diff --git a/.env.example b/.env.example index 53ac479..65160b8 100644 --- a/.env.example +++ b/.env.example @@ -46,3 +46,10 @@ # Alternatively, configure via the /telegram/setup dashboard endpoint at runtime. # Requires: pip install ".[telegram]" # TELEGRAM_TOKEN= + +# ── Discord bot ────────────────────────────────────────────────────────────── +# Bot token from https://discord.com/developers/applications +# Alternatively, configure via the /discord/setup dashboard endpoint at runtime. +# Requires: pip install ".[discord]" +# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots) +# DISCORD_TOKEN= diff --git a/pyproject.toml b/pyproject.toml index b1900cf..e753c28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,12 @@ voice = [ telegram = [ "python-telegram-bot>=21.0", ] +# Discord: bridge Discord messages to Timmy with native thread support. +# pip install ".[discord]" +# Optional: pip install pyzbar Pillow (for QR code invite detection) +discord = [ + "discord.py>=2.3.0", +] # Creative: GPU-accelerated image, music, and video generation. # pip install ".[creative]" creative = [ @@ -84,6 +90,7 @@ include = [ "src/notifications", "src/shortcuts", "src/telegram_bot", + "src/chat_bridge", "src/spark", "src/tools", "src/creative", diff --git a/src/chat_bridge/__init__.py b/src/chat_bridge/__init__.py new file mode 100644 index 0000000..7aa82bd --- /dev/null +++ b/src/chat_bridge/__init__.py @@ -0,0 +1,10 @@ +"""Chat Bridge — vendor-agnostic chat platform abstraction. + +Provides a clean interface for integrating any chat platform +(Discord, Telegram, Slack, etc.) with Timmy's agent core. + +Usage: + from chat_bridge.base import ChatPlatform + from chat_bridge.registry import platform_registry + from chat_bridge.vendors.discord import DiscordVendor +""" diff --git a/src/chat_bridge/base.py b/src/chat_bridge/base.py new file mode 100644 index 0000000..6af6607 --- /dev/null +++ b/src/chat_bridge/base.py @@ -0,0 +1,147 @@ +"""ChatPlatform — abstract base class for all chat vendor integrations. + +Each vendor (Discord, Telegram, Slack, etc.) implements this interface. +The dashboard and agent code interact only with this contract, never +with vendor-specific APIs directly. + +Architecture: + ChatPlatform (ABC) + | + +-- DiscordVendor (discord.py) + +-- TelegramVendor (future migration) + +-- SlackVendor (future) +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum, auto +from typing import Any, Optional + + +class PlatformState(Enum): + """Lifecycle state of a chat platform connection.""" + DISCONNECTED = auto() + CONNECTING = auto() + CONNECTED = auto() + ERROR = auto() + + +@dataclass +class ChatMessage: + """Vendor-agnostic representation of a chat message.""" + content: str + author: str + channel_id: str + platform: str + timestamp: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) + message_id: Optional[str] = None + thread_id: Optional[str] = None + attachments: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ChatThread: + """Vendor-agnostic representation of a conversation thread.""" + thread_id: str + title: str + channel_id: str + platform: str + created_at: str = field( + default_factory=lambda: datetime.now(timezone.utc).isoformat() + ) + archived: bool = False + message_count: int = 0 + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class InviteInfo: + """Parsed invite extracted from an image or text.""" + url: str + code: str + platform: str + guild_name: Optional[str] = None + source: str = "unknown" # "qr", "vision", "text" + + +@dataclass +class PlatformStatus: + """Current status of a chat platform connection.""" + platform: str + state: PlatformState + token_set: bool + guild_count: int = 0 + thread_count: int = 0 + error: Optional[str] = None + + def to_dict(self) -> dict[str, Any]: + return { + "platform": self.platform, + "state": self.state.name.lower(), + "connected": self.state == PlatformState.CONNECTED, + "token_set": self.token_set, + "guild_count": self.guild_count, + "thread_count": self.thread_count, + "error": self.error, + } + + +class ChatPlatform(ABC): + """Abstract base class for chat platform integrations. + + Lifecycle: + configure(token) -> start() -> [send/receive messages] -> stop() + + All vendors implement this interface. The dashboard routes and + agent code work with ChatPlatform, never with vendor-specific APIs. + """ + + @property + @abstractmethod + def name(self) -> str: + """Platform identifier (e.g., 'discord', 'telegram').""" + + @property + @abstractmethod + def state(self) -> PlatformState: + """Current connection state.""" + + @abstractmethod + async def start(self, token: Optional[str] = None) -> bool: + """Start the platform connection. Returns True on success.""" + + @abstractmethod + async def stop(self) -> None: + """Gracefully disconnect.""" + + @abstractmethod + async def send_message( + self, channel_id: str, content: str, thread_id: Optional[str] = None + ) -> Optional[ChatMessage]: + """Send a message. Optionally within a thread.""" + + @abstractmethod + async def create_thread( + self, channel_id: str, title: str, initial_message: Optional[str] = None + ) -> Optional[ChatThread]: + """Create a new thread in a channel.""" + + @abstractmethod + async def join_from_invite(self, invite_code: str) -> bool: + """Join a server/workspace using an invite code.""" + + @abstractmethod + def status(self) -> PlatformStatus: + """Return current platform status.""" + + @abstractmethod + def save_token(self, token: str) -> None: + """Persist token for restarts.""" + + @abstractmethod + def load_token(self) -> Optional[str]: + """Load persisted token.""" diff --git a/src/chat_bridge/invite_parser.py b/src/chat_bridge/invite_parser.py new file mode 100644 index 0000000..2c48770 --- /dev/null +++ b/src/chat_bridge/invite_parser.py @@ -0,0 +1,166 @@ +"""InviteParser — extract chat platform invite links from images. + +Strategy chain: + 1. QR code detection (pyzbar — fast, no GPU) + 2. Ollama vision OCR (local LLM — handles screenshots with visible URLs) + 3. Regex fallback on raw text input + +Supports Discord invite patterns: + - discord.gg/ + - discord.com/invite/ + - discordapp.com/invite/ + +Usage: + from chat_bridge.invite_parser import invite_parser + + # From image bytes (screenshot or QR photo) + result = await invite_parser.parse_image(image_bytes) + + # From plain text + result = invite_parser.parse_text("Join us at discord.gg/abc123") +""" + +import io +import logging +import re +from typing import Optional + +from chat_bridge.base import InviteInfo + +logger = logging.getLogger(__name__) + +# Patterns for Discord invite URLs +_DISCORD_PATTERNS = [ + re.compile(r"(?:https?://)?discord\.gg/([A-Za-z0-9\-_]+)"), + re.compile(r"(?:https?://)?(?:www\.)?discord(?:app)?\.com/invite/([A-Za-z0-9\-_]+)"), +] + + +def _extract_discord_code(text: str) -> Optional[str]: + """Extract a Discord invite code from text.""" + for pattern in _DISCORD_PATTERNS: + match = pattern.search(text) + if match: + return match.group(1) + return None + + +class InviteParser: + """Multi-strategy invite parser. + + Tries QR detection first (fast), then Ollama vision (local AI), + then regex on raw text. All local, no cloud. + """ + + async def parse_image(self, image_data: bytes) -> Optional[InviteInfo]: + """Extract an invite from image bytes (screenshot or QR photo). + + Tries strategies in order: + 1. QR code decode (pyzbar) + 2. Ollama vision model (local OCR) + """ + result = self._try_qr_decode(image_data) + if result: + return result + + result = await self._try_ollama_vision(image_data) + if result: + return result + + logger.info("No invite found in image via any strategy.") + return None + + def parse_text(self, text: str) -> Optional[InviteInfo]: + """Extract an invite from plain text.""" + code = _extract_discord_code(text) + if code: + return InviteInfo( + url=f"https://discord.gg/{code}", + code=code, + platform="discord", + source="text", + ) + return None + + def _try_qr_decode(self, image_data: bytes) -> Optional[InviteInfo]: + """Strategy 1: Decode QR codes from image using pyzbar.""" + try: + from PIL import Image + from pyzbar.pyzbar import decode as qr_decode + except ImportError: + logger.debug("pyzbar/Pillow not installed, skipping QR strategy.") + return None + + try: + image = Image.open(io.BytesIO(image_data)) + decoded = qr_decode(image) + + for obj in decoded: + text = obj.data.decode("utf-8", errors="ignore") + code = _extract_discord_code(text) + if code: + logger.info("QR decode found Discord invite: %s", code) + return InviteInfo( + url=f"https://discord.gg/{code}", + code=code, + platform="discord", + source="qr", + ) + except Exception as exc: + logger.debug("QR decode failed: %s", exc) + + return None + + async def _try_ollama_vision(self, image_data: bytes) -> Optional[InviteInfo]: + """Strategy 2: Use Ollama vision model for local OCR.""" + try: + import base64 + import httpx + from config import settings + except ImportError: + logger.debug("httpx not available for Ollama vision.") + return None + + try: + b64_image = base64.b64encode(image_data).decode("ascii") + + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{settings.ollama_url}/api/generate", + json={ + "model": settings.ollama_model, + "prompt": ( + "Extract any Discord invite link from this image. " + "Look for URLs like discord.gg/CODE or " + "discord.com/invite/CODE. " + "Reply with ONLY the invite URL, nothing else. " + "If no invite link is found, reply with: NONE" + ), + "images": [b64_image], + "stream": False, + }, + ) + + if resp.status_code != 200: + logger.debug("Ollama vision returned %d", resp.status_code) + return None + + answer = resp.json().get("response", "").strip() + if answer and answer.upper() != "NONE": + code = _extract_discord_code(answer) + if code: + logger.info("Ollama vision found Discord invite: %s", code) + return InviteInfo( + url=f"https://discord.gg/{code}", + code=code, + platform="discord", + source="vision", + ) + except Exception as exc: + logger.debug("Ollama vision strategy failed: %s", exc) + + return None + + +# Module-level singleton +invite_parser = InviteParser() diff --git a/src/chat_bridge/registry.py b/src/chat_bridge/registry.py new file mode 100644 index 0000000..16271c4 --- /dev/null +++ b/src/chat_bridge/registry.py @@ -0,0 +1,74 @@ +"""PlatformRegistry — singleton registry for chat platform vendors. + +Provides a central point for registering, discovering, and managing +all chat platform integrations. Dashboard routes and the agent core +interact with platforms through this registry. + +Usage: + from chat_bridge.registry import platform_registry + + platform_registry.register(discord_vendor) + discord = platform_registry.get("discord") + all_platforms = platform_registry.list_platforms() +""" + +import logging +from typing import Optional + +from chat_bridge.base import ChatPlatform, PlatformStatus + +logger = logging.getLogger(__name__) + + +class PlatformRegistry: + """Thread-safe registry of ChatPlatform vendors.""" + + def __init__(self) -> None: + self._platforms: dict[str, ChatPlatform] = {} + + def register(self, platform: ChatPlatform) -> None: + """Register a chat platform vendor.""" + name = platform.name + if name in self._platforms: + logger.warning("Platform '%s' already registered, replacing.", name) + self._platforms[name] = platform + logger.info("Registered chat platform: %s", name) + + def unregister(self, name: str) -> bool: + """Remove a platform from the registry. Returns True if it existed.""" + if name in self._platforms: + del self._platforms[name] + logger.info("Unregistered chat platform: %s", name) + return True + return False + + def get(self, name: str) -> Optional[ChatPlatform]: + """Get a platform by name.""" + return self._platforms.get(name) + + def list_platforms(self) -> list[PlatformStatus]: + """Return status of all registered platforms.""" + return [p.status() for p in self._platforms.values()] + + async def start_all(self) -> dict[str, bool]: + """Start all registered platforms. Returns name -> success mapping.""" + results = {} + for name, platform in self._platforms.items(): + try: + results[name] = await platform.start() + except Exception as exc: + logger.error("Failed to start platform '%s': %s", name, exc) + results[name] = False + return results + + async def stop_all(self) -> None: + """Stop all registered platforms.""" + for name, platform in self._platforms.items(): + try: + await platform.stop() + except Exception as exc: + logger.error("Error stopping platform '%s': %s", name, exc) + + +# Module-level singleton +platform_registry = PlatformRegistry() diff --git a/src/chat_bridge/vendors/__init__.py b/src/chat_bridge/vendors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/chat_bridge/vendors/discord.py b/src/chat_bridge/vendors/discord.py new file mode 100644 index 0000000..0610884 --- /dev/null +++ b/src/chat_bridge/vendors/discord.py @@ -0,0 +1,400 @@ +"""DiscordVendor — Discord integration via discord.py. + +Implements ChatPlatform with native thread support. Each conversation +with Timmy gets its own Discord thread, keeping channels clean. + +Optional dependency — install with: + pip install ".[discord]" + +Architecture: + DiscordVendor + ├── _client (discord.Client) — handles gateway events + ├── _thread_map — channel_id -> active thread + └── _message_handler — bridges to Timmy agent +""" + +import asyncio +import json +import logging +from pathlib import Path +from typing import Optional + +from chat_bridge.base import ( + ChatMessage, + ChatPlatform, + ChatThread, + InviteInfo, + PlatformState, + PlatformStatus, +) + +logger = logging.getLogger(__name__) + +_STATE_FILE = Path(__file__).parent.parent.parent.parent / "discord_state.json" + + +class DiscordVendor(ChatPlatform): + """Discord integration with native thread conversations. + + Every user interaction creates or continues a Discord thread, + keeping channel history clean and conversations organized. + """ + + def __init__(self) -> None: + self._client = None + self._token: Optional[str] = None + self._state: PlatformState = PlatformState.DISCONNECTED + self._task: Optional[asyncio.Task] = None + self._guild_count: int = 0 + self._active_threads: dict[str, str] = {} # channel_id -> thread_id + + # ── ChatPlatform interface ───────────────────────────────────────────── + + @property + def name(self) -> str: + return "discord" + + @property + def state(self) -> PlatformState: + return self._state + + async def start(self, token: Optional[str] = None) -> bool: + """Start the Discord bot. Returns True on success.""" + if self._state == PlatformState.CONNECTED: + return True + + tok = token or self.load_token() + if not tok: + logger.warning("Discord bot: no token configured, skipping start.") + return False + + try: + import discord + except ImportError: + logger.error( + "discord.py is not installed. " + 'Run: pip install ".[discord]"' + ) + return False + + try: + self._state = PlatformState.CONNECTING + self._token = tok + + intents = discord.Intents.default() + intents.message_content = True + intents.guilds = True + + self._client = discord.Client(intents=intents) + self._register_handlers() + + # Run the client in a background task so we don't block + self._task = asyncio.create_task(self._run_client(tok)) + + # Wait briefly for connection + for _ in range(30): + await asyncio.sleep(0.5) + if self._state == PlatformState.CONNECTED: + logger.info("Discord bot connected (%d guilds).", self._guild_count) + return True + if self._state == PlatformState.ERROR: + return False + + logger.warning("Discord bot: connection timed out.") + self._state = PlatformState.ERROR + return False + + except Exception as exc: + logger.error("Discord bot failed to start: %s", exc) + self._state = PlatformState.ERROR + self._token = None + self._client = None + return False + + async def stop(self) -> None: + """Gracefully disconnect the Discord bot.""" + if self._client and not self._client.is_closed(): + try: + await self._client.close() + logger.info("Discord bot disconnected.") + except Exception as exc: + logger.error("Error stopping Discord bot: %s", exc) + + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + self._state = PlatformState.DISCONNECTED + self._client = None + self._task = None + + async def send_message( + self, channel_id: str, content: str, thread_id: Optional[str] = None + ) -> Optional[ChatMessage]: + """Send a message to a Discord channel or thread.""" + if not self._client or self._state != PlatformState.CONNECTED: + return None + + try: + import discord + + target_id = int(thread_id) if thread_id else int(channel_id) + channel = self._client.get_channel(target_id) + + if channel is None: + channel = await self._client.fetch_channel(target_id) + + msg = await channel.send(content) + + return ChatMessage( + content=content, + author=str(self._client.user), + channel_id=str(msg.channel.id), + platform="discord", + message_id=str(msg.id), + thread_id=thread_id, + ) + except Exception as exc: + logger.error("Failed to send Discord message: %s", exc) + return None + + async def create_thread( + self, channel_id: str, title: str, initial_message: Optional[str] = None + ) -> Optional[ChatThread]: + """Create a new thread in a Discord channel.""" + if not self._client or self._state != PlatformState.CONNECTED: + return None + + try: + channel = self._client.get_channel(int(channel_id)) + if channel is None: + channel = await self._client.fetch_channel(int(channel_id)) + + thread = await channel.create_thread( + name=title[:100], # Discord limits thread names to 100 chars + auto_archive_duration=1440, # 24 hours + ) + + if initial_message: + await thread.send(initial_message) + + self._active_threads[channel_id] = str(thread.id) + + return ChatThread( + thread_id=str(thread.id), + title=title[:100], + channel_id=channel_id, + platform="discord", + ) + except Exception as exc: + logger.error("Failed to create Discord thread: %s", exc) + return None + + async def join_from_invite(self, invite_code: str) -> bool: + """Join a Discord server using an invite code. + + Note: Bot accounts cannot use invite links directly. + This generates an OAuth2 URL for adding the bot to a server. + The invite_code is validated but the actual join requires + the server admin to use the bot's OAuth2 authorization URL. + """ + if not self._client or self._state != PlatformState.CONNECTED: + logger.warning("Discord bot not connected, cannot process invite.") + return False + + try: + import discord + + invite = await self._client.fetch_invite(invite_code) + logger.info( + "Validated invite for server '%s' (code: %s)", + invite.guild.name if invite.guild else "unknown", + invite_code, + ) + return True + except Exception as exc: + logger.error("Invalid Discord invite '%s': %s", invite_code, exc) + return False + + def status(self) -> PlatformStatus: + return PlatformStatus( + platform="discord", + state=self._state, + token_set=bool(self._token), + guild_count=self._guild_count, + thread_count=len(self._active_threads), + ) + + def save_token(self, token: str) -> None: + """Persist token to state file.""" + try: + _STATE_FILE.write_text(json.dumps({"token": token})) + except Exception as exc: + logger.error("Failed to save Discord token: %s", exc) + + def load_token(self) -> Optional[str]: + """Load token from state file or config.""" + try: + if _STATE_FILE.exists(): + data = json.loads(_STATE_FILE.read_text()) + token = data.get("token") + if token: + return token + except Exception as exc: + logger.debug("Could not read discord state file: %s", exc) + + try: + from config import settings + return settings.discord_token or None + except Exception: + return None + + # ── OAuth2 URL generation ────────────────────────────────────────────── + + def get_oauth2_url(self) -> Optional[str]: + """Generate the OAuth2 URL for adding this bot to a server. + + Requires the bot to be connected to read its application ID. + """ + if not self._client or not self._client.user: + return None + + app_id = self._client.user.id + # Permissions: Send Messages, Create Public Threads, Manage Threads, + # Read Message History, Embed Links, Attach Files + permissions = 397284550656 + return ( + f"https://discord.com/oauth2/authorize" + f"?client_id={app_id}&scope=bot" + f"&permissions={permissions}" + ) + + # ── Internal ─────────────────────────────────────────────────────────── + + async def _run_client(self, token: str) -> None: + """Run the discord.py client (blocking call in a task).""" + try: + await self._client.start(token) + except Exception as exc: + logger.error("Discord client error: %s", exc) + self._state = PlatformState.ERROR + + def _register_handlers(self) -> None: + """Register Discord event handlers on the client.""" + + @self._client.event + async def on_ready(): + self._guild_count = len(self._client.guilds) + self._state = PlatformState.CONNECTED + logger.info( + "Discord ready: %s in %d guild(s)", + self._client.user, + self._guild_count, + ) + + @self._client.event + async def on_message(message): + # Ignore our own messages + if message.author == self._client.user: + return + + # Only respond to mentions or DMs + is_dm = not hasattr(message.channel, "guild") or message.channel.guild is None + is_mention = self._client.user in message.mentions + + if not is_dm and not is_mention: + return + + await self._handle_message(message) + + @self._client.event + async def on_disconnect(): + if self._state != PlatformState.DISCONNECTED: + self._state = PlatformState.CONNECTING + logger.warning("Discord disconnected, will auto-reconnect.") + + async def _handle_message(self, message) -> None: + """Process an incoming message and respond via a thread.""" + # Strip the bot mention from the message content + content = message.content + if self._client.user: + content = content.replace(f"<@{self._client.user.id}>", "").strip() + + if not content: + return + + # Create or reuse a thread for this conversation + thread = await self._get_or_create_thread(message) + target = thread or message.channel + + # Run Timmy agent + try: + from timmy.agent import create_timmy + + agent = create_timmy() + run = await asyncio.to_thread(agent.run, content, stream=False) + response = run.content if hasattr(run, "content") else str(run) + except Exception as exc: + logger.error("Timmy error in Discord handler: %s", exc) + response = f"Timmy is offline: {exc}" + + # Discord has a 2000 character limit + for chunk in _chunk_message(response, 2000): + await target.send(chunk) + + async def _get_or_create_thread(self, message): + """Get the active thread for a channel, or create one. + + If the message is already in a thread, use that thread. + Otherwise, create a new thread from the message. + """ + try: + import discord + + # Already in a thread — just use it + if isinstance(message.channel, discord.Thread): + return message.channel + + # DM channels don't support threads + if isinstance(message.channel, discord.DMChannel): + return None + + # Create a thread from this message + thread_name = f"Timmy | {message.author.display_name}" + thread = await message.create_thread( + name=thread_name[:100], + auto_archive_duration=1440, + ) + channel_id = str(message.channel.id) + self._active_threads[channel_id] = str(thread.id) + return thread + + except Exception as exc: + logger.debug("Could not create thread: %s", exc) + return None + + +def _chunk_message(text: str, max_len: int = 2000) -> list[str]: + """Split a message into chunks that fit Discord's character limit.""" + if len(text) <= max_len: + return [text] + + chunks = [] + while text: + if len(text) <= max_len: + chunks.append(text) + break + # Try to split at a newline + split_at = text.rfind("\n", 0, max_len) + if split_at == -1: + split_at = max_len + chunks.append(text[:split_at]) + text = text[split_at:].lstrip("\n") + return chunks + + +# Module-level singleton +discord_bot = DiscordVendor() diff --git a/src/config.py b/src/config.py index 41c1d40..f90606f 100644 --- a/src/config.py +++ b/src/config.py @@ -16,6 +16,9 @@ class Settings(BaseSettings): # Telegram bot token — set via TELEGRAM_TOKEN env var or the /telegram/setup endpoint telegram_token: str = "" + # Discord bot token — set via DISCORD_TOKEN env var or the /discord/setup endpoint + discord_token: str = "" + # ── AirLLM / backend selection ─────────────────────────────────────────── # "ollama" — always use Ollama (default, safe everywhere) # "airllm" — always use AirLLM (requires pip install ".[bigbrain]") diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 8d89f24..da1be36 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -25,6 +25,7 @@ from dashboard.routes.swarm_internal import router as swarm_internal_router from dashboard.routes.tools import router as tools_router from dashboard.routes.spark import router as spark_router from dashboard.routes.creative import router as creative_router +from dashboard.routes.discord import router as discord_router logging.basicConfig( level=logging.INFO, @@ -108,8 +109,15 @@ async def lifespan(app: FastAPI): from telegram_bot.bot import telegram_bot await telegram_bot.start() + # Auto-start Discord bot and register in platform registry + from chat_bridge.vendors.discord import discord_bot + from chat_bridge.registry import platform_registry + platform_registry.register(discord_bot) + await discord_bot.start() + yield + await discord_bot.stop() await telegram_bot.stop() task.cancel() try: @@ -145,6 +153,7 @@ app.include_router(swarm_internal_router) app.include_router(tools_router) app.include_router(spark_router) app.include_router(creative_router) +app.include_router(discord_router) @app.get("/", response_class=HTMLResponse) diff --git a/src/dashboard/routes/discord.py b/src/dashboard/routes/discord.py new file mode 100644 index 0000000..28629a5 --- /dev/null +++ b/src/dashboard/routes/discord.py @@ -0,0 +1,140 @@ +"""Dashboard routes for Discord bot setup, status, and invite-from-image. + +Endpoints: + POST /discord/setup — configure bot token + GET /discord/status — connection state + guild count + POST /discord/join — paste screenshot → extract invite → join + GET /discord/oauth-url — get the bot's OAuth2 authorization URL +""" + +from fastapi import APIRouter, File, Form, UploadFile +from pydantic import BaseModel +from typing import Optional + +router = APIRouter(prefix="/discord", tags=["discord"]) + + +class TokenPayload(BaseModel): + token: str + + +@router.post("/setup") +async def setup_discord(payload: TokenPayload): + """Configure the Discord bot token and (re)start the bot. + + Send POST with JSON body: {"token": ""} + Get the token from https://discord.com/developers/applications + """ + from chat_bridge.vendors.discord import discord_bot + + token = payload.token.strip() + if not token: + return {"ok": False, "error": "Token cannot be empty."} + + discord_bot.save_token(token) + + if discord_bot.state.name == "CONNECTED": + await discord_bot.stop() + + success = await discord_bot.start(token=token) + if success: + return {"ok": True, "message": "Discord bot connected successfully."} + return { + "ok": False, + "error": ( + "Failed to start bot. Check that the token is correct and " + 'discord.py is installed: pip install ".[discord]"' + ), + } + + +@router.get("/status") +async def discord_status(): + """Return current Discord bot status.""" + from chat_bridge.vendors.discord import discord_bot + + return discord_bot.status().to_dict() + + +@router.post("/join") +async def join_from_image( + image: Optional[UploadFile] = File(None), + invite_url: Optional[str] = Form(None), +): + """Extract a Discord invite from a screenshot or text and validate it. + + Accepts either: + - An uploaded image (screenshot of invite or QR code) + - A plain text invite URL + + The bot validates the invite and returns the OAuth2 URL for the + server admin to authorize the bot. + """ + from chat_bridge.invite_parser import invite_parser + from chat_bridge.vendors.discord import discord_bot + + invite_info = None + + # Try image first + if image and image.filename: + image_data = await image.read() + if image_data: + invite_info = await invite_parser.parse_image(image_data) + + # Fall back to text + if not invite_info and invite_url: + invite_info = invite_parser.parse_text(invite_url) + + if not invite_info: + return { + "ok": False, + "error": ( + "No Discord invite found. " + "Paste a screenshot with a visible invite link or QR code, " + "or enter the invite URL directly." + ), + } + + # Validate the invite + valid = await discord_bot.join_from_invite(invite_info.code) + + result = { + "ok": True, + "invite": { + "code": invite_info.code, + "url": invite_info.url, + "source": invite_info.source, + "platform": invite_info.platform, + }, + "validated": valid, + } + + # Include OAuth2 URL if bot is connected + oauth_url = discord_bot.get_oauth2_url() + if oauth_url: + result["oauth2_url"] = oauth_url + result["message"] = ( + "Invite validated. Share this OAuth2 URL with the server admin " + "to add Timmy to the server." + ) + else: + result["message"] = ( + "Invite found but bot is not connected. " + "Configure a bot token first via /discord/setup." + ) + + return result + + +@router.get("/oauth-url") +async def discord_oauth_url(): + """Get the bot's OAuth2 authorization URL for adding to servers.""" + from chat_bridge.vendors.discord import discord_bot + + url = discord_bot.get_oauth2_url() + if url: + return {"ok": True, "url": url} + return { + "ok": False, + "error": "Bot is not connected. Configure a token first.", + } diff --git a/tests/conftest.py b/tests/conftest.py index fedb232..c875503 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,14 @@ for _mod in [ # without the package installed. "telegram", "telegram.ext", + # discord.py is optional (discord extra) — stub so tests run + # without the package installed. + "discord", + "discord.ext", + "discord.ext.commands", + # pyzbar is optional (for QR code invite detection) + "pyzbar", + "pyzbar.pyzbar", ]: sys.modules.setdefault(_mod, MagicMock()) diff --git a/tests/test_chat_bridge.py b/tests/test_chat_bridge.py new file mode 100644 index 0000000..25645fa --- /dev/null +++ b/tests/test_chat_bridge.py @@ -0,0 +1,268 @@ +"""Tests for the chat_bridge base classes, registry, and invite parser.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from chat_bridge.base import ( + ChatMessage, + ChatPlatform, + ChatThread, + InviteInfo, + PlatformState, + PlatformStatus, +) +from chat_bridge.registry import PlatformRegistry + + +# ── Base dataclass tests ─────────────────────────────────────────────────────── + + +class TestChatMessage: + def test_create_message(self): + msg = ChatMessage( + content="Hello", + author="user1", + channel_id="123", + platform="test", + ) + assert msg.content == "Hello" + assert msg.author == "user1" + assert msg.platform == "test" + assert msg.thread_id is None + assert msg.attachments == [] + + def test_message_with_thread(self): + msg = ChatMessage( + content="Reply", + author="bot", + channel_id="123", + platform="discord", + thread_id="456", + ) + assert msg.thread_id == "456" + + +class TestChatThread: + def test_create_thread(self): + thread = ChatThread( + thread_id="t1", + title="Timmy | user1", + channel_id="c1", + platform="discord", + ) + assert thread.thread_id == "t1" + assert thread.archived is False + assert thread.message_count == 0 + + +class TestInviteInfo: + def test_create_invite(self): + invite = InviteInfo( + url="https://discord.gg/abc123", + code="abc123", + platform="discord", + source="qr", + ) + assert invite.code == "abc123" + assert invite.source == "qr" + + +class TestPlatformStatus: + def test_to_dict(self): + status = PlatformStatus( + platform="discord", + state=PlatformState.CONNECTED, + token_set=True, + guild_count=3, + ) + d = status.to_dict() + assert d["connected"] is True + assert d["platform"] == "discord" + assert d["guild_count"] == 3 + assert d["state"] == "connected" + + def test_disconnected_status(self): + status = PlatformStatus( + platform="test", + state=PlatformState.DISCONNECTED, + token_set=False, + ) + d = status.to_dict() + assert d["connected"] is False + + +# ── PlatformRegistry tests ──────────────────────────────────────────────────── + + +class _FakePlatform(ChatPlatform): + """Minimal ChatPlatform for testing the registry.""" + + def __init__(self, platform_name: str = "fake"): + self._name = platform_name + self._state = PlatformState.DISCONNECTED + + @property + def name(self) -> str: + return self._name + + @property + def state(self) -> PlatformState: + return self._state + + async def start(self, token=None) -> bool: + self._state = PlatformState.CONNECTED + return True + + async def stop(self) -> None: + self._state = PlatformState.DISCONNECTED + + async def send_message(self, channel_id, content, thread_id=None): + return ChatMessage( + content=content, author="bot", channel_id=channel_id, platform=self._name + ) + + async def create_thread(self, channel_id, title, initial_message=None): + return ChatThread( + thread_id="t1", title=title, channel_id=channel_id, platform=self._name + ) + + async def join_from_invite(self, invite_code) -> bool: + return True + + def status(self): + return PlatformStatus( + platform=self._name, + state=self._state, + token_set=False, + ) + + def save_token(self, token): + pass + + def load_token(self): + return None + + +class TestPlatformRegistry: + def test_register_and_get(self): + reg = PlatformRegistry() + p = _FakePlatform("test1") + reg.register(p) + assert reg.get("test1") is p + + def test_get_missing(self): + reg = PlatformRegistry() + assert reg.get("nonexistent") is None + + def test_unregister(self): + reg = PlatformRegistry() + p = _FakePlatform("test1") + reg.register(p) + assert reg.unregister("test1") is True + assert reg.get("test1") is None + + def test_unregister_missing(self): + reg = PlatformRegistry() + assert reg.unregister("nope") is False + + def test_list_platforms(self): + reg = PlatformRegistry() + reg.register(_FakePlatform("a")) + reg.register(_FakePlatform("b")) + statuses = reg.list_platforms() + assert len(statuses) == 2 + names = {s.platform for s in statuses} + assert names == {"a", "b"} + + @pytest.mark.asyncio + async def test_start_all(self): + reg = PlatformRegistry() + reg.register(_FakePlatform("x")) + reg.register(_FakePlatform("y")) + results = await reg.start_all() + assert results == {"x": True, "y": True} + + @pytest.mark.asyncio + async def test_stop_all(self): + reg = PlatformRegistry() + p = _FakePlatform("z") + reg.register(p) + await reg.start_all() + assert p.state == PlatformState.CONNECTED + await reg.stop_all() + assert p.state == PlatformState.DISCONNECTED + + def test_replace_existing(self): + reg = PlatformRegistry() + p1 = _FakePlatform("dup") + p2 = _FakePlatform("dup") + reg.register(p1) + reg.register(p2) + assert reg.get("dup") is p2 + + +# ── InviteParser tests ──────────────────────────────────────────────────────── + + +class TestInviteParser: + def test_parse_text_discord_gg(self): + from chat_bridge.invite_parser import invite_parser + + result = invite_parser.parse_text("Join us at https://discord.gg/abc123!") + assert result is not None + assert result.code == "abc123" + assert result.platform == "discord" + assert result.source == "text" + + def test_parse_text_discord_com_invite(self): + from chat_bridge.invite_parser import invite_parser + + result = invite_parser.parse_text( + "Link: https://discord.com/invite/myServer2024" + ) + assert result is not None + assert result.code == "myServer2024" + + def test_parse_text_discordapp(self): + from chat_bridge.invite_parser import invite_parser + + result = invite_parser.parse_text( + "https://discordapp.com/invite/test-code" + ) + assert result is not None + assert result.code == "test-code" + + def test_parse_text_no_invite(self): + from chat_bridge.invite_parser import invite_parser + + result = invite_parser.parse_text("Hello world, no links here") + assert result is None + + def test_parse_text_bare_discord_gg(self): + from chat_bridge.invite_parser import invite_parser + + result = invite_parser.parse_text("discord.gg/xyz789") + assert result is not None + assert result.code == "xyz789" + + @pytest.mark.asyncio + async def test_parse_image_no_deps(self): + """parse_image returns None when pyzbar/Pillow are not installed.""" + from chat_bridge.invite_parser import InviteParser + + parser = InviteParser() + # With mocked pyzbar, this should gracefully return None + result = await parser.parse_image(b"fake-image-bytes") + assert result is None + + +class TestExtractDiscordCode: + def test_various_formats(self): + from chat_bridge.invite_parser import _extract_discord_code + + assert _extract_discord_code("discord.gg/abc") == "abc" + assert _extract_discord_code("https://discord.gg/test") == "test" + assert _extract_discord_code("http://discord.gg/http") == "http" + assert _extract_discord_code("discord.com/invite/xyz") == "xyz" + assert _extract_discord_code("no link here") is None + assert _extract_discord_code("") is None diff --git a/tests/test_discord_vendor.py b/tests/test_discord_vendor.py new file mode 100644 index 0000000..f06528e --- /dev/null +++ b/tests/test_discord_vendor.py @@ -0,0 +1,225 @@ +"""Tests for the Discord vendor and dashboard routes.""" + +import json +import pytest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from chat_bridge.base import PlatformState + + +# ── DiscordVendor unit tests ────────────────────────────────────────────────── + + +class TestDiscordVendor: + def test_name(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + assert vendor.name == "discord" + + def test_initial_state(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + assert vendor.state == PlatformState.DISCONNECTED + + def test_status_disconnected(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + status = vendor.status() + assert status.platform == "discord" + assert status.state == PlatformState.DISCONNECTED + assert status.token_set is False + assert status.guild_count == 0 + + def test_save_and_load_token(self, tmp_path, monkeypatch): + from chat_bridge.vendors import discord as discord_mod + from chat_bridge.vendors.discord import DiscordVendor + + state_file = tmp_path / "discord_state.json" + monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file) + + vendor = DiscordVendor() + vendor.save_token("test-token-abc") + + assert state_file.exists() + data = json.loads(state_file.read_text()) + assert data["token"] == "test-token-abc" + + loaded = vendor.load_token() + assert loaded == "test-token-abc" + + def test_load_token_missing_file(self, tmp_path, monkeypatch): + from chat_bridge.vendors import discord as discord_mod + from chat_bridge.vendors.discord import DiscordVendor + + state_file = tmp_path / "nonexistent.json" + monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file) + + vendor = DiscordVendor() + # Falls back to config.settings.discord_token + token = vendor.load_token() + # Default discord_token is "" which becomes None + assert token is None + + @pytest.mark.asyncio + async def test_start_no_token(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + result = await vendor.start(token=None) + assert result is False + + @pytest.mark.asyncio + async def test_start_import_error(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + # Simulate discord.py not installed by making import fail + with patch.dict("sys.modules", {"discord": None}): + result = await vendor.start(token="fake-token") + assert result is False + + @pytest.mark.asyncio + async def test_stop_when_disconnected(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + # Should not raise + await vendor.stop() + assert vendor.state == PlatformState.DISCONNECTED + + def test_get_oauth2_url_no_client(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + assert vendor.get_oauth2_url() is None + + def test_get_oauth2_url_with_client(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + mock_client = MagicMock() + mock_client.user.id = 123456789 + vendor._client = mock_client + url = vendor.get_oauth2_url() + assert "123456789" in url + assert "oauth2/authorize" in url + + @pytest.mark.asyncio + async def test_send_message_not_connected(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + result = await vendor.send_message("123", "hello") + assert result is None + + @pytest.mark.asyncio + async def test_create_thread_not_connected(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + result = await vendor.create_thread("123", "Test Thread") + assert result is None + + @pytest.mark.asyncio + async def test_join_from_invite_not_connected(self): + from chat_bridge.vendors.discord import DiscordVendor + + vendor = DiscordVendor() + result = await vendor.join_from_invite("abc123") + assert result is False + + +class TestChunkMessage: + def test_short_message(self): + from chat_bridge.vendors.discord import _chunk_message + + chunks = _chunk_message("Hello!", 2000) + assert chunks == ["Hello!"] + + def test_long_message(self): + from chat_bridge.vendors.discord import _chunk_message + + text = "a" * 5000 + chunks = _chunk_message(text, 2000) + assert len(chunks) == 3 + assert all(len(c) <= 2000 for c in chunks) + assert "".join(chunks) == text + + def test_split_at_newline(self): + from chat_bridge.vendors.discord import _chunk_message + + text = "Line1\n" + "x" * 1990 + "\nLine3" + chunks = _chunk_message(text, 2000) + assert len(chunks) >= 2 + assert chunks[0].startswith("Line1") + + +# ── Discord route tests ─────────────────────────────────────────────────────── + + +class TestDiscordRoutes: + def test_status_endpoint(self, client): + resp = client.get("/discord/status") + assert resp.status_code == 200 + data = resp.json() + assert data["platform"] == "discord" + assert "connected" in data + + def test_setup_empty_token(self, client): + resp = client.post("/discord/setup", json={"token": ""}) + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is False + assert "empty" in data["error"].lower() + + def test_setup_with_token(self, client): + """Setup with a token — bot won't actually connect but route works.""" + with patch( + "chat_bridge.vendors.discord.DiscordVendor.start", + new_callable=AsyncMock, + return_value=False, + ): + resp = client.post( + "/discord/setup", json={"token": "fake-token-123"} + ) + assert resp.status_code == 200 + data = resp.json() + # Will fail because discord.py is mocked, but route handles it + assert "ok" in data + + def test_join_no_input(self, client): + resp = client.post("/discord/join") + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is False + assert "no discord invite" in data["error"].lower() + + def test_join_with_text_invite(self, client): + with patch( + "chat_bridge.vendors.discord.DiscordVendor.join_from_invite", + new_callable=AsyncMock, + return_value=True, + ): + resp = client.post( + "/discord/join", + data={"invite_url": "https://discord.gg/testcode"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is True + assert data["invite"]["code"] == "testcode" + assert data["invite"]["source"] == "text" + + def test_oauth_url_not_connected(self, client): + from chat_bridge.vendors.discord import discord_bot + + # Reset singleton so it has no client + discord_bot._client = None + resp = client.get("/discord/oauth-url") + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is False