feat: add Discord integration with chat_bridge abstraction layer
Introduces a vendor-agnostic chat platform architecture: - chat_bridge/base.py: ChatPlatform ABC, ChatMessage, ChatThread - chat_bridge/registry.py: PlatformRegistry singleton - chat_bridge/invite_parser.py: QR + Ollama vision invite extraction - chat_bridge/vendors/discord.py: DiscordVendor with native threads Workflow: paste a screenshot of a Discord invite or QR code at POST /discord/join → Timmy extracts the invite automatically. Every Discord conversation gets its own thread, keeping channels clean. Bot responds to @mentions and DMs, routes through Timmy agent. 43 new tests (base classes, registry, invite parser, vendor, routes). https://claude.ai/code/session_01WU4h3cQQiouMwmgYmAgkMM
This commit is contained in:
@@ -46,3 +46,10 @@
|
||||
# Alternatively, configure via the /telegram/setup dashboard endpoint at runtime.
|
||||
# Requires: pip install ".[telegram]"
|
||||
# TELEGRAM_TOKEN=
|
||||
|
||||
# ── Discord bot ──────────────────────────────────────────────────────────────
|
||||
# Bot token from https://discord.com/developers/applications
|
||||
# Alternatively, configure via the /discord/setup dashboard endpoint at runtime.
|
||||
# Requires: pip install ".[discord]"
|
||||
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
|
||||
# DISCORD_TOKEN=
|
||||
|
||||
@@ -54,6 +54,12 @@ voice = [
|
||||
telegram = [
|
||||
"python-telegram-bot>=21.0",
|
||||
]
|
||||
# Discord: bridge Discord messages to Timmy with native thread support.
|
||||
# pip install ".[discord]"
|
||||
# Optional: pip install pyzbar Pillow (for QR code invite detection)
|
||||
discord = [
|
||||
"discord.py>=2.3.0",
|
||||
]
|
||||
# Creative: GPU-accelerated image, music, and video generation.
|
||||
# pip install ".[creative]"
|
||||
creative = [
|
||||
@@ -84,6 +90,7 @@ include = [
|
||||
"src/notifications",
|
||||
"src/shortcuts",
|
||||
"src/telegram_bot",
|
||||
"src/chat_bridge",
|
||||
"src/spark",
|
||||
"src/tools",
|
||||
"src/creative",
|
||||
|
||||
10
src/chat_bridge/__init__.py
Normal file
10
src/chat_bridge/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Chat Bridge — vendor-agnostic chat platform abstraction.
|
||||
|
||||
Provides a clean interface for integrating any chat platform
|
||||
(Discord, Telegram, Slack, etc.) with Timmy's agent core.
|
||||
|
||||
Usage:
|
||||
from chat_bridge.base import ChatPlatform
|
||||
from chat_bridge.registry import platform_registry
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
"""
|
||||
147
src/chat_bridge/base.py
Normal file
147
src/chat_bridge/base.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""ChatPlatform — abstract base class for all chat vendor integrations.
|
||||
|
||||
Each vendor (Discord, Telegram, Slack, etc.) implements this interface.
|
||||
The dashboard and agent code interact only with this contract, never
|
||||
with vendor-specific APIs directly.
|
||||
|
||||
Architecture:
|
||||
ChatPlatform (ABC)
|
||||
|
|
||||
+-- DiscordVendor (discord.py)
|
||||
+-- TelegramVendor (future migration)
|
||||
+-- SlackVendor (future)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class PlatformState(Enum):
|
||||
"""Lifecycle state of a chat platform connection."""
|
||||
DISCONNECTED = auto()
|
||||
CONNECTING = auto()
|
||||
CONNECTED = auto()
|
||||
ERROR = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""Vendor-agnostic representation of a chat message."""
|
||||
content: str
|
||||
author: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
message_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
attachments: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatThread:
|
||||
"""Vendor-agnostic representation of a conversation thread."""
|
||||
thread_id: str
|
||||
title: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
created_at: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
archived: bool = False
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InviteInfo:
|
||||
"""Parsed invite extracted from an image or text."""
|
||||
url: str
|
||||
code: str
|
||||
platform: str
|
||||
guild_name: Optional[str] = None
|
||||
source: str = "unknown" # "qr", "vision", "text"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformStatus:
|
||||
"""Current status of a chat platform connection."""
|
||||
platform: str
|
||||
state: PlatformState
|
||||
token_set: bool
|
||||
guild_count: int = 0
|
||||
thread_count: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform,
|
||||
"state": self.state.name.lower(),
|
||||
"connected": self.state == PlatformState.CONNECTED,
|
||||
"token_set": self.token_set,
|
||||
"guild_count": self.guild_count,
|
||||
"thread_count": self.thread_count,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
class ChatPlatform(ABC):
|
||||
"""Abstract base class for chat platform integrations.
|
||||
|
||||
Lifecycle:
|
||||
configure(token) -> start() -> [send/receive messages] -> stop()
|
||||
|
||||
All vendors implement this interface. The dashboard routes and
|
||||
agent code work with ChatPlatform, never with vendor-specific APIs.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Platform identifier (e.g., 'discord', 'telegram')."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def state(self) -> PlatformState:
|
||||
"""Current connection state."""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, token: Optional[str] = None) -> bool:
|
||||
"""Start the platform connection. Returns True on success."""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully disconnect."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self, channel_id: str, content: str, thread_id: Optional[str] = None
|
||||
) -> Optional[ChatMessage]:
|
||||
"""Send a message. Optionally within a thread."""
|
||||
|
||||
@abstractmethod
|
||||
async def create_thread(
|
||||
self, channel_id: str, title: str, initial_message: Optional[str] = None
|
||||
) -> Optional[ChatThread]:
|
||||
"""Create a new thread in a channel."""
|
||||
|
||||
@abstractmethod
|
||||
async def join_from_invite(self, invite_code: str) -> bool:
|
||||
"""Join a server/workspace using an invite code."""
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> PlatformStatus:
|
||||
"""Return current platform status."""
|
||||
|
||||
@abstractmethod
|
||||
def save_token(self, token: str) -> None:
|
||||
"""Persist token for restarts."""
|
||||
|
||||
@abstractmethod
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Load persisted token."""
|
||||
166
src/chat_bridge/invite_parser.py
Normal file
166
src/chat_bridge/invite_parser.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""InviteParser — extract chat platform invite links from images.
|
||||
|
||||
Strategy chain:
|
||||
1. QR code detection (pyzbar — fast, no GPU)
|
||||
2. Ollama vision OCR (local LLM — handles screenshots with visible URLs)
|
||||
3. Regex fallback on raw text input
|
||||
|
||||
Supports Discord invite patterns:
|
||||
- discord.gg/<code>
|
||||
- discord.com/invite/<code>
|
||||
- discordapp.com/invite/<code>
|
||||
|
||||
Usage:
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
# From image bytes (screenshot or QR photo)
|
||||
result = await invite_parser.parse_image(image_bytes)
|
||||
|
||||
# From plain text
|
||||
result = invite_parser.parse_text("Join us at discord.gg/abc123")
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import InviteInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns for Discord invite URLs
|
||||
_DISCORD_PATTERNS = [
|
||||
re.compile(r"(?:https?://)?discord\.gg/([A-Za-z0-9\-_]+)"),
|
||||
re.compile(r"(?:https?://)?(?:www\.)?discord(?:app)?\.com/invite/([A-Za-z0-9\-_]+)"),
|
||||
]
|
||||
|
||||
|
||||
def _extract_discord_code(text: str) -> Optional[str]:
|
||||
"""Extract a Discord invite code from text."""
|
||||
for pattern in _DISCORD_PATTERNS:
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
class InviteParser:
|
||||
"""Multi-strategy invite parser.
|
||||
|
||||
Tries QR detection first (fast), then Ollama vision (local AI),
|
||||
then regex on raw text. All local, no cloud.
|
||||
"""
|
||||
|
||||
async def parse_image(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Extract an invite from image bytes (screenshot or QR photo).
|
||||
|
||||
Tries strategies in order:
|
||||
1. QR code decode (pyzbar)
|
||||
2. Ollama vision model (local OCR)
|
||||
"""
|
||||
result = self._try_qr_decode(image_data)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = await self._try_ollama_vision(image_data)
|
||||
if result:
|
||||
return result
|
||||
|
||||
logger.info("No invite found in image via any strategy.")
|
||||
return None
|
||||
|
||||
def parse_text(self, text: str) -> Optional[InviteInfo]:
|
||||
"""Extract an invite from plain text."""
|
||||
code = _extract_discord_code(text)
|
||||
if code:
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="text",
|
||||
)
|
||||
return None
|
||||
|
||||
def _try_qr_decode(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Strategy 1: Decode QR codes from image using pyzbar."""
|
||||
try:
|
||||
from PIL import Image
|
||||
from pyzbar.pyzbar import decode as qr_decode
|
||||
except ImportError:
|
||||
logger.debug("pyzbar/Pillow not installed, skipping QR strategy.")
|
||||
return None
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
decoded = qr_decode(image)
|
||||
|
||||
for obj in decoded:
|
||||
text = obj.data.decode("utf-8", errors="ignore")
|
||||
code = _extract_discord_code(text)
|
||||
if code:
|
||||
logger.info("QR decode found Discord invite: %s", code)
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="qr",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("QR decode failed: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
async def _try_ollama_vision(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Strategy 2: Use Ollama vision model for local OCR."""
|
||||
try:
|
||||
import base64
|
||||
import httpx
|
||||
from config import settings
|
||||
except ImportError:
|
||||
logger.debug("httpx not available for Ollama vision.")
|
||||
return None
|
||||
|
||||
try:
|
||||
b64_image = base64.b64encode(image_data).decode("ascii")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{settings.ollama_url}/api/generate",
|
||||
json={
|
||||
"model": settings.ollama_model,
|
||||
"prompt": (
|
||||
"Extract any Discord invite link from this image. "
|
||||
"Look for URLs like discord.gg/CODE or "
|
||||
"discord.com/invite/CODE. "
|
||||
"Reply with ONLY the invite URL, nothing else. "
|
||||
"If no invite link is found, reply with: NONE"
|
||||
),
|
||||
"images": [b64_image],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
logger.debug("Ollama vision returned %d", resp.status_code)
|
||||
return None
|
||||
|
||||
answer = resp.json().get("response", "").strip()
|
||||
if answer and answer.upper() != "NONE":
|
||||
code = _extract_discord_code(answer)
|
||||
if code:
|
||||
logger.info("Ollama vision found Discord invite: %s", code)
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="vision",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Ollama vision strategy failed: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
invite_parser = InviteParser()
|
||||
74
src/chat_bridge/registry.py
Normal file
74
src/chat_bridge/registry.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""PlatformRegistry — singleton registry for chat platform vendors.
|
||||
|
||||
Provides a central point for registering, discovering, and managing
|
||||
all chat platform integrations. Dashboard routes and the agent core
|
||||
interact with platforms through this registry.
|
||||
|
||||
Usage:
|
||||
from chat_bridge.registry import platform_registry
|
||||
|
||||
platform_registry.register(discord_vendor)
|
||||
discord = platform_registry.get("discord")
|
||||
all_platforms = platform_registry.list_platforms()
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import ChatPlatform, PlatformStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlatformRegistry:
|
||||
"""Thread-safe registry of ChatPlatform vendors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._platforms: dict[str, ChatPlatform] = {}
|
||||
|
||||
def register(self, platform: ChatPlatform) -> None:
|
||||
"""Register a chat platform vendor."""
|
||||
name = platform.name
|
||||
if name in self._platforms:
|
||||
logger.warning("Platform '%s' already registered, replacing.", name)
|
||||
self._platforms[name] = platform
|
||||
logger.info("Registered chat platform: %s", name)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Remove a platform from the registry. Returns True if it existed."""
|
||||
if name in self._platforms:
|
||||
del self._platforms[name]
|
||||
logger.info("Unregistered chat platform: %s", name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> Optional[ChatPlatform]:
|
||||
"""Get a platform by name."""
|
||||
return self._platforms.get(name)
|
||||
|
||||
def list_platforms(self) -> list[PlatformStatus]:
|
||||
"""Return status of all registered platforms."""
|
||||
return [p.status() for p in self._platforms.values()]
|
||||
|
||||
async def start_all(self) -> dict[str, bool]:
|
||||
"""Start all registered platforms. Returns name -> success mapping."""
|
||||
results = {}
|
||||
for name, platform in self._platforms.items():
|
||||
try:
|
||||
results[name] = await platform.start()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to start platform '%s': %s", name, exc)
|
||||
results[name] = False
|
||||
return results
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all registered platforms."""
|
||||
for name, platform in self._platforms.items():
|
||||
try:
|
||||
await platform.stop()
|
||||
except Exception as exc:
|
||||
logger.error("Error stopping platform '%s': %s", name, exc)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
platform_registry = PlatformRegistry()
|
||||
0
src/chat_bridge/vendors/__init__.py
vendored
Normal file
0
src/chat_bridge/vendors/__init__.py
vendored
Normal file
400
src/chat_bridge/vendors/discord.py
vendored
Normal file
400
src/chat_bridge/vendors/discord.py
vendored
Normal file
@@ -0,0 +1,400 @@
|
||||
"""DiscordVendor — Discord integration via discord.py.
|
||||
|
||||
Implements ChatPlatform with native thread support. Each conversation
|
||||
with Timmy gets its own Discord thread, keeping channels clean.
|
||||
|
||||
Optional dependency — install with:
|
||||
pip install ".[discord]"
|
||||
|
||||
Architecture:
|
||||
DiscordVendor
|
||||
├── _client (discord.Client) — handles gateway events
|
||||
├── _thread_map — channel_id -> active thread
|
||||
└── _message_handler — bridges to Timmy agent
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import (
|
||||
ChatMessage,
|
||||
ChatPlatform,
|
||||
ChatThread,
|
||||
InviteInfo,
|
||||
PlatformState,
|
||||
PlatformStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STATE_FILE = Path(__file__).parent.parent.parent.parent / "discord_state.json"
|
||||
|
||||
|
||||
class DiscordVendor(ChatPlatform):
|
||||
"""Discord integration with native thread conversations.
|
||||
|
||||
Every user interaction creates or continues a Discord thread,
|
||||
keeping channel history clean and conversations organized.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client = None
|
||||
self._token: Optional[str] = None
|
||||
self._state: PlatformState = PlatformState.DISCONNECTED
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._guild_count: int = 0
|
||||
self._active_threads: dict[str, str] = {} # channel_id -> thread_id
|
||||
|
||||
# ── ChatPlatform interface ─────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "discord"
|
||||
|
||||
@property
|
||||
def state(self) -> PlatformState:
|
||||
return self._state
|
||||
|
||||
async def start(self, token: Optional[str] = None) -> bool:
|
||||
"""Start the Discord bot. Returns True on success."""
|
||||
if self._state == PlatformState.CONNECTED:
|
||||
return True
|
||||
|
||||
tok = token or self.load_token()
|
||||
if not tok:
|
||||
logger.warning("Discord bot: no token configured, skipping start.")
|
||||
return False
|
||||
|
||||
try:
|
||||
import discord
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"discord.py is not installed. "
|
||||
'Run: pip install ".[discord]"'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
self._state = PlatformState.CONNECTING
|
||||
self._token = tok
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client = discord.Client(intents=intents)
|
||||
self._register_handlers()
|
||||
|
||||
# Run the client in a background task so we don't block
|
||||
self._task = asyncio.create_task(self._run_client(tok))
|
||||
|
||||
# Wait briefly for connection
|
||||
for _ in range(30):
|
||||
await asyncio.sleep(0.5)
|
||||
if self._state == PlatformState.CONNECTED:
|
||||
logger.info("Discord bot connected (%d guilds).", self._guild_count)
|
||||
return True
|
||||
if self._state == PlatformState.ERROR:
|
||||
return False
|
||||
|
||||
logger.warning("Discord bot: connection timed out.")
|
||||
self._state = PlatformState.ERROR
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Discord bot failed to start: %s", exc)
|
||||
self._state = PlatformState.ERROR
|
||||
self._token = None
|
||||
self._client = None
|
||||
return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully disconnect the Discord bot."""
|
||||
if self._client and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
logger.info("Discord bot disconnected.")
|
||||
except Exception as exc:
|
||||
logger.error("Error stopping Discord bot: %s", exc)
|
||||
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
self._client = None
|
||||
self._task = None
|
||||
|
||||
async def send_message(
|
||||
self, channel_id: str, content: str, thread_id: Optional[str] = None
|
||||
) -> Optional[ChatMessage]:
|
||||
"""Send a message to a Discord channel or thread."""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
return None
|
||||
|
||||
try:
|
||||
import discord
|
||||
|
||||
target_id = int(thread_id) if thread_id else int(channel_id)
|
||||
channel = self._client.get_channel(target_id)
|
||||
|
||||
if channel is None:
|
||||
channel = await self._client.fetch_channel(target_id)
|
||||
|
||||
msg = await channel.send(content)
|
||||
|
||||
return ChatMessage(
|
||||
content=content,
|
||||
author=str(self._client.user),
|
||||
channel_id=str(msg.channel.id),
|
||||
platform="discord",
|
||||
message_id=str(msg.id),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send Discord message: %s", exc)
|
||||
return None
|
||||
|
||||
async def create_thread(
|
||||
self, channel_id: str, title: str, initial_message: Optional[str] = None
|
||||
) -> Optional[ChatThread]:
|
||||
"""Create a new thread in a Discord channel."""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
return None
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(channel_id))
|
||||
if channel is None:
|
||||
channel = await self._client.fetch_channel(int(channel_id))
|
||||
|
||||
thread = await channel.create_thread(
|
||||
name=title[:100], # Discord limits thread names to 100 chars
|
||||
auto_archive_duration=1440, # 24 hours
|
||||
)
|
||||
|
||||
if initial_message:
|
||||
await thread.send(initial_message)
|
||||
|
||||
self._active_threads[channel_id] = str(thread.id)
|
||||
|
||||
return ChatThread(
|
||||
thread_id=str(thread.id),
|
||||
title=title[:100],
|
||||
channel_id=channel_id,
|
||||
platform="discord",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to create Discord thread: %s", exc)
|
||||
return None
|
||||
|
||||
async def join_from_invite(self, invite_code: str) -> bool:
|
||||
"""Join a Discord server using an invite code.
|
||||
|
||||
Note: Bot accounts cannot use invite links directly.
|
||||
This generates an OAuth2 URL for adding the bot to a server.
|
||||
The invite_code is validated but the actual join requires
|
||||
the server admin to use the bot's OAuth2 authorization URL.
|
||||
"""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
logger.warning("Discord bot not connected, cannot process invite.")
|
||||
return False
|
||||
|
||||
try:
|
||||
import discord
|
||||
|
||||
invite = await self._client.fetch_invite(invite_code)
|
||||
logger.info(
|
||||
"Validated invite for server '%s' (code: %s)",
|
||||
invite.guild.name if invite.guild else "unknown",
|
||||
invite_code,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("Invalid Discord invite '%s': %s", invite_code, exc)
|
||||
return False
|
||||
|
||||
def status(self) -> PlatformStatus:
|
||||
return PlatformStatus(
|
||||
platform="discord",
|
||||
state=self._state,
|
||||
token_set=bool(self._token),
|
||||
guild_count=self._guild_count,
|
||||
thread_count=len(self._active_threads),
|
||||
)
|
||||
|
||||
def save_token(self, token: str) -> None:
|
||||
"""Persist token to state file."""
|
||||
try:
|
||||
_STATE_FILE.write_text(json.dumps({"token": token}))
|
||||
except Exception as exc:
|
||||
logger.error("Failed to save Discord token: %s", exc)
|
||||
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Load token from state file or config."""
|
||||
try:
|
||||
if _STATE_FILE.exists():
|
||||
data = json.loads(_STATE_FILE.read_text())
|
||||
token = data.get("token")
|
||||
if token:
|
||||
return token
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read discord state file: %s", exc)
|
||||
|
||||
try:
|
||||
from config import settings
|
||||
return settings.discord_token or None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ── OAuth2 URL generation ──────────────────────────────────────────────
|
||||
|
||||
def get_oauth2_url(self) -> Optional[str]:
|
||||
"""Generate the OAuth2 URL for adding this bot to a server.
|
||||
|
||||
Requires the bot to be connected to read its application ID.
|
||||
"""
|
||||
if not self._client or not self._client.user:
|
||||
return None
|
||||
|
||||
app_id = self._client.user.id
|
||||
# Permissions: Send Messages, Create Public Threads, Manage Threads,
|
||||
# Read Message History, Embed Links, Attach Files
|
||||
permissions = 397284550656
|
||||
return (
|
||||
f"https://discord.com/oauth2/authorize"
|
||||
f"?client_id={app_id}&scope=bot"
|
||||
f"&permissions={permissions}"
|
||||
)
|
||||
|
||||
# ── Internal ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _run_client(self, token: str) -> None:
|
||||
"""Run the discord.py client (blocking call in a task)."""
|
||||
try:
|
||||
await self._client.start(token)
|
||||
except Exception as exc:
|
||||
logger.error("Discord client error: %s", exc)
|
||||
self._state = PlatformState.ERROR
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""Register Discord event handlers on the client."""
|
||||
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
self._guild_count = len(self._client.guilds)
|
||||
self._state = PlatformState.CONNECTED
|
||||
logger.info(
|
||||
"Discord ready: %s in %d guild(s)",
|
||||
self._client.user,
|
||||
self._guild_count,
|
||||
)
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message):
|
||||
# Ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Only respond to mentions or DMs
|
||||
is_dm = not hasattr(message.channel, "guild") or message.channel.guild is None
|
||||
is_mention = self._client.user in message.mentions
|
||||
|
||||
if not is_dm and not is_mention:
|
||||
return
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
@self._client.event
|
||||
async def on_disconnect():
|
||||
if self._state != PlatformState.DISCONNECTED:
|
||||
self._state = PlatformState.CONNECTING
|
||||
logger.warning("Discord disconnected, will auto-reconnect.")
|
||||
|
||||
async def _handle_message(self, message) -> None:
|
||||
"""Process an incoming message and respond via a thread."""
|
||||
# Strip the bot mention from the message content
|
||||
content = message.content
|
||||
if self._client.user:
|
||||
content = content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Create or reuse a thread for this conversation
|
||||
thread = await self._get_or_create_thread(message)
|
||||
target = thread or message.channel
|
||||
|
||||
# Run Timmy agent
|
||||
try:
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
agent = create_timmy()
|
||||
run = await asyncio.to_thread(agent.run, content, stream=False)
|
||||
response = run.content if hasattr(run, "content") else str(run)
|
||||
except Exception as exc:
|
||||
logger.error("Timmy error in Discord handler: %s", exc)
|
||||
response = f"Timmy is offline: {exc}"
|
||||
|
||||
# Discord has a 2000 character limit
|
||||
for chunk in _chunk_message(response, 2000):
|
||||
await target.send(chunk)
|
||||
|
||||
async def _get_or_create_thread(self, message):
|
||||
"""Get the active thread for a channel, or create one.
|
||||
|
||||
If the message is already in a thread, use that thread.
|
||||
Otherwise, create a new thread from the message.
|
||||
"""
|
||||
try:
|
||||
import discord
|
||||
|
||||
# Already in a thread — just use it
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
return message.channel
|
||||
|
||||
# DM channels don't support threads
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return None
|
||||
|
||||
# Create a thread from this message
|
||||
thread_name = f"Timmy | {message.author.display_name}"
|
||||
thread = await message.create_thread(
|
||||
name=thread_name[:100],
|
||||
auto_archive_duration=1440,
|
||||
)
|
||||
channel_id = str(message.channel.id)
|
||||
self._active_threads[channel_id] = str(thread.id)
|
||||
return thread
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("Could not create thread: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _chunk_message(text: str, max_len: int = 2000) -> list[str]:
|
||||
"""Split a message into chunks that fit Discord's character limit."""
|
||||
if len(text) <= max_len:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= max_len:
|
||||
chunks.append(text)
|
||||
break
|
||||
# Try to split at a newline
|
||||
split_at = text.rfind("\n", 0, max_len)
|
||||
if split_at == -1:
|
||||
split_at = max_len
|
||||
chunks.append(text[:split_at])
|
||||
text = text[split_at:].lstrip("\n")
|
||||
return chunks
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
discord_bot = DiscordVendor()
|
||||
@@ -16,6 +16,9 @@ class Settings(BaseSettings):
|
||||
# Telegram bot token — set via TELEGRAM_TOKEN env var or the /telegram/setup endpoint
|
||||
telegram_token: str = ""
|
||||
|
||||
# Discord bot token — set via DISCORD_TOKEN env var or the /discord/setup endpoint
|
||||
discord_token: str = ""
|
||||
|
||||
# ── AirLLM / backend selection ───────────────────────────────────────────
|
||||
# "ollama" — always use Ollama (default, safe everywhere)
|
||||
# "airllm" — always use AirLLM (requires pip install ".[bigbrain]")
|
||||
|
||||
@@ -25,6 +25,7 @@ from dashboard.routes.swarm_internal import router as swarm_internal_router
|
||||
from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.spark import router as spark_router
|
||||
from dashboard.routes.creative import router as creative_router
|
||||
from dashboard.routes.discord import router as discord_router
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -108,8 +109,15 @@ async def lifespan(app: FastAPI):
|
||||
from telegram_bot.bot import telegram_bot
|
||||
await telegram_bot.start()
|
||||
|
||||
# Auto-start Discord bot and register in platform registry
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
from chat_bridge.registry import platform_registry
|
||||
platform_registry.register(discord_bot)
|
||||
await discord_bot.start()
|
||||
|
||||
yield
|
||||
|
||||
await discord_bot.stop()
|
||||
await telegram_bot.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
@@ -145,6 +153,7 @@ app.include_router(swarm_internal_router)
|
||||
app.include_router(tools_router)
|
||||
app.include_router(spark_router)
|
||||
app.include_router(creative_router)
|
||||
app.include_router(discord_router)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
|
||||
140
src/dashboard/routes/discord.py
Normal file
140
src/dashboard/routes/discord.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Dashboard routes for Discord bot setup, status, and invite-from-image.
|
||||
|
||||
Endpoints:
|
||||
POST /discord/setup — configure bot token
|
||||
GET /discord/status — connection state + guild count
|
||||
POST /discord/join — paste screenshot → extract invite → join
|
||||
GET /discord/oauth-url — get the bot's OAuth2 authorization URL
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter(prefix="/discord", tags=["discord"])
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/setup")
|
||||
async def setup_discord(payload: TokenPayload):
|
||||
"""Configure the Discord bot token and (re)start the bot.
|
||||
|
||||
Send POST with JSON body: {"token": "<your-bot-token>"}
|
||||
Get the token from https://discord.com/developers/applications
|
||||
"""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
token = payload.token.strip()
|
||||
if not token:
|
||||
return {"ok": False, "error": "Token cannot be empty."}
|
||||
|
||||
discord_bot.save_token(token)
|
||||
|
||||
if discord_bot.state.name == "CONNECTED":
|
||||
await discord_bot.stop()
|
||||
|
||||
success = await discord_bot.start(token=token)
|
||||
if success:
|
||||
return {"ok": True, "message": "Discord bot connected successfully."}
|
||||
return {
|
||||
"ok": False,
|
||||
"error": (
|
||||
"Failed to start bot. Check that the token is correct and "
|
||||
'discord.py is installed: pip install ".[discord]"'
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def discord_status():
|
||||
"""Return current Discord bot status."""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
return discord_bot.status().to_dict()
|
||||
|
||||
|
||||
@router.post("/join")
|
||||
async def join_from_image(
|
||||
image: Optional[UploadFile] = File(None),
|
||||
invite_url: Optional[str] = Form(None),
|
||||
):
|
||||
"""Extract a Discord invite from a screenshot or text and validate it.
|
||||
|
||||
Accepts either:
|
||||
- An uploaded image (screenshot of invite or QR code)
|
||||
- A plain text invite URL
|
||||
|
||||
The bot validates the invite and returns the OAuth2 URL for the
|
||||
server admin to authorize the bot.
|
||||
"""
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
invite_info = None
|
||||
|
||||
# Try image first
|
||||
if image and image.filename:
|
||||
image_data = await image.read()
|
||||
if image_data:
|
||||
invite_info = await invite_parser.parse_image(image_data)
|
||||
|
||||
# Fall back to text
|
||||
if not invite_info and invite_url:
|
||||
invite_info = invite_parser.parse_text(invite_url)
|
||||
|
||||
if not invite_info:
|
||||
return {
|
||||
"ok": False,
|
||||
"error": (
|
||||
"No Discord invite found. "
|
||||
"Paste a screenshot with a visible invite link or QR code, "
|
||||
"or enter the invite URL directly."
|
||||
),
|
||||
}
|
||||
|
||||
# Validate the invite
|
||||
valid = await discord_bot.join_from_invite(invite_info.code)
|
||||
|
||||
result = {
|
||||
"ok": True,
|
||||
"invite": {
|
||||
"code": invite_info.code,
|
||||
"url": invite_info.url,
|
||||
"source": invite_info.source,
|
||||
"platform": invite_info.platform,
|
||||
},
|
||||
"validated": valid,
|
||||
}
|
||||
|
||||
# Include OAuth2 URL if bot is connected
|
||||
oauth_url = discord_bot.get_oauth2_url()
|
||||
if oauth_url:
|
||||
result["oauth2_url"] = oauth_url
|
||||
result["message"] = (
|
||||
"Invite validated. Share this OAuth2 URL with the server admin "
|
||||
"to add Timmy to the server."
|
||||
)
|
||||
else:
|
||||
result["message"] = (
|
||||
"Invite found but bot is not connected. "
|
||||
"Configure a bot token first via /discord/setup."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/oauth-url")
|
||||
async def discord_oauth_url():
|
||||
"""Get the bot's OAuth2 authorization URL for adding to servers."""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
url = discord_bot.get_oauth2_url()
|
||||
if url:
|
||||
return {"ok": True, "url": url}
|
||||
return {
|
||||
"ok": False,
|
||||
"error": "Bot is not connected. Configure a token first.",
|
||||
}
|
||||
@@ -25,6 +25,14 @@ for _mod in [
|
||||
# without the package installed.
|
||||
"telegram",
|
||||
"telegram.ext",
|
||||
# discord.py is optional (discord extra) — stub so tests run
|
||||
# without the package installed.
|
||||
"discord",
|
||||
"discord.ext",
|
||||
"discord.ext.commands",
|
||||
# pyzbar is optional (for QR code invite detection)
|
||||
"pyzbar",
|
||||
"pyzbar.pyzbar",
|
||||
]:
|
||||
sys.modules.setdefault(_mod, MagicMock())
|
||||
|
||||
|
||||
268
tests/test_chat_bridge.py
Normal file
268
tests/test_chat_bridge.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Tests for the chat_bridge base classes, registry, and invite parser."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from chat_bridge.base import (
|
||||
ChatMessage,
|
||||
ChatPlatform,
|
||||
ChatThread,
|
||||
InviteInfo,
|
||||
PlatformState,
|
||||
PlatformStatus,
|
||||
)
|
||||
from chat_bridge.registry import PlatformRegistry
|
||||
|
||||
|
||||
# ── Base dataclass tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestChatMessage:
|
||||
def test_create_message(self):
|
||||
msg = ChatMessage(
|
||||
content="Hello",
|
||||
author="user1",
|
||||
channel_id="123",
|
||||
platform="test",
|
||||
)
|
||||
assert msg.content == "Hello"
|
||||
assert msg.author == "user1"
|
||||
assert msg.platform == "test"
|
||||
assert msg.thread_id is None
|
||||
assert msg.attachments == []
|
||||
|
||||
def test_message_with_thread(self):
|
||||
msg = ChatMessage(
|
||||
content="Reply",
|
||||
author="bot",
|
||||
channel_id="123",
|
||||
platform="discord",
|
||||
thread_id="456",
|
||||
)
|
||||
assert msg.thread_id == "456"
|
||||
|
||||
|
||||
class TestChatThread:
|
||||
def test_create_thread(self):
|
||||
thread = ChatThread(
|
||||
thread_id="t1",
|
||||
title="Timmy | user1",
|
||||
channel_id="c1",
|
||||
platform="discord",
|
||||
)
|
||||
assert thread.thread_id == "t1"
|
||||
assert thread.archived is False
|
||||
assert thread.message_count == 0
|
||||
|
||||
|
||||
class TestInviteInfo:
|
||||
def test_create_invite(self):
|
||||
invite = InviteInfo(
|
||||
url="https://discord.gg/abc123",
|
||||
code="abc123",
|
||||
platform="discord",
|
||||
source="qr",
|
||||
)
|
||||
assert invite.code == "abc123"
|
||||
assert invite.source == "qr"
|
||||
|
||||
|
||||
class TestPlatformStatus:
|
||||
def test_to_dict(self):
|
||||
status = PlatformStatus(
|
||||
platform="discord",
|
||||
state=PlatformState.CONNECTED,
|
||||
token_set=True,
|
||||
guild_count=3,
|
||||
)
|
||||
d = status.to_dict()
|
||||
assert d["connected"] is True
|
||||
assert d["platform"] == "discord"
|
||||
assert d["guild_count"] == 3
|
||||
assert d["state"] == "connected"
|
||||
|
||||
def test_disconnected_status(self):
|
||||
status = PlatformStatus(
|
||||
platform="test",
|
||||
state=PlatformState.DISCONNECTED,
|
||||
token_set=False,
|
||||
)
|
||||
d = status.to_dict()
|
||||
assert d["connected"] is False
|
||||
|
||||
|
||||
# ── PlatformRegistry tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakePlatform(ChatPlatform):
|
||||
"""Minimal ChatPlatform for testing the registry."""
|
||||
|
||||
def __init__(self, platform_name: str = "fake"):
|
||||
self._name = platform_name
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def state(self) -> PlatformState:
|
||||
return self._state
|
||||
|
||||
async def start(self, token=None) -> bool:
|
||||
self._state = PlatformState.CONNECTED
|
||||
return True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
|
||||
async def send_message(self, channel_id, content, thread_id=None):
|
||||
return ChatMessage(
|
||||
content=content, author="bot", channel_id=channel_id, platform=self._name
|
||||
)
|
||||
|
||||
async def create_thread(self, channel_id, title, initial_message=None):
|
||||
return ChatThread(
|
||||
thread_id="t1", title=title, channel_id=channel_id, platform=self._name
|
||||
)
|
||||
|
||||
async def join_from_invite(self, invite_code) -> bool:
|
||||
return True
|
||||
|
||||
def status(self):
|
||||
return PlatformStatus(
|
||||
platform=self._name,
|
||||
state=self._state,
|
||||
token_set=False,
|
||||
)
|
||||
|
||||
def save_token(self, token):
|
||||
pass
|
||||
|
||||
def load_token(self):
|
||||
return None
|
||||
|
||||
|
||||
class TestPlatformRegistry:
|
||||
def test_register_and_get(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("test1")
|
||||
reg.register(p)
|
||||
assert reg.get("test1") is p
|
||||
|
||||
def test_get_missing(self):
|
||||
reg = PlatformRegistry()
|
||||
assert reg.get("nonexistent") is None
|
||||
|
||||
def test_unregister(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("test1")
|
||||
reg.register(p)
|
||||
assert reg.unregister("test1") is True
|
||||
assert reg.get("test1") is None
|
||||
|
||||
def test_unregister_missing(self):
|
||||
reg = PlatformRegistry()
|
||||
assert reg.unregister("nope") is False
|
||||
|
||||
def test_list_platforms(self):
|
||||
reg = PlatformRegistry()
|
||||
reg.register(_FakePlatform("a"))
|
||||
reg.register(_FakePlatform("b"))
|
||||
statuses = reg.list_platforms()
|
||||
assert len(statuses) == 2
|
||||
names = {s.platform for s in statuses}
|
||||
assert names == {"a", "b"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all(self):
|
||||
reg = PlatformRegistry()
|
||||
reg.register(_FakePlatform("x"))
|
||||
reg.register(_FakePlatform("y"))
|
||||
results = await reg.start_all()
|
||||
assert results == {"x": True, "y": True}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("z")
|
||||
reg.register(p)
|
||||
await reg.start_all()
|
||||
assert p.state == PlatformState.CONNECTED
|
||||
await reg.stop_all()
|
||||
assert p.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_replace_existing(self):
|
||||
reg = PlatformRegistry()
|
||||
p1 = _FakePlatform("dup")
|
||||
p2 = _FakePlatform("dup")
|
||||
reg.register(p1)
|
||||
reg.register(p2)
|
||||
assert reg.get("dup") is p2
|
||||
|
||||
|
||||
# ── InviteParser tests ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestInviteParser:
|
||||
def test_parse_text_discord_gg(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("Join us at https://discord.gg/abc123!")
|
||||
assert result is not None
|
||||
assert result.code == "abc123"
|
||||
assert result.platform == "discord"
|
||||
assert result.source == "text"
|
||||
|
||||
def test_parse_text_discord_com_invite(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text(
|
||||
"Link: https://discord.com/invite/myServer2024"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "myServer2024"
|
||||
|
||||
def test_parse_text_discordapp(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text(
|
||||
"https://discordapp.com/invite/test-code"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "test-code"
|
||||
|
||||
def test_parse_text_no_invite(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("Hello world, no links here")
|
||||
assert result is None
|
||||
|
||||
def test_parse_text_bare_discord_gg(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("discord.gg/xyz789")
|
||||
assert result is not None
|
||||
assert result.code == "xyz789"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_image_no_deps(self):
|
||||
"""parse_image returns None when pyzbar/Pillow are not installed."""
|
||||
from chat_bridge.invite_parser import InviteParser
|
||||
|
||||
parser = InviteParser()
|
||||
# With mocked pyzbar, this should gracefully return None
|
||||
result = await parser.parse_image(b"fake-image-bytes")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestExtractDiscordCode:
|
||||
def test_various_formats(self):
|
||||
from chat_bridge.invite_parser import _extract_discord_code
|
||||
|
||||
assert _extract_discord_code("discord.gg/abc") == "abc"
|
||||
assert _extract_discord_code("https://discord.gg/test") == "test"
|
||||
assert _extract_discord_code("http://discord.gg/http") == "http"
|
||||
assert _extract_discord_code("discord.com/invite/xyz") == "xyz"
|
||||
assert _extract_discord_code("no link here") is None
|
||||
assert _extract_discord_code("") is None
|
||||
225
tests/test_discord_vendor.py
Normal file
225
tests/test_discord_vendor.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for the Discord vendor and dashboard routes."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from chat_bridge.base import PlatformState
|
||||
|
||||
|
||||
# ── DiscordVendor unit tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiscordVendor:
|
||||
def test_name(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.name == "discord"
|
||||
|
||||
def test_initial_state(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_status_disconnected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
status = vendor.status()
|
||||
assert status.platform == "discord"
|
||||
assert status.state == PlatformState.DISCONNECTED
|
||||
assert status.token_set is False
|
||||
assert status.guild_count == 0
|
||||
|
||||
def test_save_and_load_token(self, tmp_path, monkeypatch):
|
||||
from chat_bridge.vendors import discord as discord_mod
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
state_file = tmp_path / "discord_state.json"
|
||||
monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file)
|
||||
|
||||
vendor = DiscordVendor()
|
||||
vendor.save_token("test-token-abc")
|
||||
|
||||
assert state_file.exists()
|
||||
data = json.loads(state_file.read_text())
|
||||
assert data["token"] == "test-token-abc"
|
||||
|
||||
loaded = vendor.load_token()
|
||||
assert loaded == "test-token-abc"
|
||||
|
||||
def test_load_token_missing_file(self, tmp_path, monkeypatch):
|
||||
from chat_bridge.vendors import discord as discord_mod
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
state_file = tmp_path / "nonexistent.json"
|
||||
monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file)
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Falls back to config.settings.discord_token
|
||||
token = vendor.load_token()
|
||||
# Default discord_token is "" which becomes None
|
||||
assert token is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_no_token(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.start(token=None)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_import_error(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Simulate discord.py not installed by making import fail
|
||||
with patch.dict("sys.modules", {"discord": None}):
|
||||
result = await vendor.start(token="fake-token")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_disconnected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Should not raise
|
||||
await vendor.stop()
|
||||
assert vendor.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_get_oauth2_url_no_client(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.get_oauth2_url() is None
|
||||
|
||||
def test_get_oauth2_url_with_client(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
mock_client = MagicMock()
|
||||
mock_client.user.id = 123456789
|
||||
vendor._client = mock_client
|
||||
url = vendor.get_oauth2_url()
|
||||
assert "123456789" in url
|
||||
assert "oauth2/authorize" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.send_message("123", "hello")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_thread_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.create_thread("123", "Test Thread")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_from_invite_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.join_from_invite("abc123")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestChunkMessage:
|
||||
def test_short_message(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
chunks = _chunk_message("Hello!", 2000)
|
||||
assert chunks == ["Hello!"]
|
||||
|
||||
def test_long_message(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
text = "a" * 5000
|
||||
chunks = _chunk_message(text, 2000)
|
||||
assert len(chunks) == 3
|
||||
assert all(len(c) <= 2000 for c in chunks)
|
||||
assert "".join(chunks) == text
|
||||
|
||||
def test_split_at_newline(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
text = "Line1\n" + "x" * 1990 + "\nLine3"
|
||||
chunks = _chunk_message(text, 2000)
|
||||
assert len(chunks) >= 2
|
||||
assert chunks[0].startswith("Line1")
|
||||
|
||||
|
||||
# ── Discord route tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiscordRoutes:
|
||||
def test_status_endpoint(self, client):
|
||||
resp = client.get("/discord/status")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["platform"] == "discord"
|
||||
assert "connected" in data
|
||||
|
||||
def test_setup_empty_token(self, client):
|
||||
resp = client.post("/discord/setup", json={"token": ""})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
assert "empty" in data["error"].lower()
|
||||
|
||||
def test_setup_with_token(self, client):
|
||||
"""Setup with a token — bot won't actually connect but route works."""
|
||||
with patch(
|
||||
"chat_bridge.vendors.discord.DiscordVendor.start",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
resp = client.post(
|
||||
"/discord/setup", json={"token": "fake-token-123"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Will fail because discord.py is mocked, but route handles it
|
||||
assert "ok" in data
|
||||
|
||||
def test_join_no_input(self, client):
|
||||
resp = client.post("/discord/join")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
assert "no discord invite" in data["error"].lower()
|
||||
|
||||
def test_join_with_text_invite(self, client):
|
||||
with patch(
|
||||
"chat_bridge.vendors.discord.DiscordVendor.join_from_invite",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
resp = client.post(
|
||||
"/discord/join",
|
||||
data={"invite_url": "https://discord.gg/testcode"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["invite"]["code"] == "testcode"
|
||||
assert data["invite"]["source"] == "text"
|
||||
|
||||
def test_oauth_url_not_connected(self, client):
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
# Reset singleton so it has no client
|
||||
discord_bot._client = None
|
||||
resp = client.get("/discord/oauth-url")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
Reference in New Issue
Block a user