From ae3bb1cc217bbc32cda89f1af8601205b3e2f7c3 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone <8633216+AlexanderWhitestone@users.noreply.github.com> Date: Sun, 8 Mar 2026 12:50:44 -0400 Subject: [PATCH] feat: code quality audit + autoresearch integration + infra hardening (#150) --- .env.example | 20 + .github/workflows/tests.yml | 40 +- .pre-commit-config.yaml | 10 +- CLAUDE.md | 2 +- Dockerfile | 15 +- Makefile | 5 + docker-compose.prod.yml | 56 +++ pyproject.toml | 6 + src/brain/__init__.py | 2 +- src/brain/client.py | 259 ++++++------ src/brain/embeddings.py | 27 +- src/brain/memory.py | 25 +- src/brain/schema.py | 8 +- src/brain/worker.py | 145 ++++--- src/config.py | 9 + src/dashboard/app.py | 80 ++-- src/dashboard/middleware/__init__.py | 2 +- src/dashboard/middleware/csrf.py | 115 +++--- src/dashboard/middleware/request_logging.py | 75 ++-- src/dashboard/middleware/security_headers.py | 75 ++-- src/dashboard/models/calm.py | 28 +- src/dashboard/models/database.py | 7 +- src/dashboard/routes/agents.py | 10 +- src/dashboard/routes/briefing.py | 5 +- src/dashboard/routes/calm.py | 75 ++-- src/dashboard/routes/chat_api.py | 9 +- src/dashboard/routes/discord.py | 3 +- src/dashboard/routes/experiments.py | 77 ++++ src/dashboard/routes/grok.py | 39 +- src/dashboard/routes/health.py | 29 +- src/dashboard/routes/marketplace.py | 9 +- src/dashboard/routes/memory.py | 14 +- src/dashboard/routes/models.py | 17 +- src/dashboard/routes/router.py | 10 +- src/dashboard/routes/spark.py | 35 +- src/dashboard/routes/swarm.py | 26 +- src/dashboard/routes/system.py | 66 +++- src/dashboard/routes/tasks.py | 130 +++--- src/dashboard/routes/thinking.py | 2 +- src/dashboard/routes/tools.py | 6 +- src/dashboard/routes/voice.py | 6 +- src/dashboard/routes/work_orders.py | 50 ++- src/dashboard/store.py | 6 +- src/dashboard/templates/experiments.html | 90 +++++ src/infrastructure/error_capture.py | 4 +- src/infrastructure/events/broadcaster.py | 29 +- 
src/infrastructure/events/bus.py | 68 ++-- src/infrastructure/hands/__init__.py | 2 +- src/infrastructure/hands/git.py | 26 +- src/infrastructure/hands/shell.py | 24 +- src/infrastructure/hands/tools.py | 5 +- src/infrastructure/models/__init__.py | 14 +- src/infrastructure/models/multimodal.py | 369 ++++++++++++------ src/infrastructure/models/registry.py | 39 +- src/infrastructure/notifications/push.py | 19 +- src/infrastructure/openfang/client.py | 20 +- src/infrastructure/openfang/tools.py | 2 + src/infrastructure/router/__init__.py | 2 +- src/infrastructure/router/api.py | 48 ++- src/infrastructure/router/cascade.py | 294 +++++++------- src/infrastructure/ws_manager/handler.py | 61 +-- src/integrations/chat_bridge/base.py | 13 +- src/integrations/chat_bridge/invite_parser.py | 2 + .../chat_bridge/vendors/discord.py | 15 +- src/integrations/paperclip/models.py | 1 - src/integrations/paperclip/task_runner.py | 12 +- src/integrations/shortcuts/siri.py | 1 + src/integrations/telegram_bot/bot.py | 7 +- src/integrations/voice/nlu.py | 87 +++-- src/spark/advisor.py | 225 ++++++----- src/spark/eidos.py | 28 +- src/spark/engine.py | 36 +- src/spark/memory.py | 38 +- src/swarm/event_log.py | 38 +- src/swarm/task_queue/models.py | 10 +- src/timmy/agent.py | 84 ++-- src/timmy/agent_core/interface.py | 182 +++++---- src/timmy/agent_core/ollama_adapter.py | 111 +++--- src/timmy/agentic_loop.py | 95 +++-- src/timmy/agents/base.py | 52 +-- src/timmy/agents/timmy.py | 192 +++++---- src/timmy/approvals.py | 23 +- src/timmy/autoresearch.py | 214 ++++++++++ src/timmy/backends.py | 41 +- src/timmy/briefing.py | 25 +- src/timmy/cascade_adapter.py | 37 +- src/timmy/conversation.py | 152 ++++++-- src/timmy/memory/vector_store.py | 157 ++++---- src/timmy/memory_system.py | 147 +++---- src/timmy/semantic_memory.py | 136 ++++--- src/timmy/session.py | 12 +- src/timmy/session_logger.py | 2 +- src/timmy/thinking.py | 42 +- src/timmy/tools.py | 95 ++++- 
src/timmy/tools_delegation/__init__.py | 8 +- src/timmy/tools_intro/__init__.py | 24 +- src/timmy_serve/app.py | 16 +- src/timmy_serve/cli.py | 1 + src/timmy_serve/inter_agent.py | 9 +- src/timmy_serve/voice_tts.py | 1 + tests/brain/test_brain_client.py | 67 ++-- tests/brain/test_brain_worker.py | 43 +- tests/brain/test_unified_memory.py | 16 +- tests/conftest.py | 19 +- tests/conftest_markers.py | 10 +- tests/dashboard/middleware/test_csrf.py | 98 +++-- .../dashboard/middleware/test_csrf_bypass.py | 30 +- .../test_csrf_bypass_vulnerability.py | 43 +- .../middleware/test_csrf_traversal.py | 15 +- .../middleware/test_request_logging.py | 51 +-- .../middleware/test_security_headers.py | 26 +- tests/dashboard/test_briefing.py | 24 +- tests/dashboard/test_calm.py | 35 +- tests/dashboard/test_chat_api.py | 1 - tests/dashboard/test_dashboard.py | 1 - tests/dashboard/test_experiments_route.py | 41 ++ tests/dashboard/test_input_validation.py | 32 +- tests/dashboard/test_local_models.py | 16 +- tests/dashboard/test_memory_api.py | 2 + tests/dashboard/test_middleware_migration.py | 6 + tests/dashboard/test_mission_control.py | 3 +- tests/dashboard/test_mobile_scenarios.py | 59 +-- tests/dashboard/test_paperclip_routes.py | 5 +- tests/dashboard/test_round4_fixes.py | 48 ++- tests/dashboard/test_security_headers.py | 22 +- tests/dashboard/test_tasks_api.py | 26 +- tests/dashboard/test_work_orders_api.py | 50 ++- tests/e2e/test_agentic_chain.py | 86 ++-- tests/e2e/test_ollama_integration.py | 40 +- tests/fixtures/media.py | 24 +- tests/functional/conftest.py | 48 ++- tests/functional/test_cli.py | 1 - tests/functional/test_docker_swarm.py | 6 +- tests/functional/test_fast_e2e.py | 62 +-- tests/functional/test_ollama_chat.py | 98 +++-- tests/functional/test_setup_prod.py | 63 ++- tests/functional/test_ui_selenium.py | 43 +- tests/infrastructure/test_error_capture.py | 13 +- .../infrastructure/test_event_broadcaster.py | 17 +- tests/infrastructure/test_event_bus.py | 4 +- 
.../infrastructure/test_functional_router.py | 84 ++-- tests/infrastructure/test_model_registry.py | 11 +- tests/infrastructure/test_models_api.py | 77 +--- tests/infrastructure/test_router_api.py | 270 +++++++------ tests/integrations/test_chat_bridge.py | 16 +- tests/integrations/test_discord_vendor.py | 10 +- tests/integrations/test_paperclip_bridge.py | 2 +- .../test_paperclip_task_runner.py | 189 ++++++--- tests/integrations/test_shortcuts.py | 2 +- tests/integrations/test_telegram_bot.py | 39 +- tests/integrations/test_voice_nlu.py | 4 +- .../integrations/test_voice_tts_functional.py | 3 +- tests/integrations/test_websocket.py | 1 + tests/integrations/test_websocket_extended.py | 5 +- tests/security/test_security_fixes_xss.py | 47 ++- tests/security/test_security_regression.py | 4 +- tests/security/test_xss_vulnerabilities.py | 7 +- tests/spark/test_spark.py | 96 +++-- tests/spark/test_spark_tools_creative.py | 2 +- tests/test_agentic_loop.py | 145 ++++--- tests/test_hands_git.py | 7 +- tests/test_hands_shell.py | 7 +- tests/test_openfang_client.py | 21 +- tests/test_setup_script.py | 26 +- tests/test_smoke.py | 88 +++-- tests/timmy/test_agent.py | 123 +++--- tests/timmy/test_agent_core.py | 120 ++++-- tests/timmy/test_agents_timmy.py | 23 +- tests/timmy/test_api_rate_limiting.py | 20 +- tests/timmy/test_approvals.py | 19 +- tests/timmy/test_autoresearch.py | 179 +++++++++ tests/timmy/test_backends.py | 44 ++- tests/timmy/test_conversation.py | 1 + tests/timmy/test_grok_backend.py | 28 +- tests/timmy/test_introspection.py | 4 +- tests/timmy/test_ollama_timeout.py | 24 +- tests/timmy/test_prompts.py | 2 +- tests/timmy/test_semantic_memory.py | 39 +- tests/timmy/test_session.py | 25 +- tests/timmy/test_session_logging.py | 5 +- tests/timmy/test_thinking.py | 49 ++- tests/timmy/test_timmy_tools.py | 16 +- tests/timmy/test_tools_extended.py | 17 +- tests/timmy/test_vector_store.py | 91 ++--- tests/timmy_serve/test_inter_agent.py | 6 +- tox.ini | 4 +- 186 files 
changed, 5129 insertions(+), 3289 deletions(-) create mode 100644 docker-compose.prod.yml create mode 100644 src/dashboard/routes/experiments.py create mode 100644 src/dashboard/templates/experiments.html create mode 100644 src/timmy/autoresearch.py create mode 100644 tests/dashboard/test_experiments_route.py create mode 100644 tests/timmy/test_autoresearch.py diff --git a/.env.example b/.env.example index 47ca04d..c8b346e 100644 --- a/.env.example +++ b/.env.example @@ -71,3 +71,23 @@ # Requires: pip install ".[discord]" # Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots) # DISCORD_TOKEN= + +# ── Autoresearch — autonomous ML experiment loops ──────────────────────────── +# Enable autonomous experiment loops (Karpathy autoresearch pattern). +# AUTORESEARCH_ENABLED=false +# AUTORESEARCH_WORKSPACE=data/experiments +# AUTORESEARCH_TIME_BUDGET=300 +# AUTORESEARCH_MAX_ITERATIONS=100 +# AUTORESEARCH_METRIC=val_bpb + +# ── Docker Production ──────────────────────────────────────────────────────── +# When deploying with docker-compose.prod.yml: +# - Containers run as non-root user "timmy" (defined in Dockerfile) +# - No source bind mounts — code is baked into the image +# - Set TIMMY_ENV=production to enforce security checks +# - All secrets below MUST be set before production deployment +# +# Taskosaur secrets (change from dev defaults): +# TASKOSAUR_JWT_SECRET= +# TASKOSAUR_JWT_REFRESH_SECRET= +# TASKOSAUR_ENCRYPTION_KEY= diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 202a878..1371459 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,8 +7,30 @@ on: branches: ["**"] jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install linters + run: pip install black==23.12.1 isort==5.13.2 bandit==1.7.5 + + - name: Check formatting (black) + run: black --check --line-length 100 src/ 
tests/ + + - name: Check import order (isort) + run: isort --check --profile black --line-length 100 src/ tests/ + + - name: Security scan (bandit) + run: bandit -r src/ -ll -s B101,B104,B307,B310,B324,B601,B608 -q + test: runs-on: ubuntu-latest + needs: lint # Required for publish-unit-test-result-action to post check runs and PR comments permissions: @@ -22,7 +44,15 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.11" - cache: "pip" + + - name: Cache Poetry virtualenv + uses: actions/cache@v4 + with: + path: | + ~/.cache/pypoetry + ~/.cache/pip + key: poetry-${{ hashFiles('poetry.lock') }} + restore-keys: poetry- - name: Install dependencies run: | @@ -60,3 +90,11 @@ jobs: name: coverage-report path: reports/coverage.xml retention-days: 14 + + docker-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Build Docker image + run: DOCKER_BUILDKIT=1 docker build -t timmy-time:ci . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 596cf77..f7baeb5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,12 +51,12 @@ repos: exclude: ^tests/ stages: [manual] - # Full test suite with 30-second wall-clock limit. - # Current baseline: ~18s. If tests get slow, this blocks the commit. + # Unit tests only with 30-second wall-clock limit. + # Runs only fast unit tests on commit; full suite runs in CI. 
- repo: local hooks: - id: pytest-fast - name: pytest (30s limit) + name: pytest unit (30s limit) entry: timeout 30 poetry run pytest language: system types: [python] @@ -68,4 +68,8 @@ repos: - -q - --tb=short - --timeout=10 + - -m + - unit + - -p + - no:xdist verbose: true diff --git a/CLAUDE.md b/CLAUDE.md index ac86d95..4252417 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -56,7 +56,7 @@ make test-cov # With coverage (term-missing + XML) - **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest - **FastAPI testing:** Use the `client` fixture - **Async:** `asyncio_mode = "auto"` — async tests detected automatically -- **Coverage threshold:** 60% (`fail_under` in `pyproject.toml`) +- **Coverage threshold:** 73% (`fail_under` in `pyproject.toml`) --- diff --git a/Dockerfile b/Dockerfile index 4969337..1542baf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ # timmy-time:latest \ # python -m swarm.agent_runner --agent-id w1 --name Worker-1 -# ── Stage 1: Builder — export deps via Poetry, install via pip ────────────── +# ── Stage 1: Builder — install deps via Poetry ────────────────────────────── FROM python:3.12-slim AS builder RUN apt-get update && apt-get install -y --no-install-recommends \ @@ -20,18 +20,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ WORKDIR /build -# Install Poetry + export plugin (only needed for export, not in runtime) -RUN pip install --no-cache-dir poetry poetry-plugin-export +# Install Poetry (only needed to resolve deps, not in runtime) +RUN pip install --no-cache-dir poetry # Copy dependency files only (layer caching) COPY pyproject.toml poetry.lock ./ -# Export pinned requirements and install with pip cache mount -RUN poetry export --extras swarm --extras telegram --extras discord --without-hashes \ - -f requirements.txt -o requirements.txt - -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --no-cache-dir -r requirements.txt +# Install deps directly from lock file (no virtualenv, 
no export plugin needed) +RUN poetry config virtualenvs.create false && \ + poetry install --only main --extras telegram --extras discord --no-interaction # ── Stage 2: Runtime ─────────────────────────────────────────────────────── FROM python:3.12-slim AS base diff --git a/Makefile b/Makefile index a7c0365..c7d3f71 100644 --- a/Makefile +++ b/Makefile @@ -210,6 +210,11 @@ docker-up: mkdir -p data docker compose up -d dashboard +docker-prod: + mkdir -p data + DOCKER_BUILDKIT=1 docker build -t timmy-time:latest . + docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d dashboard + docker-down: docker compose down diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml new file mode 100644 index 0000000..5c92fe4 --- /dev/null +++ b/docker-compose.prod.yml @@ -0,0 +1,56 @@ +# ── Production Compose Overlay ───────────────────────────────────────────────── +# +# Usage: +# make docker-prod # build + start with prod settings +# docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d +# +# Differences from dev: +# - Runs as non-root user (timmy) from Dockerfile +# - No bind mounts — uses image-baked source only +# - Named volumes only (no host path dependencies) +# - Read-only root filesystem with tmpfs for /tmp +# - Resource limits enforced +# - Secrets passed via environment variables (set in .env) +# +# Security note: Set all secrets in .env before deploying. 
+# Required: L402_HMAC_SECRET, L402_MACAROON_SECRET +# Recommended: TASKOSAUR_JWT_SECRET, TASKOSAUR_ENCRYPTION_KEY + +services: + + dashboard: + # Remove dev-only root user override — use Dockerfile's USER timmy + user: "" + read_only: true + tmpfs: + - /tmp:size=100M + volumes: + # Override: named volume only, no host bind mounts + - timmy-data:/app/data + # Remove ./src and ./static bind mounts (use baked-in image files) + environment: + DEBUG: "false" + TIMMY_ENV: "production" + deploy: + resources: + limits: + cpus: "2.0" + memory: 2G + + celery-worker: + user: "" + read_only: true + tmpfs: + - /tmp:size=100M + volumes: + - timmy-data:/app/data + deploy: + resources: + limits: + cpus: "1.0" + memory: 1G + +# Override timmy-data to use a simple named volume (no host bind) +volumes: + timmy-data: + driver: local diff --git a/pyproject.toml b/pyproject.toml index a03ed86..421a2dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,12 @@ markers = [ "skip_ci: Skip in CI environment (local development only)", ] +[tool.isort] +profile = "black" +line_length = 100 +src_paths = ["src", "tests"] +known_first_party = ["brain", "config", "dashboard", "infrastructure", "integrations", "spark", "swarm", "timmy", "timmy_serve"] + [tool.coverage.run] source = ["src"] omit = [ diff --git a/src/brain/__init__.py b/src/brain/__init__.py index 5555b48..c66ca0e 100644 --- a/src/brain/__init__.py +++ b/src/brain/__init__.py @@ -11,9 +11,9 @@ upgrade to distributed rqlite over Tailscale — same API, replicated. 
""" from brain.client import BrainClient -from brain.worker import DistributedWorker from brain.embeddings import LocalEmbedder from brain.memory import UnifiedMemory, get_memory +from brain.worker import DistributedWorker __all__ = [ "BrainClient", diff --git a/src/brain/client.py b/src/brain/client.py index 6131168..8c7ad98 100644 --- a/src/brain/client.py +++ b/src/brain/client.py @@ -21,52 +21,54 @@ DEFAULT_RQLITE_URL = "http://localhost:4001" class BrainClient: """Client for distributed brain (rqlite). - + Connects to local rqlite instance, which handles replication. All writes go to leader, reads can come from local node. """ - + def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None): from config import settings + self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}" self.source = self._detect_source() self._client = httpx.AsyncClient(timeout=30) - + def _detect_source(self) -> str: """Detect what component is using the brain.""" # Could be 'timmy', 'zeroclaw', 'worker', etc. # For now, infer from context or env from config import settings + return settings.brain_source - + # ────────────────────────────────────────────────────────────────────────── # Memory Operations # ────────────────────────────────────────────────────────────────────────── - + async def remember( self, content: str, tags: Optional[List[str]] = None, source: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Store a memory with embedding. 
- + Args: content: Text content to remember tags: Optional list of tags (e.g., ['shell', 'result']) source: Source identifier (defaults to self.source) metadata: Additional JSON-serializable metadata - + Returns: Dict with 'id' and 'status' """ from brain.embeddings import get_embedder - + embedder = get_embedder() embedding_bytes = embedder.encode_single(content) - + query = """ INSERT INTO memories (content, embedding, source, tags, metadata, created_at) VALUES (?, ?, ?, ?, ?, ?) @@ -77,100 +79,90 @@ class BrainClient: source or self.source, json.dumps(tags or []), json.dumps(metadata or {}), - datetime.utcnow().isoformat() + datetime.utcnow().isoformat(), ] - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/execute", - json=[query, params] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params]) resp.raise_for_status() result = resp.json() - + # Extract inserted ID last_id = None if "results" in result and result["results"]: last_id = result["results"][0].get("last_insert_id") - + logger.debug(f"Stored memory {last_id}: {content[:50]}...") return {"id": last_id, "status": "stored"} - + except Exception as e: logger.error(f"Failed to store memory: {e}") raise - + async def recall( - self, - query: str, - limit: int = 5, - sources: Optional[List[str]] = None + self, query: str, limit: int = 5, sources: Optional[List[str]] = None ) -> List[str]: """Semantic search for memories. - + Args: query: Search query text limit: Max results to return sources: Filter by source(s) (e.g., ['timmy', 'user']) - + Returns: List of memory content strings """ from brain.embeddings import get_embedder - + embedder = get_embedder() query_emb = embedder.encode_single(query) - + # rqlite with sqlite-vec extension for vector search sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?" 
params = [query_emb] - + if sources: placeholders = ",".join(["?"] * len(sources)) sql += f" AND source IN ({placeholders})" params.extend(sources) - + sql += " ORDER BY distance LIMIT ?" params.append(limit) - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/query", - json=[sql, params] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params]) resp.raise_for_status() result = resp.json() - + results = [] if "results" in result and result["results"]: for row in result["results"][0].get("rows", []): - results.append({ - "content": row[0], - "source": row[1], - "metadata": json.loads(row[2]) if row[2] else {}, - "distance": row[3] - }) - + results.append( + { + "content": row[0], + "source": row[1], + "metadata": json.loads(row[2]) if row[2] else {}, + "distance": row[3], + } + ) + return results - + except Exception as e: logger.error(f"Failed to search memories: {e}") # Graceful fallback - return empty list return [] - + async def get_recent( - self, - hours: int = 24, - limit: int = 20, - sources: Optional[List[str]] = None + self, hours: int = 24, limit: int = 20, sources: Optional[List[str]] = None ) -> List[Dict[str, Any]]: """Get recent memories by time. - + Args: hours: Look back this many hours limit: Max results sources: Optional source filter - + Returns: List of memory dicts """ @@ -180,84 +172,83 @@ class BrainClient: WHERE created_at > datetime('now', ?) """ params = [f"-{hours} hours"] - + if sources: placeholders = ",".join(["?"] * len(sources)) sql += f" AND source IN ({placeholders})" params.extend(sources) - + sql += " ORDER BY created_at DESC LIMIT ?" 
params.append(limit) - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/query", - json=[sql, params] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params]) resp.raise_for_status() result = resp.json() - + memories = [] if "results" in result and result["results"]: for row in result["results"][0].get("rows", []): - memories.append({ - "id": row[0], - "content": row[1], - "source": row[2], - "tags": json.loads(row[3]) if row[3] else [], - "metadata": json.loads(row[4]) if row[4] else {}, - "created_at": row[5] - }) - + memories.append( + { + "id": row[0], + "content": row[1], + "source": row[2], + "tags": json.loads(row[3]) if row[3] else [], + "metadata": json.loads(row[4]) if row[4] else {}, + "created_at": row[5], + } + ) + return memories - + except Exception as e: logger.error(f"Failed to get recent memories: {e}") return [] - + async def get_context(self, query: str) -> str: """Get formatted context for system prompt. - + Combines recent memories + relevant memories. - + Args: query: Current user query to find relevant context - + Returns: Formatted context string for prompt injection """ recent = await self.get_recent(hours=24, limit=10) relevant = await self.recall(query, limit=5) - + lines = ["Recent activity:"] for m in recent[:5]: lines.append(f"- {m['content'][:100]}") - + lines.append("\nRelevant memories:") for r in relevant[:5]: lines.append(f"- {r['content'][:100]}") - + return "\n".join(lines) - + # ────────────────────────────────────────────────────────────────────────── # Task Queue Operations # ────────────────────────────────────────────────────────────────────────── - + async def submit_task( self, content: str, task_type: str = "general", priority: int = 0, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Submit a task to the distributed queue. 
- + Args: content: Task description/prompt task_type: Type of task (shell, creative, code, research, general) priority: Higher = processed first metadata: Additional task data - + Returns: Dict with task 'id' """ @@ -270,50 +261,45 @@ class BrainClient: task_type, priority, json.dumps(metadata or {}), - datetime.utcnow().isoformat() + datetime.utcnow().isoformat(), ] - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/execute", - json=[query, params] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params]) resp.raise_for_status() result = resp.json() - + last_id = None if "results" in result and result["results"]: last_id = result["results"][0].get("last_insert_id") - + logger.info(f"Submitted task {last_id}: {content[:50]}...") return {"id": last_id, "status": "queued"} - + except Exception as e: logger.error(f"Failed to submit task: {e}") raise - + async def claim_task( - self, - capabilities: List[str], - node_id: Optional[str] = None + self, capabilities: List[str], node_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: """Atomically claim next available task. - + Uses UPDATE ... RETURNING pattern for atomic claim. 
- + Args: capabilities: List of capabilities this node has node_id: Identifier for claiming node - + Returns: Task dict or None if no tasks available """ claimer = node_id or self.node_id - + # Try to claim a matching task atomically # This works because rqlite uses Raft consensus - only one node wins placeholders = ",".join(["?"] * len(capabilities)) - + query = f""" UPDATE tasks SET status = 'claimed', @@ -330,15 +316,12 @@ class BrainClient: RETURNING id, content, task_type, priority, metadata """ params = [claimer, datetime.utcnow().isoformat()] + capabilities - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/execute", - json=[query, params] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params]) resp.raise_for_status() result = resp.json() - + if "results" in result and result["results"]: rows = result["results"][0].get("rows", []) if rows: @@ -348,24 +331,20 @@ class BrainClient: "content": row[1], "type": row[2], "priority": row[3], - "metadata": json.loads(row[4]) if row[4] else {} + "metadata": json.loads(row[4]) if row[4] else {}, } - + return None - + except Exception as e: logger.error(f"Failed to claim task: {e}") return None - + async def complete_task( - self, - task_id: int, - success: bool, - result: Optional[str] = None, - error: Optional[str] = None + self, task_id: int, success: bool, result: Optional[str] = None, error: Optional[str] = None ) -> None: """Mark task as completed or failed. - + Args: task_id: Task ID success: True if task succeeded @@ -373,7 +352,7 @@ class BrainClient: error: Error message if failed """ status = "done" if success else "failed" - + query = """ UPDATE tasks SET status = ?, @@ -383,23 +362,20 @@ class BrainClient: WHERE id = ? 
""" params = [status, result, error, datetime.utcnow().isoformat(), task_id] - + try: - await self._client.post( - f"{self.rqlite_url}/db/execute", - json=[query, params] - ) + await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params]) logger.debug(f"Task {task_id} marked {status}") - + except Exception as e: logger.error(f"Failed to complete task {task_id}: {e}") - + async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]: """Get list of pending tasks (for dashboard/monitoring). - + Args: limit: Max tasks to return - + Returns: List of pending task dicts """ @@ -410,33 +386,32 @@ class BrainClient: ORDER BY priority DESC, created_at ASC LIMIT ? """ - + try: - resp = await self._client.post( - f"{self.rqlite_url}/db/query", - json=[sql, [limit]] - ) + resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, [limit]]) resp.raise_for_status() result = resp.json() - + tasks = [] if "results" in result and result["results"]: for row in result["results"][0].get("rows", []): - tasks.append({ - "id": row[0], - "content": row[1], - "type": row[2], - "priority": row[3], - "metadata": json.loads(row[4]) if row[4] else {}, - "created_at": row[5] - }) - + tasks.append( + { + "id": row[0], + "content": row[1], + "type": row[2], + "priority": row[3], + "metadata": json.loads(row[4]) if row[4] else {}, + "created_at": row[5], + } + ) + return tasks - + except Exception as e: logger.error(f"Failed to get pending tasks: {e}") return [] - + async def close(self): """Close HTTP client.""" await self._client.aclose() diff --git a/src/brain/embeddings.py b/src/brain/embeddings.py index 1a09ed9..1988e93 100644 --- a/src/brain/embeddings.py +++ b/src/brain/embeddings.py @@ -18,48 +18,51 @@ _dimensions = 384 class LocalEmbedder: """Local sentence transformer for embeddings. - + Uses all-MiniLM-L6-v2 (80MB download, runs on CPU). 384-dimensional embeddings, good enough for semantic search. 
""" - + def __init__(self, model_name: str = _model_name): self.model_name = model_name self._model = None self._dimensions = _dimensions - + def _load_model(self): """Lazy load the model.""" global _model if _model is not None: self._model = _model return - + try: from sentence_transformers import SentenceTransformer + logger.info(f"Loading embedding model: {self.model_name}") _model = SentenceTransformer(self.model_name) self._model = _model logger.info(f"Embedding model loaded ({self._dimensions} dims)") except ImportError: - logger.error("sentence-transformers not installed. Run: pip install sentence-transformers") + logger.error( + "sentence-transformers not installed. Run: pip install sentence-transformers" + ) raise - + def encode(self, text: Union[str, List[str]]): """Encode text to embedding vector(s). - + Args: text: String or list of strings to encode - + Returns: Numpy array of shape (dims,) for single string or (n, dims) for list """ if self._model is None: self._load_model() - + # Normalize embeddings for cosine similarity return self._model.encode(text, normalize_embeddings=True) - + def encode_single(self, text: str) -> bytes: """Encode single text to bytes for SQLite storage. @@ -67,17 +70,19 @@ class LocalEmbedder: Float32 bytes """ import numpy as np + embedding = self.encode(text) if len(embedding.shape) > 1: embedding = embedding[0] return embedding.astype(np.float32).tobytes() - + def similarity(self, a, b) -> float: """Compute cosine similarity between two vectors. Vectors should already be normalized from encode(). 
""" import numpy as np + return float(np.dot(a, b)) diff --git a/src/brain/memory.py b/src/brain/memory.py index ba5c15e..9524a66 100644 --- a/src/brain/memory.py +++ b/src/brain/memory.py @@ -48,6 +48,7 @@ _SCHEMA_VERSION = 1 def _get_db_path() -> Path: """Get the brain database path from env or default.""" from config import settings + if settings.brain_db_path: return Path(settings.brain_db_path) return _DEFAULT_DB_PATH @@ -75,6 +76,7 @@ class UnifiedMemory: # Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite if use_rqlite is None: from config import settings as _settings + use_rqlite = bool(_settings.rqlite_url) self._use_rqlite = use_rqlite @@ -107,10 +109,12 @@ class UnifiedMemory: """Lazy-load the embedding model.""" if self._embedder is None: from config import settings as _settings + if _settings.timmy_skip_embeddings: return None try: from brain.embeddings import LocalEmbedder + self._embedder = LocalEmbedder() except ImportError: logger.warning("sentence-transformers not available — semantic search disabled") @@ -125,6 +129,7 @@ class UnifiedMemory: """Lazy-load the rqlite BrainClient.""" if self._rqlite_client is None: from brain.client import BrainClient + self._rqlite_client = BrainClient() return self._rqlite_client @@ -292,15 +297,17 @@ class UnifiedMemory: results = [] for score, row in scored[:limit]: - results.append({ - "id": row["id"], - "content": row["content"], - "source": row["source"], - "tags": json.loads(row["tags"]) if row["tags"] else [], - "metadata": json.loads(row["metadata"]) if row["metadata"] else {}, - "score": score, - "created_at": row["created_at"], - }) + results.append( + { + "id": row["id"], + "content": row["content"], + "source": row["source"], + "tags": json.loads(row["tags"]) if row["tags"] else [], + "metadata": json.loads(row["metadata"]) if row["metadata"] else {}, + "score": score, + "created_at": row["created_at"], + } + ) return results finally: diff --git a/src/brain/schema.py 
b/src/brain/schema.py index d504ae9..94620e1 100644 --- a/src/brain/schema.py +++ b/src/brain/schema.py @@ -84,11 +84,13 @@ def get_migration_sql(from_version: int, to_version: int) -> str: """Get SQL to migrate between versions.""" if to_version <= from_version: return "" - + sql_parts = [] for v in range(from_version + 1, to_version + 1): if v in MIGRATIONS: sql_parts.append(MIGRATIONS[v]) - sql_parts.append(f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');") - + sql_parts.append( + f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');" + ) + return "\n".join(sql_parts) diff --git a/src/brain/worker.py b/src/brain/worker.py index 10db77e..f7700de 100644 --- a/src/brain/worker.py +++ b/src/brain/worker.py @@ -21,11 +21,11 @@ logger = logging.getLogger(__name__) class DistributedWorker: """Continuous task processor for the distributed brain. - + Runs on every device, claims tasks matching its capabilities, executes them immediately, stores results. 
""" - + def __init__(self, brain_client: Optional[BrainClient] = None): self.brain = brain_client or BrainClient() self.node_id = f"{socket.gethostname()}-{os.getpid()}" @@ -33,30 +33,30 @@ class DistributedWorker: self.running = False self._handlers: Dict[str, Callable] = {} self._register_default_handlers() - + def _detect_capabilities(self) -> List[str]: """Detect what this node can do.""" caps = ["general", "shell", "file_ops", "git"] - + # Check for GPU if self._has_gpu(): caps.append("gpu") caps.append("creative") caps.append("image_gen") caps.append("video_gen") - + # Check for internet if self._has_internet(): caps.append("web") caps.append("research") - + # Check memory mem_gb = self._get_memory_gb() if mem_gb > 16: caps.append("large_model") if mem_gb > 32: caps.append("huge_model") - + # Check for specific tools if self._has_command("ollama"): caps.append("ollama") @@ -64,17 +64,15 @@ class DistributedWorker: caps.append("docker") if self._has_command("cargo"): caps.append("rust") - + logger.info(f"Worker capabilities: {caps}") return caps - + def _has_gpu(self) -> bool: """Check for NVIDIA or AMD GPU.""" try: # Check for nvidia-smi - result = subprocess.run( - ["nvidia-smi"], capture_output=True, timeout=5 - ) + result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=5) if result.returncode == 0: return True except (OSError, subprocess.SubprocessError): @@ -83,13 +81,15 @@ class DistributedWorker: # Check for ROCm if os.path.exists("/opt/rocm"): return True - + # Check for Apple Silicon Metal if os.uname().sysname == "Darwin": try: result = subprocess.run( ["system_profiler", "SPDisplaysDataType"], - capture_output=True, text=True, timeout=5 + capture_output=True, + text=True, + timeout=5, ) if "Metal" in result.stdout: return True @@ -102,8 +102,7 @@ class DistributedWorker: """Check if we have internet connectivity.""" try: result = subprocess.run( - ["curl", "-s", "--max-time", "3", "https://1.1.1.1"], - capture_output=True, timeout=5 + 
["curl", "-s", "--max-time", "3", "https://1.1.1.1"], capture_output=True, timeout=5 ) return result.returncode == 0 except (OSError, subprocess.SubprocessError): @@ -114,8 +113,7 @@ class DistributedWorker: try: if os.uname().sysname == "Darwin": result = subprocess.run( - ["sysctl", "-n", "hw.memsize"], - capture_output=True, text=True + ["sysctl", "-n", "hw.memsize"], capture_output=True, text=True ) bytes_mem = int(result.stdout.strip()) return bytes_mem / (1024**3) @@ -128,13 +126,11 @@ class DistributedWorker: except (OSError, ValueError): pass return 8.0 # Assume 8GB if we can't detect - + def _has_command(self, cmd: str) -> bool: """Check if command exists.""" try: - result = subprocess.run( - ["which", cmd], capture_output=True, timeout=5 - ) + result = subprocess.run(["which", cmd], capture_output=True, timeout=5) return result.returncode == 0 except (OSError, subprocess.SubprocessError): return False @@ -148,10 +144,10 @@ class DistributedWorker: "research": self._handle_research, "general": self._handle_general, } - + def register_handler(self, task_type: str, handler: Callable[[str], Any]): """Register a custom task handler. 
- + Args: task_type: Type of task this handler handles handler: Async function that takes task content and returns result @@ -159,11 +155,11 @@ class DistributedWorker: self._handlers[task_type] = handler if task_type not in self.capabilities: self.capabilities.append(task_type) - + # ────────────────────────────────────────────────────────────────────────── # Task Handlers # ────────────────────────────────────────────────────────────────────────── - + async def _handle_shell(self, command: str) -> str: """Execute shell command via ZeroClaw or direct subprocess.""" # Try ZeroClaw first if available @@ -171,156 +167,153 @@ class DistributedWorker: proc = await asyncio.create_subprocess_shell( f"zeroclaw exec --json '{command}'", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() - + # Store result in brain await self.brain.remember( content=f"Shell: {command}\nOutput: {stdout.decode()}", tags=["shell", "result"], source=self.node_id, - metadata={"command": command, "exit_code": proc.returncode} + metadata={"command": command, "exit_code": proc.returncode}, ) - + if proc.returncode != 0: raise Exception(f"Command failed: {stderr.decode()}") return stdout.decode() - + # Fallback to direct subprocess (less safe) proc = await asyncio.create_subprocess_shell( - command, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await proc.communicate() - + if proc.returncode != 0: raise Exception(f"Command failed: {stderr.decode()}") return stdout.decode() - + async def _handle_creative(self, prompt: str) -> str: """Generate creative media (requires GPU).""" if "gpu" not in self.capabilities: raise Exception("GPU not available on this node") - + # This would call creative tools (Stable Diffusion, etc.) 
# For now, placeholder logger.info(f"Creative task: {prompt[:50]}...") - + # Store result result = f"Creative output for: {prompt}" await self.brain.remember( content=result, tags=["creative", "generated"], source=self.node_id, - metadata={"prompt": prompt} + metadata={"prompt": prompt}, ) - + return result - + async def _handle_code(self, description: str) -> str: """Code generation and modification.""" # Would use LLM to generate code # For now, placeholder logger.info(f"Code task: {description[:50]}...") return f"Code generated for: {description}" - + async def _handle_research(self, query: str) -> str: """Web research.""" if "web" not in self.capabilities: raise Exception("Internet not available on this node") - + # Would use browser automation or search logger.info(f"Research task: {query[:50]}...") return f"Research results for: {query}" - + async def _handle_general(self, prompt: str) -> str: """General LLM task via local Ollama.""" if "ollama" not in self.capabilities: raise Exception("Ollama not available on this node") - + # Call Ollama try: proc = await asyncio.create_subprocess_exec( - "curl", "-s", "http://localhost:11434/api/generate", - "-d", json.dumps({ - "model": "llama3.1:8b-instruct", - "prompt": prompt, - "stream": False - }), - stdout=asyncio.subprocess.PIPE + "curl", + "-s", + "http://localhost:11434/api/generate", + "-d", + json.dumps({"model": "llama3.1:8b-instruct", "prompt": prompt, "stream": False}), + stdout=asyncio.subprocess.PIPE, ) stdout, _ = await proc.communicate() - + response = json.loads(stdout.decode()) result = response.get("response", "No response") - + # Store in brain await self.brain.remember( content=f"Task: {prompt}\nResult: {result}", tags=["llm", "result"], source=self.node_id, - metadata={"model": "llama3.1:8b-instruct"} + metadata={"model": "llama3.1:8b-instruct"}, ) - + return result - + except Exception as e: raise Exception(f"LLM failed: {e}") - + # 
────────────────────────────────────────────────────────────────────────── # Main Loop # ────────────────────────────────────────────────────────────────────────── - + async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]: """Execute a claimed task.""" task_type = task.get("type", "general") content = task.get("content", "") task_id = task.get("id") - + handler = self._handlers.get(task_type, self._handlers["general"]) - + try: logger.info(f"Executing task {task_id}: {task_type}") result = await handler(content) - + await self.brain.complete_task(task_id, success=True, result=result) logger.info(f"Task {task_id} completed") return {"success": True, "result": result} - + except Exception as e: error_msg = str(e) logger.error(f"Task {task_id} failed: {error_msg}") await self.brain.complete_task(task_id, success=False, error=error_msg) return {"success": False, "error": error_msg} - + async def run_once(self) -> bool: """Process one task if available. - + Returns: True if a task was processed, False if no tasks available """ task = await self.brain.claim_task(self.capabilities, self.node_id) - + if task: await self.execute_task(task) return True - + return False - + async def run(self): """Main loop — continuously process tasks.""" logger.info(f"Worker {self.node_id} started") logger.info(f"Capabilities: {self.capabilities}") - + self.running = True consecutive_empty = 0 - + while self.running: try: had_work = await self.run_once() - + if had_work: # Immediately check for more work consecutive_empty = 0 @@ -331,11 +324,11 @@ class DistributedWorker: # Sleep 0.5s, but up to 2s if consistently empty sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0) await asyncio.sleep(sleep_time) - + except Exception as e: logger.error(f"Worker error: {e}") await asyncio.sleep(1) - + def stop(self): """Stop the worker loop.""" self.running = False @@ -345,7 +338,7 @@ class DistributedWorker: async def main(): """CLI entry point for worker.""" import sys - + # Allow 
capability overrides from CLI if len(sys.argv) > 1: caps = sys.argv[1].split(",") @@ -354,12 +347,12 @@ async def main(): logger.info(f"Overriding capabilities: {caps}") else: worker = DistributedWorker() - + try: await worker.run() except KeyboardInterrupt: worker.stop() - print("\nWorker stopped.") + logger.info("Worker stopped.") if __name__ == "__main__": diff --git a/src/config.py b/src/config.py index b26e51e..d8d452f 100644 --- a/src/config.py +++ b/src/config.py @@ -213,6 +213,15 @@ class Settings(BaseSettings): # Timeout in seconds for OpenFang hand execution (some hands are slow). openfang_timeout: int = 120 + # ── Autoresearch — autonomous ML experiment loops ────────────────── + # Integrates Karpathy's autoresearch pattern: agents modify training + # code, run time-boxed experiments, evaluate metrics, and iterate. + autoresearch_enabled: bool = False + autoresearch_workspace: str = "data/experiments" + autoresearch_time_budget: int = 300 # seconds per experiment run + autoresearch_max_iterations: int = 100 + autoresearch_metric: str = "val_bpb" # metric to optimise (lower = better) + # ── Local Hands (Shell + Git) ────────────────────────────────────── # Enable local shell/git execution hands. 
hands_shell_enabled: bool = True diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 1f2084a..dd5b117 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -18,36 +18,38 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles + from config import settings -from dashboard.routes.agents import router as agents_router -from dashboard.routes.health import router as health_router -from dashboard.routes.marketplace import router as marketplace_router -from dashboard.routes.voice import router as voice_router -from dashboard.routes.mobile import router as mobile_router -from dashboard.routes.briefing import router as briefing_router -from dashboard.routes.telegram import router as telegram_router -from dashboard.routes.tools import router as tools_router -from dashboard.routes.spark import router as spark_router -from dashboard.routes.discord import router as discord_router -from dashboard.routes.memory import router as memory_router -from dashboard.routes.router import router as router_status_router -from dashboard.routes.grok import router as grok_router -from dashboard.routes.models import router as models_router -from dashboard.routes.models import api_router as models_api_router -from dashboard.routes.chat_api import router as chat_api_router -from dashboard.routes.thinking import router as thinking_router -from dashboard.routes.calm import router as calm_router -from dashboard.routes.swarm import router as swarm_router -from dashboard.routes.tasks import router as tasks_router -from dashboard.routes.work_orders import router as work_orders_router -from dashboard.routes.system import router as system_router -from dashboard.routes.paperclip import router as paperclip_router -from infrastructure.router.api import router as cascade_router # Import dedicated middleware from dashboard.middleware.csrf 
import CSRFMiddleware from dashboard.middleware.request_logging import RequestLoggingMiddleware from dashboard.middleware.security_headers import SecurityHeadersMiddleware +from dashboard.routes.agents import router as agents_router +from dashboard.routes.briefing import router as briefing_router +from dashboard.routes.calm import router as calm_router +from dashboard.routes.chat_api import router as chat_api_router +from dashboard.routes.discord import router as discord_router +from dashboard.routes.experiments import router as experiments_router +from dashboard.routes.grok import router as grok_router +from dashboard.routes.health import router as health_router +from dashboard.routes.marketplace import router as marketplace_router +from dashboard.routes.memory import router as memory_router +from dashboard.routes.mobile import router as mobile_router +from dashboard.routes.models import api_router as models_api_router +from dashboard.routes.models import router as models_router +from dashboard.routes.paperclip import router as paperclip_router +from dashboard.routes.router import router as router_status_router +from dashboard.routes.spark import router as spark_router +from dashboard.routes.swarm import router as swarm_router +from dashboard.routes.system import router as system_router +from dashboard.routes.tasks import router as tasks_router +from dashboard.routes.telegram import router as telegram_router +from dashboard.routes.thinking import router as thinking_router +from dashboard.routes.tools import router as tools_router +from dashboard.routes.voice import router as voice_router +from dashboard.routes.work_orders import router as work_orders_router +from infrastructure.router.api import router as cascade_router def _configure_logging() -> None: @@ -100,8 +102,8 @@ _BRIEFING_INTERVAL_HOURS = 6 async def _briefing_scheduler() -> None: """Background task: regenerate Timmy's briefing every 6 hours.""" - from timmy.briefing import engine as briefing_engine 
from infrastructure.notifications.push import notify_briefing_ready + from timmy.briefing import engine as briefing_engine await asyncio.sleep(2) @@ -121,9 +123,9 @@ async def _briefing_scheduler() -> None: async def _start_chat_integrations_background() -> None: """Background task: start chat integrations without blocking startup.""" - from integrations.telegram_bot.bot import telegram_bot - from integrations.chat_bridge.vendors.discord import discord_bot from integrations.chat_bridge.registry import platform_registry + from integrations.chat_bridge.vendors.discord import discord_bot + from integrations.telegram_bot.bot import telegram_bot await asyncio.sleep(0.5) @@ -164,9 +166,9 @@ async def _discord_token_watcher() -> None: if discord_bot.state.name == "CONNECTED": return # Already running — stop watching - # 1. Check live environment variable (intentionally uses os.environ, - # not settings, because this polls for runtime hot-reload changes) - token = os.environ.get("DISCORD_TOKEN", "") + # 1. Check settings (pydantic-settings reads env on instantiation; + # hot-reload is handled by re-reading .env below) + token = settings.discord_token # 2. 
Re-read .env file for hot-reload if not token: @@ -203,6 +205,7 @@ async def lifespan(app: FastAPI): # Initialize Spark Intelligence engine from spark.engine import spark_engine + if spark_engine.enabled: logger.info("Spark Intelligence active — event capture enabled") @@ -210,12 +213,17 @@ async def lifespan(app: FastAPI): if settings.memory_prune_days > 0: try: from timmy.memory.vector_store import prune_memories + pruned = prune_memories( older_than_days=settings.memory_prune_days, keep_facts=settings.memory_prune_keep_facts, ) if pruned: - logger.info("Memory auto-prune: removed %d entries older than %d days", pruned, settings.memory_prune_days) + logger.info( + "Memory auto-prune: removed %d entries older than %d days", + pruned, + settings.memory_prune_days, + ) except Exception as exc: logger.debug("Memory auto-prune skipped: %s", exc) @@ -229,7 +237,8 @@ async def lifespan(app: FastAPI): if total_mb > settings.memory_vault_max_mb: logger.warning( "Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes", - total_mb, settings.memory_vault_max_mb, + total_mb, + settings.memory_vault_max_mb, ) except Exception as exc: logger.debug("Vault size check skipped: %s", exc) @@ -284,10 +293,7 @@ def _get_cors_origins() -> list[str]: app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"]) # 2. Security Headers -app.add_middleware( - SecurityHeadersMiddleware, - production=not settings.debug -) +app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug) # 3. 
CSRF Protection app.add_middleware(CSRFMiddleware) @@ -314,7 +320,6 @@ if static_dir.exists(): # Shared templates instance from dashboard.templating import templates # noqa: E402 - # Include routers app.include_router(health_router) app.include_router(agents_router) @@ -339,6 +344,7 @@ app.include_router(tasks_router) app.include_router(work_orders_router) app.include_router(system_router) app.include_router(paperclip_router) +app.include_router(experiments_router) app.include_router(cascade_router) diff --git a/src/dashboard/middleware/__init__.py b/src/dashboard/middleware/__init__.py index b3682d2..24a85ff 100644 --- a/src/dashboard/middleware/__init__.py +++ b/src/dashboard/middleware/__init__.py @@ -1,8 +1,8 @@ """Dashboard middleware package.""" from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token -from .security_headers import SecurityHeadersMiddleware from .request_logging import RequestLoggingMiddleware +from .security_headers import SecurityHeadersMiddleware __all__ = [ "CSRFMiddleware", diff --git a/src/dashboard/middleware/csrf.py b/src/dashboard/middleware/csrf.py index dd52591..fa30ed5 100644 --- a/src/dashboard/middleware/csrf.py +++ b/src/dashboard/middleware/csrf.py @@ -4,16 +4,15 @@ Provides CSRF token generation, validation, and middleware integration to protect state-changing endpoints from cross-site request attacks. 
""" -import secrets -import hmac import hashlib -from typing import Callable, Optional +import hmac +import secrets from functools import wraps +from typing import Callable, Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response, JSONResponse - +from starlette.responses import JSONResponse, Response # Module-level set to track exempt routes _exempt_routes: set[str] = set() @@ -21,26 +20,27 @@ _exempt_routes: set[str] = set() def csrf_exempt(endpoint: Callable) -> Callable: """Decorator to mark an endpoint as exempt from CSRF validation. - + Usage: @app.post("/webhook") @csrf_exempt def webhook_endpoint(): ... """ + @wraps(endpoint) async def async_wrapper(*args, **kwargs): return await endpoint(*args, **kwargs) - + @wraps(endpoint) def sync_wrapper(*args, **kwargs): return endpoint(*args, **kwargs) - + # Mark the original function as exempt endpoint._csrf_exempt = True # type: ignore - + # Also mark the wrapper - if hasattr(endpoint, '__code__') and endpoint.__code__.co_flags & 0x80: + if hasattr(endpoint, "__code__") and endpoint.__code__.co_flags & 0x80: async_wrapper._csrf_exempt = True # type: ignore return async_wrapper else: @@ -50,12 +50,12 @@ def csrf_exempt(endpoint: Callable) -> Callable: def is_csrf_exempt(endpoint: Callable) -> bool: """Check if an endpoint is marked as CSRF exempt.""" - return getattr(endpoint, '_csrf_exempt', False) + return getattr(endpoint, "_csrf_exempt", False) def generate_csrf_token() -> str: """Generate a cryptographically secure CSRF token. - + Returns: A secure random token string. """ @@ -64,77 +64,78 @@ def generate_csrf_token() -> str: def validate_csrf_token(token: str, expected_token: str) -> bool: """Validate a CSRF token against the expected token. - + Uses constant-time comparison to prevent timing attacks. - + Args: token: The token provided by the client. expected_token: The expected token (from cookie/session). 
- + Returns: True if the token is valid, False otherwise. """ if not token or not expected_token: return False - + return hmac.compare_digest(token, expected_token) class CSRFMiddleware(BaseHTTPMiddleware): """Middleware to enforce CSRF protection on state-changing requests. - + Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens. State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token. - + The token is expected to be: - In the X-CSRF-Token header, or - In the request body as 'csrf_token', or - Matching the token in the csrf_token cookie - + Usage: app.add_middleware(CSRFMiddleware, secret="your-secret-key") - + Attributes: secret: Secret key for token signing (optional, for future use). cookie_name: Name of the CSRF cookie. header_name: Name of the CSRF header. safe_methods: HTTP methods that don't require CSRF tokens. """ - + SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"} - + def __init__( self, app, secret: Optional[str] = None, cookie_name: str = "csrf_token", header_name: str = "X-CSRF-Token", - form_field: str = "csrf_token" + form_field: str = "csrf_token", ): super().__init__(app) self.secret = secret self.cookie_name = cookie_name self.header_name = header_name self.form_field = form_field - + async def dispatch(self, request: Request, call_next) -> Response: """Process the request and enforce CSRF protection. - + For safe methods: Set a CSRF token cookie if not present. For unsafe methods: Validate the CSRF token. """ # Bypass CSRF if explicitly disabled (e.g. 
in tests) from config import settings + if settings.timmy_disable_csrf: return await call_next(request) # Get existing CSRF token from cookie csrf_cookie = request.cookies.get(self.cookie_name) - + # For safe methods, just ensure a token exists if request.method in self.SAFE_METHODS: response = await call_next(request) - + # Set CSRF token cookie if not present if not csrf_cookie: new_token = generate_csrf_token() @@ -144,15 +145,15 @@ class CSRFMiddleware(BaseHTTPMiddleware): httponly=False, # Must be readable by JavaScript secure=settings.csrf_cookie_secure, samesite="Lax", - max_age=86400 # 24 hours + max_age=86400, # 24 hours ) - + return response - + # For unsafe methods, check if route is exempt first # Note: We need to let the request proceed and check at response time # since FastAPI routes are resolved after middleware - + # Try to validate token early if not await self._validate_request(request, csrf_cookie): # Check if this might be an exempt route by checking path patterns @@ -164,33 +165,34 @@ class CSRFMiddleware(BaseHTTPMiddleware): content={ "error": "CSRF validation failed", "code": "CSRF_INVALID", - "message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field." - } + "message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field.", + }, ) - + return await call_next(request) - + def _is_likely_exempt(self, path: str) -> bool: """Check if a path is likely to be CSRF exempt. - + Common patterns like webhooks, API endpoints, etc. Uses path normalization and exact/prefix matching to prevent bypasses. - + Args: path: The request path. - + Returns: True if the path is likely exempt. """ # 1. 
Normalize path to prevent /webhook/../ bypasses # Use posixpath for consistent behavior on all platforms import posixpath + normalized_path = posixpath.normpath(path) - + # Ensure it starts with / for comparison if not normalized_path.startswith("/"): normalized_path = "/" + normalized_path - + # Add back trailing slash if it was present in original path # to ensure prefix matching behaves as expected if path.endswith("/") and not normalized_path.endswith("/"): @@ -200,15 +202,15 @@ class CSRFMiddleware(BaseHTTPMiddleware): # Patterns ending with / are prefix-matched # Patterns NOT ending with / are exact-matched exempt_patterns = [ - "/webhook/", # Prefix match (e.g., /webhook/stripe) - "/webhook", # Exact match - "/api/v1/", # Prefix match - "/lightning/webhook/", # Prefix match + "/webhook/", # Prefix match (e.g., /webhook/stripe) + "/webhook", # Exact match + "/api/v1/", # Prefix match + "/lightning/webhook/", # Prefix match "/lightning/webhook", # Exact match - "/_internal/", # Prefix match - "/_internal", # Exact match + "/_internal/", # Prefix match + "/_internal", # Exact match ] - + for pattern in exempt_patterns: if pattern.endswith("/"): if normalized_path.startswith(pattern): @@ -216,20 +218,20 @@ class CSRFMiddleware(BaseHTTPMiddleware): else: if normalized_path == pattern: return True - + return False - + async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool: """Validate the CSRF token in the request. - + Checks for token in: 1. X-CSRF-Token header 2. csrf_token form field - + Args: request: The incoming request. csrf_cookie: The expected token from the cookie. - + Returns: True if the token is valid, False otherwise. 
""" @@ -241,11 +243,14 @@ class CSRFMiddleware(BaseHTTPMiddleware): header_token = request.headers.get(self.header_name) if header_token and validate_csrf_token(header_token, csrf_cookie): return True - + # If no header token, try form data (for non-JSON POSTs) # Check Content-Type to avoid hanging on non-form requests content_type = request.headers.get("Content-Type", "") - if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type: + if ( + "application/x-www-form-urlencoded" in content_type + or "multipart/form-data" in content_type + ): try: form_data = await request.form() form_token = form_data.get(self.form_field) @@ -254,5 +259,5 @@ class CSRFMiddleware(BaseHTTPMiddleware): except Exception: # Error parsing form data, treat as invalid pass - + return False diff --git a/src/dashboard/middleware/request_logging.py b/src/dashboard/middleware/request_logging.py index 69818ea..5f136da 100644 --- a/src/dashboard/middleware/request_logging.py +++ b/src/dashboard/middleware/request_logging.py @@ -4,22 +4,21 @@ Logs HTTP requests with timing, status codes, and client information for monitoring and debugging purposes. """ +import logging import time import uuid -import logging -from typing import Optional, List +from typing import List, Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response - logger = logging.getLogger("timmy.requests") class RequestLoggingMiddleware(BaseHTTPMiddleware): """Middleware to log all HTTP requests. 
- + Logs the following information for each request: - HTTP method and path - Response status code @@ -27,60 +26,55 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): - Client IP address - User-Agent header - Correlation ID for tracing - + Usage: app.add_middleware(RequestLoggingMiddleware) - + # Skip certain paths: app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"]) - + Attributes: skip_paths: List of URL paths to skip logging. log_level: Logging level for successful requests. """ - - def __init__( - self, - app, - skip_paths: Optional[List[str]] = None, - log_level: int = logging.INFO - ): + + def __init__(self, app, skip_paths: Optional[List[str]] = None, log_level: int = logging.INFO): super().__init__(app) self.skip_paths = set(skip_paths or []) self.log_level = log_level - + async def dispatch(self, request: Request, call_next) -> Response: """Log the request and response details. - + Args: request: The incoming request. call_next: Callable to get the response from downstream. - + Returns: The response from downstream. 
""" # Check if we should skip logging this path if request.url.path in self.skip_paths: return await call_next(request) - + # Generate correlation ID correlation_id = str(uuid.uuid4())[:8] request.state.correlation_id = correlation_id - + # Record start time start_time = time.time() - + # Get client info client_ip = self._get_client_ip(request) user_agent = request.headers.get("user-agent", "-") - + try: # Process the request response = await call_next(request) - + # Calculate duration duration_ms = (time.time() - start_time) * 1000 - + # Log the request self._log_request( method=request.method, @@ -89,14 +83,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): duration_ms=duration_ms, client_ip=client_ip, user_agent=user_agent, - correlation_id=correlation_id + correlation_id=correlation_id, ) - + # Add correlation ID to response headers response.headers["X-Correlation-ID"] = correlation_id - + return response - + except Exception as exc: # Calculate duration even for failed requests duration_ms = (time.time() - start_time) * 1000 @@ -110,6 +104,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): # Auto-escalate: create bug report task from unhandled exception try: from infrastructure.error_capture import capture_error + capture_error( exc, source="http", @@ -126,16 +121,16 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): # Re-raise the exception raise - + def _get_client_ip(self, request: Request) -> str: """Extract the client IP address from the request. - + Checks X-Forwarded-For and X-Real-IP headers first for proxied requests, falls back to the direct client IP. - + Args: request: The incoming request. - + Returns: Client IP address string. 
""" @@ -144,17 +139,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): if forwarded_for: # X-Forwarded-For can contain multiple IPs, take the first one return forwarded_for.split(",")[0].strip() - + real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip - + # Fall back to direct connection if request.client: return request.client.host - + return "-" - + def _log_request( self, method: str, @@ -163,10 +158,10 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): duration_ms: float, client_ip: str, user_agent: str, - correlation_id: str + correlation_id: str, ) -> None: """Format and log the request details. - + Args: method: HTTP method. path: Request path. @@ -182,14 +177,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware): level = logging.ERROR elif status_code >= 400: level = logging.WARNING - + message = ( f"[{correlation_id}] {method} {path} - {status_code} " f"- {duration_ms:.2f}ms - {client_ip}" ) - + # Add user agent for non-health requests if path not in self.skip_paths: message += f" - {user_agent[:50]}" - + logger.log(level, message) diff --git a/src/dashboard/middleware/security_headers.py b/src/dashboard/middleware/security_headers.py index 63dacab..403f3d5 100644 --- a/src/dashboard/middleware/security_headers.py +++ b/src/dashboard/middleware/security_headers.py @@ -4,6 +4,8 @@ Adds common security headers to all HTTP responses to improve application security posture against various attacks. """ +from typing import Optional + from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response @@ -11,7 +13,7 @@ from starlette.responses import Response class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Middleware to add security headers to all responses. 
- + Adds the following headers: - X-Content-Type-Options: Prevents MIME type sniffing - X-Frame-Options: Prevents clickjacking @@ -20,41 +22,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): - Permissions-Policy: Restricts feature access - Content-Security-Policy: Mitigates XSS and data injection - Strict-Transport-Security: Enforces HTTPS (production only) - + Usage: app.add_middleware(SecurityHeadersMiddleware) - + # Or with production settings: app.add_middleware(SecurityHeadersMiddleware, production=True) - + Attributes: production: If True, adds HSTS header for HTTPS enforcement. csp_report_only: If True, sends CSP in report-only mode. """ - + def __init__( self, app, production: bool = False, csp_report_only: bool = False, - custom_csp: str = None + custom_csp: Optional[str] = None, ): super().__init__(app) self.production = production self.csp_report_only = csp_report_only - + # Build CSP directive self.csp_directive = custom_csp or self._build_csp() - + def _build_csp(self) -> str: """Build the Content-Security-Policy directive. - + Creates a restrictive default policy that allows: - Same-origin resources by default - Inline scripts/styles (needed for HTMX/Bootstrap) - Data URIs for images - WebSocket connections - + Returns: CSP directive string. """ @@ -73,25 +75,25 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): "form-action 'self'", ] return "; ".join(directives) - + def _add_security_headers(self, response: Response) -> None: """Add security headers to a response. - + Args: response: The response to add headers to. 
""" # Prevent MIME type sniffing response.headers["X-Content-Type-Options"] = "nosniff" - + # Prevent clickjacking response.headers["X-Frame-Options"] = "SAMEORIGIN" - + # Enable XSS protection (legacy browsers) response.headers["X-XSS-Protection"] = "1; mode=block" - + # Control referrer information response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" - + # Restrict browser features response.headers["Permissions-Policy"] = ( "camera=(), " @@ -103,38 +105,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): "gyroscope=(), " "accelerometer=()" ) - + # Content Security Policy - csp_header = "Content-Security-Policy-Report-Only" if self.csp_report_only else "Content-Security-Policy" + csp_header = ( + "Content-Security-Policy-Report-Only" + if self.csp_report_only + else "Content-Security-Policy" + ) response.headers[csp_header] = self.csp_directive - + # HTTPS enforcement (production only) if self.production: - response.headers["Strict-Transport-Security"] = ( - "max-age=31536000; includeSubDomains; preload" - ) - + response.headers[ + "Strict-Transport-Security" + ] = "max-age=31536000; includeSubDomains; preload" + async def dispatch(self, request: Request, call_next) -> Response: """Add security headers to the response. - + Args: request: The incoming request. call_next: Callable to get the response from downstream. - + Returns: Response with security headers added. 
""" try: response = await call_next(request) - self._add_security_headers(response) - return response except Exception: - # Create a response for the error with security headers - from starlette.responses import PlainTextResponse - response = PlainTextResponse( - content="Internal Server Error", - status_code=500 + import logging + + logging.getLogger(__name__).debug( + "Upstream error in security headers middleware", exc_info=True ) - self._add_security_headers(response) - # Return the error response with headers (don't re-raise) - return response + from starlette.responses import PlainTextResponse + + response = PlainTextResponse("Internal Server Error", status_code=500) + self._add_security_headers(response) + return response diff --git a/src/dashboard/models/calm.py b/src/dashboard/models/calm.py index bc0cb67..7417fec 100644 --- a/src/dashboard/models/calm.py +++ b/src/dashboard/models/calm.py @@ -1,24 +1,27 @@ - -from datetime import datetime, date +from datetime import date, datetime from enum import Enum as PyEnum -from sqlalchemy import ( - Column, Integer, String, DateTime, Boolean, Enum as SQLEnum, - Date, ForeignKey, Index, JSON -) + +from sqlalchemy import JSON, Boolean, Column, Date, DateTime +from sqlalchemy import Enum as SQLEnum +from sqlalchemy import ForeignKey, Index, Integer, String from sqlalchemy.orm import relationship + from .database import Base # Assuming a shared Base in models/database.py + class TaskState(str, PyEnum): LATER = "LATER" NEXT = "NEXT" NOW = "NOW" DONE = "DONE" - DEFERRED = "DEFERRED" # Task pushed to tomorrow + DEFERRED = "DEFERRED" # Task pushed to tomorrow + class TaskCertainty(str, PyEnum): - FUZZY = "FUZZY" # An intention without a time - SOFT = "SOFT" # A flexible task with a time - HARD = "HARD" # A fixed meeting/appointment + FUZZY = "FUZZY" # An intention without a time + SOFT = "SOFT" # A flexible task with a time + HARD = "HARD" # A fixed meeting/appointment + class Task(Base): __tablename__ = "tasks" @@ -29,7 
+32,7 @@ class Task(Base): state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True) certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False) - is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day + is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day sort_order = Column(Integer, default=0, nullable=False) @@ -42,7 +45,8 @@ class Task(Base): created_at = Column(DateTime, default=datetime.utcnow, nullable=False) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) - __table_args__ = (Index('ix_task_state_order', 'state', 'sort_order'),) + __table_args__ = (Index("ix_task_state_order", "state", "sort_order"),) + class JournalEntry(Base): __tablename__ = "journal_entries" diff --git a/src/dashboard/models/database.py b/src/dashboard/models/database.py index 0994996..8a5b914 100644 --- a/src/dashboard/models/database.py +++ b/src/dashboard/models/database.py @@ -1,17 +1,16 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db" -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} -) +engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() + def create_tables(): """Create all tables defined by models that have imported Base.""" Base.metadata.create_all(bind=engine) diff --git a/src/dashboard/routes/agents.py b/src/dashboard/routes/agents.py index 55489d8..6e4e5e4 100644 --- a/src/dashboard/routes/agents.py +++ b/src/dashboard/routes/agents.py @@ -5,9 +5,9 @@ from datetime import datetime from fastapi import APIRouter, Form, Request from fastapi.responses import 
HTMLResponse -from timmy.session import chat as agent_chat from dashboard.store import message_log from dashboard.templating import templates +from timmy.session import chat as agent_chat logger = logging.getLogger(__name__) @@ -38,9 +38,7 @@ async def list_agents(): @router.get("/default/panel", response_class=HTMLResponse) async def agent_panel(request: Request): """Chat panel — for HTMX main-panel swaps.""" - return templates.TemplateResponse( - request, "partials/agent_panel_chat.html", {"agent": None} - ) + return templates.TemplateResponse(request, "partials/agent_panel_chat.html", {"agent": None}) @router.get("/default/history", response_class=HTMLResponse) @@ -77,7 +75,9 @@ async def chat_agent(request: Request, message: str = Form(...)): message_log.append(role="user", content=message, timestamp=timestamp, source="browser") if response_text is not None: - message_log.append(role="agent", content=response_text, timestamp=timestamp, source="browser") + message_log.append( + role="agent", content=response_text, timestamp=timestamp, source="browser" + ) elif error_text: message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser") diff --git a/src/dashboard/routes/briefing.py b/src/dashboard/routes/briefing.py index 15d0883..feff0ba 100644 --- a/src/dashboard/routes/briefing.py +++ b/src/dashboard/routes/briefing.py @@ -12,9 +12,10 @@ from datetime import datetime, timezone from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse, JSONResponse -from timmy.briefing import Briefing, engine as briefing_engine -from timmy import approvals as approval_store from dashboard.templating import templates +from timmy import approvals as approval_store +from timmy.briefing import Briefing +from timmy.briefing import engine as briefing_engine logger = logging.getLogger(__name__) diff --git a/src/dashboard/routes/calm.py b/src/dashboard/routes/calm.py index 46ebe07..b11229b 100644 --- a/src/dashboard/routes/calm.py +++ 
b/src/dashboard/routes/calm.py @@ -1,4 +1,3 @@ - import logging from datetime import date, datetime from typing import List, Optional @@ -8,7 +7,7 @@ from fastapi.responses import HTMLResponse from sqlalchemy.orm import Session from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState -from dashboard.models.database import SessionLocal, engine, get_db, create_tables +from dashboard.models.database import SessionLocal, create_tables, engine, get_db from dashboard.templating import templates # Ensure CALM tables exist (safe to call multiple times) @@ -23,11 +22,19 @@ router = APIRouter(tags=["calm"]) def get_now_task(db: Session) -> Optional[Task]: return db.query(Task).filter(Task.state == TaskState.NOW).first() + def get_next_task(db: Session) -> Optional[Task]: return db.query(Task).filter(Task.state == TaskState.NEXT).first() + def get_later_tasks(db: Session) -> List[Task]: - return db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() + return ( + db.query(Task) + .filter(Task.state == TaskState.LATER) + .order_by(Task.is_mit.desc(), Task.sort_order) + .all() + ) + def promote_tasks(db: Session): # Ensure only one NOW task exists. If multiple, demote extras to NEXT. 
@@ -38,7 +45,7 @@ def promote_tasks(db: Session): for task_to_demote in now_tasks[1:]: task_to_demote.state = TaskState.NEXT db.add(task_to_demote) - db.flush() # Make changes visible + db.flush() # Make changes visible # If no NOW task, promote NEXT to NOW current_now = db.query(Task).filter(Task.state == TaskState.NOW).first() @@ -47,12 +54,17 @@ def promote_tasks(db: Session): if next_task: next_task.state = TaskState.NOW db.add(next_task) - db.flush() # Make changes visible + db.flush() # Make changes visible # If no NEXT task, promote highest priority LATER to NEXT current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first() if not current_next: - later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() + later_tasks = ( + db.query(Task) + .filter(Task.state == TaskState.LATER) + .order_by(Task.is_mit.desc(), Task.sort_order) + .all() + ) if later_tasks: later_tasks[0].state = TaskState.NEXT db.add(later_tasks[0]) @@ -60,14 +72,17 @@ def promote_tasks(db: Session): db.commit() - # Endpoints @router.get("/calm", response_class=HTMLResponse) async def get_calm_view(request: Request, db: Session = Depends(get_db)): now_task = get_now_task(db) next_task = get_next_task(db) later_tasks_count = len(get_later_tasks(db)) - return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": now_task, + return templates.TemplateResponse( + request, + "calm/calm_view.html", + { + "now_task": now_task, "next_task": next_task, "later_tasks_count": later_tasks_count, }, @@ -101,7 +116,7 @@ async def post_morning_ritual( task = Task( title=mit_title, is_mit=True, - state=TaskState.LATER, # Initially LATER, will be promoted + state=TaskState.LATER, # Initially LATER, will be promoted certainty=TaskCertainty.SOFT, ) db.add(task) @@ -113,7 +128,7 @@ async def post_morning_ritual( db.add(journal_entry) # Create other tasks - for task_title in other_tasks.split('\n'): + for task_title in 
other_tasks.split("\n"): task_title = task_title.strip() if task_title: task = Task( @@ -128,20 +143,29 @@ async def post_morning_ritual( # Set initial NOW/NEXT states # Set initial NOW/NEXT states after all tasks are created if not get_now_task(db) and not get_next_task(db): - later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() + later_tasks = ( + db.query(Task) + .filter(Task.state == TaskState.LATER) + .order_by(Task.is_mit.desc(), Task.sort_order) + .all() + ) if later_tasks: # Set the highest priority LATER task to NOW later_tasks[0].state = TaskState.NOW db.add(later_tasks[0]) - db.flush() # Flush to make the change visible for the next query + db.flush() # Flush to make the change visible for the next query # Set the next highest priority LATER task to NEXT if len(later_tasks) > 1: later_tasks[1].state = TaskState.NEXT db.add(later_tasks[1]) - db.commit() # Commit changes after initial NOW/NEXT setup + db.commit() # Commit changes after initial NOW/NEXT setup - return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": get_now_task(db), + return templates.TemplateResponse( + request, + "calm/calm_view.html", + { + "now_task": get_now_task(db), "next_task": get_next_task(db), "later_tasks_count": len(get_later_tasks(db)), }, @@ -154,7 +178,8 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db if not journal_entry: raise HTTPException(status_code=404, detail="No journal entry for today") return templates.TemplateResponse( - "calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}) + "calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry} + ) @router.post("/calm/ritual/evening", response_class=HTMLResponse) @@ -175,9 +200,13 @@ async def post_evening_ritual( db.add(journal_entry) # Archive any remaining active tasks - active_tasks = db.query(Task).filter(Task.state.in_([TaskState.NOW, 
TaskState.NEXT, TaskState.LATER])).all() + active_tasks = ( + db.query(Task) + .filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER])) + .all() + ) for task in active_tasks: - task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic + task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic task.deferred_at = datetime.utcnow() db.add(task) @@ -221,7 +250,7 @@ async def start_task( ): current_now_task = get_now_task(db) if current_now_task and current_now_task.id != task_id: - current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT + current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT db.add(current_now_task) task = db.query(Task).filter(Task.id == task_id).first() @@ -322,7 +351,7 @@ async def reorder_tasks( ): # Reorder LATER tasks if later_task_ids: - ids_in_order = [int(x.strip()) for x in later_task_ids.split(',') if x.strip()] + ids_in_order = [int(x.strip()) for x in later_task_ids.split(",") if x.strip()] for index, task_id in enumerate(ids_in_order): task = db.query(Task).filter(Task.id == task_id).first() if task and task.state == TaskState.LATER: @@ -332,16 +361,18 @@ async def reorder_tasks( # Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly) if next_task_id: task = db.query(Task).filter(Task.id == next_task_id).first() - if task and task.state == TaskState.LATER: # Only if it was a LATER task being promoted manually + if ( + task and task.state == TaskState.LATER + ): # Only if it was a LATER task being promoted manually # Demote current NEXT to LATER current_next = get_next_task(db) if current_next: current_next.state = TaskState.LATER - current_next.sort_order = len(get_later_tasks(db)) # Add to end of later + current_next.sort_order = len(get_later_tasks(db)) # Add to end of later db.add(current_next) task.state = TaskState.NEXT - task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency + 
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency db.add(task) db.commit() diff --git a/src/dashboard/routes/chat_api.py b/src/dashboard/routes/chat_api.py index 4f855c8..04b16ca 100644 --- a/src/dashboard/routes/chat_api.py +++ b/src/dashboard/routes/chat_api.py @@ -27,12 +27,13 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["chat-api"]) -_UPLOAD_DIR = os.path.join("data", "chat-uploads") +_UPLOAD_DIR = str(Path(settings.repo_root) / "data" / "chat-uploads") _MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB # ── POST /api/chat ──────────────────────────────────────────────────────────── + @router.post("/chat") async def api_chat(request: Request): """Accept a JSON chat payload and return the agent's reply. @@ -65,7 +66,8 @@ async def api_chat(request: Request): # Handle multimodal content arrays — extract text parts if isinstance(content, list): text_parts = [ - p.get("text", "") for p in content + p.get("text", "") + for p in content if isinstance(p, dict) and p.get("type") == "text" ] last_user_msg = " ".join(text_parts).strip() @@ -109,6 +111,7 @@ async def api_chat(request: Request): # ── POST /api/upload ────────────────────────────────────────────────────────── + @router.post("/upload") async def api_upload(file: UploadFile = File(...)): """Accept a file upload and return its URL. 
@@ -147,6 +150,7 @@ async def api_upload(file: UploadFile = File(...)): # ── GET /api/chat/history ──────────────────────────────────────────────────── + @router.get("/chat/history") async def api_chat_history(): """Return the in-memory chat history as JSON.""" @@ -165,6 +169,7 @@ async def api_chat_history(): # ── DELETE /api/chat/history ────────────────────────────────────────────────── + @router.delete("/chat/history") async def api_clear_history(): """Clear the in-memory chat history.""" diff --git a/src/dashboard/routes/discord.py b/src/dashboard/routes/discord.py index 781f789..4608164 100644 --- a/src/dashboard/routes/discord.py +++ b/src/dashboard/routes/discord.py @@ -7,9 +7,10 @@ Endpoints: GET /discord/oauth-url — get the bot's OAuth2 authorization URL """ +from typing import Optional + from fastapi import APIRouter, File, Form, UploadFile from pydantic import BaseModel -from typing import Optional router = APIRouter(prefix="/discord", tags=["discord"]) diff --git a/src/dashboard/routes/experiments.py b/src/dashboard/routes/experiments.py new file mode 100644 index 0000000..37b0be3 --- /dev/null +++ b/src/dashboard/routes/experiments.py @@ -0,0 +1,77 @@ +"""Experiment dashboard routes — autoresearch experiment monitoring. + +Provides endpoints for viewing, starting, and monitoring autonomous +ML experiment loops powered by Karpathy's autoresearch pattern. 
+""" + +import logging +from pathlib import Path + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse + +from config import settings +from dashboard.templating import templates + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/experiments", tags=["experiments"]) + + +def _workspace() -> Path: + return Path(settings.repo_root) / settings.autoresearch_workspace + + +@router.get("", response_class=HTMLResponse) +async def experiments_page(request: Request): + """Experiment dashboard — lists past runs and allows starting new ones.""" + from timmy.autoresearch import get_experiment_history + + history = [] + try: + history = get_experiment_history(_workspace()) + except Exception: + logger.debug("Failed to load experiment history", exc_info=True) + + return templates.TemplateResponse( + request, + "experiments.html", + { + "page_title": "Experiments — Autoresearch", + "enabled": settings.autoresearch_enabled, + "history": history[:50], + "metric_name": settings.autoresearch_metric, + "time_budget": settings.autoresearch_time_budget, + "max_iterations": settings.autoresearch_max_iterations, + }, + ) + + +@router.post("/start", response_class=JSONResponse) +async def start_experiment(request: Request): + """Kick off an experiment loop in the background.""" + if not settings.autoresearch_enabled: + raise HTTPException( + status_code=403, + detail="Autoresearch is disabled. 
Set AUTORESEARCH_ENABLED=true.", + ) + + from timmy.autoresearch import prepare_experiment + + workspace = _workspace() + status = prepare_experiment(workspace) + + return {"status": "started", "workspace": str(workspace), "prepare": status} + + +@router.get("/{run_id}", response_class=JSONResponse) +async def experiment_detail(run_id: str): + """Get details for a specific experiment run.""" + from timmy.autoresearch import get_experiment_history + + history = get_experiment_history(_workspace()) + for entry in history: + if entry.get("run_id") == run_id: + return entry + + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") diff --git a/src/dashboard/routes/grok.py b/src/dashboard/routes/grok.py index 2856dae..9dd439b 100644 --- a/src/dashboard/routes/grok.py +++ b/src/dashboard/routes/grok.py @@ -43,6 +43,7 @@ async def grok_status(request: Request): stats = None try: from timmy.backends import get_grok_backend + backend = get_grok_backend() stats = { "total_requests": backend.stats.total_requests, @@ -52,12 +53,16 @@ async def grok_status(request: Request): "errors": backend.stats.errors, } except Exception: - pass + logger.debug("Failed to load Grok stats", exc_info=True) - return templates.TemplateResponse(request, "grok_status.html", { - "status": status, - "stats": stats, - }) + return templates.TemplateResponse( + request, + "grok_status.html", + { + "status": status, + "stats": stats, + }, + ) @router.post("/toggle") @@ -90,7 +95,7 @@ async def toggle_grok_mode(request: Request): success=True, ) except Exception: - pass + logger.debug("Failed to log Grok toggle to Spark", exc_info=True) return HTMLResponse( _render_toggle_card(_grok_mode_active), @@ -104,10 +109,13 @@ def _run_grok_query(message: str) -> dict: Returns: {"response": str | None, "error": str | None} """ - from timmy.backends import grok_available, get_grok_backend + from timmy.backends import get_grok_backend, grok_available if not grok_available(): - return {"response": 
None, "error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY."} + return { + "response": None, + "error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY.", + } backend = get_grok_backend() @@ -115,12 +123,13 @@ def _run_grok_query(message: str) -> dict: if not settings.grok_free: try: from lightning.factory import get_backend as get_ln_backend + ln = get_ln_backend() sats = min(settings.grok_max_sats_per_query, 100) ln.create_invoice(sats, f"Grok: {message[:50]}") invoice_note = f" | {sats} sats" except Exception: - pass + logger.debug("Lightning invoice creation failed", exc_info=True) try: result = backend.run(message) @@ -132,9 +141,10 @@ def _run_grok_query(message: str) -> dict: @router.post("/chat", response_class=HTMLResponse) async def grok_chat(request: Request, message: str = Form(...)): """Send a message directly to Grok and return HTMX chat partial.""" - from dashboard.store import message_log from datetime import datetime + from dashboard.store import message_log + timestamp = datetime.now().strftime("%H:%M:%S") result = _run_grok_query(message) @@ -142,9 +152,13 @@ async def grok_chat(request: Request, message: str = Form(...)): message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser") if result["response"]: - message_log.append(role="agent", content=result["response"], timestamp=timestamp, source="browser") + message_log.append( + role="agent", content=result["response"], timestamp=timestamp, source="browser" + ) else: - message_log.append(role="error", content=result["error"], timestamp=timestamp, source="browser") + message_log.append( + role="error", content=result["error"], timestamp=timestamp, source="browser" + ) return templates.TemplateResponse( request, @@ -185,6 +199,7 @@ async def grok_stats(): def _render_toggle_card(active: bool) -> str: """Render the Grok Mode toggle card HTML.""" import html + color = "#00ff88" if active else "#666" state = "ACTIVE" if active else "STANDBY" 
glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none" diff --git a/src/dashboard/routes/health.py b/src/dashboard/routes/health.py index b8a21f8..cb151e4 100644 --- a/src/dashboard/routes/health.py +++ b/src/dashboard/routes/health.py @@ -22,6 +22,7 @@ router = APIRouter(tags=["health"]) class DependencyStatus(BaseModel): """Status of a single dependency.""" + name: str status: str # "healthy", "degraded", "unavailable" sovereignty_score: int # 0-10 @@ -30,6 +31,7 @@ class DependencyStatus(BaseModel): class SovereigntyReport(BaseModel): """Full sovereignty audit report.""" + overall_score: float dependencies: list[DependencyStatus] timestamp: str @@ -38,6 +40,7 @@ class SovereigntyReport(BaseModel): class HealthStatus(BaseModel): """System health status.""" + status: str timestamp: str version: str @@ -52,6 +55,7 @@ def _check_ollama_sync() -> DependencyStatus: """Synchronous Ollama check — run via asyncio.to_thread().""" try: import urllib.request + url = settings.ollama_url.replace("localhost", "127.0.0.1") req = urllib.request.Request( f"{url}/api/tags", @@ -67,7 +71,7 @@ def _check_ollama_sync() -> DependencyStatus: details={"url": settings.ollama_url, "model": settings.ollama_model}, ) except Exception: - pass + logger.debug("Ollama health check failed", exc_info=True) return DependencyStatus( name="Ollama AI", @@ -142,7 +146,7 @@ def _calculate_overall_score(deps: list[DependencyStatus]) -> float: def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]: """Generate recommendations based on dependency status.""" recommendations = [] - + for dep in deps: if dep.status == "unavailable": recommendations.append(f"{dep.name} is unavailable - check configuration") @@ -151,25 +155,25 @@ def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]: recommendations.append( "Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND" ) - + if not recommendations: recommendations.append("System operating optimally - all 
dependencies healthy") - + return recommendations @router.get("/health") async def health_check(): """Basic health check endpoint. - + Returns legacy format for backward compatibility with existing tests, plus extended information for the Mission Control dashboard. """ uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds() - + # Legacy format for test compatibility ollama_ok = await check_ollama() - + agent_status = "idle" if ollama_ok else "offline" return { @@ -193,12 +197,13 @@ async def health_check(): async def health_status_panel(request: Request): """Simple HTML health status panel.""" ollama_ok = await check_ollama() - + status_text = "UP" if ollama_ok else "DOWN" status_color = "#10b981" if ollama_ok else "#ef4444" import html + model = html.escape(settings.ollama_model) # Include model for test compatibility - + html_content = f""" @@ -217,7 +222,7 @@ async def health_status_panel(request: Request): @router.get("/health/sovereignty", response_model=SovereigntyReport) async def sovereignty_check(): """Comprehensive sovereignty audit report. - + Returns the status of all external dependencies with sovereignty scores. Use this to verify the system is operating in a sovereign manner. """ @@ -226,10 +231,10 @@ async def sovereignty_check(): _check_lightning(), _check_sqlite(), ] - + overall = _calculate_overall_score(dependencies) recommendations = _generate_recommendations(dependencies) - + return SovereigntyReport( overall_score=overall, dependencies=dependencies, diff --git a/src/dashboard/routes/marketplace.py b/src/dashboard/routes/marketplace.py index 7afa684..583914d 100644 --- a/src/dashboard/routes/marketplace.py +++ b/src/dashboard/routes/marketplace.py @@ -19,8 +19,7 @@ AGENT_CATALOG = [ "name": "Orchestrator", "role": "Local AI", "description": ( - "Primary AI agent. Coordinates tasks, manages memory. " - "Uses distributed brain." + "Primary AI agent. Coordinates tasks, manages memory. " "Uses distributed brain." 
), "capabilities": "chat,reasoning,coordination,memory", "rate_sats": 0, @@ -37,11 +36,11 @@ async def api_list_agents(): pending_tasks = len(await brain.get_pending_tasks(limit=1000)) except Exception: pending_tasks = 0 - + catalog = [dict(AGENT_CATALOG[0])] catalog[0]["pending_tasks"] = pending_tasks catalog[0]["status"] = "active" - + # Include 'total' for backward compatibility with tests return {"agents": catalog, "total": len(catalog)} @@ -82,7 +81,7 @@ async def marketplace_ui(request: Request): "page_title": "Agent Marketplace", "active_count": active, "planned_count": 0, - } + }, ) diff --git a/src/dashboard/routes/memory.py b/src/dashboard/routes/memory.py index 26dd13f..6b0629b 100644 --- a/src/dashboard/routes/memory.py +++ b/src/dashboard/routes/memory.py @@ -5,17 +5,17 @@ from typing import Optional from fastapi import APIRouter, Form, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse +from dashboard.templating import templates from timmy.memory.vector_store import ( - store_memory, - search_memories, + delete_memory, get_memory_stats, recall_personal_facts, recall_personal_facts_with_ids, + search_memories, + store_memory, store_personal_fact, update_personal_fact, - delete_memory, ) -from dashboard.templating import templates router = APIRouter(prefix="/memory", tags=["memory"]) @@ -36,10 +36,10 @@ async def memory_page( agent_id=agent_id, limit=20, ) - + stats = get_memory_stats() facts = recall_personal_facts_with_ids()[:10] - + return templates.TemplateResponse( request, "memory.html", @@ -67,7 +67,7 @@ async def memory_search( context_type=context_type, limit=20, ) - + # Return partial for HTMX return templates.TemplateResponse( request, diff --git a/src/dashboard/routes/models.py b/src/dashboard/routes/models.py index 5b63a9f..ed714aa 100644 --- a/src/dashboard/routes/models.py +++ b/src/dashboard/routes/models.py @@ -13,6 +13,7 @@ from fastapi.responses import HTMLResponse from pydantic import BaseModel from 
config import settings +from dashboard.templating import templates from infrastructure.models.registry import ( CustomModel, ModelFormat, @@ -20,7 +21,6 @@ from infrastructure.models.registry import ( ModelRole, model_registry, ) -from dashboard.templating import templates logger = logging.getLogger(__name__) @@ -33,6 +33,7 @@ api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"]) class RegisterModelRequest(BaseModel): """Request body for model registration.""" + name: str format: str # gguf, safetensors, hf, ollama path: str @@ -45,12 +46,14 @@ class RegisterModelRequest(BaseModel): class AssignModelRequest(BaseModel): """Request body for assigning a model to an agent.""" + agent_id: str model_name: str class SetActiveRequest(BaseModel): """Request body for enabling/disabling a model.""" + active: bool @@ -92,15 +95,14 @@ async def register_model(request: RegisterModelRequest) -> dict[str, Any]: raise HTTPException( status_code=400, detail=f"Invalid format: {request.format}. " - f"Choose from: {[f.value for f in ModelFormat]}", + f"Choose from: {[f.value for f in ModelFormat]}", ) try: role = ModelRole(request.role) except ValueError: raise HTTPException( status_code=400, - detail=f"Invalid role: {request.role}. " - f"Choose from: {[r.value for r in ModelRole]}", + detail=f"Invalid role: {request.role}. 
" f"Choose from: {[r.value for r in ModelRole]}", ) # Validate path exists for non-Ollama formats @@ -163,9 +165,7 @@ async def unregister_model(model_name: str) -> dict[str, str]: @api_router.patch("/{model_name}/active") -async def set_model_active( - model_name: str, request: SetActiveRequest -) -> dict[str, str]: +async def set_model_active(model_name: str, request: SetActiveRequest) -> dict[str, str]: """Enable or disable a model.""" if not model_registry.set_active(model_name, request.active): raise HTTPException(status_code=404, detail=f"Model {model_name} not found") @@ -182,8 +182,7 @@ async def list_assignments() -> dict[str, Any]: assignments = model_registry.get_agent_assignments() return { "assignments": [ - {"agent_id": aid, "model_name": mname} - for aid, mname in assignments.items() + {"agent_id": aid, "model_name": mname} for aid, mname in assignments.items() ], "total": len(assignments), } diff --git a/src/dashboard/routes/router.py b/src/dashboard/routes/router.py index 4a833fc..5be982e 100644 --- a/src/dashboard/routes/router.py +++ b/src/dashboard/routes/router.py @@ -3,8 +3,8 @@ from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse -from timmy.cascade_adapter import get_cascade_adapter from dashboard.templating import templates +from timmy.cascade_adapter import get_cascade_adapter router = APIRouter(prefix="/router", tags=["router"]) @@ -13,19 +13,19 @@ router = APIRouter(prefix="/router", tags=["router"]) async def router_status_page(request: Request): """Cascade Router status dashboard.""" adapter = get_cascade_adapter() - + providers = adapter.get_provider_status() preferred = adapter.get_preferred_provider() - + # Calculate overall stats total_requests = sum(p["metrics"]["total"] for p in providers) total_success = sum(p["metrics"]["success"] for p in providers) total_failed = sum(p["metrics"]["failed"] for p in providers) - + avg_latency = 0.0 if providers: avg_latency = sum(p["metrics"]["avg_latency_ms"] for 
p in providers) / len(providers) - + return templates.TemplateResponse( request, "router_status.html", diff --git a/src/dashboard/routes/spark.py b/src/dashboard/routes/spark.py index f542417..c7cbe45 100644 --- a/src/dashboard/routes/spark.py +++ b/src/dashboard/routes/spark.py @@ -13,8 +13,8 @@ import logging from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse -from spark.engine import spark_engine from dashboard.templating import templates +from spark.engine import spark_engine logger = logging.getLogger(__name__) @@ -86,23 +86,26 @@ async def spark_ui(request: Request): async def spark_status_json(): """Return Spark Intelligence status as JSON.""" from fastapi.responses import JSONResponse + status = spark_engine.status() advisories = spark_engine.get_advisories() - return JSONResponse({ - "status": status, - "advisories": [ - { - "category": a.category, - "priority": a.priority, - "title": a.title, - "detail": a.detail, - "suggested_action": a.suggested_action, - "subject": a.subject, - "evidence_count": a.evidence_count, - } - for a in advisories - ], - }) + return JSONResponse( + { + "status": status, + "advisories": [ + { + "category": a.category, + "priority": a.priority, + "title": a.title, + "detail": a.detail, + "suggested_action": a.suggested_action, + "subject": a.subject, + "evidence_count": a.evidence_count, + } + for a in advisories + ], + } + ) @router.get("/timeline", response_class=HTMLResponse) diff --git a/src/dashboard/routes/swarm.py b/src/dashboard/routes/swarm.py index 7f9d2d4..2e0c78b 100644 --- a/src/dashboard/routes/swarm.py +++ b/src/dashboard/routes/swarm.py @@ -7,9 +7,9 @@ from typing import Optional from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse -from spark.engine import spark_engine from dashboard.templating import templates from infrastructure.ws_manager.handler import ws_manager +from spark.engine import spark_engine logger = 
logging.getLogger(__name__) @@ -25,7 +25,7 @@ async def swarm_events( ): """Event log page.""" events = spark_engine.get_timeline(limit=100) - + # Filter if requested if task_id: events = [e for e in events if e.task_id == task_id] @@ -33,7 +33,7 @@ async def swarm_events( events = [e for e in events if e.agent_id == agent_id] if event_type: events = [e for e in events if e.event_type == event_type] - + # Prepare summary and event types for template summary = {} event_types = set() @@ -41,7 +41,7 @@ async def swarm_events( etype = e.event_type event_types.add(etype) summary[etype] = summary.get(etype, 0) + 1 - + return templates.TemplateResponse( request, "events.html", @@ -78,14 +78,16 @@ async def swarm_ws(websocket: WebSocket): await ws_manager.connect(websocket) try: # Send initial state so frontend can clear loading placeholders - await websocket.send_json({ - "type": "initial_state", - "data": { - "agents": {"total": 0, "active": 0, "list": []}, - "tasks": {"active": 0}, - "auctions": {"list": []}, - }, - }) + await websocket.send_json( + { + "type": "initial_state", + "data": { + "agents": {"total": 0, "active": 0, "list": []}, + "tasks": {"active": 0}, + "auctions": {"list": []}, + }, + } + ) while True: await websocket.receive_text() except WebSocketDisconnect: diff --git a/src/dashboard/routes/system.py b/src/dashboard/routes/system.py index f691e09..86ac817 100644 --- a/src/dashboard/routes/system.py +++ b/src/dashboard/routes/system.py @@ -25,26 +25,42 @@ async def lightning_ledger(request: Request): "pending_incoming_sats": 0, "pending_outgoing_sats": 0, } - + # Mock transactions from collections import namedtuple from enum import Enum - + class TxType(Enum): incoming = "incoming" outgoing = "outgoing" - + class TxStatus(Enum): completed = "completed" pending = "pending" - - Tx = namedtuple("Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"]) - + + Tx = namedtuple( + "Tx", ["tx_type", "status", "amount_sats", 
"payment_hash", "memo", "created_at"] + ) + transactions = [ - Tx(TxType.outgoing, TxStatus.completed, 50, "hash1", "Model inference", "2026-03-04 10:00:00"), - Tx(TxType.incoming, TxStatus.completed, 1000, "hash2", "Manual deposit", "2026-03-03 15:00:00"), + Tx( + TxType.outgoing, + TxStatus.completed, + 50, + "hash1", + "Model inference", + "2026-03-04 10:00:00", + ), + Tx( + TxType.incoming, + TxStatus.completed, + 1000, + "hash2", + "Manual deposit", + "2026-03-03 15:00:00", + ), ] - + return templates.TemplateResponse( request, "ledger.html", @@ -84,9 +100,16 @@ async def mission_control(request: Request): @router.get("/bugs", response_class=HTMLResponse) async def bugs_page(request: Request): - return templates.TemplateResponse(request, "bugs.html", { - "bugs": [], "total": 0, "stats": {}, "filter_status": None, - }) + return templates.TemplateResponse( + request, + "bugs.html", + { + "bugs": [], + "total": 0, + "stats": {}, + "filter_status": None, + }, + ) @router.get("/self-coding", response_class=HTMLResponse) @@ -109,14 +132,17 @@ async def api_notifications(): """Return recent system events for the notification dropdown.""" try: from spark.engine import spark_engine + events = spark_engine.get_timeline(limit=20) - return JSONResponse([ - { - "event_type": e.event_type, - "title": getattr(e, "description", e.event_type), - "timestamp": str(getattr(e, "timestamp", "")), - } - for e in events - ]) + return JSONResponse( + [ + { + "event_type": e.event_type, + "title": getattr(e, "description", e.event_type), + "timestamp": str(getattr(e, "timestamp", "")), + } + for e in events + ] + ) except Exception: return JSONResponse([]) diff --git a/src/dashboard/routes/tasks.py b/src/dashboard/routes/tasks.py index e39da87..ae33de6 100644 --- a/src/dashboard/routes/tasks.py +++ b/src/dashboard/routes/tasks.py @@ -7,9 +7,10 @@ from datetime import datetime from pathlib import Path from typing import Optional -from fastapi import APIRouter, HTTPException, Request, 
Form +from fastapi import APIRouter, Form, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse +from config import settings from dashboard.templating import templates logger = logging.getLogger(__name__) @@ -20,11 +21,17 @@ router = APIRouter(tags=["tasks"]) # Database helpers # --------------------------------------------------------------------------- -DB_PATH = Path("data/tasks.db") +DB_PATH = Path(settings.repo_root) / "data" / "tasks.db" VALID_STATUSES = { - "pending_approval", "approved", "running", "paused", - "completed", "vetoed", "failed", "backlogged", + "pending_approval", + "approved", + "running", + "paused", + "completed", + "vetoed", + "failed", + "backlogged", } VALID_PRIORITIES = {"low", "normal", "high", "urgent"} @@ -33,7 +40,8 @@ def _get_db() -> sqlite3.Connection: DB_PATH.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, title TEXT NOT NULL, @@ -46,7 +54,8 @@ def _get_db() -> sqlite3.Connection: created_at TEXT DEFAULT (datetime('now')), completed_at TEXT ) - """) + """ + ) conn.commit() return conn @@ -91,37 +100,52 @@ class _TaskView: # Page routes # --------------------------------------------------------------------------- + @router.get("/tasks", response_class=HTMLResponse) async def tasks_page(request: Request): """Render the main task queue page with 3-column layout.""" db = _get_db() try: - pending = [_TaskView(_row_to_dict(r)) for r in db.execute( - "SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC" - ).fetchall()] - active = [_TaskView(_row_to_dict(r)) for r in db.execute( - "SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC" - ).fetchall()] - completed = [_TaskView(_row_to_dict(r)) for r in db.execute( - "SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') 
ORDER BY completed_at DESC LIMIT 50" - ).fetchall()] + pending = [ + _TaskView(_row_to_dict(r)) + for r in db.execute( + "SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC" + ).fetchall() + ] + active = [ + _TaskView(_row_to_dict(r)) + for r in db.execute( + "SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC" + ).fetchall() + ] + completed = [ + _TaskView(_row_to_dict(r)) + for r in db.execute( + "SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50" + ).fetchall() + ] finally: db.close() - return templates.TemplateResponse(request, "tasks.html", { - "pending_count": len(pending), - "pending": pending, - "active": active, - "completed": completed, - "agents": [], # no agent roster wired yet - "pre_assign": "", - }) + return templates.TemplateResponse( + request, + "tasks.html", + { + "pending_count": len(pending), + "pending": pending, + "active": active, + "completed": completed, + "agents": [], # no agent roster wired yet + "pre_assign": "", + }, + ) # --------------------------------------------------------------------------- # HTMX partials (polled by the template) # --------------------------------------------------------------------------- + @router.get("/tasks/pending", response_class=HTMLResponse) async def tasks_pending(request: Request): db = _get_db() @@ -134,9 +158,11 @@ async def tasks_pending(request: Request): tasks = [_TaskView(_row_to_dict(r)) for r in rows] parts = [] for task in tasks: - parts.append(templates.TemplateResponse( - request, "partials/task_card.html", {"task": task} - ).body.decode()) + parts.append( + templates.TemplateResponse( + request, "partials/task_card.html", {"task": task} + ).body.decode() + ) if not parts: return HTMLResponse('
No pending tasks
') return HTMLResponse("".join(parts)) @@ -154,9 +180,11 @@ async def tasks_active(request: Request): tasks = [_TaskView(_row_to_dict(r)) for r in rows] parts = [] for task in tasks: - parts.append(templates.TemplateResponse( - request, "partials/task_card.html", {"task": task} - ).body.decode()) + parts.append( + templates.TemplateResponse( + request, "partials/task_card.html", {"task": task} + ).body.decode() + ) if not parts: return HTMLResponse('
No active tasks
') return HTMLResponse("".join(parts)) @@ -174,9 +202,11 @@ async def tasks_completed(request: Request): tasks = [_TaskView(_row_to_dict(r)) for r in rows] parts = [] for task in tasks: - parts.append(templates.TemplateResponse( - request, "partials/task_card.html", {"task": task} - ).body.decode()) + parts.append( + templates.TemplateResponse( + request, "partials/task_card.html", {"task": task} + ).body.decode() + ) if not parts: return HTMLResponse('
No completed tasks yet
') return HTMLResponse("".join(parts)) @@ -186,6 +216,7 @@ async def tasks_completed(request: Request): # Form-based create (used by the modal in tasks.html) # --------------------------------------------------------------------------- + @router.post("/tasks/create", response_class=HTMLResponse) async def create_task_form( request: Request, @@ -218,6 +249,7 @@ async def create_task_form( # Task action endpoints (approve, veto, modify, pause, cancel, retry) # --------------------------------------------------------------------------- + @router.post("/tasks/{task_id}/approve", response_class=HTMLResponse) async def approve_task(request: Request, task_id: str): return await _set_status(request, task_id, "approved") @@ -268,7 +300,9 @@ async def modify_task( async def _set_status(request: Request, task_id: str, new_status: str): """Helper to update status and return refreshed task card.""" - completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None + completed_at = ( + datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None + ) db = _get_db() try: db.execute( @@ -289,6 +323,7 @@ async def _set_status(request: Request, task_id: str, new_status: str): # JSON API (for programmatic access / Timmy's tool calls) # --------------------------------------------------------------------------- + @router.post("/api/tasks", response_class=JSONResponse, status_code=201) async def api_create_task(request: Request): """Create a task via JSON API.""" @@ -345,7 +380,9 @@ async def api_update_status(task_id: str, request: Request): if not new_status or new_status not in VALID_STATUSES: raise HTTPException(422, f"Invalid status. 
Must be one of: {VALID_STATUSES}") - completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None + completed_at = ( + datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None + ) db = _get_db() try: db.execute( @@ -379,6 +416,7 @@ async def api_delete_task(task_id: str): # Queue status (polled by the chat panel every 10 seconds) # --------------------------------------------------------------------------- + @router.get("/api/queue/status", response_class=JSONResponse) async def queue_status(assigned_to: str = "default"): """Return queue status for the chat panel's agent status indicator.""" @@ -396,14 +434,18 @@ async def queue_status(assigned_to: str = "default"): db.close() if running: - return JSONResponse({ - "is_working": True, - "current_task": {"id": running["id"], "title": running["title"]}, - "tasks_ahead": 0, - }) + return JSONResponse( + { + "is_working": True, + "current_task": {"id": running["id"], "title": running["title"]}, + "tasks_ahead": 0, + } + ) - return JSONResponse({ - "is_working": False, - "current_task": None, - "tasks_ahead": ahead["cnt"] if ahead else 0, - }) + return JSONResponse( + { + "is_working": False, + "current_task": None, + "tasks_ahead": ahead["cnt"] if ahead else 0, + } + ) diff --git a/src/dashboard/routes/thinking.py b/src/dashboard/routes/thinking.py index 563c2d7..a15b39c 100644 --- a/src/dashboard/routes/thinking.py +++ b/src/dashboard/routes/thinking.py @@ -10,8 +10,8 @@ import logging from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse, JSONResponse -from timmy.thinking import thinking_engine from dashboard.templating import templates +from timmy.thinking import thinking_engine logger = logging.getLogger(__name__) diff --git a/src/dashboard/routes/tools.py b/src/dashboard/routes/tools.py index 3cd52a8..d65ed9a 100644 --- a/src/dashboard/routes/tools.py +++ b/src/dashboard/routes/tools.py @@ -8,8 +8,8 @@ from 
collections import namedtuple from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse, JSONResponse -from timmy.tools import get_all_available_tools from dashboard.templating import templates +from timmy.tools import get_all_available_tools router = APIRouter(tags=["tools"]) @@ -29,9 +29,7 @@ def _build_agent_tools(): for name, fn in available.items() ] - return [ - _AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0)) - ] + return [_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))] @router.get("/tools", response_class=HTMLResponse) diff --git a/src/dashboard/routes/voice.py b/src/dashboard/routes/voice.py index 7482c10..08dcc6f 100644 --- a/src/dashboard/routes/voice.py +++ b/src/dashboard/routes/voice.py @@ -10,9 +10,9 @@ import logging from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse +from dashboard.templating import templates from integrations.voice.nlu import detect_intent, extract_command from timmy.agent import create_timmy -from dashboard.templating import templates logger = logging.getLogger(__name__) @@ -38,6 +38,7 @@ async def tts_status(): """Check TTS engine availability.""" try: from timmy_serve.voice_tts import voice_tts + return { "available": voice_tts.available, "voices": voice_tts.get_voices() if voice_tts.available else [], @@ -51,6 +52,7 @@ async def tts_speak(text: str = Form(...)): """Speak text aloud via TTS.""" try: from timmy_serve.voice_tts import voice_tts + if not voice_tts.available: return {"spoken": False, "reason": "TTS engine not available"} voice_tts.speak(text) @@ -86,6 +88,7 @@ async def voice_command(text: str = Form(...)): # ── Enhanced voice pipeline ────────────────────────────────────────────── + @router.post("/enhanced/process") async def process_voice_input( text: str = Form(...), @@ -133,6 +136,7 @@ async def process_voice_input( if speak_response and response_text: try: from 
timmy_serve.voice_tts import voice_tts + if voice_tts.available: voice_tts.speak(response_text) except Exception: diff --git a/src/dashboard/routes/work_orders.py b/src/dashboard/routes/work_orders.py index 1365f3e..a296840 100644 --- a/src/dashboard/routes/work_orders.py +++ b/src/dashboard/routes/work_orders.py @@ -6,7 +6,7 @@ import uuid from datetime import datetime from pathlib import Path -from fastapi import APIRouter, HTTPException, Request, Form +from fastapi import APIRouter, Form, HTTPException, Request from fastapi.responses import HTMLResponse, JSONResponse from dashboard.templating import templates @@ -26,7 +26,8 @@ def _get_db() -> sqlite3.Connection: DB_PATH.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS work_orders ( id TEXT PRIMARY KEY, title TEXT NOT NULL, @@ -41,7 +42,8 @@ def _get_db() -> sqlite3.Connection: created_at TEXT DEFAULT (datetime('now')), completed_at TEXT ) - """) + """ + ) conn.commit() return conn @@ -71,7 +73,9 @@ class _WOView: self.submitter = row.get("submitter", "dashboard") self.status = _EnumLike(row.get("status", "submitted")) raw_files = row.get("related_files", "") - self.related_files = [f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else [] + self.related_files = ( + [f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else [] + ) self.result = row.get("result", "") self.rejection_reason = row.get("rejection_reason", "") self.created_at = row.get("created_at", "") @@ -98,6 +102,7 @@ def _query_wos(db, statuses): # Page route # --------------------------------------------------------------------------- + @router.get("/work-orders/queue", response_class=HTMLResponse) async def work_orders_page(request: Request): db = _get_db() @@ -109,21 +114,26 @@ async def work_orders_page(request: Request): finally: db.close() - return templates.TemplateResponse(request, 
"work_orders.html", { - "pending_count": len(pending), - "pending": pending, - "active": active, - "completed": completed, - "rejected": rejected, - "priorities": PRIORITIES, - "categories": CATEGORIES, - }) + return templates.TemplateResponse( + request, + "work_orders.html", + { + "pending_count": len(pending), + "pending": pending, + "active": active, + "completed": completed, + "rejected": rejected, + "priorities": PRIORITIES, + "categories": CATEGORIES, + }, + ) # --------------------------------------------------------------------------- # Form submit # --------------------------------------------------------------------------- + @router.post("/work-orders/submit", response_class=HTMLResponse) async def submit_work_order( request: Request, @@ -159,6 +169,7 @@ async def submit_work_order( # HTMX partials # --------------------------------------------------------------------------- + @router.get("/work-orders/queue/pending", response_class=HTMLResponse) async def pending_partial(request: Request): db = _get_db() @@ -174,7 +185,9 @@ async def pending_partial(request: Request): parts = [] for wo in wos: parts.append( - templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode() + templates.TemplateResponse( + request, "partials/work_order_card.html", {"wo": wo} + ).body.decode() ) return HTMLResponse("".join(parts)) @@ -194,7 +207,9 @@ async def active_partial(request: Request): parts = [] for wo in wos: parts.append( - templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode() + templates.TemplateResponse( + request, "partials/work_order_card.html", {"wo": wo} + ).body.decode() ) return HTMLResponse("".join(parts)) @@ -203,8 +218,11 @@ async def active_partial(request: Request): # Action endpoints # --------------------------------------------------------------------------- + async def _update_status(request: Request, wo_id: str, new_status: str, **extra): - completed_at = 
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None + completed_at = ( + datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None + ) db = _get_db() try: sets = ["status=?", "completed_at=COALESCE(?, completed_at)"] diff --git a/src/dashboard/store.py b/src/dashboard/store.py index dd4bfcc..c79aa47 100644 --- a/src/dashboard/store.py +++ b/src/dashboard/store.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field @dataclass class Message: - role: str # "user" | "agent" | "error" + role: str # "user" | "agent" | "error" content: str timestamp: str source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system" @@ -16,7 +16,9 @@ class MessageLog: self._entries: list[Message] = [] def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None: - self._entries.append(Message(role=role, content=content, timestamp=timestamp, source=source)) + self._entries.append( + Message(role=role, content=content, timestamp=timestamp, source=source) + ) def all(self) -> list[Message]: return list(self._entries) diff --git a/src/dashboard/templates/experiments.html b/src/dashboard/templates/experiments.html new file mode 100644 index 0000000..d5fc330 --- /dev/null +++ b/src/dashboard/templates/experiments.html @@ -0,0 +1,90 @@ +{% extends "base.html" %} + +{% block title %}{{ page_title }}{% endblock %} + +{% block extra_styles %} + +{% endblock %} + +{% block content %} +
+
+
+
Autoresearch Experiments
+
Autonomous ML experiment loops — modify code, train, evaluate, iterate
+
+
+ {% if enabled %} + + {% else %} + +
Set AUTORESEARCH_ENABLED=true to enable
+ {% endif %} +
+
+ +
+ Metric: {{ metric_name }} + Budget: {{ time_budget }}s + Max iters: {{ max_iterations }} +
+ +
+ + {% if history %} + + + + + + + + + + + {% for run in history %} + + + + + + + {% endfor %} + +
#{{ metric_name }}DurationStatus
{{ loop.index }} + {% if run.metric is not none %} + {{ "%.4f"|format(run.metric) }} + {% else %} + — + {% endif %} + {{ run.get("duration_s", "—") }}s{% if run.get("success") %}OK{% else %}{{ run.get("error", "failed") }}{% endif %}
+ {% else %} +
+ No experiments yet. Start one to begin autonomous training. +
+ {% endif %} +
+{% endblock %} diff --git a/src/infrastructure/error_capture.py b/src/infrastructure/error_capture.py index b1de0bc..7b7c6bf 100644 --- a/src/infrastructure/error_capture.py +++ b/src/infrastructure/error_capture.py @@ -119,9 +119,7 @@ def capture_error( return None # Format the stack trace - tb_str = "".join( - traceback.format_exception(type(exc), exc, exc.__traceback__) - ) + tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) # Extract file/line from traceback tb_obj = exc.__traceback__ diff --git a/src/infrastructure/events/broadcaster.py b/src/infrastructure/events/broadcaster.py index e4f4623..fa53154 100644 --- a/src/infrastructure/events/broadcaster.py +++ b/src/infrastructure/events/broadcaster.py @@ -19,38 +19,39 @@ logger = logging.getLogger(__name__) class EventBroadcaster: """Broadcasts events to WebSocket clients. - + Usage: from infrastructure.events.broadcaster import event_broadcaster event_broadcaster.broadcast(event) """ - + def __init__(self) -> None: self._ws_manager: Optional = None - + def _get_ws_manager(self): """Lazy import to avoid circular deps.""" if self._ws_manager is None: try: from infrastructure.ws_manager.handler import ws_manager + self._ws_manager = ws_manager except Exception as exc: logger.debug("WebSocket manager not available: %s", exc) return self._ws_manager - + async def broadcast(self, event: EventLogEntry) -> int: """Broadcast an event to all connected WebSocket clients. 
- + Args: event: The event to broadcast - + Returns: Number of clients notified """ ws_manager = self._get_ws_manager() if not ws_manager: return 0 - + # Build message payload payload = { "type": "event", @@ -62,9 +63,9 @@ class EventBroadcaster: "agent_id": event.agent_id, "timestamp": event.timestamp, "data": event.data, - } + }, } - + try: # Broadcast to all connected clients count = await ws_manager.broadcast_json(payload) @@ -73,10 +74,10 @@ class EventBroadcaster: except Exception as exc: logger.error("Failed to broadcast event: %s", exc) return 0 - + def broadcast_sync(self, event: EventLogEntry) -> None: """Synchronous wrapper for broadcast. - + Use this from synchronous code - it schedules the async broadcast in the event loop if one is running. """ @@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str: def format_event_for_display(event: EventLogEntry) -> dict: """Format event for display in activity feed. - + Returns dict with display-friendly fields. """ data = event.data or {} - + # Build description based on event type description = "" if event.event_type.value == "task.created": @@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict: val = str(data[key]) description = val[:60] + "..." if len(val) > 60 else val break - + return { "id": event.id, "icon": get_event_icon(event.event_type.value), diff --git a/src/infrastructure/events/bus.py b/src/infrastructure/events/bus.py index 476a76f..a0a6492 100644 --- a/src/infrastructure/events/bus.py +++ b/src/infrastructure/events/bus.py @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) @dataclass class Event: """A typed event in the system.""" + type: str # e.g., "agent.task.assigned", "tool.execution.completed" source: str # Agent or component that emitted the event data: dict = field(default_factory=dict) @@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]] class EventBus: """Async event bus for publish/subscribe pattern. 
- + Usage: bus = EventBus() - + # Subscribe to events @bus.subscribe("agent.task.*") async def handle_task(event: Event): print(f"Task event: {event.data}") - + # Publish events await bus.publish(Event( type="agent.task.assigned", @@ -45,88 +46,89 @@ class EventBus: data={"task_id": "123", "agent": "forge"} )) """ - + def __init__(self) -> None: self._subscribers: dict[str, list[EventHandler]] = {} self._history: list[Event] = [] self._max_history = 1000 logger.info("EventBus initialized") - + def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]: """Decorator to subscribe to events matching a pattern. - + Patterns support wildcards: - "agent.task.assigned" — exact match - "agent.task.*" — any task event - "agent.*" — any agent event - "*" — all events """ + def decorator(handler: EventHandler) -> EventHandler: if event_pattern not in self._subscribers: self._subscribers[event_pattern] = [] self._subscribers[event_pattern].append(handler) logger.debug("Subscribed handler to '%s'", event_pattern) return handler + return decorator - + def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool: """Remove a handler from a subscription.""" if event_pattern not in self._subscribers: return False - + if handler in self._subscribers[event_pattern]: self._subscribers[event_pattern].remove(handler) logger.debug("Unsubscribed handler from '%s'", event_pattern) return True - + return False - + async def publish(self, event: Event) -> int: """Publish an event to all matching subscribers. 
- + Returns: Number of handlers invoked """ # Store in history self._history.append(event) if len(self._history) > self._max_history: - self._history = self._history[-self._max_history:] - + self._history = self._history[-self._max_history :] + # Find matching handlers handlers: list[EventHandler] = [] - + for pattern, pattern_handlers in self._subscribers.items(): if self._match_pattern(event.type, pattern): handlers.extend(pattern_handlers) - + # Invoke handlers concurrently if handlers: await asyncio.gather( - *[self._invoke_handler(h, event) for h in handlers], - return_exceptions=True + *[self._invoke_handler(h, event) for h in handlers], return_exceptions=True ) - + logger.debug("Published event '%s' to %d handlers", event.type, len(handlers)) return len(handlers) - + async def _invoke_handler(self, handler: EventHandler, event: Event) -> None: """Invoke a handler with error handling.""" try: await handler(event) except Exception as exc: logger.error("Event handler failed for '%s': %s", event.type, exc) - + def _match_pattern(self, event_type: str, pattern: str) -> bool: """Check if event type matches a wildcard pattern.""" if pattern == "*": return True - + if pattern.endswith(".*"): prefix = pattern[:-2] return event_type.startswith(prefix + ".") - + return event_type == pattern - + def get_history( self, event_type: str | None = None, @@ -135,15 +137,15 @@ class EventBus: ) -> list[Event]: """Get recent event history with optional filtering.""" events = self._history - + if event_type: events = [e for e in events if e.type == event_type] - + if source: events = [e for e in events if e.source == source] - + return events[-limit:] - + def clear_history(self) -> None: """Clear event history.""" self._history.clear() @@ -156,11 +158,13 @@ event_bus = EventBus() # Convenience functions async def emit(event_type: str, source: str, data: dict) -> int: """Quick emit an event.""" - return await event_bus.publish(Event( - type=event_type, - source=source, - 
data=data, - )) + return await event_bus.publish( + Event( + type=event_type, + source=source, + data=data, + ) + ) def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]: diff --git a/src/infrastructure/hands/__init__.py b/src/infrastructure/hands/__init__.py index 15309f7..0a3f744 100644 --- a/src/infrastructure/hands/__init__.py +++ b/src/infrastructure/hands/__init__.py @@ -11,7 +11,7 @@ Usage: result = await git_hand.run("status") """ -from infrastructure.hands.shell import shell_hand from infrastructure.hands.git import git_hand +from infrastructure.hands.shell import shell_hand __all__ = ["shell_hand", "git_hand"] diff --git a/src/infrastructure/hands/git.py b/src/infrastructure/hands/git.py index f1af5ef..d404ba1 100644 --- a/src/infrastructure/hands/git.py +++ b/src/infrastructure/hands/git.py @@ -25,16 +25,18 @@ from config import settings logger = logging.getLogger(__name__) # Operations that require explicit confirmation before execution -DESTRUCTIVE_OPS = frozenset({ - "push --force", - "push -f", - "reset --hard", - "clean -fd", - "clean -f", - "branch -D", - "checkout -- .", - "restore .", -}) +DESTRUCTIVE_OPS = frozenset( + { + "push --force", + "push -f", + "reset --hard", + "clean -fd", + "clean -f", + "branch -D", + "checkout -- .", + "restore .", + } +) @dataclass @@ -190,7 +192,9 @@ class GitHand: flag = "-b" if create else "" return await self.run(f"checkout {flag} {branch}".strip()) - async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult: + async def push( + self, remote: str = "origin", branch: str = "", force: bool = False + ) -> GitResult: """Push to remote. 
Force-push requires explicit opt-in.""" args = f"push -u {remote} {branch}".strip() if force: diff --git a/src/infrastructure/hands/shell.py b/src/infrastructure/hands/shell.py index c44c8f8..ee10c79 100644 --- a/src/infrastructure/hands/shell.py +++ b/src/infrastructure/hands/shell.py @@ -26,15 +26,17 @@ from config import settings logger = logging.getLogger(__name__) # Commands that are always blocked regardless of allow-list -_BLOCKED_COMMANDS = frozenset({ - "rm -rf /", - "rm -rf /*", - "mkfs", - "dd if=/dev/zero", - ":(){ :|:& };:", # fork bomb - "> /dev/sda", - "chmod -R 777 /", -}) +_BLOCKED_COMMANDS = frozenset( + { + "rm -rf /", + "rm -rf /*", + "mkfs", + "dd if=/dev/zero", + ":(){ :|:& };:", # fork bomb + "> /dev/sda", + "chmod -R 777 /", + } +) # Default allow-list: safe build/dev commands DEFAULT_ALLOWED_PREFIXES = ( @@ -199,9 +201,7 @@ class ShellHand: proc.kill() await proc.wait() latency = (time.time() - start) * 1000 - logger.warning( - "Shell command timed out after %ds: %s", effective_timeout, command - ) + logger.warning("Shell command timed out after %ds: %s", effective_timeout, command) return ShellResult( command=command, success=False, diff --git a/src/infrastructure/hands/tools.py b/src/infrastructure/hands/tools.py index a08860b..b2aaa96 100644 --- a/src/infrastructure/hands/tools.py +++ b/src/infrastructure/hands/tools.py @@ -11,15 +11,17 @@ the tool registry. 
import logging from typing import Any -from infrastructure.hands.shell import shell_hand from infrastructure.hands.git import git_hand +from infrastructure.hands.shell import shell_hand try: from mcp.schemas.base import create_tool_schema except ImportError: + def create_tool_schema(**kwargs): return kwargs + logger = logging.getLogger(__name__) # ── Tool schemas ───────────────────────────────────────────────────────────── @@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = { # ── Handlers ───────────────────────────────────────────────────────────────── + async def _handle_shell(**kwargs: Any) -> str: """Handler for the shell MCP tool.""" command = kwargs.get("command", "") diff --git a/src/infrastructure/models/__init__.py b/src/infrastructure/models/__init__.py index ee8fe04..2f42430 100644 --- a/src/infrastructure/models/__init__.py +++ b/src/infrastructure/models/__init__.py @@ -1,12 +1,5 @@ """Infrastructure models package.""" -from infrastructure.models.registry import ( - CustomModel, - ModelFormat, - ModelRegistry, - ModelRole, - model_registry, -) from infrastructure.models.multimodal import ( ModelCapability, ModelInfo, @@ -17,6 +10,13 @@ from infrastructure.models.multimodal import ( model_supports_vision, pull_model_with_fallback, ) +from infrastructure.models.registry import ( + CustomModel, + ModelFormat, + ModelRegistry, + ModelRole, + model_registry, +) __all__ = [ # Registry diff --git a/src/infrastructure/models/multimodal.py b/src/infrastructure/models/multimodal.py index b648397..d7220f6 100644 --- a/src/infrastructure/models/multimodal.py +++ b/src/infrastructure/models/multimodal.py @@ -21,39 +21,130 @@ logger = logging.getLogger(__name__) class ModelCapability(Enum): """Capabilities a model can have.""" - TEXT = auto() # Standard text completion - VISION = auto() # Image understanding - AUDIO = auto() # Audio/speech processing - TOOLS = auto() # Function calling / tool use - JSON = auto() # Structured output / JSON mode - 
STREAMING = auto() # Streaming responses + + TEXT = auto() # Standard text completion + VISION = auto() # Image understanding + AUDIO = auto() # Audio/speech processing + TOOLS = auto() # Function calling / tool use + JSON = auto() # Structured output / JSON mode + STREAMING = auto() # Streaming responses # Known model capabilities (local Ollama models) # These are used when we can't query the model directly KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = { # Llama 3.x series - "llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, + "llama3.1": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "llama3.1:8b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "llama3.1:8b-instruct": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "llama3.1:70b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "llama3.1:405b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "llama3.2": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, 
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, - "llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - "llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - "llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - + "llama3.2:3b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, + "llama3.2-vision": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, + "llama3.2-vision:11b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, # Qwen series - "qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - "qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - "qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, - + "qwen2.5": { + 
ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "qwen2.5:7b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "qwen2.5:14b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "qwen2.5:32b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "qwen2.5:72b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "qwen2.5-vl": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, + "qwen2.5-vl:3b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, + "qwen2.5-vl:7b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + ModelCapability.VISION, + }, # DeepSeek series "deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, @@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = { "deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, - "deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - + "deepseek-v3": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, # Gemma series "gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:2b": {ModelCapability.TEXT, 
ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, - # Mistral series - "mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - + "mistral": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "mistral:7b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "mistral-nemo": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "mistral-small": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "mistral-large": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, # Vision-specific models "llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, @@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = { "bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, - # Phi 
series "phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, - "phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - + "phi4": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, # Command R - "command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - + "command-r": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "command-r:35b": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "command-r-plus": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, # Granite (IBM) - "granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, - "granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, + "granite3-dense": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, + "granite3-moe": { + ModelCapability.TEXT, + ModelCapability.TOOLS, + ModelCapability.JSON, + ModelCapability.STREAMING, + }, } @@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = { # These are tried in order when the primary model doesn't support a capability DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = { ModelCapability.VISION: [ - "llama3.2:3b", # 
Fast vision model - "llava:7b", # Classic vision model - "qwen2.5-vl:3b", # Qwen vision - "moondream:1.8b", # Tiny vision model (last resort) + "llama3.2:3b", # Fast vision model + "llava:7b", # Classic vision model + "qwen2.5-vl:3b", # Qwen vision + "moondream:1.8b", # Tiny vision model (last resort) ], ModelCapability.TOOLS: [ "llama3.1:8b-instruct", # Best tool use - "llama3.2:3b", # Smaller but capable - "qwen2.5:7b", # Reliable fallback + "llama3.2:3b", # Smaller but capable + "qwen2.5:7b", # Reliable fallback ], ModelCapability.AUDIO: [ # Audio models are less common in Ollama @@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = { @dataclass class ModelInfo: """Information about a model's capabilities and availability.""" + name: str capabilities: set[ModelCapability] = field(default_factory=set) is_available: bool = False is_pulled: bool = False size_mb: Optional[int] = None description: str = "" - + def supports(self, capability: ModelCapability) -> bool: """Check if model supports a specific capability.""" return capability in self.capabilities @@ -142,26 +288,26 @@ class ModelInfo: class MultiModalManager: """Manages multi-modal model capabilities and fallback chains. - + This class: 1. Detects what capabilities each model has 2. Maintains fallback chains for different capabilities 3. Pulls models on-demand with automatic fallback 4. 
Routes requests to appropriate models based on content type """ - + def __init__(self, ollama_url: Optional[str] = None) -> None: self.ollama_url = ollama_url or settings.ollama_url self._available_models: dict[str, ModelInfo] = {} self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS) self._refresh_available_models() - + def _refresh_available_models(self) -> None: """Query Ollama for available models.""" try: - import urllib.request import json - + import urllib.request + url = self.ollama_url.replace("localhost", "127.0.0.1") req = urllib.request.Request( f"{url}/api/tags", @@ -170,7 +316,7 @@ class MultiModalManager: ) with urllib.request.urlopen(req, timeout=5) as response: data = json.loads(response.read().decode()) - + for model_data in data.get("models", []): name = model_data.get("name", "") self._available_models[name] = ModelInfo( @@ -181,58 +327,53 @@ class MultiModalManager: size_mb=model_data.get("size", 0) // (1024 * 1024), description=model_data.get("details", {}).get("family", ""), ) - + logger.info("Found %d models in Ollama", len(self._available_models)) - + except Exception as exc: logger.warning("Could not refresh available models: %s", exc) - + def _detect_capabilities(self, model_name: str) -> set[ModelCapability]: """Detect capabilities for a model based on known data.""" # Normalize model name (strip tags for lookup) base_name = model_name.split(":")[0] - + # Try exact match first if model_name in KNOWN_MODEL_CAPABILITIES: return set(KNOWN_MODEL_CAPABILITIES[model_name]) - + # Try base name match if base_name in KNOWN_MODEL_CAPABILITIES: return set(KNOWN_MODEL_CAPABILITIES[base_name]) - + # Default to text-only for unknown models logger.debug("Unknown model %s, defaulting to TEXT only", model_name) return {ModelCapability.TEXT, ModelCapability.STREAMING} - + def get_model_capabilities(self, model_name: str) -> set[ModelCapability]: """Get capabilities for a specific model.""" if model_name in 
self._available_models: return self._available_models[model_name].capabilities return self._detect_capabilities(model_name) - + def model_supports(self, model_name: str, capability: ModelCapability) -> bool: """Check if a model supports a specific capability.""" capabilities = self.get_model_capabilities(model_name) return capability in capabilities - + def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]: """Get all available models that support a capability.""" - return [ - info for info in self._available_models.values() - if capability in info.capabilities - ] - + return [info for info in self._available_models.values() if capability in info.capabilities] + def get_best_model_for( - self, - capability: ModelCapability, - preferred_model: Optional[str] = None + self, capability: ModelCapability, preferred_model: Optional[str] = None ) -> Optional[str]: """Get the best available model for a specific capability. - + Args: capability: The required capability preferred_model: Preferred model to use if available and capable - + Returns: Model name or None if no suitable model found """ @@ -243,25 +384,26 @@ class MultiModalManager: return preferred_model logger.debug( "Preferred model %s doesn't support %s, checking fallbacks", - preferred_model, capability.name + preferred_model, + capability.name, ) - + # Check fallback chain for this capability fallback_chain = self._fallback_chains.get(capability, []) for model_name in fallback_chain: if model_name in self._available_models: logger.debug("Using fallback model %s for %s", model_name, capability.name) return model_name - + # Find any available model with this capability capable_models = self.get_models_with_capability(capability) if capable_models: # Sort by size (prefer smaller/faster models as fallback) - capable_models.sort(key=lambda m: m.size_mb or float('inf')) + capable_models.sort(key=lambda m: m.size_mb or float("inf")) return capable_models[0].name - + return None - + def 
pull_model_with_fallback( self, primary_model: str, @@ -269,58 +411,58 @@ class MultiModalManager: auto_pull: bool = True, ) -> tuple[str, bool]: """Pull a model with automatic fallback if unavailable. - + Args: primary_model: The desired model to use capability: Required capability (for finding fallback) auto_pull: Whether to attempt pulling missing models - + Returns: Tuple of (model_name, is_fallback) """ # Check if primary model is already available if primary_model in self._available_models: return primary_model, False - + # Try to pull the primary model if auto_pull: if self._pull_model(primary_model): return primary_model, False - + # Need to find a fallback if capability: fallback = self.get_best_model_for(capability, primary_model) if fallback: logger.info( - "Primary model %s unavailable, using fallback %s", - primary_model, fallback + "Primary model %s unavailable, using fallback %s", primary_model, fallback ) return fallback, True - + # Last resort: use the configured default model default_model = settings.ollama_model if default_model in self._available_models: logger.warning( "Falling back to default model %s (primary: %s unavailable)", - default_model, primary_model + default_model, + primary_model, ) return default_model, True - + # Absolute last resort return primary_model, False - + def _pull_model(self, model_name: str) -> bool: """Attempt to pull a model from Ollama. 
- + Returns: True if successful or model already exists """ try: - import urllib.request import json - + import urllib.request + logger.info("Pulling model: %s", model_name) - + url = self.ollama_url.replace("localhost", "127.0.0.1") req = urllib.request.Request( f"{url}/api/pull", @@ -328,7 +470,7 @@ class MultiModalManager: headers={"Content-Type": "application/json"}, data=json.dumps({"name": model_name, "stream": False}).encode(), ) - + with urllib.request.urlopen(req, timeout=300) as response: if response.status == 200: logger.info("Successfully pulled model: %s", model_name) @@ -338,55 +480,51 @@ class MultiModalManager: else: logger.error("Failed to pull %s: HTTP %s", model_name, response.status) return False - + except Exception as exc: logger.error("Error pulling model %s: %s", model_name, exc) return False - - def configure_fallback_chain( - self, - capability: ModelCapability, - models: list[str] - ) -> None: + + def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None: """Configure a custom fallback chain for a capability.""" self._fallback_chains[capability] = models logger.info("Configured fallback chain for %s: %s", capability.name, models) - + def get_fallback_chain(self, capability: ModelCapability) -> list[str]: """Get the fallback chain for a capability.""" return list(self._fallback_chains.get(capability, [])) - + def list_available_models(self) -> list[ModelInfo]: """List all available models with their capabilities.""" return list(self._available_models.values()) - + def refresh(self) -> None: """Refresh the list of available models.""" self._refresh_available_models() - + def get_model_for_content( self, content_type: str, # "text", "image", "audio", "multimodal" preferred_model: Optional[str] = None, ) -> tuple[str, bool]: """Get appropriate model based on content type. 
- + Args: content_type: Type of content (text, image, audio, multimodal) preferred_model: User's preferred model - + Returns: Tuple of (model_name, is_fallback) """ content_type = content_type.lower() - + if content_type in ("image", "vision", "multimodal"): # For vision content, we need a vision-capable model return self.pull_model_with_fallback( preferred_model or "llava:7b", capability=ModelCapability.VISION, ) - + elif content_type == "audio": # Audio support is limited in Ollama # Would need specific audio models @@ -395,7 +533,7 @@ class MultiModalManager: preferred_model or settings.ollama_model, capability=ModelCapability.TEXT, ) - + else: # Standard text content return self.pull_model_with_fallback( @@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager: def get_model_for_capability( - capability: ModelCapability, - preferred_model: Optional[str] = None + capability: ModelCapability, preferred_model: Optional[str] = None ) -> Optional[str]: """Convenience function to get best model for a capability.""" return get_multimodal_manager().get_best_model_for(capability, preferred_model) @@ -430,9 +567,7 @@ def pull_model_with_fallback( auto_pull: bool = True, ) -> tuple[str, bool]: """Convenience function to pull model with fallback.""" - return get_multimodal_manager().pull_model_with_fallback( - primary_model, capability, auto_pull - ) + return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull) def model_supports_vision(model_name: str) -> bool: diff --git a/src/infrastructure/models/registry.py b/src/infrastructure/models/registry.py index b9a568c..34bedf5 100644 --- a/src/infrastructure/models/registry.py +++ b/src/infrastructure/models/registry.py @@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db") class ModelFormat(str, Enum): """Supported model weight formats.""" - GGUF = "gguf" # Ollama-compatible quantised weights - SAFETENSORS = "safetensors" # HuggingFace safetensors - HF_CHECKPOINT = "hf" # Full 
HuggingFace checkpoint directory - OLLAMA = "ollama" # Already loaded in Ollama by name + + GGUF = "gguf" # Ollama-compatible quantised weights + SAFETENSORS = "safetensors" # HuggingFace safetensors + HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory + OLLAMA = "ollama" # Already loaded in Ollama by name class ModelRole(str, Enum): """Role a model can play in the system (OpenClaw-RL style).""" - GENERAL = "general" # Default agent inference - REWARD = "reward" # Process Reward Model (PRM) scoring - TEACHER = "teacher" # On-policy distillation teacher - JUDGE = "judge" # Output quality evaluation + + GENERAL = "general" # Default agent inference + REWARD = "reward" # Process Reward Model (PRM) scoring + TEACHER = "teacher" # On-policy distillation teacher + JUDGE = "judge" # Output quality evaluation @dataclass class CustomModel: """A registered custom model.""" + name: str format: ModelFormat - path: str # Absolute path or Ollama model name + path: str # Absolute path or Ollama model name role: ModelRole = ModelRole.GENERAL context_window: int = 4096 description: str = "" @@ -141,10 +144,16 @@ class ModelRegistry: VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
""", ( - model.name, model.format.value, model.path, - model.role.value, model.context_window, model.description, - model.registered_at, int(model.active), - model.default_temperature, model.max_tokens, + model.name, + model.format.value, + model.path, + model.role.value, + model.context_window, + model.description, + model.registered_at, + int(model.active), + model.default_temperature, + model.max_tokens, ), ) conn.commit() @@ -160,9 +169,7 @@ class ModelRegistry: return False conn = _get_conn() conn.execute("DELETE FROM custom_models WHERE name = ?", (name,)) - conn.execute( - "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,) - ) + conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)) conn.commit() conn.close() del self._models[name] diff --git a/src/infrastructure/notifications/push.py b/src/infrastructure/notifications/push.py index 1bd251f..6775143 100644 --- a/src/infrastructure/notifications/push.py +++ b/src/infrastructure/notifications/push.py @@ -9,8 +9,8 @@ No cloud push services — everything stays local. 
""" import logging -import subprocess import platform +import subprocess from collections import deque from dataclasses import dataclass, field from datetime import datetime, timezone @@ -25,9 +25,7 @@ class Notification: title: str message: str category: str # swarm | task | agent | system | payment - timestamp: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat() - ) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) read: bool = False @@ -74,9 +72,11 @@ class PushNotifier: def _native_notify(self, title: str, message: str) -> None: """Send a native macOS notification via osascript.""" try: + safe_message = message.replace("\\", "\\\\").replace('"', '\\"') + safe_title = title.replace("\\", "\\\\").replace('"', '\\"') script = ( - f'display notification "{message}" ' - f'with title "Agent Dashboard" subtitle "{title}"' + f'display notification "{safe_message}" ' + f'with title "Agent Dashboard" subtitle "{safe_title}"' ) subprocess.Popen( ["osascript", "-e", script], @@ -114,7 +114,7 @@ class PushNotifier: def clear(self) -> None: self._notifications.clear() - def add_listener(self, callback) -> None: + def add_listener(self, callback: "Callable[[Notification], None]") -> None: """Register a callback for real-time notification delivery.""" self._listeners.append(callback) @@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None: logger.info("Briefing ready but no pending approvals — skipping native notification") return - message = ( - f"Your morning briefing is ready. " - f"{n_approvals} item(s) await your approval." - ) + message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval." 
notifier.notify( title="Morning Briefing Ready", message=message, diff --git a/src/infrastructure/openfang/client.py b/src/infrastructure/openfang/client.py index 3c234f0..f1a8ec8 100644 --- a/src/infrastructure/openfang/client.py +++ b/src/infrastructure/openfang/client.py @@ -156,33 +156,23 @@ class OpenFangClient: async def browse(self, url: str, instruction: str = "") -> HandResult: """Web automation via OpenFang's Browser hand.""" - return await self.execute_hand( - "browser", {"url": url, "instruction": instruction} - ) + return await self.execute_hand("browser", {"url": url, "instruction": instruction}) async def collect(self, target: str, depth: str = "shallow") -> HandResult: """OSINT collection via OpenFang's Collector hand.""" - return await self.execute_hand( - "collector", {"target": target, "depth": depth} - ) + return await self.execute_hand("collector", {"target": target, "depth": depth}) async def predict(self, question: str, horizon: str = "1w") -> HandResult: """Superforecasting via OpenFang's Predictor hand.""" - return await self.execute_hand( - "predictor", {"question": question, "horizon": horizon} - ) + return await self.execute_hand("predictor", {"question": question, "horizon": horizon}) async def find_leads(self, icp: str, max_results: int = 10) -> HandResult: """Prospect discovery via OpenFang's Lead hand.""" - return await self.execute_hand( - "lead", {"icp": icp, "max_results": max_results} - ) + return await self.execute_hand("lead", {"icp": icp, "max_results": max_results}) async def research(self, topic: str, depth: str = "standard") -> HandResult: """Deep research via OpenFang's Researcher hand.""" - return await self.execute_hand( - "researcher", {"topic": topic, "depth": depth} - ) + return await self.execute_hand("researcher", {"topic": topic, "depth": depth}) # ── Inventory ──────────────────────────────────────────────────────────── diff --git a/src/infrastructure/openfang/tools.py b/src/infrastructure/openfang/tools.py index 
a51e8d8..db21c74 100644 --- a/src/infrastructure/openfang/tools.py +++ b/src/infrastructure/openfang/tools.py @@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client try: from mcp.schemas.base import create_tool_schema except ImportError: + def create_tool_schema(**kwargs): return kwargs + logger = logging.getLogger(__name__) # ── Tool schemas ───────────────────────────────────────────────────────────── diff --git a/src/infrastructure/router/__init__.py b/src/infrastructure/router/__init__.py index 730f623..49d97ee 100644 --- a/src/infrastructure/router/__init__.py +++ b/src/infrastructure/router/__init__.py @@ -1,7 +1,7 @@ """Cascade LLM Router — Automatic failover between providers.""" -from .cascade import CascadeRouter, Provider, ProviderStatus, get_router from .api import router +from .cascade import CascadeRouter, Provider, ProviderStatus, get_router __all__ = [ "CascadeRouter", diff --git a/src/infrastructure/router/api.py b/src/infrastructure/router/api.py index 7558c4e..a76c345 100644 --- a/src/infrastructure/router/api.py +++ b/src/infrastructure/router/api.py @@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"]) class CompletionRequest(BaseModel): """Request body for completions.""" + messages: list[dict[str, str]] model: str | None = None temperature: float = 0.7 @@ -23,6 +24,7 @@ class CompletionRequest(BaseModel): class CompletionResponse(BaseModel): """Response from completion endpoint.""" + content: str provider: str model: str @@ -31,6 +33,7 @@ class CompletionResponse(BaseModel): class ProviderControl(BaseModel): """Control a provider's status.""" + action: str # "enable", "disable", "reset_circuit" @@ -45,7 +48,7 @@ async def complete( cascade: Annotated[CascadeRouter, Depends(get_cascade_router)], ) -> dict[str, Any]: """Complete a conversation with automatic failover. - + Routes through providers in priority order until one succeeds. 
""" try: @@ -108,30 +111,32 @@ async def control_provider( if p.name == provider_name: provider = p break - + if not provider: raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found") - + if control.action == "enable": provider.enabled = True provider.status = provider.status.__class__.HEALTHY return {"message": f"Provider {provider_name} enabled"} - + elif control.action == "disable": provider.enabled = False from .cascade import ProviderStatus + provider.status = ProviderStatus.DISABLED return {"message": f"Provider {provider_name} disabled"} - + elif control.action == "reset_circuit": from .cascade import CircuitState, ProviderStatus + provider.circuit_state = CircuitState.CLOSED provider.circuit_opened_at = None provider.half_open_calls = 0 provider.metrics.consecutive_failures = 0 provider.status = ProviderStatus.HEALTHY return {"message": f"Circuit breaker reset for {provider_name}"} - + else: raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}") @@ -142,28 +147,35 @@ async def run_health_check( ) -> dict[str, Any]: """Run health checks on all providers.""" results = [] - + for provider in cascade.providers: # Quick ping to check availability is_healthy = cascade._check_provider_available(provider) - + from .cascade import ProviderStatus + if is_healthy: if provider.status == ProviderStatus.UNHEALTHY: # Reset circuit if it was open but now healthy provider.circuit_state = provider.circuit_state.__class__.CLOSED provider.circuit_opened_at = None - provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED + provider.status = ( + ProviderStatus.HEALTHY + if provider.metrics.error_rate < 0.1 + else ProviderStatus.DEGRADED + ) else: provider.status = ProviderStatus.UNHEALTHY - - results.append({ - "name": provider.name, - "type": provider.type, - "healthy": is_healthy, - "status": provider.status.value, - }) - + + results.append( + { + "name": provider.name, + 
"type": provider.type, + "healthy": is_healthy, + "status": provider.status.value, + } + ) + return { "checked_at": asyncio.get_event_loop().time(), "providers": results, @@ -177,7 +189,7 @@ async def get_config( ) -> dict[str, Any]: """Get router configuration (without secrets).""" cfg = cascade.config - + return { "timeout_seconds": cfg.timeout_seconds, "max_retries_per_provider": cfg.max_retries_per_provider, diff --git a/src/infrastructure/router/cascade.py b/src/infrastructure/router/cascade.py index a747530..cec0def 100644 --- a/src/infrastructure/router/cascade.py +++ b/src/infrastructure/router/cascade.py @@ -33,6 +33,7 @@ logger = logging.getLogger(__name__) class ProviderStatus(Enum): """Health status of a provider.""" + HEALTHY = "healthy" DEGRADED = "degraded" # Working but slow or occasional errors UNHEALTHY = "unhealthy" # Circuit breaker open @@ -41,22 +42,25 @@ class ProviderStatus(Enum): class CircuitState(Enum): """Circuit breaker state.""" - CLOSED = "closed" # Normal operation - OPEN = "open" # Failing, rejecting requests + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, rejecting requests HALF_OPEN = "half_open" # Testing if recovered class ContentType(Enum): """Type of content in the request.""" + TEXT = "text" - VISION = "vision" # Contains images - AUDIO = "audio" # Contains audio + VISION = "vision" # Contains images + AUDIO = "audio" # Contains audio MULTIMODAL = "multimodal" # Multiple content types @dataclass class ProviderMetrics: """Metrics for a single provider.""" + total_requests: int = 0 successful_requests: int = 0 failed_requests: int = 0 @@ -64,13 +68,13 @@ class ProviderMetrics: last_request_time: Optional[str] = None last_error_time: Optional[str] = None consecutive_failures: int = 0 - + @property def avg_latency_ms(self) -> float: if self.total_requests == 0: return 0.0 return self.total_latency_ms / self.total_requests - + @property def error_rate(self) -> float: if self.total_requests == 0: @@ -81,6 +85,7 
@@ class ProviderMetrics: @dataclass class ModelCapability: """Capabilities a model supports.""" + name: str supports_vision: bool = False supports_audio: bool = False @@ -93,6 +98,7 @@ class ModelCapability: @dataclass class Provider: """LLM provider configuration and state.""" + name: str type: str # ollama, openai, anthropic, airllm enabled: bool @@ -101,14 +107,14 @@ class Provider: api_key: Optional[str] = None base_url: Optional[str] = None models: list[dict] = field(default_factory=list) - + # Runtime state status: ProviderStatus = ProviderStatus.HEALTHY metrics: ProviderMetrics = field(default_factory=ProviderMetrics) circuit_state: CircuitState = CircuitState.CLOSED circuit_opened_at: Optional[float] = None half_open_calls: int = 0 - + def get_default_model(self) -> Optional[str]: """Get the default model for this provider.""" for model in self.models: @@ -117,7 +123,7 @@ class Provider: if self.models: return self.models[0]["name"] return None - + def get_model_with_capability(self, capability: str) -> Optional[str]: """Get a model that supports the given capability.""" for model in self.models: @@ -126,7 +132,7 @@ class Provider: return model["name"] # Fall back to default return self.get_default_model() - + def model_has_capability(self, model_name: str, capability: str) -> bool: """Check if a specific model has a capability.""" for model in self.models: @@ -139,6 +145,7 @@ class Provider: @dataclass class RouterConfig: """Cascade router configuration.""" + timeout_seconds: int = 30 max_retries_per_provider: int = 2 retry_delay_seconds: int = 1 @@ -154,22 +161,22 @@ class RouterConfig: class CascadeRouter: """Routes LLM requests with automatic failover. 
- + Now with multi-modal support: - Automatically detects content type (text, vision, audio) - Selects appropriate models based on capabilities - Falls back through capability-specific model chains - Supports image URLs and base64 encoding - + Usage: router = CascadeRouter() - + # Text request response = await router.complete( messages=[{"role": "user", "content": "Hello"}], model="llama3.2" ) - + # Vision request (automatically detects and selects vision model) response = await router.complete( messages=[{ @@ -179,68 +186,75 @@ class CascadeRouter: }], model="llava:7b" ) - + # Check metrics metrics = router.get_metrics() """ - + def __init__(self, config_path: Optional[Path] = None) -> None: self.config_path = config_path or Path("config/providers.yaml") self.providers: list[Provider] = [] self.config: RouterConfig = RouterConfig() self._load_config() - + # Initialize multi-modal manager if available self._mm_manager: Optional[Any] = None try: from infrastructure.models.multimodal import get_multimodal_manager + self._mm_manager = get_multimodal_manager() except Exception as exc: logger.debug("Multi-modal manager not available: %s", exc) - + logger.info("CascadeRouter initialized with %d providers", len(self.providers)) - + def _load_config(self) -> None: """Load configuration from YAML.""" if not self.config_path.exists(): logger.warning("Config not found: %s, using defaults", self.config_path) return - + try: if yaml is None: raise RuntimeError("PyYAML not installed") - + content = self.config_path.read_text() # Expand environment variables content = self._expand_env_vars(content) data = yaml.safe_load(content) - + # Load cascade settings cascade = data.get("cascade", {}) - + # Load fallback chains fallback_chains = data.get("fallback_chains", {}) - + # Load multi-modal settings multimodal = data.get("multimodal", {}) - + self.config = RouterConfig( timeout_seconds=cascade.get("timeout_seconds", 30), 
max_retries_per_provider=cascade.get("max_retries_per_provider", 2), retry_delay_seconds=cascade.get("retry_delay_seconds", 1), - circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5), - circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60), - circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2), + circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get( + "failure_threshold", 5 + ), + circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get( + "recovery_timeout", 60 + ), + circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get( + "half_open_max_calls", 2 + ), auto_pull_models=multimodal.get("auto_pull", True), fallback_chains=fallback_chains, ) - + # Load providers for p_data in data.get("providers", []): # Skip disabled providers if not p_data.get("enabled", False): continue - + provider = Provider( name=p_data["name"], type=p_data["type"], @@ -251,30 +265,34 @@ class CascadeRouter: base_url=p_data.get("base_url"), models=p_data.get("models", []), ) - + # Check if provider is actually available if self._check_provider_available(provider): self.providers.append(provider) else: logger.warning("Provider %s not available, skipping", provider.name) - + # Sort by priority self.providers.sort(key=lambda p: p.priority) - + except Exception as exc: logger.error("Failed to load config: %s", exc) - + def _expand_env_vars(self, content: str) -> str: - """Expand ${VAR} syntax in YAML content.""" + """Expand ${VAR} syntax in YAML content. + + Uses os.environ directly (not settings) because this is a generic + YAML config loader that must expand arbitrary variable references. 
+ """ import os import re - - def replace_var(match): + + def replace_var(match: "re.Match[str]") -> str: var_name = match.group(1) return os.environ.get(var_name, match.group(0)) - + return re.sub(r"\$\{(\w+)\}", replace_var, content) - + def _check_provider_available(self, provider: Provider) -> bool: """Check if a provider is actually available.""" if provider.type == "ollama": @@ -288,48 +306,49 @@ class CascadeRouter: return response.status_code == 200 except Exception: return False - + elif provider.type == "airllm": # Check if airllm is installed try: import airllm + return True except ImportError: return False - + elif provider.type in ("openai", "anthropic", "grok"): # Check if API key is set return provider.api_key is not None and provider.api_key != "" return True - + def _detect_content_type(self, messages: list[dict]) -> ContentType: """Detect the type of content in the messages. - + Checks for images, audio, etc. in the message content. """ has_image = False has_audio = False - + for msg in messages: content = msg.get("content", "") - + # Check for image URLs/paths if msg.get("images"): has_image = True - + # Check for image URLs in content if isinstance(content, str): - image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp') + image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp") if any(ext in content.lower() for ext in image_extensions): has_image = True if content.startswith("data:image/"): has_image = True - + # Check for audio if msg.get("audio"): has_audio = True - + # Check for multimodal content structure if isinstance(content, list): for item in content: @@ -338,7 +357,7 @@ class CascadeRouter: has_image = True elif item.get("type") == "audio": has_audio = True - + if has_image and has_audio: return ContentType.MULTIMODAL elif has_image: @@ -346,12 +365,9 @@ class CascadeRouter: elif has_audio: return ContentType.AUDIO return ContentType.TEXT - + def _get_fallback_model( - self, - provider: Provider, - 
original_model: str, - content_type: ContentType + self, provider: Provider, original_model: str, content_type: ContentType ) -> Optional[str]: """Get a fallback model for the given content type.""" # Map content type to capability @@ -360,24 +376,24 @@ class CascadeRouter: ContentType.AUDIO: "audio", ContentType.MULTIMODAL: "vision", # Vision models often do both } - + capability = capability_map.get(content_type) if not capability: return None - + # Check provider's models for capability fallback_model = provider.get_model_with_capability(capability) if fallback_model and fallback_model != original_model: return fallback_model - + # Use fallback chains from config fallback_chain = self.config.fallback_chains.get(capability, []) for model_name in fallback_chain: if provider.model_has_capability(model_name, capability): return model_name - + return None - + async def complete( self, messages: list[dict], @@ -386,21 +402,21 @@ class CascadeRouter: max_tokens: Optional[int] = None, ) -> dict: """Complete a chat conversation with automatic failover. 
- + Multi-modal support: - Automatically detects if messages contain images - Falls back to vision-capable models when needed - Supports image URLs, paths, and base64 encoding - + Args: messages: List of message dicts with role and content model: Preferred model (tries this first, then provider defaults) temperature: Sampling temperature max_tokens: Maximum tokens to generate - + Returns: Dict with content, provider_used, and metrics - + Raises: RuntimeError: If all providers fail """ @@ -408,15 +424,15 @@ class CascadeRouter: content_type = self._detect_content_type(messages) if content_type != ContentType.TEXT: logger.debug("Detected %s content, selecting appropriate model", content_type.value) - + errors = [] - + for provider in self.providers: # Skip disabled providers if not provider.enabled: logger.debug("Skipping %s (disabled)", provider.name) continue - + # Skip unhealthy providers (circuit breaker) if provider.status == ProviderStatus.UNHEALTHY: # Check if circuit breaker can close @@ -427,16 +443,16 @@ class CascadeRouter: else: logger.debug("Skipping %s (circuit open)", provider.name) continue - + # Determine which model to use selected_model = model or provider.get_default_model() is_fallback_model = False - + # For non-text content, check if model supports it if content_type != ContentType.TEXT and selected_model: if provider.type == "ollama" and self._mm_manager: from infrastructure.models.multimodal import ModelCapability - + # Check if selected model supports the required capability if content_type == ContentType.VISION: supports = self._mm_manager.model_supports( @@ -450,16 +466,17 @@ class CascadeRouter: if fallback: logger.info( "Model %s doesn't support vision, falling back to %s", - selected_model, fallback + selected_model, + fallback, ) selected_model = fallback is_fallback_model = True else: logger.warning( "No vision-capable model found on %s, trying anyway", - provider.name + provider.name, ) - + # Try this provider for attempt in 
range(self.config.max_retries_per_provider): try: @@ -471,34 +488,35 @@ class CascadeRouter: max_tokens=max_tokens, content_type=content_type, ) - + # Success! Update metrics and return self._record_success(provider, result.get("latency_ms", 0)) return { "content": result["content"], "provider": provider.name, - "model": result.get("model", selected_model or provider.get_default_model()), + "model": result.get( + "model", selected_model or provider.get_default_model() + ), "latency_ms": result.get("latency_ms", 0), "is_fallback_model": is_fallback_model, } - + except Exception as exc: error_msg = str(exc) logger.warning( - "Provider %s attempt %d failed: %s", - provider.name, attempt + 1, error_msg + "Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg ) errors.append(f"{provider.name}: {error_msg}") - + if attempt < self.config.max_retries_per_provider - 1: await asyncio.sleep(self.config.retry_delay_seconds) - + # All retries failed for this provider self._record_failure(provider) - + # All providers failed raise RuntimeError(f"All providers failed: {'; '.join(errors)}") - + async def _try_provider( self, provider: Provider, @@ -510,7 +528,7 @@ class CascadeRouter: ) -> dict: """Try a single provider request.""" start_time = time.time() - + if provider.type == "ollama": result = await self._call_ollama( provider=provider, @@ -545,12 +563,12 @@ class CascadeRouter: ) else: raise ValueError(f"Unknown provider type: {provider.type}") - + latency_ms = (time.time() - start_time) * 1000 result["latency_ms"] = latency_ms - + return result - + async def _call_ollama( self, provider: Provider, @@ -561,12 +579,12 @@ class CascadeRouter: ) -> dict: """Call Ollama API with multi-modal support.""" import aiohttp - + url = f"{provider.url}/api/chat" - + # Transform messages for Ollama format (including images) transformed_messages = self._transform_messages_for_ollama(messages) - + payload = { "model": model, "messages": transformed_messages, @@ -575,31 
+593,31 @@ class CascadeRouter: "temperature": temperature, }, } - + timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) - + async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, json=payload) as response: if response.status != 200: text = await response.text() raise RuntimeError(f"Ollama error {response.status}: {text}") - + data = await response.json() return { "content": data["message"]["content"], "model": model, } - + def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]: """Transform messages to Ollama format, handling images.""" transformed = [] - + for msg in messages: new_msg = { "role": msg.get("role", "user"), "content": msg.get("content", ""), } - + # Handle images images = msg.get("images", []) if images: @@ -620,11 +638,11 @@ class CascadeRouter: new_msg["images"].append(img_data) except Exception as exc: logger.error("Failed to read image %s: %s", img, exc) - + transformed.append(new_msg) - + return transformed - + async def _call_openai( self, provider: Provider, @@ -635,13 +653,13 @@ class CascadeRouter: ) -> dict: """Call OpenAI API.""" import openai - + client = openai.AsyncOpenAI( api_key=provider.api_key, base_url=provider.base_url, timeout=self.config.timeout_seconds, ) - + kwargs = { "model": model, "messages": messages, @@ -649,14 +667,14 @@ class CascadeRouter: } if max_tokens: kwargs["max_tokens"] = max_tokens - + response = await client.chat.completions.create(**kwargs) - + return { "content": response.choices[0].message.content, "model": response.model, } - + async def _call_anthropic( self, provider: Provider, @@ -667,12 +685,12 @@ class CascadeRouter: ) -> dict: """Call Anthropic API.""" import anthropic - + client = anthropic.AsyncAnthropic( api_key=provider.api_key, timeout=self.config.timeout_seconds, ) - + # Convert messages to Anthropic format system_msg = None conversation = [] @@ -680,11 +698,13 @@ class CascadeRouter: if msg["role"] == "system": 
system_msg = msg["content"] else: - conversation.append({ - "role": msg["role"], - "content": msg["content"], - }) - + conversation.append( + { + "role": msg["role"], + "content": msg["content"], + } + ) + kwargs = { "model": model, "messages": conversation, @@ -693,9 +713,9 @@ class CascadeRouter: } if system_msg: kwargs["system"] = system_msg - + response = await client.messages.create(**kwargs) - + return { "content": response.content[0].text, "model": response.model, @@ -733,7 +753,7 @@ class CascadeRouter: "content": response.choices[0].message.content, "model": response.model, } - + def _record_success(self, provider: Provider, latency_ms: float) -> None: """Record a successful request.""" provider.metrics.total_requests += 1 @@ -741,50 +761,50 @@ class CascadeRouter: provider.metrics.total_latency_ms += latency_ms provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat() provider.metrics.consecutive_failures = 0 - + # Close circuit breaker if half-open if provider.circuit_state == CircuitState.HALF_OPEN: provider.half_open_calls += 1 if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls: self._close_circuit(provider) - + # Update status based on error rate if provider.metrics.error_rate < 0.1: provider.status = ProviderStatus.HEALTHY elif provider.metrics.error_rate < 0.3: provider.status = ProviderStatus.DEGRADED - + def _record_failure(self, provider: Provider) -> None: """Record a failed request.""" provider.metrics.total_requests += 1 provider.metrics.failed_requests += 1 provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat() provider.metrics.consecutive_failures += 1 - + # Check if we should open circuit breaker if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold: self._open_circuit(provider) - + # Update status if provider.metrics.error_rate > 0.3: provider.status = ProviderStatus.DEGRADED if provider.metrics.error_rate > 0.5: provider.status = 
ProviderStatus.UNHEALTHY - + def _open_circuit(self, provider: Provider) -> None: """Open the circuit breaker for a provider.""" provider.circuit_state = CircuitState.OPEN provider.circuit_opened_at = time.time() provider.status = ProviderStatus.UNHEALTHY logger.warning("Circuit breaker OPEN for %s", provider.name) - + def _can_close_circuit(self, provider: Provider) -> bool: """Check if circuit breaker can transition to half-open.""" if provider.circuit_opened_at is None: return False elapsed = time.time() - provider.circuit_opened_at return elapsed >= self.config.circuit_breaker_recovery_timeout - + def _close_circuit(self, provider: Provider) -> None: """Close the circuit breaker (provider healthy again).""" provider.circuit_state = CircuitState.CLOSED @@ -793,7 +813,7 @@ class CascadeRouter: provider.metrics.consecutive_failures = 0 provider.status = ProviderStatus.HEALTHY logger.info("Circuit breaker CLOSED for %s", provider.name) - + def get_metrics(self) -> dict: """Get metrics for all providers.""" return { @@ -814,16 +834,20 @@ class CascadeRouter: for p in self.providers ] } - + def get_status(self) -> dict: """Get current router status.""" healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY) - + return { "total_providers": len(self.providers), "healthy_providers": healthy, - "degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED), - "unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY), + "degraded_providers": sum( + 1 for p in self.providers if p.status == ProviderStatus.DEGRADED + ), + "unhealthy_providers": sum( + 1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY + ), "providers": [ { "name": p.name, @@ -835,7 +859,7 @@ class CascadeRouter: for p in self.providers ], } - + async def generate_with_image( self, prompt: str, @@ -844,21 +868,23 @@ class CascadeRouter: temperature: float = 0.7, ) -> dict: """Convenience method for vision 
requests. - + Args: prompt: Text prompt about the image image_path: Path to image file model: Vision-capable model (auto-selected if not provided) temperature: Sampling temperature - + Returns: Response dict with content and metadata """ - messages = [{ - "role": "user", - "content": prompt, - "images": [image_path], - }] + messages = [ + { + "role": "user", + "content": prompt, + "images": [image_path], + } + ] return await self.complete( messages=messages, model=model, diff --git a/src/infrastructure/ws_manager/handler.py b/src/infrastructure/ws_manager/handler.py index fff894a..205dd85 100644 --- a/src/infrastructure/ws_manager/handler.py +++ b/src/infrastructure/ws_manager/handler.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) @dataclass class WSEvent: """A WebSocket event to broadcast to connected clients.""" + event: str data: dict timestamp: str @@ -93,28 +94,42 @@ class WebSocketManager: await self.broadcast("agent_left", {"agent_id": agent_id, "name": name}) async def broadcast_task_posted(self, task_id: str, description: str) -> None: - await self.broadcast("task_posted", { - "task_id": task_id, "description": description, - }) + await self.broadcast( + "task_posted", + { + "task_id": task_id, + "description": description, + }, + ) - async def broadcast_bid_submitted( - self, task_id: str, agent_id: str, bid_sats: int - ) -> None: - await self.broadcast("bid_submitted", { - "task_id": task_id, "agent_id": agent_id, "bid_sats": bid_sats, - }) + async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None: + await self.broadcast( + "bid_submitted", + { + "task_id": task_id, + "agent_id": agent_id, + "bid_sats": bid_sats, + }, + ) async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None: - await self.broadcast("task_assigned", { - "task_id": task_id, "agent_id": agent_id, - }) + await self.broadcast( + "task_assigned", + { + "task_id": task_id, + "agent_id": agent_id, + }, + ) - async def 
broadcast_task_completed( - self, task_id: str, agent_id: str, result: str - ) -> None: - await self.broadcast("task_completed", { - "task_id": task_id, "agent_id": agent_id, "result": result[:200], - }) + async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None: + await self.broadcast( + "task_completed", + { + "task_id": task_id, + "agent_id": agent_id, + "result": result[:200], + }, + ) @property def connection_count(self) -> int: @@ -122,28 +137,28 @@ class WebSocketManager: async def broadcast_json(self, data: dict) -> int: """Broadcast raw JSON data to all connected clients. - + Args: data: Dictionary to send as JSON - + Returns: Number of clients notified """ message = json.dumps(data) disconnected = [] count = 0 - + for ws in self._connections: try: await ws.send_text(message) count += 1 except Exception: disconnected.append(ws) - + # Clean up dead connections for ws in disconnected: self.disconnect(ws) - + return count @property diff --git a/src/integrations/chat_bridge/base.py b/src/integrations/chat_bridge/base.py index 6af6607..17be5fd 100644 --- a/src/integrations/chat_bridge/base.py +++ b/src/integrations/chat_bridge/base.py @@ -21,6 +21,7 @@ from typing import Any, Optional class PlatformState(Enum): """Lifecycle state of a chat platform connection.""" + DISCONNECTED = auto() CONNECTING = auto() CONNECTED = auto() @@ -30,13 +31,12 @@ class PlatformState(Enum): @dataclass class ChatMessage: """Vendor-agnostic representation of a chat message.""" + content: str author: str channel_id: str platform: str - timestamp: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat() - ) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) message_id: Optional[str] = None thread_id: Optional[str] = None attachments: list[str] = field(default_factory=list) @@ -46,13 +46,12 @@ class ChatMessage: @dataclass class ChatThread: """Vendor-agnostic representation of a conversation 
thread.""" + thread_id: str title: str channel_id: str platform: str - created_at: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat() - ) + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) archived: bool = False message_count: int = 0 metadata: dict[str, Any] = field(default_factory=dict) @@ -61,6 +60,7 @@ class ChatThread: @dataclass class InviteInfo: """Parsed invite extracted from an image or text.""" + url: str code: str platform: str @@ -71,6 +71,7 @@ class InviteInfo: @dataclass class PlatformStatus: """Current status of a chat platform connection.""" + platform: str state: PlatformState token_set: bool diff --git a/src/integrations/chat_bridge/invite_parser.py b/src/integrations/chat_bridge/invite_parser.py index df64c44..0e6f315 100644 --- a/src/integrations/chat_bridge/invite_parser.py +++ b/src/integrations/chat_bridge/invite_parser.py @@ -115,7 +115,9 @@ class InviteParser: """Strategy 2: Use Ollama vision model for local OCR.""" try: import base64 + import httpx + from config import settings except ImportError: logger.debug("httpx not available for Ollama vision.") diff --git a/src/integrations/chat_bridge/vendors/discord.py b/src/integrations/chat_bridge/vendors/discord.py index a77b4cd..deeac5a 100644 --- a/src/integrations/chat_bridge/vendors/discord.py +++ b/src/integrations/chat_bridge/vendors/discord.py @@ -90,10 +90,7 @@ class DiscordVendor(ChatPlatform): try: import discord except ImportError: - logger.error( - "discord.py is not installed. " - 'Run: pip install ".[discord]"' - ) + logger.error("discord.py is not installed. 
" 'Run: pip install ".[discord]"') return False try: @@ -267,6 +264,7 @@ class DiscordVendor(ChatPlatform): try: from config import settings + return settings.discord_token or None except Exception: return None @@ -363,9 +361,7 @@ class DiscordVendor(ChatPlatform): # Show typing indicator while the agent processes async with target.typing(): run = await asyncio.wait_for( - asyncio.to_thread( - agent.run, content, stream=False, session_id=session_id - ), + asyncio.to_thread(agent.run, content, stream=False, session_id=session_id), timeout=300, ) response = run.content if hasattr(run, "content") else str(run) @@ -374,7 +370,9 @@ class DiscordVendor(ChatPlatform): response = "Sorry, that took too long. Please try a simpler request." except Exception as exc: logger.error("Discord: agent.run() failed: %s", exc) - response = "I'm having trouble reaching my language model right now. Please try again shortly." + response = ( + "I'm having trouble reaching my language model right now. Please try again shortly." 
+ ) # Strip hallucinated tool-call JSON and chain-of-thought narration from timmy.session import _clean_response @@ -408,6 +406,7 @@ class DiscordVendor(ChatPlatform): # Create a thread from this message from config import settings + thread_name = f"{settings.agent_name} | {message.author.display_name}" thread = await message.create_thread( name=thread_name[:100], diff --git a/src/integrations/paperclip/models.py b/src/integrations/paperclip/models.py index c0f0c45..c8faaf0 100644 --- a/src/integrations/paperclip/models.py +++ b/src/integrations/paperclip/models.py @@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field - # ── Inbound: Paperclip → Timmy ────────────────────────────────────────────── diff --git a/src/integrations/paperclip/task_runner.py b/src/integrations/paperclip/task_runner.py index c4a9d3d..d30b0e9 100644 --- a/src/integrations/paperclip/task_runner.py +++ b/src/integrations/paperclip/task_runner.py @@ -20,7 +20,8 @@ import logging from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable from config import settings -from integrations.paperclip.bridge import PaperclipBridge, bridge as default_bridge +from integrations.paperclip.bridge import PaperclipBridge +from integrations.paperclip.bridge import bridge as default_bridge from integrations.paperclip.models import PaperclipIssue logger = logging.getLogger(__name__) @@ -30,9 +31,8 @@ logger = logging.getLogger(__name__) class Orchestrator(Protocol): """Anything with an ``execute_task`` matching Timmy's orchestrator.""" - async def execute_task( - self, task_id: str, description: str, context: dict - ) -> Any: ... + async def execute_task(self, task_id: str, description: str, context: dict) -> Any: + ... 
def _wrap_orchestrator(orch: Orchestrator) -> Callable: @@ -125,7 +125,9 @@ class TaskRunner: # Mark the issue as done return await self.bridge.close_issue(issue.id, comment=None) - async def create_follow_up(self, original: PaperclipIssue, result: str) -> Optional[PaperclipIssue]: + async def create_follow_up( + self, original: PaperclipIssue, result: str + ) -> Optional[PaperclipIssue]: """Create a recursive follow-up task for Timmy. Timmy muses about task automation and writes a follow-up issue diff --git a/src/integrations/shortcuts/siri.py b/src/integrations/shortcuts/siri.py index ff7d444..5c10ce1 100644 --- a/src/integrations/shortcuts/siri.py +++ b/src/integrations/shortcuts/siri.py @@ -22,6 +22,7 @@ logger = logging.getLogger(__name__) @dataclass class ShortcutAction: """Describes a Siri Shortcut action for the setup guide.""" + name: str endpoint: str method: str diff --git a/src/integrations/telegram_bot/bot.py b/src/integrations/telegram_bot/bot.py index 7412531..b1f9b5e 100644 --- a/src/integrations/telegram_bot/bot.py +++ b/src/integrations/telegram_bot/bot.py @@ -54,6 +54,7 @@ class TelegramBot: return from_file try: from config import settings + return settings.telegram_token or None except Exception: return None @@ -94,10 +95,7 @@ class TelegramBot: filters, ) except ImportError: - logger.error( - "python-telegram-bot is not installed. " - 'Run: pip install ".[telegram]"' - ) + logger.error("python-telegram-bot is not installed. 
" 'Run: pip install ".[telegram]"') return False try: @@ -149,6 +147,7 @@ class TelegramBot: user_text = update.message.text try: from timmy.agent import create_timmy + agent = create_timmy() run = await asyncio.to_thread(agent.run, user_text, stream=False) response = run.content if hasattr(run, "content") else str(run) diff --git a/src/integrations/voice/nlu.py b/src/integrations/voice/nlu.py index 2e9b535..4d3c9da 100644 --- a/src/integrations/voice/nlu.py +++ b/src/integrations/voice/nlu.py @@ -15,8 +15,8 @@ Intents: - unknown: Unrecognized intent """ -import re import logging +import re from dataclasses import dataclass from typing import Optional @@ -35,47 +35,68 @@ class Intent: _PATTERNS: list[tuple[str, re.Pattern, float]] = [ # Status queries - ("status", re.compile( - r"\b(status|health|how are you|are you (running|online|alive)|check)\b", - re.IGNORECASE, - ), 0.9), - + ( + "status", + re.compile( + r"\b(status|health|how are you|are you (running|online|alive)|check)\b", + re.IGNORECASE, + ), + 0.9, + ), # Swarm commands - ("swarm", re.compile( - r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b", - re.IGNORECASE, - ), 0.85), - + ( + "swarm", + re.compile( + r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b", + re.IGNORECASE, + ), + 0.85, + ), # Task commands - ("task", re.compile( - r"\b(task|assign|create task|new task|post task|bid)\b", - re.IGNORECASE, - ), 0.85), - + ( + "task", + re.compile( + r"\b(task|assign|create task|new task|post task|bid)\b", + re.IGNORECASE, + ), + 0.85, + ), # Help - ("help", re.compile( - r"\b(help|commands?|what can you do|capabilities)\b", - re.IGNORECASE, - ), 0.9), - + ( + "help", + re.compile( + r"\b(help|commands?|what can you do|capabilities)\b", + re.IGNORECASE, + ), + 0.9, + ), # Voice settings - ("voice", re.compile( - r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b", - re.IGNORECASE, - ), 0.85), - + ( + "voice", + re.compile( + 
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b", + re.IGNORECASE, + ), + 0.85, + ), # Code modification / self-modify - ("code", re.compile( - r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b" - r"|\bself[- ]?modify\b" - r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b", - re.IGNORECASE, - ), 0.9), + ( + "code", + re.compile( + r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b" + r"|\bself[- ]?modify\b" + r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b", + re.IGNORECASE, + ), + 0.9, + ), ] # Keywords for entity extraction _ENTITY_PATTERNS = { - "agent_name": re.compile(r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE), + "agent_name": re.compile( + r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE + ), "task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE), "number": re.compile(r"\b(\d+)\b"), "target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE), diff --git a/src/spark/advisor.py b/src/spark/advisor.py index a0bc465..d2723b7 100644 --- a/src/spark/advisor.py +++ b/src/spark/advisor.py @@ -17,8 +17,8 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Optional -from spark import memory as spark_memory from spark import eidos as spark_eidos +from spark import memory as spark_memory logger = logging.getLogger(__name__) @@ -29,10 +29,11 @@ _MIN_EVENTS = 3 @dataclass class Advisory: """A single ranked recommendation.""" - category: str # agent_performance, bid_optimization, etc. - priority: float # 0.0–1.0 (higher = more urgent) - title: str # Short headline - detail: str # Longer explanation + + category: str # agent_performance, bid_optimization, etc. 
+ priority: float # 0.0–1.0 (higher = more urgent) + title: str # Short headline + detail: str # Longer explanation suggested_action: str # What to do about it subject: Optional[str] = None # agent_id or None for system-level evidence_count: int = 0 # Number of supporting events @@ -47,15 +48,17 @@ def generate_advisories() -> list[Advisory]: event_count = spark_memory.count_events() if event_count < _MIN_EVENTS: - advisories.append(Advisory( - category="system_health", - priority=0.3, - title="Insufficient data", - detail=f"Only {event_count} events captured. " - f"Spark needs at least {_MIN_EVENTS} events to generate insights.", - suggested_action="Run more swarm tasks to build intelligence.", - evidence_count=event_count, - )) + advisories.append( + Advisory( + category="system_health", + priority=0.3, + title="Insufficient data", + detail=f"Only {event_count} events captured. " + f"Spark needs at least {_MIN_EVENTS} events to generate insights.", + suggested_action="Run more swarm tasks to build intelligence.", + evidence_count=event_count, + ) + ) return advisories advisories.extend(_check_failure_patterns()) @@ -82,18 +85,20 @@ def _check_failure_patterns() -> list[Advisory]: for aid, count in agent_failures.items(): if count >= 2: - results.append(Advisory( - category="failure_prevention", - priority=min(1.0, 0.5 + count * 0.15), - title=f"Agent {aid[:8]} has {count} failures", - detail=f"Agent {aid[:8]}... has failed {count} recent tasks. " - f"This pattern may indicate a capability mismatch or " - f"configuration issue.", - suggested_action=f"Review task types assigned to {aid[:8]}... " - f"and consider adjusting routing preferences.", - subject=aid, - evidence_count=count, - )) + results.append( + Advisory( + category="failure_prevention", + priority=min(1.0, 0.5 + count * 0.15), + title=f"Agent {aid[:8]} has {count} failures", + detail=f"Agent {aid[:8]}... has failed {count} recent tasks. 
" + f"This pattern may indicate a capability mismatch or " + f"configuration issue.", + suggested_action=f"Review task types assigned to {aid[:8]}... " + f"and consider adjusting routing preferences.", + subject=aid, + evidence_count=count, + ) + ) return results @@ -128,27 +133,31 @@ def _check_agent_performance() -> list[Advisory]: rate = wins / total if rate >= 0.8 and total >= 3: - results.append(Advisory( - category="agent_performance", - priority=0.6, - title=f"Agent {aid[:8]} excels ({rate:.0%} success)", - detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks " - f"successfully. Consider routing more tasks to this agent.", - suggested_action="Increase task routing weight for this agent.", - subject=aid, - evidence_count=total, - )) + results.append( + Advisory( + category="agent_performance", + priority=0.6, + title=f"Agent {aid[:8]} excels ({rate:.0%} success)", + detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks " + f"successfully. Consider routing more tasks to this agent.", + suggested_action="Increase task routing weight for this agent.", + subject=aid, + evidence_count=total, + ) + ) elif rate <= 0.3 and total >= 3: - results.append(Advisory( - category="agent_performance", - priority=0.75, - title=f"Agent {aid[:8]} struggling ({rate:.0%} success)", - detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. " - f"May need different task types or capability updates.", - suggested_action="Review this agent's capabilities and assigned task types.", - subject=aid, - evidence_count=total, - )) + results.append( + Advisory( + category="agent_performance", + priority=0.75, + title=f"Agent {aid[:8]} struggling ({rate:.0%} success)", + detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. 
" + f"May need different task types or capability updates.", + suggested_action="Review this agent's capabilities and assigned task types.", + subject=aid, + evidence_count=total, + ) + ) return results @@ -181,27 +190,31 @@ def _check_bid_patterns() -> list[Advisory]: spread = max_bid - min_bid if spread > avg_bid * 1.5: - results.append(Advisory( - category="bid_optimization", - priority=0.5, - title=f"Wide bid spread ({min_bid}–{max_bid} sats)", - detail=f"Bids range from {min_bid} to {max_bid} sats " - f"(avg {avg_bid:.0f}). Large spread may indicate " - f"inefficient auction dynamics.", - suggested_action="Review agent bid strategies for consistency.", - evidence_count=len(bid_amounts), - )) + results.append( + Advisory( + category="bid_optimization", + priority=0.5, + title=f"Wide bid spread ({min_bid}–{max_bid} sats)", + detail=f"Bids range from {min_bid} to {max_bid} sats " + f"(avg {avg_bid:.0f}). Large spread may indicate " + f"inefficient auction dynamics.", + suggested_action="Review agent bid strategies for consistency.", + evidence_count=len(bid_amounts), + ) + ) if avg_bid > 70: - results.append(Advisory( - category="bid_optimization", - priority=0.45, - title=f"High average bid ({avg_bid:.0f} sats)", - detail=f"The swarm average bid is {avg_bid:.0f} sats across " - f"{len(bid_amounts)} bids. This may be above optimal.", - suggested_action="Consider adjusting base bid rates for persona agents.", - evidence_count=len(bid_amounts), - )) + results.append( + Advisory( + category="bid_optimization", + priority=0.45, + title=f"High average bid ({avg_bid:.0f} sats)", + detail=f"The swarm average bid is {avg_bid:.0f} sats across " + f"{len(bid_amounts)} bids. 
This may be above optimal.", + suggested_action="Consider adjusting base bid rates for persona agents.", + evidence_count=len(bid_amounts), + ) + ) return results @@ -216,27 +229,31 @@ def _check_prediction_accuracy() -> list[Advisory]: avg = stats["avg_accuracy"] if avg < 0.4: - results.append(Advisory( - category="system_health", - priority=0.65, - title=f"Low prediction accuracy ({avg:.0%})", - detail=f"EIDOS predictions have averaged {avg:.0%} accuracy " - f"over {stats['evaluated']} evaluations. The learning " - f"model needs more data or the swarm behaviour is changing.", - suggested_action="Continue running tasks; accuracy should improve " - "as the model accumulates more training data.", - evidence_count=stats["evaluated"], - )) + results.append( + Advisory( + category="system_health", + priority=0.65, + title=f"Low prediction accuracy ({avg:.0%})", + detail=f"EIDOS predictions have averaged {avg:.0%} accuracy " + f"over {stats['evaluated']} evaluations. The learning " + f"model needs more data or the swarm behaviour is changing.", + suggested_action="Continue running tasks; accuracy should improve " + "as the model accumulates more training data.", + evidence_count=stats["evaluated"], + ) + ) elif avg >= 0.75: - results.append(Advisory( - category="system_health", - priority=0.3, - title=f"Strong prediction accuracy ({avg:.0%})", - detail=f"EIDOS predictions are performing well at {avg:.0%} " - f"average accuracy over {stats['evaluated']} evaluations.", - suggested_action="No action needed. Spark intelligence is learning effectively.", - evidence_count=stats["evaluated"], - )) + results.append( + Advisory( + category="system_health", + priority=0.3, + title=f"Strong prediction accuracy ({avg:.0%})", + detail=f"EIDOS predictions are performing well at {avg:.0%} " + f"average accuracy over {stats['evaluated']} evaluations.", + suggested_action="No action needed. 
Spark intelligence is learning effectively.", + evidence_count=stats["evaluated"], + ) + ) return results @@ -247,14 +264,16 @@ def _check_system_activity() -> list[Advisory]: recent = spark_memory.get_events(limit=5) if not recent: - results.append(Advisory( - category="system_health", - priority=0.4, - title="No swarm activity detected", - detail="Spark has not captured any events. " - "The swarm may be idle or Spark event capture is not active.", - suggested_action="Post a task to the swarm to activate the pipeline.", - )) + results.append( + Advisory( + category="system_health", + priority=0.4, + title="No swarm activity detected", + detail="Spark has not captured any events. " + "The swarm may be idle or Spark event capture is not active.", + suggested_action="Post a task to the swarm to activate the pipeline.", + ) + ) return results # Check event type distribution @@ -265,14 +284,16 @@ def _check_system_activity() -> list[Advisory]: if "task_completed" not in type_counts and "task_failed" not in type_counts: if type_counts.get("task_posted", 0) > 3: - results.append(Advisory( - category="system_health", - priority=0.6, - title="Tasks posted but none completing", - detail=f"{type_counts.get('task_posted', 0)} tasks posted " - f"but no completions or failures recorded.", - suggested_action="Check agent availability and auction configuration.", - evidence_count=type_counts.get("task_posted", 0), - )) + results.append( + Advisory( + category="system_health", + priority=0.6, + title="Tasks posted but none completing", + detail=f"{type_counts.get('task_posted', 0)} tasks posted " + f"but no completions or failures recorded.", + suggested_action="Check agent availability and auction configuration.", + evidence_count=type_counts.get("task_posted", 0), + ) + ) return results diff --git a/src/spark/eidos.py b/src/spark/eidos.py index 0377d40..8367658 100644 --- a/src/spark/eidos.py +++ b/src/spark/eidos.py @@ -29,12 +29,13 @@ DB_PATH = Path("data/spark.db") @dataclass 
class Prediction: """A prediction made by the EIDOS loop.""" + id: str task_id: str - prediction_type: str # outcome, best_agent, bid_range - predicted_value: str # JSON-encoded prediction - actual_value: Optional[str] # JSON-encoded actual (filled on evaluation) - accuracy: Optional[float] # 0.0–1.0 (filled on evaluation) + prediction_type: str # outcome, best_agent, bid_range + predicted_value: str # JSON-encoded prediction + actual_value: Optional[str] # JSON-encoded actual (filled on evaluation) + accuracy: Optional[float] # 0.0–1.0 (filled on evaluation) created_at: str evaluated_at: Optional[str] @@ -57,18 +58,15 @@ def _get_conn() -> sqlite3.Connection: ) """ ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)") conn.commit() return conn # ── Prediction phase ──────────────────────────────────────────────────────── + def predict_task_outcome( task_id: str, task_description: str, @@ -104,12 +102,8 @@ def predict_task_outcome( if best_agent: prediction["likely_winner"] = best_agent - prediction["success_probability"] = round( - min(1.0, 0.5 + best_rate * 0.4), 2 - ) - prediction["reasoning"] = ( - f"agent {best_agent[:8]} has {best_rate:.0%} success rate" - ) + prediction["success_probability"] = round(min(1.0, 0.5 + best_rate * 0.4), 2) + prediction["reasoning"] = f"agent {best_agent[:8]} has {best_rate:.0%} success rate" # Adjust bid range from history all_bids = [] @@ -144,6 +138,7 @@ def predict_task_outcome( # ── Evaluation phase ──────────────────────────────────────────────────────── + def evaluate_prediction( task_id: str, actual_winner: Optional[str], @@ -242,6 +237,7 @@ def _compute_accuracy(predicted: dict, actual: dict) -> 
float: # ── Query helpers ────────────────────────────────────────────────────────── + def get_predictions( task_id: Optional[str] = None, evaluated_only: bool = False, diff --git a/src/spark/engine.py b/src/spark/engine.py index 83d314e..695ff0c 100644 --- a/src/spark/engine.py +++ b/src/spark/engine.py @@ -76,7 +76,10 @@ class SparkEngine: return event_id def on_bid_submitted( - self, task_id: str, agent_id: str, bid_sats: int, + self, + task_id: str, + agent_id: str, + bid_sats: int, ) -> Optional[str]: """Capture a bid event.""" if not self._enabled: @@ -90,12 +93,13 @@ class SparkEngine: data=json.dumps({"bid_sats": bid_sats}), ) - logger.debug("Spark: captured bid %s→%s (%d sats)", - agent_id[:8], task_id[:8], bid_sats) + logger.debug("Spark: captured bid %s→%s (%d sats)", agent_id[:8], task_id[:8], bid_sats) return event_id def on_task_assigned( - self, task_id: str, agent_id: str, + self, + task_id: str, + agent_id: str, ) -> Optional[str]: """Capture a task-assigned event.""" if not self._enabled: @@ -108,8 +112,7 @@ class SparkEngine: task_id=task_id, ) - logger.debug("Spark: captured assignment %s→%s", - task_id[:8], agent_id[:8]) + logger.debug("Spark: captured assignment %s→%s", task_id[:8], agent_id[:8]) return event_id def on_task_completed( @@ -128,10 +131,12 @@ class SparkEngine: description=f"Task completed by {agent_id[:8]}", agent_id=agent_id, task_id=task_id, - data=json.dumps({ - "result_length": len(result), - "winning_bid": winning_bid, - }), + data=json.dumps( + { + "result_length": len(result), + "winning_bid": winning_bid, + } + ), ) # Evaluate EIDOS prediction @@ -154,8 +159,7 @@ class SparkEngine: # Consolidate memory if enough events for this agent self._maybe_consolidate(agent_id) - logger.debug("Spark: captured completion %s by %s", - task_id[:8], agent_id[:8]) + logger.debug("Spark: captured completion %s by %s", task_id[:8], agent_id[:8]) return event_id def on_task_failed( @@ -186,8 +190,7 @@ class SparkEngine: # Failures always 
worth consolidating self._maybe_consolidate(agent_id) - logger.debug("Spark: captured failure %s by %s", - task_id[:8], agent_id[:8]) + logger.debug("Spark: captured failure %s by %s", task_id[:8], agent_id[:8]) return event_id def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]: @@ -288,7 +291,7 @@ class SparkEngine: memory_type="pattern", subject=agent_id, content=f"Agent {agent_id[:8]} has a strong track record: " - f"{len(completions)}/{total} tasks completed successfully.", + f"{len(completions)}/{total} tasks completed successfully.", confidence=min(0.95, 0.6 + total * 0.05), source_events=total, ) @@ -297,7 +300,7 @@ class SparkEngine: memory_type="anomaly", subject=agent_id, content=f"Agent {agent_id[:8]} is struggling: only " - f"{len(completions)}/{total} tasks completed.", + f"{len(completions)}/{total} tasks completed.", confidence=min(0.95, 0.6 + total * 0.05), source_events=total, ) @@ -347,6 +350,7 @@ class SparkEngine: def _create_engine() -> SparkEngine: try: from config import settings + return SparkEngine(enabled=settings.spark_enabled) except Exception: return SparkEngine(enabled=True) diff --git a/src/spark/memory.py b/src/spark/memory.py index 238d4f3..b09c3c6 100644 --- a/src/spark/memory.py +++ b/src/spark/memory.py @@ -28,25 +28,27 @@ IMPORTANCE_HIGH = 0.8 @dataclass class SparkEvent: """A single captured swarm event.""" + id: str - event_type: str # task_posted, bid, assignment, completion, failure + event_type: str # task_posted, bid, assignment, completion, failure agent_id: Optional[str] task_id: Optional[str] description: str - data: str # JSON payload - importance: float # 0.0–1.0 + data: str # JSON payload + importance: float # 0.0–1.0 created_at: str @dataclass class SparkMemory: """A consolidated memory distilled from event patterns.""" + id: str - memory_type: str # pattern, insight, anomaly - subject: str # agent_id or "system" - content: str # Human-readable insight - confidence: float # 0.0–1.0 - source_events: 
int # How many events contributed + memory_type: str # pattern, insight, anomaly + subject: str # agent_id or "system" + content: str # Human-readable insight + confidence: float # 0.0–1.0 + source_events: int # How many events contributed created_at: str expires_at: Optional[str] @@ -83,24 +85,17 @@ def _get_conn() -> sqlite3.Connection: ) """ ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)") conn.commit() return conn # ── Importance scoring ────────────────────────────────────────────────────── + def score_importance(event_type: str, data: dict) -> float: """Compute importance score for an event (0.0–1.0). @@ -132,6 +127,7 @@ def score_importance(event_type: str, data: dict) -> float: # ── Event recording ───────────────────────────────────────────────────────── + def record_event( event_type: str, description: str, @@ -142,6 +138,7 @@ def record_event( ) -> str: """Record a swarm event. 
Returns the event id.""" import json + event_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() @@ -224,6 +221,7 @@ def count_events(event_type: Optional[str] = None) -> int: # ── Memory consolidation ─────────────────────────────────────────────────── + def store_memory( memory_type: str, subject: str, diff --git a/src/swarm/event_log.py b/src/swarm/event_log.py index 23a60d6..0e92f59 100644 --- a/src/swarm/event_log.py +++ b/src/swarm/event_log.py @@ -73,7 +73,8 @@ def _ensure_db() -> sqlite3.Connection: DB_PATH.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS events ( id TEXT PRIMARY KEY, event_type TEXT NOT NULL, @@ -83,7 +84,8 @@ def _ensure_db() -> sqlite3.Connection: data TEXT DEFAULT '{}', timestamp TEXT NOT NULL ) - """) + """ + ) conn.commit() return conn @@ -119,8 +121,15 @@ def log_event( db.execute( "INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) " "VALUES (?, ?, ?, ?, ?, ?, ?)", - (entry.id, event_type.value, source, task_id, agent_id, - json.dumps(data or {}), entry.timestamp), + ( + entry.id, + event_type.value, + source, + task_id, + agent_id, + json.dumps(data or {}), + entry.timestamp, + ), ) db.commit() finally: @@ -131,6 +140,7 @@ def log_event( # Broadcast to WebSocket clients (non-blocking) try: from infrastructure.events.broadcaster import event_broadcaster + event_broadcaster.broadcast_sync(entry) except Exception: pass @@ -157,13 +167,15 @@ def get_task_events(task_id: str, limit: int = 50) -> list[EventLogEntry]: et = EventType(r["event_type"]) except ValueError: et = EventType.SYSTEM_INFO - entries.append(EventLogEntry( - id=r["id"], - event_type=et, - source=r["source"], - timestamp=r["timestamp"], - data=json.loads(r["data"]) if r["data"] else {}, - task_id=r["task_id"], - agent_id=r["agent_id"], - )) + entries.append( + EventLogEntry( + id=r["id"], + 
event_type=et, + source=r["source"], + timestamp=r["timestamp"], + data=json.loads(r["data"]) if r["data"] else {}, + task_id=r["task_id"], + agent_id=r["agent_id"], + ) + ) return entries diff --git a/src/swarm/task_queue/models.py b/src/swarm/task_queue/models.py index 34d4134..0bb8706 100644 --- a/src/swarm/task_queue/models.py +++ b/src/swarm/task_queue/models.py @@ -29,7 +29,8 @@ def _ensure_db() -> sqlite3.Connection: DB_PATH.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, title TEXT NOT NULL, @@ -42,7 +43,8 @@ def _ensure_db() -> sqlite3.Connection: created_at TEXT DEFAULT (datetime('now')), completed_at TEXT ) - """) + """ + ) conn.commit() return conn @@ -103,9 +105,7 @@ def get_task_summary_for_briefing() -> dict: """Return a summary of task counts by status for the morning briefing.""" db = _ensure_db() try: - rows = db.execute( - "SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status" - ).fetchall() + rows = db.execute("SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status").fetchall() finally: db.close() diff --git a/src/timmy/agent.py b/src/timmy/agent.py index 6d16e06..61e4cf7 100644 --- a/src/timmy/agent.py +++ b/src/timmy/agent.py @@ -69,16 +69,16 @@ def _check_model_available(model_name: str) -> bool: def _pull_model(model_name: str) -> bool: """Attempt to pull a model from Ollama. 
- + Returns: True if successful or model already exists """ try: - import urllib.request import json - + import urllib.request + logger.info("Pulling model: %s", model_name) - + url = settings.ollama_url.replace("localhost", "127.0.0.1") req = urllib.request.Request( f"{url}/api/pull", @@ -86,7 +86,7 @@ def _pull_model(model_name: str) -> bool: headers={"Content-Type": "application/json"}, data=json.dumps({"name": model_name, "stream": False}).encode(), ) - + with urllib.request.urlopen(req, timeout=300) as response: if response.status == 200: logger.info("Successfully pulled model: %s", model_name) @@ -94,7 +94,7 @@ def _pull_model(model_name: str) -> bool: else: logger.error("Failed to pull %s: HTTP %s", model_name, response.status) return False - + except Exception as exc: logger.error("Error pulling model %s: %s", model_name, exc) return False @@ -106,53 +106,44 @@ def _resolve_model_with_fallback( auto_pull: bool = True, ) -> tuple[str, bool]: """Resolve model with automatic pulling and fallback. 
- + Args: requested_model: Preferred model to use require_vision: Whether the model needs vision capabilities auto_pull: Whether to attempt pulling missing models - + Returns: Tuple of (model_name, is_fallback) """ model = requested_model or settings.ollama_model - + # Check if requested model is available if _check_model_available(model): logger.debug("Using available model: %s", model) return model, False - + # Try to pull the requested model if auto_pull: logger.info("Model %s not available locally, attempting to pull...", model) if _pull_model(model): return model, False logger.warning("Failed to pull %s, checking fallbacks...", model) - + # Use appropriate fallback chain fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS - + for fallback_model in fallback_chain: if _check_model_available(fallback_model): - logger.warning( - "Using fallback model %s (requested: %s)", - fallback_model, model - ) + logger.warning("Using fallback model %s (requested: %s)", fallback_model, model) return fallback_model, True - + # Try to pull the fallback if auto_pull and _pull_model(fallback_model): - logger.info( - "Pulled and using fallback model %s (requested: %s)", - fallback_model, model - ) + logger.info("Pulled and using fallback model %s (requested: %s)", fallback_model, model) return fallback_model, True - + # Absolute last resort - return the requested model and hope for the best - logger.error( - "No models available in fallback chain. Requested: %s", - model - ) + logger.error("No models available in fallback chain. Requested: %s", model) return model, False @@ -190,6 +181,7 @@ def _resolve_backend(requested: str | None) -> str: # "auto" path — lazy import to keep startup fast and tests clean. 
from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon + if is_apple_silicon() and airllm_available(): return "airllm" return "ollama" @@ -215,14 +207,17 @@ def create_timmy( if resolved == "claude": from timmy.backends import ClaudeBackend + return ClaudeBackend() if resolved == "grok": from timmy.backends import GrokBackend + return GrokBackend() if resolved == "airllm": from timmy.backends import TimmyAirLLMAgent + return TimmyAirLLMAgent(model_size=size) # Default: Ollama via Agno. @@ -236,16 +231,16 @@ def create_timmy( # If Ollama is completely unreachable, fall back to Claude if available if not _check_model_available(model_name): from timmy.backends import claude_available + if claude_available(): - logger.warning( - "Ollama unreachable — falling back to Claude backend" - ) + logger.warning("Ollama unreachable — falling back to Claude backend") from timmy.backends import ClaudeBackend + return ClaudeBackend() if is_fallback: logger.info("Using fallback model %s (requested was unavailable)", model_name) - + use_tools = _model_supports_tools(model_name) # Conditionally include tools — small models get none @@ -259,6 +254,7 @@ def create_timmy( # Try to load memory context try: from timmy.memory_system import memory_system + memory_context = memory_system.get_system_context() if memory_context: # Truncate if too long — smaller budget for small models @@ -290,32 +286,32 @@ def create_timmy( class TimmyWithMemory: """Agent wrapper with explicit three-tier memory management.""" - + def __init__(self, db_file: str = "timmy.db") -> None: from timmy.memory_system import memory_system - + self.agent = create_timmy(db_file=db_file) self.memory = memory_system self.session_active = True - + # Store initial context for reference self.initial_context = self.memory.get_system_context() - + def chat(self, message: str) -> str: """Simple chat interface that tracks in memory.""" # Check for user facts to extract 
self._extract_and_store_facts(message) - + # Run agent result = self.agent.run(message, stream=False) response_text = result.content if hasattr(result, "content") else str(result) - + return response_text - + def _extract_and_store_facts(self, message: str) -> None: """Extract user facts from message and store in memory.""" message_lower = message.lower() - + # Extract name name_patterns = [ ("my name is ", 11), @@ -323,7 +319,7 @@ class TimmyWithMemory: ("i am ", 5), ("call me ", 8), ] - + for pattern, offset in name_patterns: if pattern in message_lower: idx = message_lower.find(pattern) + offset @@ -332,7 +328,7 @@ class TimmyWithMemory: self.memory.update_user_fact("Name", name) self.memory.record_decision(f"Learned user's name: {name}") break - + # Extract preferences pref_patterns = [ ("i like ", "Likes"), @@ -341,7 +337,7 @@ class TimmyWithMemory: ("i don't like ", "Dislikes"), ("i hate ", "Dislikes"), ] - + for pattern, category in pref_patterns: if pattern in message_lower: idx = message_lower.find(pattern) + len(pattern) @@ -349,16 +345,16 @@ class TimmyWithMemory: if pref and len(pref) > 3: self.memory.record_open_item(f"User {category.lower()}: {pref}") break - + def end_session(self, summary: str = "Session completed") -> None: """End session and write handoff.""" if self.session_active: self.memory.end_session(summary) self.session_active = False - + def __enter__(self): return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.end_session() return False diff --git a/src/timmy/agent_core/interface.py b/src/timmy/agent_core/interface.py index d1595b7..5ce9cb2 100644 --- a/src/timmy/agent_core/interface.py +++ b/src/timmy/agent_core/interface.py @@ -16,38 +16,41 @@ Architecture: All methods return effects that can be logged, audited, and replayed. 
""" +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum, auto from typing import Any, Optional -import uuid class PerceptionType(Enum): """Types of sensory input an agent can receive.""" - TEXT = auto() # Natural language - IMAGE = auto() # Visual input - AUDIO = auto() # Sound/speech - SENSOR = auto() # Temperature, distance, etc. - MOTION = auto() # Accelerometer, gyroscope - NETWORK = auto() # API calls, messages - INTERNAL = auto() # Self-monitoring (battery, temp) + + TEXT = auto() # Natural language + IMAGE = auto() # Visual input + AUDIO = auto() # Sound/speech + SENSOR = auto() # Temperature, distance, etc. + MOTION = auto() # Accelerometer, gyroscope + NETWORK = auto() # API calls, messages + INTERNAL = auto() # Self-monitoring (battery, temp) class ActionType(Enum): """Types of actions an agent can perform.""" - TEXT = auto() # Generate text response - SPEAK = auto() # Text-to-speech - MOVE = auto() # Physical movement - GRIP = auto() # Manipulate objects - CALL = auto() # API/network call - EMIT = auto() # Signal/light/sound - SLEEP = auto() # Power management + + TEXT = auto() # Generate text response + SPEAK = auto() # Text-to-speech + MOVE = auto() # Physical movement + GRIP = auto() # Manipulate objects + CALL = auto() # API/network call + EMIT = auto() # Signal/light/sound + SLEEP = auto() # Power management class AgentCapability(Enum): """High-level capabilities a TimAgent may possess.""" + REASONING = "reasoning" CODING = "coding" WRITING = "writing" @@ -63,15 +66,16 @@ class AgentCapability(Enum): @dataclass(frozen=True) class AgentIdentity: """Immutable identity for an agent instance. - + This persists across sessions and substrates. If Timmy moves from cloud to robot, the identity follows. 
""" + id: str name: str version: str created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) - + @classmethod def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity": """Generate a new unique identity.""" @@ -85,16 +89,17 @@ class AgentIdentity: @dataclass class Perception: """A sensory input to the agent. - - Substrate-agnostic representation. A camera image and a + + Substrate-agnostic representation. A camera image and a LiDAR point cloud are both Perception instances. """ + type: PerceptionType data: Any # Content depends on type timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) source: str = "unknown" # e.g., "camera_1", "microphone", "user_input" metadata: dict = field(default_factory=dict) - + @classmethod def text(cls, content: str, source: str = "user") -> "Perception": """Factory for text perception.""" @@ -103,7 +108,7 @@ class Perception: data=content, source=source, ) - + @classmethod def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception": """Factory for sensor readings.""" @@ -117,16 +122,17 @@ class Perception: @dataclass class Action: """An action the agent intends to perform. - + Actions are effects — they describe what should happen, not how. The substrate implements the "how." 
""" + type: ActionType payload: Any # Action-specific data timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) confidence: float = 1.0 # 0-1, agent's certainty deadline: Optional[str] = None # When action must complete - + @classmethod def respond(cls, text: str, confidence: float = 1.0) -> "Action": """Factory for text response action.""" @@ -135,7 +141,7 @@ class Action: payload=text, confidence=confidence, ) - + @classmethod def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action": """Factory for movement action (x, y, z meters).""" @@ -148,10 +154,11 @@ class Action: @dataclass class Memory: """A stored experience or fact. - + Memories are substrate-agnostic. A conversation history and a video recording are both Memory instances. """ + id: str content: Any created_at: str @@ -159,7 +166,7 @@ class Memory: last_accessed: Optional[str] = None importance: float = 0.5 # 0-1, for pruning decisions tags: list[str] = field(default_factory=list) - + def touch(self) -> None: """Mark memory as accessed.""" self.access_count += 1 @@ -169,6 +176,7 @@ class Memory: @dataclass class Communication: """A message to/from another agent or human.""" + sender: str recipient: str content: Any @@ -179,132 +187,132 @@ class Communication: class TimAgent(ABC): """Abstract base class for all Timmy agent implementations. - + This is the substrate-agnostic interface. 
Implementations: - OllamaAgent: LLM-based reasoning (today) - RobotAgent: Physical embodiment (future) - SimulationAgent: Virtual environment (future) - + Usage: agent = OllamaAgent(identity) # Today's implementation - + perception = Perception.text("Hello Timmy") memory = agent.perceive(perception) - + action = agent.reason("How should I respond?") result = agent.act(action) - + agent.remember(memory) # Store for future """ - + def __init__(self, identity: AgentIdentity) -> None: self._identity = identity self._capabilities: set[AgentCapability] = set() self._state: dict[str, Any] = {} - + @property def identity(self) -> AgentIdentity: """Return this agent's immutable identity.""" return self._identity - + @property def capabilities(self) -> set[AgentCapability]: """Return set of supported capabilities.""" return self._capabilities.copy() - + def has_capability(self, capability: AgentCapability) -> bool: """Check if agent supports a capability.""" return capability in self._capabilities - + @abstractmethod def perceive(self, perception: Perception) -> Memory: """Process sensory input and create a memory. - + This is the entry point for all agent interaction. A text message, camera frame, or temperature reading all enter through perceive(). - + Args: perception: Sensory input - + Returns: Memory: Stored representation of the perception """ pass - + @abstractmethod def reason(self, query: str, context: list[Memory]) -> Action: """Reason about a situation and decide on action. - + This is where "thinking" happens. The agent uses its substrate-appropriate reasoning (LLM, neural net, rules) to decide what to do. - + Args: query: What to reason about context: Relevant memories for context - + Returns: Action: What the agent decides to do """ pass - + @abstractmethod def act(self, action: Action) -> Any: """Execute an action in the substrate. 
- + This is where the abstract action becomes concrete: - TEXT → Generate LLM response - MOVE → Send motor commands - SPEAK → Call TTS engine - + Args: action: The action to execute - + Returns: Result of the action (substrate-specific) """ pass - + @abstractmethod def remember(self, memory: Memory) -> None: """Store a memory for future retrieval. - + The storage mechanism depends on substrate: - Cloud: SQLite, vector DB - Robot: Local flash storage - Hybrid: Synced with conflict resolution - + Args: memory: Experience to store """ pass - + @abstractmethod def recall(self, query: str, limit: int = 5) -> list[Memory]: """Retrieve relevant memories. - + Args: query: What to search for limit: Maximum memories to return - + Returns: List of relevant memories, sorted by relevance """ pass - + @abstractmethod def communicate(self, message: Communication) -> bool: """Send/receive communication with another agent. - + Args: message: Message to send - + Returns: True if communication succeeded """ pass - + def get_state(self) -> dict[str, Any]: """Get current agent state for monitoring/debugging.""" return { @@ -312,7 +320,7 @@ class TimAgent(ABC): "capabilities": list(self._capabilities), "state": self._state.copy(), } - + def shutdown(self) -> None: """Graceful shutdown. Persist state, close connections.""" # Override in subclass for cleanup @@ -321,7 +329,7 @@ class TimAgent(ABC): class AgentEffect: """Log entry for agent actions — for audit and replay. - + The complete history of an agent's life can be captured as a sequence of AgentEffects. This enables: - Debugging: What did the agent see and do? 
@@ -329,40 +337,46 @@ class AgentEffect: - Replay: Reconstruct agent state from log - Training: Learn from agent experiences """ - + def __init__(self, log_path: Optional[str] = None) -> None: self._effects: list[dict] = [] self._log_path = log_path - + def log_perceive(self, perception: Perception, memory_id: str) -> None: """Log a perception event.""" - self._effects.append({ - "type": "perceive", - "perception_type": perception.type.name, - "source": perception.source, - "memory_id": memory_id, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - + self._effects.append( + { + "type": "perceive", + "perception_type": perception.type.name, + "source": perception.source, + "memory_id": memory_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + ) + def log_reason(self, query: str, action_type: ActionType) -> None: """Log a reasoning event.""" - self._effects.append({ - "type": "reason", - "query": query, - "action_type": action_type.name, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - + self._effects.append( + { + "type": "reason", + "query": query, + "action_type": action_type.name, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + ) + def log_act(self, action: Action, result: Any) -> None: """Log an action event.""" - self._effects.append({ - "type": "act", - "action_type": action.type.name, - "confidence": action.confidence, - "result_type": type(result).__name__, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - + self._effects.append( + { + "type": "act", + "action_type": action.type.name, + "confidence": action.confidence, + "result_type": type(result).__name__, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + ) + def export(self) -> list[dict]: """Export effect log for analysis.""" return self._effects.copy() diff --git a/src/timmy/agent_core/ollama_adapter.py b/src/timmy/agent_core/ollama_adapter.py index e81b136..43a53c1 100644 --- a/src/timmy/agent_core/ollama_adapter.py +++ 
b/src/timmy/agent_core/ollama_adapter.py @@ -7,10 +7,10 @@ between the old codebase and the new embodiment-ready architecture. Usage: from timmy.agent_core import AgentIdentity, Perception from timmy.agent_core.ollama_adapter import OllamaAgent - + identity = AgentIdentity.generate("Timmy") agent = OllamaAgent(identity) - + perception = Perception.text("Hello!") memory = agent.perceive(perception) action = agent.reason("How should I respond?", [memory]) @@ -19,27 +19,27 @@ Usage: from typing import Any, Optional +from timmy.agent import _resolve_model_with_fallback, create_timmy from timmy.agent_core.interface import ( - AgentCapability, - AgentIdentity, - Perception, - PerceptionType, Action, ActionType, - Memory, - Communication, - TimAgent, + AgentCapability, AgentEffect, + AgentIdentity, + Communication, + Memory, + Perception, + PerceptionType, + TimAgent, ) -from timmy.agent import create_timmy, _resolve_model_with_fallback class OllamaAgent(TimAgent): """TimAgent implementation using local Ollama LLM. - + This is the production agent for Timmy Time v2. It uses Ollama for reasoning and SQLite for memory persistence. - + Capabilities: - REASONING: LLM-based inference - CODING: Code generation and analysis @@ -47,7 +47,7 @@ class OllamaAgent(TimAgent): - ANALYSIS: Data processing and insights - COMMUNICATION: Multi-agent messaging """ - + def __init__( self, identity: AgentIdentity, @@ -56,7 +56,7 @@ class OllamaAgent(TimAgent): require_vision: bool = False, ) -> None: """Initialize Ollama-based agent. 
- + Args: identity: Agent identity (persistent across sessions) model: Ollama model to use (auto-resolves with fallback) @@ -64,23 +64,24 @@ class OllamaAgent(TimAgent): require_vision: Whether to select a vision-capable model """ super().__init__(identity) - + # Resolve model with automatic pulling and fallback resolved_model, is_fallback = _resolve_model_with_fallback( requested_model=model, require_vision=require_vision, auto_pull=True, ) - + if is_fallback: import logging + logging.getLogger(__name__).info( "OllamaAdapter using fallback model %s", resolved_model ) - + # Initialize underlying Ollama agent self._timmy = create_timmy(model=resolved_model) - + # Set capabilities based on what Ollama can do self._capabilities = { AgentCapability.REASONING, @@ -89,17 +90,17 @@ class OllamaAgent(TimAgent): AgentCapability.ANALYSIS, AgentCapability.COMMUNICATION, } - + # Effect logging for audit/replay self._effect_log = AgentEffect(effect_log) if effect_log else None - + # Simple in-memory working memory (short term) self._working_memory: list[Memory] = [] self._max_working_memory = 10 - + def perceive(self, perception: Perception) -> Memory: """Process perception and store in memory. - + For text perceptions, we might do light preprocessing (summarization, keyword extraction) before storage. """ @@ -114,28 +115,28 @@ class OllamaAgent(TimAgent): created_at=perception.timestamp, tags=self._extract_tags(perception), ) - + # Add to working memory self._working_memory.append(memory) if len(self._working_memory) > self._max_working_memory: self._working_memory.pop(0) # FIFO eviction - + # Log effect if self._effect_log: self._effect_log.log_perceive(perception, memory.id) - + return memory - + def reason(self, query: str, context: list[Memory]) -> Action: """Use LLM to reason and decide on action. - + This is where the Ollama agent does its work. We construct a prompt from the query and context, then interpret the response as an action. 
""" # Build context string from memories context_str = self._format_context(context) - + # Construct prompt prompt = f"""You are {self._identity.name}, an AI assistant. @@ -145,30 +146,30 @@ Context from previous interactions: Current query: {query} Respond naturally and helpfully.""" - + # Run LLM inference result = self._timmy.run(prompt, stream=False) response_text = result.content if hasattr(result, "content") else str(result) - + # Create text response action action = Action.respond(response_text, confidence=0.9) - + # Log effect if self._effect_log: self._effect_log.log_reason(query, action.type) - + return action - + def act(self, action: Action) -> Any: """Execute action in the Ollama substrate. - + For text actions, the "execution" is just returning the text (already generated during reasoning). For future action types (MOVE, SPEAK), this would trigger the appropriate Ollama tool calls. """ result = None - + if action.type == ActionType.TEXT: result = action.payload elif action.type == ActionType.SPEAK: @@ -179,13 +180,13 @@ Respond naturally and helpfully.""" result = {"status": "not_implemented", "payload": action.payload} else: result = {"error": f"Action type {action.type} not supported by OllamaAgent"} - + # Log effect if self._effect_log: self._effect_log.log_act(action, result) - + return result - + def remember(self, memory: Memory) -> None: """Store memory in working memory. @@ -200,48 +201,48 @@ Respond naturally and helpfully.""" # Evict oldest if over capacity if len(self._working_memory) > self._max_working_memory: self._working_memory.pop(0) - + def recall(self, query: str, limit: int = 5) -> list[Memory]: """Retrieve relevant memories. - + Simple keyword matching for now. Future: vector similarity. 
""" query_lower = query.lower() scored = [] - + for memory in self._working_memory: score = 0 content_str = str(memory.content).lower() - + # Simple keyword overlap query_words = set(query_lower.split()) content_words = set(content_str.split()) overlap = len(query_words & content_words) score += overlap - + # Boost recent memories score += memory.importance - + scored.append((score, memory)) - + # Sort by score descending scored.sort(key=lambda x: x[0], reverse=True) - + # Return top N return [m for _, m in scored[:limit]] - + def communicate(self, message: Communication) -> bool: """Send message to another agent. - + Swarm comms removed — inter-agent communication will be handled by the unified brain memory layer. """ return False - + def _extract_tags(self, perception: Perception) -> list[str]: """Extract searchable tags from perception.""" tags = [perception.type.name, perception.source] - + if perception.type == PerceptionType.TEXT: # Simple keyword extraction text = str(perception.data).lower() @@ -249,14 +250,14 @@ Respond naturally and helpfully.""" for kw in keywords: if kw in text: tags.append(kw) - + return tags - + def _format_context(self, memories: list[Memory]) -> str: """Format memories into context string for prompt.""" if not memories: return "No previous context." 
- + parts = [] for mem in memories[-5:]: # Last 5 memories if isinstance(mem.content, dict): @@ -264,9 +265,9 @@ Respond naturally and helpfully.""" parts.append(f"- {data}") else: parts.append(f"- {mem.content}") - + return "\n".join(parts) - + def get_effect_log(self) -> Optional[list[dict]]: """Export effect log if logging is enabled.""" if self._effect_log: diff --git a/src/timmy/agentic_loop.py b/src/timmy/agentic_loop.py index 826884d..6931b17 100644 --- a/src/timmy/agentic_loop.py +++ b/src/timmy/agentic_loop.py @@ -30,9 +30,11 @@ logger = logging.getLogger(__name__) # Data structures # --------------------------------------------------------------------------- + @dataclass class AgenticStep: """Result of a single step in the agentic loop.""" + step_num: int description: str result: str @@ -43,6 +45,7 @@ class AgenticStep: @dataclass class AgenticResult: """Final result of the entire agentic loop.""" + task_id: str task: str summary: str @@ -55,6 +58,7 @@ class AgenticResult: # Agent factory # --------------------------------------------------------------------------- + def _get_loop_agent(): """Create a fresh agent for the agentic loop. @@ -62,6 +66,7 @@ def _get_loop_agent(): dedicated session so it doesn't pollute the main chat history. 
""" from timmy.agent import create_timmy + return create_timmy() @@ -85,6 +90,7 @@ def _parse_steps(plan_text: str) -> list[str]: # Core loop # --------------------------------------------------------------------------- + async def run_agentic_loop( task: str, *, @@ -146,12 +152,15 @@ async def run_agentic_loop( was_truncated = planned_steps > total_steps # Broadcast plan - await _broadcast_progress("agentic.plan_ready", { - "task_id": task_id, - "task": task, - "steps": steps, - "total": total_steps, - }) + await _broadcast_progress( + "agentic.plan_ready", + { + "task_id": task_id, + "task": task, + "steps": steps, + "total": total_steps, + }, + ) # ── Phase 2: Execution ───────────────────────────────────────────────── completed_results: list[str] = [] @@ -175,6 +184,7 @@ async def run_agentic_loop( # Clean the response from timmy.session import _clean_response + step_result = _clean_response(step_result) step = AgenticStep( @@ -188,13 +198,16 @@ async def run_agentic_loop( completed_results.append(f"Step {i}: {step_result[:200]}") # Broadcast progress - await _broadcast_progress("agentic.step_complete", { - "task_id": task_id, - "step": i, - "total": total_steps, - "description": step_desc, - "result": step_result[:200], - }) + await _broadcast_progress( + "agentic.step_complete", + { + "task_id": task_id, + "step": i, + "total": total_steps, + "description": step_desc, + "result": step_result[:200], + }, + ) if on_progress: await on_progress(step_desc, i, total_steps) @@ -210,11 +223,16 @@ async def run_agentic_loop( ) try: adapt_run = await asyncio.to_thread( - agent.run, adapt_prompt, stream=False, + agent.run, + adapt_prompt, + stream=False, session_id=f"{session_id}_adapt{i}", ) - adapt_result = adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run) + adapt_result = ( + adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run) + ) from timmy.session import _clean_response + adapt_result = _clean_response(adapt_result) step = 
AgenticStep( @@ -227,14 +245,17 @@ async def run_agentic_loop( result.steps.append(step) completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}") - await _broadcast_progress("agentic.step_adapted", { - "task_id": task_id, - "step": i, - "total": total_steps, - "description": step_desc, - "error": str(exc), - "adaptation": adapt_result[:200], - }) + await _broadcast_progress( + "agentic.step_adapted", + { + "task_id": task_id, + "step": i, + "total": total_steps, + "description": step_desc, + "error": str(exc), + "adaptation": adapt_result[:200], + }, + ) if on_progress: await on_progress(f"[Adapted] {step_desc}", i, total_steps) @@ -259,11 +280,16 @@ async def run_agentic_loop( ) try: summary_run = await asyncio.to_thread( - agent.run, summary_prompt, stream=False, + agent.run, + summary_prompt, + stream=False, session_id=f"{session_id}_summary", ) - result.summary = summary_run.content if hasattr(summary_run, "content") else str(summary_run) + result.summary = ( + summary_run.content if hasattr(summary_run, "content") else str(summary_run) + ) from timmy.session import _clean_response + result.summary = _clean_response(result.summary) except Exception as exc: logger.error("Agentic loop summary failed: %s", exc) @@ -281,13 +307,16 @@ async def run_agentic_loop( result.total_duration_ms = int((time.monotonic() - start_time) * 1000) - await _broadcast_progress("agentic.task_complete", { - "task_id": task_id, - "status": result.status, - "steps_completed": len(result.steps), - "summary": result.summary[:300], - "duration_ms": result.total_duration_ms, - }) + await _broadcast_progress( + "agentic.task_complete", + { + "task_id": task_id, + "status": result.status, + "steps_completed": len(result.steps), + "summary": result.summary[:300], + "duration_ms": result.total_duration_ms, + }, + ) return result @@ -296,10 +325,12 @@ async def run_agentic_loop( # WebSocket broadcast helper # --------------------------------------------------------------------------- 
+ async def _broadcast_progress(event: str, data: dict) -> None: """Broadcast agentic loop progress via WebSocket (best-effort).""" try: from infrastructure.ws_manager.handler import ws_manager + await ws_manager.broadcast(event, data) except Exception: logger.debug("Agentic loop: WS broadcast failed for %s", event) diff --git a/src/timmy/agents/base.py b/src/timmy/agents/base.py index ec7abbd..8a13e43 100644 --- a/src/timmy/agents/base.py +++ b/src/timmy/agents/base.py @@ -18,7 +18,7 @@ from agno.agent import Agent from agno.models.ollama import Ollama from config import settings -from infrastructure.events.bus import EventBus, Event +from infrastructure.events.bus import Event, EventBus try: from mcp.registry import tool_registry @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class BaseAgent(ABC): """Base class for all sub-agents.""" - + def __init__( self, agent_id: str, @@ -43,15 +43,15 @@ class BaseAgent(ABC): self.name = name self.role = role self.tools = tools or [] - + # Create Agno agent self.agent = self._create_agent(system_prompt) - + # Event bus for communication self.event_bus: Optional[EventBus] = None - + logger.info("%s agent initialized (id: %s)", name, agent_id) - + def _create_agent(self, system_prompt: str) -> Agent: """Create the underlying Agno agent.""" # Get tools from registry @@ -60,7 +60,7 @@ class BaseAgent(ABC): handler = tool_registry.get_handler(tool_name) if handler: tool_instances.append(handler) - + return Agent( name=self.name, model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300), @@ -71,19 +71,19 @@ class BaseAgent(ABC): markdown=True, telemetry=settings.telemetry_enabled, ) - + def connect_event_bus(self, bus: EventBus) -> None: """Connect to the event bus for inter-agent communication.""" self.event_bus = bus - + # Subscribe to relevant events bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message) bus.subscribe("agent.task.assigned")(self._handle_task_assignment) - + async def 
_handle_direct_message(self, event: Event) -> None: """Handle direct messages to this agent.""" logger.debug("%s received message: %s", self.name, event.type) - + async def _handle_task_assignment(self, event: Event) -> None: """Handle task assignment events.""" assigned_agent = event.data.get("agent_id") @@ -91,41 +91,43 @@ class BaseAgent(ABC): task_id = event.data.get("task_id") description = event.data.get("description", "") logger.info("%s assigned task %s: %s", self.name, task_id, description[:50]) - + # Execute the task await self.execute_task(task_id, description, event.data) - + @abstractmethod async def execute_task(self, task_id: str, description: str, context: dict) -> Any: """Execute a task assigned to this agent. - + Must be implemented by subclasses. """ pass - + async def run(self, message: str) -> str: """Run the agent with a message. - + Returns: Agent response """ result = self.agent.run(message, stream=False) response = result.content if hasattr(result, "content") else str(result) - + # Emit completion event if self.event_bus: - await self.event_bus.publish(Event( - type=f"agent.{self.agent_id}.response", - source=self.agent_id, - data={"input": message, "output": response}, - )) - + await self.event_bus.publish( + Event( + type=f"agent.{self.agent_id}.response", + source=self.agent_id, + data={"input": message, "output": response}, + ) + ) + return response - + def get_capabilities(self) -> list[str]: """Get list of capabilities this agent provides.""" return self.tools - + def get_status(self) -> dict: """Get current agent status.""" return { diff --git a/src/timmy/agents/timmy.py b/src/timmy/agents/timmy.py index 2189a78..9e27773 100644 --- a/src/timmy/agents/timmy.py +++ b/src/timmy/agents/timmy.py @@ -12,9 +12,9 @@ from typing import Any, Optional from agno.agent import Agent from agno.models.ollama import Ollama -from timmy.agents.base import BaseAgent, SubAgent from config import settings from infrastructure.events.bus import EventBus, 
event_bus +from timmy.agents.base import BaseAgent, SubAgent logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ _timmy_context: dict[str, Any] = { async def _load_hands_async() -> list[dict]: """Async helper to load hands. - + Hands registry removed — hand definitions live in TOML files under hands/. This will be rewired to read from brain memory. """ @@ -42,7 +42,7 @@ def build_timmy_context_sync() -> dict[str, Any]: Gathers git commits, active sub-agents, and hot memory. """ global _timmy_context - + ctx: dict[str, Any] = { "timestamp": datetime.now(timezone.utc).isoformat(), "repo_root": settings.repo_root, @@ -51,45 +51,52 @@ def build_timmy_context_sync() -> dict[str, Any]: "hands": [], "memory": "", } - + # 1. Get recent git commits try: from tools.git_tools import git_log + result = git_log(max_count=20) if result.get("success"): commits = result.get("commits", []) - ctx["git_log"] = "\n".join([ - f"{c['short_sha']} {c['message'].split(chr(10))[0]}" - for c in commits[:20] - ]) + ctx["git_log"] = "\n".join( + [f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits[:20]] + ) except Exception as exc: logger.warning("Could not load git log for context: %s", exc) ctx["git_log"] = "(Git log unavailable)" - + # 2. Get active sub-agents try: from swarm import registry as swarm_registry + conn = swarm_registry._get_conn() rows = conn.execute( "SELECT id, name, status, capabilities FROM agents ORDER BY name" ).fetchall() ctx["agents"] = [ - {"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]} + { + "id": r["id"], + "name": r["name"], + "status": r["status"], + "capabilities": r["capabilities"], + } for r in rows ] conn.close() except Exception as exc: logger.warning("Could not load agents for context: %s", exc) ctx["agents"] = [] - + # 3. 
Read hot memory (via HotMemory to auto-create if missing) try: from timmy.memory_system import memory_system + ctx["memory"] = memory_system.hot.read()[:2000] except Exception as exc: logger.warning("Could not load memory for context: %s", exc) ctx["memory"] = "(Memory unavailable)" - + _timmy_context.update(ctx) logger.info("Context built (sync): %d agents", len(ctx["agents"])) return ctx @@ -110,21 +117,31 @@ build_timmy_context = build_timmy_context_sync def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str: """Format the system prompt with dynamic context.""" - + # Format agents list - agents_list = "\n".join([ - f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |" - for a in context.get("agents", []) - ]) or "(No agents registered yet)" - + agents_list = ( + "\n".join( + [ + f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |" + for a in context.get("agents", []) + ] + ) + or "(No agents registered yet)" + ) + # Format hands list - hands_list = "\n".join([ - f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |" - for h in context.get("hands", []) - ]) or "(No hands configured)" - - repo_root = context.get('repo_root', settings.repo_root) - + hands_list = ( + "\n".join( + [ + f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |" + for h in context.get("hands", []) + ] + ) + or "(No hands configured)" + ) + + repo_root = context.get("repo_root", settings.repo_root) + context_block = f""" ## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())}) @@ -149,10 +166,10 @@ def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str: ### Hot Memory: {context.get('memory', '(unavailable)')[:1000]} """ - + # Replace {REPO_ROOT} placeholder with actual path base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root) - + # Insert context after the first line lines = base_prompt.split("\n") if lines: @@ 
-227,63 +244,71 @@ class TimmyOrchestrator(BaseAgent): name="Orchestrator", role="orchestrator", system_prompt=formatted_prompt, - tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write", "system_status"], + tools=[ + "web_search", + "read_file", + "write_file", + "python", + "memory_search", + "memory_write", + "system_status", + ], ) - + # Sub-agent registry self.sub_agents: dict[str, BaseAgent] = {} - + # Session tracking for init behavior self._session_initialized = False self._session_context: dict[str, Any] = {} self._context_fully_loaded = False - + # Connect to event bus self.connect_event_bus(event_bus) - + logger.info("Orchestrator initialized with context-aware prompt") - + def register_sub_agent(self, agent: BaseAgent) -> None: """Register a sub-agent with the orchestrator.""" self.sub_agents[agent.agent_id] = agent agent.connect_event_bus(event_bus) logger.info("Registered sub-agent: %s", agent.name) - + async def _session_init(self) -> None: """Initialize session context on first user message. - + Silently reads git log and AGENTS.md to ground the orchestrator in real data. This runs once per session before the first response. 
""" if self._session_initialized: return - + logger.debug("Running session init...") - + # Load full context including hands if not already done if not self._context_fully_loaded: await build_timmy_context_async() self._context_fully_loaded = True - + # Read recent git log --oneline -15 from repo root try: from tools.git_tools import git_log + git_result = git_log(max_count=15) if git_result.get("success"): commits = git_result.get("commits", []) self._session_context["git_log_commits"] = commits # Format as oneline for easy reading - self._session_context["git_log_oneline"] = "\n".join([ - f"{c['short_sha']} {c['message'].split(chr(10))[0]}" - for c in commits - ]) + self._session_context["git_log_oneline"] = "\n".join( + [f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits] + ) logger.debug(f"Session init: loaded {len(commits)} commits from git log") else: self._session_context["git_log_oneline"] = "Git log unavailable" except Exception as exc: logger.warning("Session init: could not read git log: %s", exc) self._session_context["git_log_oneline"] = "Git log unavailable" - + # Read AGENTS.md for self-awareness try: agents_md_path = Path(settings.repo_root) / "AGENTS.md" @@ -291,7 +316,7 @@ class TimmyOrchestrator(BaseAgent): self._session_context["agents_md"] = agents_md_path.read_text()[:3000] except Exception as exc: logger.warning("Session init: could not read AGENTS.md: %s", exc) - + # Read CHANGELOG for recent changes try: changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md" @@ -299,11 +324,13 @@ class TimmyOrchestrator(BaseAgent): self._session_context["changelog"] = changelog_path.read_text()[:2000] except Exception: pass # Changelog is optional - + # Build session-specific context block for the prompt recent_changes = self._session_context.get("git_log_oneline", "") if recent_changes and recent_changes != "Git log unavailable": - self._session_context["recent_changes_block"] = f""" + self._session_context[ + 
"recent_changes_block" + ] = f""" ## Recent Changes to Your Codebase (last 15 commits): ``` {recent_changes} @@ -312,17 +339,17 @@ When asked "what's new?" or similar, refer to these commits for actual changes. """ else: self._session_context["recent_changes_block"] = "" - + self._session_initialized = True logger.debug("Session init complete") - + def _get_enhanced_system_prompt(self) -> str: """Get system prompt enhanced with session-specific context. - + Prepends the recent git log to the system prompt for grounding. """ base = self.system_prompt - + # Add recent changes block if available recent_changes = self._session_context.get("recent_changes_block", "") if recent_changes: @@ -330,36 +357,45 @@ When asked "what's new?" or similar, refer to these commits for actual changes. lines = base.split("\n") if lines: return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:]) - + return base - + async def orchestrate(self, user_request: str) -> str: """Main entry point for user requests. - + Analyzes the request and either handles directly or delegates. """ # Run session init on first message (loads git log, etc.) 
await self._session_init() - + # Quick classification request_lower = user_request.lower() - + # Direct response patterns (no delegation needed) direct_patterns = [ - "your name", "who are you", "what are you", - "hello", "hi", "how are you", - "help", "what can you do", + "your name", + "who are you", + "what are you", + "hello", + "hi", + "how are you", + "help", + "what can you do", ] - + for pattern in direct_patterns: if pattern in request_lower: return await self.run(user_request) - + # Check for memory references — delegate to Echo memory_patterns = [ - "we talked about", "we discussed", "remember", - "what did i say", "what did we decide", - "remind me", "have we", + "we talked about", + "we discussed", + "remember", + "what did i say", + "what did we decide", + "remind me", + "have we", ] for pattern in memory_patterns: @@ -395,19 +431,16 @@ When asked "what's new?" or similar, refer to these commits for actual changes. if agent in text_lower: return agent return "orchestrator" - + async def execute_task(self, task_id: str, description: str, context: dict) -> Any: """Execute a task (usually delegates to appropriate agent).""" return await self.orchestrate(description) - + def get_swarm_status(self) -> dict: """Get status of all agents in the swarm.""" return { "orchestrator": self.get_status(), - "sub_agents": { - aid: agent.get_status() - for aid, agent in self.sub_agents.items() - }, + "sub_agents": {aid: agent.get_status() for aid, agent in self.sub_agents.items()}, "total_agents": 1 + len(self.sub_agents), } @@ -468,10 +501,29 @@ _PERSONAS: list[dict[str, Any]] = [ "system_prompt": ( "You are Helm, a routing and orchestration specialist.\n" "Analyze tasks and decide how to route them to other agents.\n" - "Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory).\n" + "Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory), Lab (experiments).\n" "Respond with: Primary Agent: [agent name]" ), }, + { + 
"agent_id": "lab", + "name": "Lab", + "role": "experiment", + "tools": [ + "run_experiment", + "prepare_experiment", + "shell", + "python", + "read_file", + "write_file", + ], + "system_prompt": ( + "You are Lab, an autonomous ML experimentation specialist.\n" + "You run time-boxed training experiments, evaluate metrics,\n" + "modify training code to improve results, and iterate.\n" + "Always report the metric delta. Never exceed the time budget." + ), + }, ] diff --git a/src/timmy/approvals.py b/src/timmy/approvals.py index 52888c8..2df441c 100644 --- a/src/timmy/approvals.py +++ b/src/timmy/approvals.py @@ -38,10 +38,10 @@ class ApprovalItem: id: str title: str description: str - proposed_action: str # what Timmy wants to do - impact: str # "low" | "medium" | "high" + proposed_action: str # what Timmy wants to do + impact: str # "low" | "medium" | "high" created_at: datetime - status: str # "pending" | "approved" | "rejected" + status: str # "pending" | "approved" | "rejected" def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: @@ -81,6 +81,7 @@ def _row_to_item(row: sqlite3.Row) -> ApprovalItem: # Public API # --------------------------------------------------------------------------- + def create_item( title: str, description: str, @@ -133,18 +134,14 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: """Return all approval items regardless of status, newest first.""" conn = _get_conn(db_path) - rows = conn.execute( - "SELECT * FROM approval_items ORDER BY created_at DESC" - ).fetchall() + rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall() conn.close() return [_row_to_item(r) for r in rows] def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: conn = _get_conn(db_path) - row = conn.execute( - "SELECT * FROM approval_items WHERE id = ?", (item_id,) - ).fetchone() + row = conn.execute("SELECT * FROM 
approval_items WHERE id = ?", (item_id,)).fetchone() conn.close() return _row_to_item(row) if row else None @@ -152,9 +149,7 @@ def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: """Mark an approval item as approved.""" conn = _get_conn(db_path) - conn.execute( - "UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,) - ) + conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,)) conn.commit() conn.close() return get_item(item_id, db_path) @@ -163,9 +158,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem] def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: """Mark an approval item as rejected.""" conn = _get_conn(db_path) - conn.execute( - "UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,) - ) + conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,)) conn.commit() conn.close() return get_item(item_id, db_path) diff --git a/src/timmy/autoresearch.py b/src/timmy/autoresearch.py new file mode 100644 index 0000000..396f858 --- /dev/null +++ b/src/timmy/autoresearch.py @@ -0,0 +1,214 @@ +"""Autoresearch — autonomous ML experiment loops. + +Integrates Karpathy's autoresearch pattern: an agent modifies training +code, runs time-boxed GPU experiments, evaluates a target metric +(val_bpb by default), and iterates to find improvements. + +Flow: + 1. prepare_experiment — clone repo + run data prep + 2. run_experiment — execute train.py with wall-clock timeout + 3. evaluate_result — compare metric against baseline + 4. experiment_loop — orchestrate the full cycle (not yet implemented in this module) + +Subprocess calls pass explicit timeouts; run_experiment additionally catches TimeoutExpired so a timed-out run degrades gracefully, while prepare_experiment lets TimeoutExpired propagate to the caller.
+""" + +from __future__ import annotations + +import json +import logging +import re +import subprocess +import time +from pathlib import Path +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + +DEFAULT_REPO = "https://github.com/karpathy/autoresearch.git" +_METRIC_RE = re.compile(r"val_bpb[:\s]+([0-9]+\.?[0-9]*)") + + +def prepare_experiment( + workspace: Path, + repo_url: str = DEFAULT_REPO, +) -> str: + """Clone autoresearch repo and run data preparation. + + Args: + workspace: Directory to set up the experiment in. + repo_url: Git URL for the autoresearch repository. + + Returns: + Status message describing what was prepared. + """ + workspace = Path(workspace) + workspace.mkdir(parents=True, exist_ok=True) + + repo_dir = workspace / "autoresearch" + if not repo_dir.exists(): + logger.info("Cloning autoresearch into %s", repo_dir) + result = subprocess.run( + ["git", "clone", "--depth", "1", repo_url, str(repo_dir)], + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode != 0: + return f"Clone failed: {result.stderr.strip()}" + else: + logger.info("Autoresearch repo already present at %s", repo_dir) + + # Run prepare.py (data download + tokeniser training) + prepare_script = repo_dir / "prepare.py" + if prepare_script.exists(): + logger.info("Running prepare.py …") + result = subprocess.run( + ["python", str(prepare_script)], + capture_output=True, + text=True, + cwd=str(repo_dir), + timeout=300, + ) + if result.returncode != 0: + return f"Preparation failed: {result.stderr.strip()[:500]}" + return "Preparation complete — data downloaded and tokeniser trained." + + return "Preparation skipped — no prepare.py found." + + +def run_experiment( + workspace: Path, + timeout: int = 300, + metric_name: str = "val_bpb", +) -> dict[str, Any]: + """Run a single training experiment with a wall-clock timeout. + + Args: + workspace: Experiment workspace (contains autoresearch/ subdir). 
+ timeout: Maximum wall-clock seconds for the run. + metric_name: Name of the metric to extract from stdout. + + Returns: + Dict with keys: metric (float|None), log (str), duration_s (int), + success (bool), error (str|None). + """ + repo_dir = Path(workspace) / "autoresearch" + train_script = repo_dir / "train.py" + + if not train_script.exists(): + return { + "metric": None, + "log": "", + "duration_s": 0, + "success": False, + "error": f"train.py not found in {repo_dir}", + } + + start = time.monotonic() + try: + result = subprocess.run( + ["python", str(train_script)], + capture_output=True, + text=True, + cwd=str(repo_dir), + timeout=timeout, + ) + duration = int(time.monotonic() - start) + output = result.stdout + result.stderr + + # Extract metric from output + metric_val = _extract_metric(output, metric_name) + + return { + "metric": metric_val, + "log": output[-2000:], # Keep last 2k chars + "duration_s": duration, + "success": result.returncode == 0, + "error": None if result.returncode == 0 else f"Exit code {result.returncode}", + } + except subprocess.TimeoutExpired: + duration = int(time.monotonic() - start) + return { + "metric": None, + "log": f"Experiment timed out after {timeout}s", + "duration_s": duration, + "success": False, + "error": f"Timed out after {timeout}s", + } + except OSError as exc: + return { + "metric": None, + "log": "", + "duration_s": 0, + "success": False, + "error": str(exc), + } + + +def _extract_metric(output: str, metric_name: str = "val_bpb") -> Optional[float]: + """Extract the last occurrence of a metric value from training output.""" + pattern = re.compile(rf"{re.escape(metric_name)}[:\s]+([0-9]+\.?[0-9]*)") + matches = pattern.findall(output) + if matches: + try: + return float(matches[-1]) + except ValueError: + pass + return None + + +def evaluate_result( + current: float, + baseline: float, + metric_name: str = "val_bpb", +) -> str: + """Compare a metric against baseline and return an assessment. 
+ + For val_bpb, lower is better. + + Args: + current: Current experiment's metric value. + baseline: Baseline metric to compare against. + metric_name: Name of the metric (for display). + + Returns: + Human-readable assessment string. + """ + delta = current - baseline + pct = (delta / baseline) * 100 if baseline != 0 else 0.0 + + if delta < 0: + return f"Improvement: {metric_name} {baseline:.4f} -> {current:.4f} " f"({pct:+.2f}%)" + elif delta > 0: + return f"Regression: {metric_name} {baseline:.4f} -> {current:.4f} " f"({pct:+.2f}%)" + else: + return f"No change: {metric_name} = {current:.4f}" + + +def get_experiment_history(workspace: Path) -> list[dict[str, Any]]: + """Read experiment history from the workspace results file. + + Returns: + List of experiment result dicts, most recent first. + """ + results_file = Path(workspace) / "results.jsonl" + if not results_file.exists(): + return [] + + history: list[dict[str, Any]] = [] + for line in results_file.read_text().strip().splitlines(): + try: + history.append(json.loads(line)) + except json.JSONDecodeError: + continue + + return list(reversed(history)) + + +def _append_result(workspace: Path, result: dict[str, Any]) -> None: + """Append a result to the workspace JSONL log.""" + results_file = Path(workspace) / "results.jsonl" + results_file.parent.mkdir(parents=True, exist_ok=True) + with results_file.open("a") as f: + f.write(json.dumps(result) + "\n") diff --git a/src/timmy/backends.py b/src/timmy/backends.py index 0e6642d..91c6f36 100644 --- a/src/timmy/backends.py +++ b/src/timmy/backends.py @@ -24,8 +24,8 @@ logger = logging.getLogger(__name__) # HuggingFace model IDs for each supported size. 
_AIRLLM_MODELS: dict[str, str] = { - "8b": "meta-llama/Meta-Llama-3.1-8B-Instruct", - "70b": "meta-llama/Meta-Llama-3.1-70B-Instruct", + "8b": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "70b": "meta-llama/Meta-Llama-3.1-70B-Instruct", "405b": "meta-llama/Meta-Llama-3.1-405B-Instruct", } @@ -35,6 +35,7 @@ ModelSize = Literal["8b", "70b", "405b"] @dataclass class RunResult: """Minimal Agno-compatible run result — carries the model's response text.""" + content: str @@ -47,6 +48,7 @@ def airllm_available() -> bool: """Return True when the airllm package is importable.""" try: import airllm # noqa: F401 + return True except ImportError: return False @@ -67,15 +69,16 @@ class TimmyAirLLMAgent: model_id = _AIRLLM_MODELS.get(model_size) if model_id is None: raise ValueError( - f"Unknown model size {model_size!r}. " - f"Choose from: {list(_AIRLLM_MODELS)}" + f"Unknown model size {model_size!r}. " f"Choose from: {list(_AIRLLM_MODELS)}" ) if is_apple_silicon(): from airllm import AirLLMMLX # type: ignore[import] + self._model = AirLLMMLX(model_id) else: from airllm import AutoModel # type: ignore[import] + self._model = AutoModel.from_pretrained(model_id) self._history: list[str] = [] @@ -137,6 +140,7 @@ class TimmyAirLLMAgent: try: from rich.console import Console from rich.markdown import Markdown + Console().print(Markdown(text)) except ImportError: print(text) @@ -157,6 +161,7 @@ GROK_MODELS: dict[str, str] = { @dataclass class GrokUsageStats: """Tracks Grok API usage for cost monitoring and Spark logging.""" + total_requests: int = 0 total_prompt_tokens: int = 0 total_completion_tokens: int = 0 @@ -240,9 +245,7 @@ class GrokBackend: RunResult with response content """ if not self._api_key: - return RunResult( - content="Grok is not configured. Set XAI_API_KEY to enable." - ) + return RunResult(content="Grok is not configured. 
Set XAI_API_KEY to enable.") start = time.time() messages = self._build_messages(message) @@ -285,16 +288,12 @@ class GrokBackend: except Exception as exc: self.stats.errors += 1 logger.error("Grok API error: %s", exc) - return RunResult( - content=f"Grok temporarily unavailable: {exc}" - ) + return RunResult(content=f"Grok temporarily unavailable: {exc}") async def arun(self, message: str) -> RunResult: """Async inference via Grok API — used by cascade router and tools.""" if not self._api_key: - return RunResult( - content="Grok is not configured. Set XAI_API_KEY to enable." - ) + return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.") start = time.time() messages = self._build_messages(message) @@ -336,9 +335,7 @@ class GrokBackend: except Exception as exc: self.stats.errors += 1 logger.error("Grok async API error: %s", exc) - return RunResult( - content=f"Grok temporarily unavailable: {exc}" - ) + return RunResult(content=f"Grok temporarily unavailable: {exc}") def print_response(self, message: str, *, stream: bool = True) -> None: """Run inference and render the response to stdout (CLI interface).""" @@ -346,6 +343,7 @@ class GrokBackend: try: from rich.console import Console from rich.markdown import Markdown + Console().print(Markdown(result.content)) except ImportError: print(result.content) @@ -415,6 +413,7 @@ def grok_available() -> bool: """Return True when Grok is enabled and API key is configured.""" try: from config import settings + return settings.grok_enabled and bool(settings.xai_api_key) except Exception: return False @@ -472,9 +471,7 @@ class ClaudeBackend: def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult: """Synchronous inference via Claude API.""" if not self._api_key: - return RunResult( - content="Claude is not configured. Set ANTHROPIC_API_KEY to enable." - ) + return RunResult(content="Claude is not configured. 
Set ANTHROPIC_API_KEY to enable.") start = time.time() messages = self._build_messages(message) @@ -508,9 +505,7 @@ class ClaudeBackend: except Exception as exc: logger.error("Claude API error: %s", exc) - return RunResult( - content=f"Claude temporarily unavailable: {exc}" - ) + return RunResult(content=f"Claude temporarily unavailable: {exc}") def print_response(self, message: str, *, stream: bool = True) -> None: """Run inference and render the response to stdout (CLI interface).""" @@ -518,6 +513,7 @@ class ClaudeBackend: try: from rich.console import Console from rich.markdown import Markdown + Console().print(Markdown(result.content)) except ImportError: print(result.content) @@ -569,6 +565,7 @@ def claude_available() -> bool: """Return True when Anthropic API key is configured.""" try: from config import settings + return bool(settings.anthropic_api_key) except Exception: return False diff --git a/src/timmy/briefing.py b/src/timmy/briefing.py index 11c8645..4acf4bf 100644 --- a/src/timmy/briefing.py +++ b/src/timmy/briefing.py @@ -25,6 +25,7 @@ _CACHE_MINUTES = 30 # Data structures # --------------------------------------------------------------------------- + @dataclass class ApprovalItem: """Lightweight representation used inside a Briefing. @@ -32,6 +33,7 @@ class ApprovalItem: The canonical mutable version (with persistence) lives in timmy.approvals. This one travels with the Briefing dataclass as a read-only snapshot. 
""" + id: str title: str description: str @@ -44,20 +46,19 @@ class ApprovalItem: @dataclass class Briefing: generated_at: datetime - summary: str # 150-300 words + summary: str # 150-300 words approval_items: list[ApprovalItem] = field(default_factory=list) period_start: datetime = field( default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6) ) - period_end: datetime = field( - default_factory=lambda: datetime.now(timezone.utc) - ) + period_end: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) # --------------------------------------------------------------------------- # SQLite cache # --------------------------------------------------------------------------- + def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: db_path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(db_path)) @@ -98,9 +99,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None: def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]: """Load the most-recently cached briefing, or None if there is none.""" conn = _get_cache_conn(db_path) - row = conn.execute( - "SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1" - ).fetchone() + row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone() conn.close() if row is None: return None @@ -115,7 +114,11 @@ def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]: def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool: """Return True if the briefing was generated within max_age_minutes.""" now = datetime.now(timezone.utc) - age = now - briefing.generated_at.replace(tzinfo=timezone.utc) if briefing.generated_at.tzinfo is None else now - briefing.generated_at + age = ( + now - briefing.generated_at.replace(tzinfo=timezone.utc) + if briefing.generated_at.tzinfo is None + else now - briefing.generated_at + ) return age.total_seconds() < max_age_minutes * 60 @@ -123,6 
+126,7 @@ def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool: # Activity gathering helpers # --------------------------------------------------------------------------- + def _gather_swarm_summary(since: datetime) -> str: """Pull recent task/agent stats from swarm.db. Graceful if DB missing.""" swarm_db = Path("data/swarm.db") @@ -170,6 +174,7 @@ def _gather_task_queue_summary() -> str: """Pull task queue stats for the briefing. Graceful if unavailable.""" try: from swarm.task_queue.models import get_task_summary_for_briefing + stats = get_task_summary_for_briefing() parts = [] if stats["pending_approval"]: @@ -194,6 +199,7 @@ def _gather_chat_summary(since: datetime) -> str: """Pull recent chat messages from the in-memory log.""" try: from dashboard.store import message_log + messages = message_log.all() # Filter to messages in the briefing window (best-effort: no timestamps) recent = messages[-10:] if len(messages) > 10 else messages @@ -213,6 +219,7 @@ def _gather_chat_summary(since: datetime) -> str: # BriefingEngine # --------------------------------------------------------------------------- + class BriefingEngine: """Generates morning briefings by querying activity and asking Timmy.""" @@ -297,6 +304,7 @@ class BriefingEngine: """Call Timmy's Agno agent and return the response text.""" try: from timmy.agent import create_timmy + agent = create_timmy() run = agent.run(prompt, stream=False) result = run.content if hasattr(run, "content") else str(run) @@ -317,6 +325,7 @@ class BriefingEngine: """Return pending ApprovalItems from the approvals DB.""" try: from timmy import approvals as _approvals + raw_items = _approvals.list_pending() return [ ApprovalItem( diff --git a/src/timmy/cascade_adapter.py b/src/timmy/cascade_adapter.py index 52d2622..2c2c3c4 100644 --- a/src/timmy/cascade_adapter.py +++ b/src/timmy/cascade_adapter.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) @dataclass class TimmyResponse: """Response from 
Timmy via Cascade Router.""" + content: str provider_used: str latency_ms: float @@ -27,31 +28,30 @@ class TimmyResponse: class TimmyCascadeAdapter: """Adapter that routes Timmy requests through Cascade Router. - + Usage: adapter = TimmyCascadeAdapter() response = await adapter.chat("Hello") print(f"Response: {response.content}") print(f"Provider: {response.provider_used}") """ - + def __init__(self, router: Optional[CascadeRouter] = None) -> None: """Initialize adapter with Cascade Router. - + Args: router: CascadeRouter instance. If None, creates default. """ self.router = router or CascadeRouter() - logger.info("TimmyCascadeAdapter initialized with %d providers", - len(self.router.providers)) - + logger.info("TimmyCascadeAdapter initialized with %d providers", len(self.router.providers)) + async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse: """Send message through cascade router with automatic failover. - + Args: message: User message context: Optional conversation context - + Returns: TimmyResponse with content and metadata """ @@ -60,37 +60,38 @@ class TimmyCascadeAdapter: if context: messages.append({"role": "system", "content": context}) messages.append({"role": "user", "content": message}) - + # Route through cascade import time + start = time.time() - + try: result = await self.router.complete( messages=messages, system_prompt=SYSTEM_PROMPT, ) - + latency = (time.time() - start) * 1000 - + # Determine if fallback was used primary = self.router.providers[0] if self.router.providers else None fallback_used = primary and primary.status.value != "healthy" - + return TimmyResponse( content=result.content, provider_used=result.provider_name, latency_ms=latency, fallback_used=fallback_used, ) - + except Exception as exc: logger.error("All providers failed: %s", exc) raise - + def get_provider_status(self) -> list[dict]: """Get status of all providers. 
- + Returns: List of provider status dicts """ @@ -112,10 +113,10 @@ class TimmyCascadeAdapter: } for p in self.router.providers ] - + def get_preferred_provider(self) -> Optional[str]: """Get name of highest-priority healthy provider. - + Returns: Provider name or None if all unhealthy """ diff --git a/src/timmy/conversation.py b/src/timmy/conversation.py index dbca651..8da6ed2 100644 --- a/src/timmy/conversation.py +++ b/src/timmy/conversation.py @@ -17,22 +17,23 @@ logger = logging.getLogger(__name__) @dataclass class ConversationContext: """Tracks the current conversation state.""" + user_name: Optional[str] = None current_topic: Optional[str] = None last_intent: Optional[str] = None turn_count: int = 0 started_at: datetime = field(default_factory=datetime.now) - + def update_topic(self, topic: str) -> None: """Update the current conversation topic.""" self.current_topic = topic self.turn_count += 1 - + def set_user_name(self, name: str) -> None: """Remember the user's name.""" self.user_name = name logger.info("User name set to: %s", name) - + def get_context_summary(self) -> str: """Generate a context summary for the prompt.""" parts = [] @@ -47,35 +48,88 @@ class ConversationContext: class ConversationManager: """Manages conversation context across sessions.""" - + def __init__(self) -> None: self._contexts: dict[str, ConversationContext] = {} - + def get_context(self, session_id: str) -> ConversationContext: """Get or create context for a session.""" if session_id not in self._contexts: self._contexts[session_id] = ConversationContext() return self._contexts[session_id] - + def clear_context(self, session_id: str) -> None: """Clear context for a session.""" if session_id in self._contexts: del self._contexts[session_id] - + # Words that look like names but are actually verbs/UI states - _NAME_BLOCKLIST = frozenset({ - "sending", "loading", "pending", "processing", "typing", - "working", "going", "trying", "looking", "getting", "doing", - "waiting", 
"running", "checking", "coming", "leaving", - "thinking", "reading", "writing", "watching", "listening", - "playing", "eating", "sleeping", "sitting", "standing", - "walking", "talking", "asking", "telling", "feeling", - "hoping", "wondering", "glad", "happy", "sorry", "sure", - "fine", "good", "great", "okay", "here", "there", "back", - "done", "ready", "busy", "free", "available", "interested", - "confused", "lost", "stuck", "curious", "excited", "tired", - "not", "also", "just", "still", "already", "currently", - }) + _NAME_BLOCKLIST = frozenset( + { + "sending", + "loading", + "pending", + "processing", + "typing", + "working", + "going", + "trying", + "looking", + "getting", + "doing", + "waiting", + "running", + "checking", + "coming", + "leaving", + "thinking", + "reading", + "writing", + "watching", + "listening", + "playing", + "eating", + "sleeping", + "sitting", + "standing", + "walking", + "talking", + "asking", + "telling", + "feeling", + "hoping", + "wondering", + "glad", + "happy", + "sorry", + "sure", + "fine", + "good", + "great", + "okay", + "here", + "there", + "back", + "done", + "ready", + "busy", + "free", + "available", + "interested", + "confused", + "lost", + "stuck", + "curious", + "excited", + "tired", + "not", + "also", + "just", + "still", + "already", + "currently", + } + ) def extract_user_name(self, message: str) -> Optional[str]: """Try to extract user's name from message.""" @@ -106,40 +160,66 @@ class ConversationManager: return name.capitalize() return None - + def should_use_tools(self, message: str, context: ConversationContext) -> bool: """Determine if this message likely requires tools. - + Returns True if tools are likely needed, False for simple chat. 
""" message_lower = message.lower().strip() - + # Tool keywords that suggest tool usage is needed tool_keywords = [ - "search", "look up", "find", "google", "current price", - "latest", "today's", "news", "weather", "stock price", - "read file", "write file", "save", "calculate", "compute", - "run ", "execute", "shell", "command", "install", + "search", + "look up", + "find", + "google", + "current price", + "latest", + "today's", + "news", + "weather", + "stock price", + "read file", + "write file", + "save", + "calculate", + "compute", + "run ", + "execute", + "shell", + "command", + "install", ] - + # Chat-only keywords that definitely don't need tools chat_only = [ - "hello", "hi ", "hey", "how are you", "what's up", - "your name", "who are you", "what are you", - "thanks", "thank you", "bye", "goodbye", - "tell me about yourself", "what can you do", + "hello", + "hi ", + "hey", + "how are you", + "what's up", + "your name", + "who are you", + "what are you", + "thanks", + "thank you", + "bye", + "goodbye", + "tell me about yourself", + "what can you do", ] - + # Check for chat-only patterns first for pattern in chat_only: if pattern in message_lower: return False - + # Check for tool keywords for keyword in tool_keywords: if keyword in message_lower: return True - + # Simple questions (starting with what, who, how, why, when, where) # usually don't need tools unless about current/real-time info simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"] @@ -150,7 +230,7 @@ class ConversationManager: if any(t in message_lower for t in time_words): return True return False - + # Default: don't use tools for unclear cases return False diff --git a/src/timmy/memory/vector_store.py b/src/timmy/memory/vector_store.py index 43562ad..61db5b9 100644 --- a/src/timmy/memory/vector_store.py +++ b/src/timmy/memory/vector_store.py @@ -25,11 +25,12 @@ def _get_model(): global _model, _has_embeddings if _has_embeddings is False: return None - + 
if _model is not None: return _model - + from config import settings + # In test mode or low-memory environments, skip embedding model load if settings.timmy_skip_embeddings: _has_embeddings = False @@ -37,7 +38,8 @@ def _get_model(): try: from sentence_transformers import SentenceTransformer - _model = SentenceTransformer('all-MiniLM-L6-v2') + + _model = SentenceTransformer("all-MiniLM-L6-v2") _has_embeddings = True return _model except (ImportError, RuntimeError, Exception): @@ -56,7 +58,7 @@ def _get_embedding_dimension() -> int: def _compute_embedding(text: str) -> list[float]: """Compute embedding vector for text. - + Uses sentence-transformers if available, otherwise returns a simple hash-based vector for basic similarity. """ @@ -66,30 +68,31 @@ def _compute_embedding(text: str) -> list[float]: return model.encode(text).tolist() except Exception: pass - + # Fallback: simple character n-gram hash embedding # Not as good but allows the system to work without heavy deps dim = 384 vec = [0.0] * dim text = text.lower() - + # Generate character trigram features for i in range(len(text) - 2): - trigram = text[i:i+3] + trigram = text[i : i + 3] hash_val = hash(trigram) % dim vec[hash_val] += 1.0 - + # Normalize - norm = sum(x*x for x in vec) ** 0.5 + norm = sum(x * x for x in vec) ** 0.5 if norm > 0: - vec = [x/norm for x in vec] - + vec = [x / norm for x in vec] + return vec @dataclass class MemoryEntry: """A memory entry with vector embedding.""" + id: str = field(default_factory=lambda: str(uuid.uuid4())) content: str = "" # The actual text content source: str = "" # Where it came from (agent, user, system) @@ -99,9 +102,7 @@ class MemoryEntry: session_id: Optional[str] = None metadata: Optional[dict] = None embedding: Optional[list[float]] = None - timestamp: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat() - ) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) relevance_score: Optional[float] = 
None # Set during search @@ -110,7 +111,7 @@ def _get_conn() -> sqlite3.Connection: DB_PATH.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row - + # Try to load sqlite-vss extension try: conn.enable_load_extension(True) @@ -119,7 +120,7 @@ def _get_conn() -> sqlite3.Connection: _has_vss = True except Exception: _has_vss = False - + # Create tables conn.execute( """ @@ -137,24 +138,14 @@ def _get_conn() -> sqlite3.Connection: ) """ ) - + # Create indexes - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)" - ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)" - ) - + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)") + conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)") + conn.commit() return conn @@ -170,7 +161,7 @@ def store_memory( compute_embedding: bool = True, ) -> MemoryEntry: """Store a memory entry with optional embedding. 
- + Args: content: The text content to store source: Source of the memory (agent name, user, system) @@ -180,14 +171,14 @@ def store_memory( session_id: Session identifier metadata: Additional structured data compute_embedding: Whether to compute vector embedding - + Returns: The stored MemoryEntry """ embedding = None if compute_embedding: embedding = _compute_embedding(content) - + entry = MemoryEntry( content=content, source=source, @@ -198,7 +189,7 @@ def store_memory( metadata=metadata, embedding=embedding, ) - + conn = _get_conn() conn.execute( """ @@ -222,7 +213,7 @@ def store_memory( ) conn.commit() conn.close() - + return entry @@ -235,7 +226,7 @@ def search_memories( min_relevance: float = 0.0, ) -> list[MemoryEntry]: """Search for memories by semantic similarity. - + Args: query: Search query text limit: Maximum results @@ -243,18 +234,18 @@ def search_memories( agent_id: Filter by agent session_id: Filter by session min_relevance: Minimum similarity score (0-1) - + Returns: List of MemoryEntry objects sorted by relevance """ query_embedding = _compute_embedding(query) - + conn = _get_conn() - + # Build query with filters conditions = [] params = [] - + if context_type: conditions.append("context_type = ?") params.append(context_type) @@ -264,9 +255,9 @@ def search_memories( if session_id: conditions.append("session_id = ?") params.append(session_id) - + where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" - + # Fetch candidates (we'll do in-memory similarity for now) # For production with sqlite-vss, this would use vector similarity index query_sql = f""" @@ -276,10 +267,10 @@ def search_memories( LIMIT ? 
""" params.append(limit * 3) # Get more candidates for ranking - + rows = conn.execute(query_sql, params).fetchall() conn.close() - + # Compute similarity scores results = [] for row in rows: @@ -295,7 +286,7 @@ def search_memories( embedding=json.loads(row["embedding"]) if row["embedding"] else None, timestamp=row["timestamp"], ) - + if entry.embedding: # Cosine similarity score = _cosine_similarity(query_embedding, entry.embedding) @@ -308,7 +299,7 @@ def search_memories( entry.relevance_score = score if score >= min_relevance: results.append(entry) - + # Sort by relevance and return top results results.sort(key=lambda x: x.relevance_score or 0, reverse=True) return results[:limit] @@ -316,9 +307,9 @@ def search_memories( def _cosine_similarity(a: list[float], b: list[float]) -> float: """Compute cosine similarity between two vectors.""" - dot = sum(x*y for x, y in zip(a, b)) - norm_a = sum(x*x for x in a) ** 0.5 - norm_b = sum(x*x for x in b) ** 0.5 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 if norm_a == 0 or norm_b == 0: return 0.0 return dot / (norm_a * norm_b) @@ -334,51 +325,47 @@ def _keyword_overlap(query: str, content: str) -> float: return overlap / len(query_words) -def get_memory_context( - query: str, - max_tokens: int = 2000, - **filters -) -> str: +def get_memory_context(query: str, max_tokens: int = 2000, **filters) -> str: """Get relevant memory context as formatted text for LLM prompts. - + Args: query: Search query max_tokens: Approximate maximum tokens to return **filters: Additional filters (agent_id, session_id, etc.) 
- + Returns: Formatted context string for inclusion in prompts """ memories = search_memories(query, limit=20, **filters) - + context_parts = [] total_chars = 0 max_chars = max_tokens * 4 # Rough approximation - + for mem in memories: formatted = f"[{mem.source}]: {mem.content}" if total_chars + len(formatted) > max_chars: break context_parts.append(formatted) total_chars += len(formatted) - + if not context_parts: return "" - + return "Relevant context from memory:\n" + "\n\n".join(context_parts) def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]: """Recall personal facts about the user or system. - + Args: agent_id: Optional agent filter - + Returns: List of fact strings """ conn = _get_conn() - + if agent_id: rows = conn.execute( """ @@ -398,7 +385,7 @@ def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]: LIMIT 100 """, ).fetchall() - + conn.close() return [r["content"] for r in rows] @@ -434,11 +421,11 @@ def update_personal_fact(memory_id: str, new_content: str) -> bool: def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry: """Store a personal fact about the user or system. - + Args: fact: The fact to store agent_id: Associated agent - + Returns: The stored MemoryEntry """ @@ -453,7 +440,7 @@ def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntr def delete_memory(memory_id: str) -> bool: """Delete a memory entry by ID. - + Returns: True if deleted, False if not found """ @@ -470,29 +457,27 @@ def delete_memory(memory_id: str) -> bool: def get_memory_stats() -> dict: """Get statistics about the memory store. - + Returns: Dict with counts by type, total entries, etc. 
""" conn = _get_conn() - - total = conn.execute( - "SELECT COUNT(*) as count FROM memory_entries" - ).fetchone()["count"] - + + total = conn.execute("SELECT COUNT(*) as count FROM memory_entries").fetchone()["count"] + by_type = {} rows = conn.execute( "SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type" ).fetchall() for row in rows: by_type[row["context_type"]] = row["count"] - + with_embeddings = conn.execute( "SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL" ).fetchone()["count"] - + conn.close() - + return { "total_entries": total, "by_type": by_type, @@ -503,20 +488,20 @@ def get_memory_stats() -> dict: def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int: """Delete old memories to manage storage. - + Args: older_than_days: Delete memories older than this keep_facts: Whether to preserve fact-type memories - + Returns: Number of entries deleted """ from datetime import timedelta - + cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat() - + conn = _get_conn() - + if keep_facts: cursor = conn.execute( """ @@ -530,9 +515,9 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int: "DELETE FROM memory_entries WHERE timestamp < ?", (cutoff,), ) - + deleted = cursor.rowcount conn.commit() conn.close() - + return deleted diff --git a/src/timmy/memory_system.py b/src/timmy/memory_system.py index bbc7bff..e0d6680 100644 --- a/src/timmy/memory_system.py +++ b/src/timmy/memory_system.py @@ -28,50 +28,52 @@ HANDOFF_PATH = VAULT_PATH / "notes" / "last-session-handoff.md" class HotMemory: """Tier 1: Hot memory (MEMORY.md) — always loaded.""" - + def __init__(self) -> None: self.path = HOT_MEMORY_PATH self._content: Optional[str] = None self._last_modified: Optional[float] = None - + def read(self, force_refresh: bool = False) -> str: """Read hot memory, with caching.""" if not self.path.exists(): self._create_default() - + # Check if file 
changed current_mtime = self.path.stat().st_mtime if not force_refresh and self._content and self._last_modified == current_mtime: return self._content - + self._content = self.path.read_text() self._last_modified = current_mtime logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path) return self._content - + def update_section(self, section: str, content: str) -> None: """Update a specific section in MEMORY.md.""" full_content = self.read() - + # Find section pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)" match = re.search(pattern, full_content, re.DOTALL) - + if match: # Replace section new_section = f"## {section}\n\n{content}\n\n" - full_content = full_content[:match.start()] + new_section + full_content[match.end():] + full_content = full_content[: match.start()] + new_section + full_content[match.end() :] else: # Append section before last updated line insert_point = full_content.rfind("*Prune date:") new_section = f"## {section}\n\n{content}\n\n" - full_content = full_content[:insert_point] + new_section + "\n" + full_content[insert_point:] - + full_content = ( + full_content[:insert_point] + new_section + "\n" + full_content[insert_point:] + ) + self.path.write_text(full_content) self._content = full_content self._last_modified = self.path.stat().st_mtime logger.info("HotMemory: Updated section '%s'", section) - + def _create_default(self) -> None: """Create default MEMORY.md if missing.""" default_content = """# Timmy Hot Memory @@ -130,33 +132,33 @@ class HotMemory: *Prune date: {prune_date}* """.format( date=datetime.now(timezone.utc).strftime("%Y-%m-%d"), - prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d") + prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d"), ) - + self.path.write_text(default_content) logger.info("HotMemory: Created default MEMORY.md") class VaultMemory: """Tier 2: Structured vault (memory/) — append-only markdown.""" - + def __init__(self) -> None: 
self.path = VAULT_PATH self._ensure_structure() - + def _ensure_structure(self) -> None: """Ensure vault directory structure exists.""" (self.path / "self").mkdir(parents=True, exist_ok=True) (self.path / "notes").mkdir(parents=True, exist_ok=True) (self.path / "aar").mkdir(parents=True, exist_ok=True) - + def write_note(self, name: str, content: str, namespace: str = "notes") -> Path: """Write a note to the vault.""" # Add timestamp to filename timestamp = datetime.now(timezone.utc).strftime("%Y%m%d") filename = f"{timestamp}_{name}.md" filepath = self.path / namespace / filename - + # Add header full_content = f"""# {name.replace('_', ' ').title()} @@ -171,39 +173,39 @@ class VaultMemory: *Auto-generated by Timmy Memory System* """ - + filepath.write_text(full_content) logger.info("VaultMemory: Wrote %s", filepath) return filepath - + def read_file(self, filepath: Path) -> str: """Read a file from the vault.""" if not filepath.exists(): return "" return filepath.read_text() - + def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]: """List files in a namespace.""" dir_path = self.path / namespace if not dir_path.exists(): return [] return sorted(dir_path.glob(pattern)) - + def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]: """Get most recent file in namespace.""" files = self.list_files(namespace, pattern) return files[-1] if files else None - + def update_user_profile(self, key: str, value: str) -> None: """Update a field in user_profile.md.""" profile_path = self.path / "self" / "user_profile.md" - + if not profile_path.exists(): # Create default profile self._create_default_profile() - + content = profile_path.read_text() - + # Simple pattern replacement pattern = rf"(\*\*{re.escape(key)}:\*\*).*" if re.search(pattern, content): @@ -214,17 +216,17 @@ class VaultMemory: if facts_section in content: insert_point = content.find(facts_section) + len(facts_section) content = content[:insert_point] 
+ f"\n- {key}: {value}" + content[insert_point:] - + # Update last_updated content = re.sub( r"\*Last updated:.*\*", f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*", - content + content, ) - + profile_path.write_text(content) logger.info("VaultMemory: Updated user profile: %s = %s", key, value) - + def _create_default_profile(self) -> None: """Create default user profile.""" profile_path = self.path / "self" / "user_profile.md" @@ -254,24 +256,26 @@ class VaultMemory: --- *Last updated: {date}* -""".format(date=datetime.now(timezone.utc).strftime("%Y-%m-%d")) - +""".format( + date=datetime.now(timezone.utc).strftime("%Y-%m-%d") + ) + profile_path.write_text(default) class HandoffProtocol: """Session handoff protocol for continuity.""" - + def __init__(self) -> None: self.path = HANDOFF_PATH self.vault = VaultMemory() - + def write_handoff( self, session_summary: str, key_decisions: list[str], open_items: list[str], - next_steps: list[str] + next_steps: list[str], ) -> None: """Write handoff at session end.""" content = f"""# Last Session Handoff @@ -303,25 +307,24 @@ The user was last working on: {session_summary[:200]}... *This handoff will be auto-loaded at next session start* """ - + self.path.write_text(content) - + # Also archive to notes - self.vault.write_note( - "session_handoff", - content, - namespace="notes" + self.vault.write_note("session_handoff", content, namespace="notes") + + logger.info( + "HandoffProtocol: Wrote handoff with %d decisions, %d open items", + len(key_decisions), + len(open_items), ) - - logger.info("HandoffProtocol: Wrote handoff with %d decisions, %d open items", - len(key_decisions), len(open_items)) - + def read_handoff(self) -> Optional[str]: """Read handoff if exists.""" if not self.path.exists(): return None return self.path.read_text() - + def clear_handoff(self) -> None: """Clear handoff after loading.""" if self.path.exists(): @@ -331,7 +334,7 @@ The user was last working on: {session_summary[:200]}... 
class MemorySystem: """Central memory system coordinating all tiers.""" - + def __init__(self) -> None: self.hot = HotMemory() self.vault = VaultMemory() @@ -339,52 +342,52 @@ class MemorySystem: self.session_start_time: Optional[datetime] = None self.session_decisions: list[str] = [] self.session_open_items: list[str] = [] - + def start_session(self) -> str: """Start a new session, loading context from memory.""" self.session_start_time = datetime.now(timezone.utc) - + # Build context context_parts = [] - + # 1. Hot memory hot_content = self.hot.read() context_parts.append("## Hot Memory\n" + hot_content) - + # 2. Last session handoff handoff_content = self.handoff.read_handoff() if handoff_content: context_parts.append("## Previous Session\n" + handoff_content) self.handoff.clear_handoff() - + # 3. User profile (key fields only) profile = self._load_user_profile_summary() if profile: context_parts.append("## User Context\n" + profile) - + full_context = "\n\n---\n\n".join(context_parts) logger.info("MemorySystem: Session started with %d chars context", len(full_context)) - + return full_context - + def end_session(self, summary: str) -> None: """End session, write handoff.""" self.handoff.write_handoff( session_summary=summary, key_decisions=self.session_decisions, open_items=self.session_open_items, - next_steps=[] + next_steps=[], ) - + # Update hot memory self.hot.update_section( "Current Session", - f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n" + - f"**Summary:** {summary[:100]}..." 
+ f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n" + + f"**Summary:** {summary[:100]}...", ) - + logger.info("MemorySystem: Session ended, handoff written") - + def record_decision(self, decision: str) -> None: """Record a key decision during session.""" self.session_decisions.append(decision) @@ -393,43 +396,47 @@ class MemorySystem: if "## Key Decisions" in current: # Append to section pass # Handled at session end - + def record_open_item(self, item: str) -> None: """Record an open item for follow-up.""" self.session_open_items.append(item) - + def update_user_fact(self, key: str, value: str) -> None: """Update user profile in vault.""" self.vault.update_user_profile(key, value) # Also update hot memory if key.lower() == "name": self.hot.update_section("User Profile", f"**Name:** {value}") - + def _load_user_profile_summary(self) -> str: """Load condensed user profile.""" profile_path = self.vault.path / "self" / "user_profile.md" if not profile_path.exists(): return "" - + content = profile_path.read_text() - + # Extract key fields summary_parts = [] - + # Name name_match = re.search(r"\*\*Name:\*\* (.+)", content) if name_match and "unknown" not in name_match.group(1).lower(): summary_parts.append(f"Name: {name_match.group(1).strip()}") - + # Interests interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL) if interests_section: - interests = [i.strip() for i in interests_section.group(1).split("\n-") if i.strip() and "to be" not in i] + interests = [ + i.strip() + for i in interests_section.group(1).split("\n-") + if i.strip() and "to be" not in i + ] if interests: summary_parts.append(f"Interests: {', '.join(interests[:3])}") - + return "\n".join(summary_parts) if summary_parts else "" - + def get_system_context(self) -> str: """Get full context for system prompt injection. 
diff --git a/src/timmy/semantic_memory.py b/src/timmy/semantic_memory.py index d36d5fa..e996b19 100644 --- a/src/timmy/semantic_memory.py +++ b/src/timmy/semantic_memory.py @@ -38,12 +38,14 @@ def _get_embedding_model(): global EMBEDDING_MODEL if EMBEDDING_MODEL is None: from config import settings + if settings.timmy_skip_embeddings: EMBEDDING_MODEL = False return EMBEDDING_MODEL try: from sentence_transformers import SentenceTransformer - EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2') + + EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2") logger.info("SemanticMemory: Loaded embedding model") except ImportError: logger.warning("SemanticMemory: sentence-transformers not installed, using fallback") @@ -60,11 +62,12 @@ def _simple_hash_embedding(text: str) -> list[float]: h = hashlib.md5(word.encode()).hexdigest() for j in range(8): idx = (i * 8 + j) % 128 - vec[idx] += int(h[j*2:j*2+2], 16) / 255.0 + vec[idx] += int(h[j * 2 : j * 2 + 2], 16) / 255.0 # Normalize import math - mag = math.sqrt(sum(x*x for x in vec)) or 1.0 - return [x/mag for x in vec] + + mag = math.sqrt(sum(x * x for x in vec)) or 1.0 + return [x / mag for x in vec] def embed_text(text: str) -> list[float]: @@ -80,9 +83,10 @@ def embed_text(text: str) -> list[float]: def cosine_similarity(a: list[float], b: list[float]) -> float: """Calculate cosine similarity between two vectors.""" import math - dot = sum(x*y for x, y in zip(a, b)) - mag_a = math.sqrt(sum(x*x for x in a)) - mag_b = math.sqrt(sum(x*x for x in b)) + + dot = sum(x * y for x, y in zip(a, b)) + mag_a = math.sqrt(sum(x * x for x in a)) + mag_b = math.sqrt(sum(x * x for x in b)) if mag_a == 0 or mag_b == 0: return 0.0 return dot / (mag_a * mag_b) @@ -91,6 +95,7 @@ def cosine_similarity(a: list[float], b: list[float]) -> float: @dataclass class MemoryChunk: """A searchable chunk of memory.""" + id: str source: str # filepath content: str @@ -100,17 +105,18 @@ class MemoryChunk: class SemanticMemory: """Vector-based 
semantic search over vault content.""" - + def __init__(self) -> None: self.db_path = SEMANTIC_DB_PATH self.vault_path = VAULT_PATH self._init_db() - + def _init_db(self) -> None: """Initialize SQLite with vector storage.""" self.db_path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(str(self.db_path)) - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS chunks ( id TEXT PRIMARY KEY, source TEXT NOT NULL, @@ -119,76 +125,76 @@ class SemanticMemory: created_at TEXT NOT NULL, source_hash TEXT NOT NULL ) - """) + """ + ) conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)") conn.commit() conn.close() - + def index_file(self, filepath: Path) -> int: """Index a single file into semantic memory.""" if not filepath.exists(): return 0 - + content = filepath.read_text() file_hash = hashlib.md5(content.encode()).hexdigest() - + # Check if already indexed with same hash conn = sqlite3.connect(str(self.db_path)) cursor = conn.execute( - "SELECT source_hash FROM chunks WHERE source = ? LIMIT 1", - (str(filepath),) + "SELECT source_hash FROM chunks WHERE source = ? 
LIMIT 1", (str(filepath),) ) existing = cursor.fetchone() if existing and existing[0] == file_hash: conn.close() return 0 # Already indexed - + # Delete old chunks for this file conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),)) - + # Split into chunks (paragraphs) chunks = self._split_into_chunks(content) - + # Index each chunk now = datetime.now(timezone.utc).isoformat() for i, chunk_text in enumerate(chunks): if len(chunk_text.strip()) < 20: # Skip tiny chunks continue - + chunk_id = f"{filepath.stem}_{i}" embedding = embed_text(chunk_text) - + conn.execute( """INSERT INTO chunks (id, source, content, embedding, created_at, source_hash) VALUES (?, ?, ?, ?, ?, ?)""", - (chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash) + (chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash), ) - + conn.commit() conn.close() - + logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks)) return len(chunks) - + def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]: """Split text into semantic chunks.""" # Split by paragraphs first - paragraphs = text.split('\n\n') + paragraphs = text.split("\n\n") chunks = [] - + for para in paragraphs: para = para.strip() if not para: continue - + # If paragraph is small enough, keep as one chunk if len(para) <= max_chunk_size: chunks.append(para) else: # Split long paragraphs by sentences - sentences = para.replace('. ', '.\n').split('\n') + sentences = para.replace(". 
", ".\n").split("\n") current_chunk = "" - + for sent in sentences: if len(current_chunk) + len(sent) < max_chunk_size: current_chunk += " " + sent if current_chunk else sent @@ -196,82 +202,80 @@ class SemanticMemory: if current_chunk: chunks.append(current_chunk.strip()) current_chunk = sent - + if current_chunk: chunks.append(current_chunk.strip()) - + return chunks - + def index_vault(self) -> int: """Index entire vault directory.""" total_chunks = 0 - + for md_file in self.vault_path.rglob("*.md"): # Skip handoff file (handled separately) if "last-session-handoff" in md_file.name: continue total_chunks += self.index_file(md_file) - + logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks) return total_chunks - + def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]: """Search for relevant memory chunks.""" query_embedding = embed_text(query) - + conn = sqlite3.connect(str(self.db_path)) conn.row_factory = sqlite3.Row - + # Get all chunks (in production, use vector index) - rows = conn.execute( - "SELECT source, content, embedding FROM chunks" - ).fetchall() - + rows = conn.execute("SELECT source, content, embedding FROM chunks").fetchall() + conn.close() - + # Calculate similarities scored = [] for row in rows: embedding = json.loads(row["embedding"]) score = cosine_similarity(query_embedding, embedding) scored.append((row["source"], row["content"], score)) - + # Sort by score descending scored.sort(key=lambda x: x[2], reverse=True) - + # Return top_k return [(content, score) for _, content, score in scored[:top_k]] - + def get_relevant_context(self, query: str, max_chars: int = 2000) -> str: """Get formatted context string for a query.""" results = self.search(query, top_k=3) - + if not results: return "" - + parts = [] total_chars = 0 - + for content, score in results: if score < 0.3: # Similarity threshold continue - + chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..." 
if total_chars + len(chunk) > max_chars: break - + parts.append(chunk) total_chars += len(chunk) - + return "\n\n".join(parts) if parts else "" - + def stats(self) -> dict: """Get indexing statistics.""" conn = sqlite3.connect(str(self.db_path)) cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks") total_chunks, total_files = cursor.fetchone() conn.close() - + return { "total_chunks": total_chunks, "total_files": total_files, @@ -281,40 +285,39 @@ class SemanticMemory: class MemorySearcher: """High-level interface for memory search.""" - + def __init__(self) -> None: self.semantic = SemanticMemory() - + def search(self, query: str, tiers: list[str] = None) -> dict: """Search across memory tiers. - + Args: query: Search query tiers: List of tiers to search ["hot", "vault", "semantic"] - + Returns: Dict with results from each tier """ tiers = tiers or ["semantic"] # Default to semantic only results = {} - + if "semantic" in tiers: semantic_results = self.semantic.search(query, top_k=5) results["semantic"] = [ - {"content": content, "score": score} - for content, score in semantic_results + {"content": content, "score": score} for content, score in semantic_results ] - + return results - + def get_context_for_query(self, query: str) -> str: """Get comprehensive context for a user query.""" # Get semantic context semantic_context = self.semantic.get_relevant_context(query) - + if semantic_context: return f"## Relevant Past Context\n\n{semantic_context}" - + return "" @@ -353,6 +356,7 @@ def memory_search(query: str, top_k: int = 5) -> str: # 2. 
Search runtime vector store (stored facts/conversations) try: from timmy.memory.vector_store import search_memories + runtime_results = search_memories(query, limit=top_k, min_relevance=0.2) for entry in runtime_results: label = entry.context_type or "memory" @@ -387,6 +391,7 @@ def memory_read(query: str = "", top_k: int = 5) -> str: # Always include personal facts first try: from timmy.memory.vector_store import search_memories + facts = search_memories(query or "", limit=top_k, min_relevance=0.0) fact_entries = [e for e in facts if (e.context_type or "") == "fact"] if fact_entries: @@ -433,6 +438,7 @@ def memory_write(content: str, context_type: str = "fact") -> str: try: from timmy.memory.vector_store import store_memory + entry = store_memory( content=content.strip(), source="agent", diff --git a/src/timmy/session.py b/src/timmy/session.py index fea1bb8..ab642a8 100644 --- a/src/timmy/session.py +++ b/src/timmy/session.py @@ -32,13 +32,15 @@ _TOOL_CALL_JSON = re.compile( # Matches function-call-style text: memory_search(query="...") etc. 
_FUNC_CALL_TEXT = re.compile( - r'\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)' - r'\s*\([^)]*\)', + r"\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)" + r"\s*\([^)]*\)", ) # Matches chain-of-thought narration lines the model should keep internal _COT_PATTERNS = [ - re.compile(r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE), + re.compile( + r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE + ), re.compile(r"^(?:I found a relevant |This context suggests ).*$", re.MULTILINE), ] @@ -48,6 +50,7 @@ def _get_agent(): global _agent if _agent is None: from timmy.agent import create_timmy + try: _agent = create_timmy() logger.info("Session: Timmy agent initialized (singleton)") @@ -99,6 +102,7 @@ def reset_session(session_id: Optional[str] = None) -> None: sid = session_id or _DEFAULT_SESSION_ID try: from timmy.conversation import conversation_manager + conversation_manager.clear_context(sid) except Exception as exc: logger.debug("Session: context clear failed for %s: %s", sid, exc) @@ -112,10 +116,12 @@ def _extract_facts(message: str) -> None: """ try: from timmy.conversation import conversation_manager + name = conversation_manager.extract_user_name(message) if name: try: from timmy.memory_system import memory_system + memory_system.update_user_fact("Name", name) logger.info("Session: Learned user name: %s", name) except Exception as exc: diff --git a/src/timmy/session_logger.py b/src/timmy/session_logger.py index 1fb4408..8a52088 100644 --- a/src/timmy/session_logger.py +++ b/src/timmy/session_logger.py @@ -6,7 +6,7 @@ including any mistakes or errors that occur during the session." 
import json import logging -from datetime import datetime, date +from datetime import date, datetime from pathlib import Path from typing import Any diff --git a/src/timmy/thinking.py b/src/timmy/thinking.py index d2160f4..abe72bb 100644 --- a/src/timmy/thinking.py +++ b/src/timmy/thinking.py @@ -75,6 +75,7 @@ Continue your train of thought.""" @dataclass class Thought: """A single thought in Timmy's inner stream.""" + id: str content: str seed_type: str @@ -98,9 +99,7 @@ def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: ) """ ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)") conn.commit() return conn @@ -190,9 +189,7 @@ class ThinkingEngine: def get_thought(self, thought_id: str) -> Optional[Thought]: """Retrieve a single thought by ID.""" conn = _get_conn(self._db_path) - row = conn.execute( - "SELECT * FROM thoughts WHERE id = ?", (thought_id,) - ).fetchone() + row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone() conn.close() return _row_to_thought(row) if row else None @@ -208,9 +205,7 @@ class ThinkingEngine: for _ in range(max_depth): if not current_id: break - row = conn.execute( - "SELECT * FROM thoughts WHERE id = ?", (current_id,) - ).fetchone() + row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (current_id,)).fetchone() if not row: break chain.append(_row_to_thought(row)) @@ -254,8 +249,10 @@ class ThinkingEngine: def _seed_from_swarm(self) -> str: """Gather recent swarm activity as thought seed.""" try: - from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary from datetime import timedelta + + from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary + since = datetime.now(timezone.utc) - timedelta(hours=1) swarm = _gather_swarm_summary(since) tasks = _gather_task_queue_summary() @@ -272,6 +269,7 @@ class ThinkingEngine: 
"""Gather memory context as thought seed.""" try: from timmy.memory_system import memory_system + context = memory_system.get_system_context() if context: # Truncate to a reasonable size for a thought seed @@ -299,10 +297,12 @@ class ThinkingEngine: """ try: from timmy.session import chat + return chat(prompt, session_id="thinking") except Exception: # Fallback: create a fresh agent from timmy.agent import create_timmy + agent = create_timmy() run = agent.run(prompt, stream=False) return run.content if hasattr(run, "content") else str(run) @@ -323,8 +323,7 @@ class ThinkingEngine: INSERT INTO thoughts (id, content, seed_type, parent_id, created_at) VALUES (?, ?, ?, ?, ?) """, - (thought.id, thought.content, thought.seed_type, - thought.parent_id, thought.created_at), + (thought.id, thought.content, thought.seed_type, thought.parent_id, thought.created_at), ) conn.commit() conn.close() @@ -333,7 +332,8 @@ class ThinkingEngine: def _log_event(self, thought: Thought) -> None: """Log the thought as a swarm event.""" try: - from swarm.event_log import log_event, EventType + from swarm.event_log import EventType, log_event + log_event( EventType.TIMMY_THOUGHT, source="thinking-engine", @@ -351,12 +351,16 @@ class ThinkingEngine: """Broadcast the thought to WebSocket clients.""" try: from infrastructure.ws_manager.handler import ws_manager - await ws_manager.broadcast("timmy_thought", { - "thought_id": thought.id, - "content": thought.content, - "seed_type": thought.seed_type, - "created_at": thought.created_at, - }) + + await ws_manager.broadcast( + "timmy_thought", + { + "thought_id": thought.id, + "content": thought.content, + "seed_type": thought.seed_type, + "created_at": thought.created_at, + }, + ) except Exception as exc: logger.debug("Failed to broadcast thought: %s", exc) diff --git a/src/timmy/tools.py b/src/timmy/tools.py index b7222c2..6da81f7 100644 --- a/src/timmy/tools.py +++ b/src/timmy/tools.py @@ -227,11 +227,7 @@ def create_aider_tool(base_path: Path): 
) if result.returncode == 0: - return ( - result.stdout - if result.stdout - else "Code changes applied successfully" - ) + return result.stdout if result.stdout else "Code changes applied successfully" else: return f"Aider error: {result.stderr}" except FileNotFoundError: @@ -354,7 +350,7 @@ def consult_grok(query: str) -> str: Grok's response text, or an error/status message. """ from config import settings - from timmy.backends import grok_available, get_grok_backend + from timmy.backends import get_grok_backend, grok_available if not grok_available(): return ( @@ -385,9 +381,7 @@ def consult_grok(query: str) -> str: ln = get_ln_backend() sats = min(settings.grok_max_sats_per_query, 100) inv = ln.create_invoice(sats, f"Grok query: {query[:50]}") - invoice_info = ( - f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]" - ) + invoice_info = f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]" except Exception: pass @@ -447,7 +441,7 @@ def create_full_toolkit(base_dir: str | Path | None = None): # Memory search and write — persistent recall across all channels try: - from timmy.semantic_memory import memory_search, memory_write, memory_read + from timmy.semantic_memory import memory_read, memory_search, memory_write toolkit.register(memory_search, name="memory_search") toolkit.register(memory_write, name="memory_write") @@ -473,6 +467,7 @@ def create_full_toolkit(base_dir: str | Path | None = None): Task ID and confirmation that background execution has started. 
""" import asyncio + task_id = None async def _launch(): @@ -502,11 +497,7 @@ def create_full_toolkit(base_dir: str | Path | None = None): # System introspection - query runtime environment (sovereign self-knowledge) try: - from timmy.tools_intro import ( - get_system_info, - check_ollama_health, - get_memory_status, - ) + from timmy.tools_intro import check_ollama_health, get_memory_status, get_system_info toolkit.register(get_system_info, name="get_system_info") toolkit.register(check_ollama_health, name="check_ollama_health") @@ -526,6 +517,60 @@ def create_full_toolkit(base_dir: str | Path | None = None): return toolkit +def create_experiment_tools(base_dir: str | Path | None = None): + """Create tools for the experiment agent (Lab). + + Includes: prepare_experiment, run_experiment, evaluate_result, + plus shell + file ops for editing training code. + """ + if not _AGNO_TOOLS_AVAILABLE: + raise ImportError(f"Agno tools not available: {_ImportError}") + + from config import settings + + toolkit = Toolkit(name="experiment") + + from timmy.autoresearch import evaluate_result, prepare_experiment, run_experiment + + workspace = ( + Path(base_dir) if base_dir else Path(settings.repo_root) / settings.autoresearch_workspace + ) + + def _prepare(repo_url: str = "https://github.com/karpathy/autoresearch.git") -> str: + """Clone and prepare an autoresearch experiment workspace.""" + return prepare_experiment(workspace, repo_url) + + def _run(timeout: int = 0) -> str: + """Run a single training experiment with wall-clock timeout.""" + t = timeout or settings.autoresearch_time_budget + result = run_experiment(workspace, timeout=t, metric_name=settings.autoresearch_metric) + if result["success"] and result["metric"] is not None: + return ( + f"{settings.autoresearch_metric}: {result['metric']:.4f} ({result['duration_s']}s)" + ) + return result.get("error") or "Experiment failed" + + def _evaluate(current: float, baseline: float) -> str: + """Compare current metric against 
baseline.""" + return evaluate_result(current, baseline, metric_name=settings.autoresearch_metric) + + toolkit.register(_prepare, name="prepare_experiment") + toolkit.register(_run, name="run_experiment") + toolkit.register(_evaluate, name="evaluate_result") + + # Also give Lab access to file + shell tools for editing train.py + shell_tools = ShellTools() + toolkit.register(shell_tools.run_shell_command, name="shell") + + base_path = Path(base_dir) if base_dir else Path(settings.repo_root) + file_tools = FileTools(base_dir=base_path) + toolkit.register(file_tools.read_file, name="read_file") + toolkit.register(file_tools.save_file, name="write_file") + toolkit.register(file_tools.list_files, name="list_files") + + return toolkit + + # Mapping of agent IDs to their toolkits AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = { "echo": create_research_tools, @@ -534,6 +579,7 @@ AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = { "seer": create_data_tools, "forge": create_code_tools, "quill": create_writing_tools, + "lab": create_experiment_tools, "pixel": lambda base_dir=None: _create_stub_toolkit("pixel"), "lyra": lambda base_dir=None: _create_stub_toolkit("lyra"), "reel": lambda base_dir=None: _create_stub_toolkit("reel"), @@ -553,9 +599,7 @@ def _create_stub_toolkit(name: str): return toolkit -def get_tools_for_agent( - agent_id: str, base_dir: str | Path | None = None -) -> Toolkit | None: +def get_tools_for_agent(agent_id: str, base_dir: str | Path | None = None) -> Toolkit | None: """Get the appropriate toolkit for an agent. 
Args: @@ -643,6 +687,21 @@ def get_all_available_tools() -> dict[str, dict]: "description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)", "available_in": ["forge", "orchestrator"], }, + "prepare_experiment": { + "name": "Prepare Experiment", + "description": "Clone autoresearch repo and run data preparation for ML experiments", + "available_in": ["lab", "orchestrator"], + }, + "run_experiment": { + "name": "Run Experiment", + "description": "Execute a time-boxed ML training experiment and capture metrics", + "available_in": ["lab", "orchestrator"], + }, + "evaluate_result": { + "name": "Evaluate Result", + "description": "Compare experiment metric against baseline to assess improvement", + "available_in": ["lab", "orchestrator"], + }, } # ── Git tools ───────────────────────────────────────────────────────────── diff --git a/src/timmy/tools_delegation/__init__.py b/src/timmy/tools_delegation/__init__.py index 80d652c..b711911 100644 --- a/src/timmy/tools_delegation/__init__.py +++ b/src/timmy/tools_delegation/__init__.py @@ -20,7 +20,9 @@ _VALID_AGENTS: dict[str, str] = { } -def delegate_task(agent_name: str, task_description: str, priority: str = "normal") -> dict[str, Any]: +def delegate_task( + agent_name: str, task_description: str, priority: str = "normal" +) -> dict[str, Any]: """Record a delegation intent to another agent. 
Args: @@ -44,7 +46,9 @@ def delegate_task(agent_name: str, task_description: str, priority: str = "norma if priority not in valid_priorities: priority = "normal" - logger.info("Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority) + logger.info( + "Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority + ) return { "success": True, diff --git a/src/timmy/tools_intro/__init__.py b/src/timmy/tools_intro/__init__.py index 5f4cd7e..7dbf972 100644 --- a/src/timmy/tools_intro/__init__.py +++ b/src/timmy/tools_intro/__init__.py @@ -65,9 +65,7 @@ def _get_ollama_model() -> str: models = response.json().get("models", []) # Check if configured model is available for model in models: - if model.get("name", "").startswith( - settings.ollama_model.split(":")[0] - ): + if model.get("name", "").startswith(settings.ollama_model.split(":")[0]): return settings.ollama_model # Fallback: return configured model @@ -139,9 +137,7 @@ def get_memory_status() -> dict[str, Any]: if tier1_exists: lines = memory_md.read_text().splitlines() tier1_info["line_count"] = len(lines) - tier1_info["sections"] = [ - ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ") - ] + tier1_info["sections"] = [ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")] # Vault — scan all subdirs under memory/ vault_root = repo_root / "memory" @@ -233,13 +229,15 @@ def get_agent_roster() -> dict[str, Any]: roster = [] for persona in _PERSONAS: - roster.append({ - "id": persona["agent_id"], - "name": persona["name"], - "status": "available", - "capabilities": ", ".join(persona.get("tools", [])), - "role": persona.get("role", ""), - }) + roster.append( + { + "id": persona["agent_id"], + "name": persona["name"], + "status": "available", + "capabilities": ", ".join(persona.get("tools", [])), + "role": persona.get("role", ""), + } + ) return { "agents": roster, diff --git a/src/timmy_serve/app.py b/src/timmy_serve/app.py index 
c21ae44..f5f3ed8 100644 --- a/src/timmy_serve/app.py +++ b/src/timmy_serve/app.py @@ -41,7 +41,7 @@ class StatusResponse(BaseModel): class RateLimitMiddleware(BaseHTTPMiddleware): """Simple in-memory rate limiting middleware.""" - + def __init__(self, app, limit: int = 10, window: int = 60): super().__init__(app) self.limit = limit @@ -53,22 +53,20 @@ class RateLimitMiddleware(BaseHTTPMiddleware): if request.url.path == "/serve/chat" and request.method == "POST": client_ip = request.client.host if request.client else "unknown" now = time.time() - + # Clean up old requests self.requests[client_ip] = [ - t for t in self.requests[client_ip] - if now - t < self.window + t for t in self.requests[client_ip] if now - t < self.window ] - + if len(self.requests[client_ip]) >= self.limit: logger.warning("Rate limit exceeded for %s", client_ip) return JSONResponse( - status_code=429, - content={"error": "Rate limit exceeded. Try again later."} + status_code=429, content={"error": "Rate limit exceeded. 
Try again later."} ) - + self.requests[client_ip].append(now) - + return await call_next(request) diff --git a/src/timmy_serve/cli.py b/src/timmy_serve/cli.py index ca3738f..b4d95ea 100644 --- a/src/timmy_serve/cli.py +++ b/src/timmy_serve/cli.py @@ -33,6 +33,7 @@ def start( return import uvicorn + from timmy_serve.app import create_timmy_serve_app serve_app = create_timmy_serve_app() diff --git a/src/timmy_serve/inter_agent.py b/src/timmy_serve/inter_agent.py index 6df6e6d..e113efe 100644 --- a/src/timmy_serve/inter_agent.py +++ b/src/timmy_serve/inter_agent.py @@ -23,9 +23,7 @@ class AgentMessage: to_agent: str = "" content: str = "" message_type: str = "text" # text | command | response | error - timestamp: str = field( - default_factory=lambda: datetime.now(timezone.utc).isoformat() - ) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) replied: bool = False @@ -56,7 +54,10 @@ class InterAgentMessenger: self._all_messages.append(msg) logger.info( "Message %s → %s: %s (%s)", - from_agent, to_agent, content[:50], message_type, + from_agent, + to_agent, + content[:50], + message_type, ) return msg diff --git a/src/timmy_serve/voice_tts.py b/src/timmy_serve/voice_tts.py index 7163c81..38ed430 100644 --- a/src/timmy_serve/voice_tts.py +++ b/src/timmy_serve/voice_tts.py @@ -26,6 +26,7 @@ class VoiceTTS: def _init_engine(self) -> None: try: import pyttsx3 + self._engine = pyttsx3.init() self._engine.setProperty("rate", self._rate) self._engine.setProperty("volume", self._volume) diff --git a/tests/brain/test_brain_client.py b/tests/brain/test_brain_client.py index 0cd94f6..bf01504 100644 --- a/tests/brain/test_brain_client.py +++ b/tests/brain/test_brain_client.py @@ -1,10 +1,11 @@ """Tests for brain.client — BrainClient memory + task operations.""" import json -import pytest from unittest.mock import AsyncMock, MagicMock, patch -from brain.client import BrainClient, DEFAULT_RQLITE_URL +import pytest + +from brain.client import 
DEFAULT_RQLITE_URL, BrainClient class TestBrainClientInit: @@ -40,9 +41,7 @@ class TestBrainClientMemory: async def test_remember_success(self): client = self._make_client() mock_response = MagicMock() - mock_response.json.return_value = { - "results": [{"last_insert_id": 42}] - } + mock_response.json.return_value = {"results": [{"last_insert_id": 42}]} mock_response.raise_for_status = MagicMock() client._client = MagicMock() client._client.post = AsyncMock(return_value=mock_response) @@ -74,9 +73,13 @@ class TestBrainClientMemory: client = self._make_client() mock_response = MagicMock() mock_response.json.return_value = { - "results": [{"rows": [ - ["memory content", "test", '{"key": "val"}', 0.1], - ]}] + "results": [ + { + "rows": [ + ["memory content", "test", '{"key": "val"}', 0.1], + ] + } + ] } mock_response.raise_for_status = MagicMock() client._client = MagicMock() @@ -129,9 +132,13 @@ class TestBrainClientMemory: client = self._make_client() mock_response = MagicMock() mock_response.json.return_value = { - "results": [{"rows": [ - [1, "recent memory", "test", '["tag1"]', '{}', "2026-03-06T00:00:00"], - ]}] + "results": [ + { + "rows": [ + [1, "recent memory", "test", '["tag1"]', "{}", "2026-03-06T00:00:00"], + ] + } + ] } mock_response.raise_for_status = MagicMock() client._client = MagicMock() @@ -152,13 +159,17 @@ class TestBrainClientMemory: async def test_get_context(self): client = self._make_client() - client.get_recent = AsyncMock(return_value=[ - {"content": "Recent item 1"}, - {"content": "Recent item 2"}, - ]) - client.recall = AsyncMock(return_value=[ - {"content": "Relevant item 1"}, - ]) + client.get_recent = AsyncMock( + return_value=[ + {"content": "Recent item 1"}, + {"content": "Recent item 2"}, + ] + ) + client.recall = AsyncMock( + return_value=[ + {"content": "Relevant item 1"}, + ] + ) ctx = await client.get_context("test query") assert "Recent activity:" in ctx @@ -176,9 +187,7 @@ class TestBrainClientTasks: async def 
test_submit_task(self): client = self._make_client() mock_response = MagicMock() - mock_response.json.return_value = { - "results": [{"last_insert_id": 7}] - } + mock_response.json.return_value = {"results": [{"last_insert_id": 7}]} mock_response.raise_for_status = MagicMock() client._client = MagicMock() client._client.post = AsyncMock(return_value=mock_response) @@ -199,9 +208,7 @@ class TestBrainClientTasks: client = self._make_client() mock_response = MagicMock() mock_response.json.return_value = { - "results": [{"rows": [ - [1, "task content", "shell", 5, '{"key": "val"}'] - ]}] + "results": [{"rows": [[1, "task content", "shell", 5, '{"key": "val"}']]}] } mock_response.raise_for_status = MagicMock() client._client = MagicMock() @@ -253,10 +260,14 @@ class TestBrainClientTasks: client = self._make_client() mock_response = MagicMock() mock_response.json.return_value = { - "results": [{"rows": [ - [1, "task 1", "general", 0, '{}', "2026-03-06"], - [2, "task 2", "shell", 5, '{}', "2026-03-06"], - ]}] + "results": [ + { + "rows": [ + [1, "task 1", "general", 0, "{}", "2026-03-06"], + [2, "task 2", "shell", 5, "{}", "2026-03-06"], + ] + } + ] } mock_response.raise_for_status = MagicMock() client._client = MagicMock() diff --git a/tests/brain/test_brain_worker.py b/tests/brain/test_brain_worker.py index 48c9291..3ce84c6 100644 --- a/tests/brain/test_brain_worker.py +++ b/tests/brain/test_brain_worker.py @@ -1,7 +1,8 @@ """Tests for brain.worker — DistributedWorker capability detection + task execution.""" +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock, AsyncMock from brain.worker import DistributedWorker @@ -156,11 +157,13 @@ class TestExecuteTask: worker._handlers["test_type"] = fake_handler - result = await worker.execute_task({ - "id": 1, - "type": "test_type", - "content": "do it", - }) + result = await worker.execute_task( + { + "id": 1, + "type": "test_type", + "content": "do it", + } + ) 
assert result["success"] is True assert result["result"] == "result" worker.brain.complete_task.assert_awaited_once_with(1, success=True, result="result") @@ -173,11 +176,13 @@ class TestExecuteTask: worker._handlers["fail_type"] = failing_handler - result = await worker.execute_task({ - "id": 2, - "type": "fail_type", - "content": "fail", - }) + result = await worker.execute_task( + { + "id": 2, + "type": "fail_type", + "content": "fail", + } + ) assert result["success"] is False assert "oops" in result["error"] worker.brain.complete_task.assert_awaited_once() @@ -190,11 +195,13 @@ class TestExecuteTask: worker._handlers["general"] = general_handler - result = await worker.execute_task({ - "id": 3, - "type": "unknown_type", - "content": "something", - }) + result = await worker.execute_task( + { + "id": 3, + "type": "unknown_type", + "content": "something", + } + ) assert result["success"] is True assert result["result"] == "general result" @@ -219,9 +226,7 @@ class TestRunOnce: async def test_run_once_with_task(self): worker = self._make_worker() - worker.brain.claim_task.return_value = { - "id": 1, "type": "code", "content": "write code" - } + worker.brain.claim_task.return_value = {"id": 1, "type": "code", "content": "write code"} had_work = await worker.run_once() assert had_work is True diff --git a/tests/brain/test_unified_memory.py b/tests/brain/test_unified_memory.py index 7f7f50c..9488e09 100644 --- a/tests/brain/test_unified_memory.py +++ b/tests/brain/test_unified_memory.py @@ -112,7 +112,9 @@ class TestRememberSync: memory.remember_sync("Dark mode enabled", tags=["preference", "ui"]) conn = memory._get_conn() try: - row = conn.execute("SELECT tags FROM memories WHERE content = ?", ("Dark mode enabled",)).fetchone() + row = conn.execute( + "SELECT tags FROM memories WHERE content = ?", ("Dark mode enabled",) + ).fetchone() tags = json.loads(row["tags"]) assert "preference" in tags assert "ui" in tags @@ -191,9 +193,9 @@ class TestRecallSync: results = 
memory.recall_sync("underwater basket weaving") if results: # If semantic search returned something, score should be low - assert results[0]["score"] < 0.7, ( - f"Expected low score for irrelevant query, got {results[0]['score']}" - ) + assert ( + results[0]["score"] < 0.7 + ), f"Expected low score for irrelevant query, got {results[0]['score']}" def test_recall_respects_limit(self, memory): """Recall should respect the limit parameter.""" @@ -254,9 +256,9 @@ class TestFacts: # Second access — count should be higher facts = memory.get_facts_sync(category="test_cat") second_count = facts[0]["access_count"] - assert second_count > first_count, ( - f"Access count should increment: {first_count} -> {second_count}" - ) + assert ( + second_count > first_count + ), f"Access count should increment: {first_count} -> {second_count}" def test_fact_confidence_ordering(self, memory): """Facts should be ordered by confidence (highest first).""" diff --git a/tests/conftest.py b/tests/conftest.py index 6fa3850..dad3742 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,10 @@ except ImportError: # ── Stub heavy optional dependencies so tests run without them installed ────── # Uses setdefault: real module is used if already installed, mock otherwise. +# Stub heavy optional dependencies so tests run without them installed. +# Uses setdefault: real module is used if already installed, mock otherwise. +# Note: only stub packages that are truly optional and may not be installed. +# Packages like typer, httpx, fastapi are required deps — never stub those. 
for _mod in [ "agno", "agno.agent", @@ -31,7 +35,6 @@ for _mod in [ "discord.ext.commands", "pyzbar", "pyzbar.pyzbar", - "requests", "pyttsx3", "sentence_transformers", ]: @@ -47,6 +50,7 @@ os.environ["TIMMY_SKIP_EMBEDDINGS"] = "1" def reset_message_log(): """Clear the in-memory chat log before and after every test.""" from dashboard.store import message_log + message_log.clear() yield message_log.clear() @@ -127,6 +131,7 @@ def cleanup_event_loops(): """Clean up any leftover event loops after each test.""" import asyncio import warnings + yield try: try: @@ -147,7 +152,9 @@ def cleanup_event_loops(): def client(): """FastAPI test client with fresh app instance.""" from fastapi.testclient import TestClient + from dashboard.app import app + with TestClient(app) as c: yield c @@ -157,7 +164,8 @@ def db_connection(): """Provide a fresh in-memory SQLite connection for tests.""" conn = sqlite3.connect(":memory:") conn.row_factory = sqlite3.Row - conn.executescript(""" + conn.executescript( + """ CREATE TABLE IF NOT EXISTS agents ( id TEXT PRIMARY KEY, name TEXT NOT NULL, @@ -175,14 +183,13 @@ def db_connection(): created_at TEXT NOT NULL, completed_at TEXT ); - """) + """ + ) conn.commit() yield conn conn.close() - - @pytest.fixture def mock_ollama_client(): """Provide a mock Ollama client for unit tests.""" @@ -201,5 +208,3 @@ def mock_timmy_agent(): agent.run = MagicMock(return_value="Test response from Timmy") agent.chat = MagicMock(return_value="Test chat response") return agent - - diff --git a/tests/conftest_markers.py b/tests/conftest_markers.py index 30d12eb..cd64688 100644 --- a/tests/conftest_markers.py +++ b/tests/conftest_markers.py @@ -30,7 +30,7 @@ def pytest_collection_modifyitems(config, items): """Automatically assign markers to tests based on file location.""" for item in items: test_path = str(item.fspath) - + # Categorize based on directory if "e2e" in test_path: item.add_marker(pytest.mark.e2e) @@ -41,19 +41,19 @@ def 
pytest_collection_modifyitems(config, items): item.add_marker(pytest.mark.integration) else: item.add_marker(pytest.mark.unit) - + # Add additional markers based on test name/path if "selenium" in test_path or "ui_" in item.name: item.add_marker(pytest.mark.selenium) item.add_marker(pytest.mark.skip_ci) - + if "docker" in test_path: item.add_marker(pytest.mark.docker) item.add_marker(pytest.mark.skip_ci) - + if "ollama" in test_path or "test_ollama" in item.name: item.add_marker(pytest.mark.ollama) - + # Mark slow tests if "slow" in item.name: item.add_marker(pytest.mark.slow) diff --git a/tests/dashboard/middleware/test_csrf.py b/tests/dashboard/middleware/test_csrf.py index 60507a0..4a26487 100644 --- a/tests/dashboard/middleware/test_csrf.py +++ b/tests/dashboard/middleware/test_csrf.py @@ -13,6 +13,7 @@ class TestCSRFMiddleware: def enable_csrf(self): """Re-enable CSRF for these tests.""" from config import settings + original = settings.timmy_disable_csrf settings.timmy_disable_csrf = False yield @@ -21,29 +22,29 @@ class TestCSRFMiddleware: def test_csrf_token_generation(self): """CSRF token should be generated and stored in session/state.""" from dashboard.middleware.csrf import generate_csrf_token - + token1 = generate_csrf_token() token2 = generate_csrf_token() - + # Tokens should be non-empty strings assert isinstance(token1, str) assert len(token1) > 0 - + # Each token should be unique assert token1 != token2 def test_csrf_token_validation(self): """Valid CSRF tokens should pass validation.""" from dashboard.middleware.csrf import generate_csrf_token, validate_csrf_token - + token = generate_csrf_token() - + # Same token should validate assert validate_csrf_token(token, token) is True - + # Different tokens should not validate assert validate_csrf_token(token, "different-token") is False - + # Empty tokens should not validate assert validate_csrf_token(token, "") is False assert validate_csrf_token("", token) is False @@ -51,16 +52,16 @@ class 
TestCSRFMiddleware: def test_csrf_middleware_allows_safe_methods(self): """GET, HEAD, OPTIONS requests should not require CSRF token.""" from dashboard.middleware.csrf import CSRFMiddleware - + app = FastAPI() app.add_middleware(CSRFMiddleware, secret="test-secret") - + @app.get("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # GET should work without CSRF token response = client.get("/test") assert response.status_code == 200 @@ -69,16 +70,16 @@ class TestCSRFMiddleware: def test_csrf_middleware_blocks_unsafe_methods_without_token(self): """POST, PUT, DELETE should require CSRF token.""" from dashboard.middleware.csrf import CSRFMiddleware - + app = FastAPI() app.add_middleware(CSRFMiddleware, secret="test-secret") - + @app.post("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # POST without CSRF token should fail response = client.post("/test") assert response.status_code == 403 @@ -87,24 +88,22 @@ class TestCSRFMiddleware: def test_csrf_middleware_allows_with_valid_token(self): """POST with valid CSRF token should succeed.""" from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token - + app = FastAPI() app.add_middleware(CSRFMiddleware, secret="test-secret") - + @app.post("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # Get CSRF token from cookie or header token = generate_csrf_token() - + # POST with valid CSRF token response = client.post( - "/test", - headers={"X-CSRF-Token": token}, - cookies={"csrf_token": token} + "/test", headers={"X-CSRF-Token": token}, cookies={"csrf_token": token} ) assert response.status_code == 200 assert response.json() == {"message": "success"} @@ -112,16 +111,16 @@ class TestCSRFMiddleware: def test_csrf_middleware_exempt_routes(self): """Routes with webhook patterns should bypass CSRF validation.""" from dashboard.middleware.csrf import CSRFMiddleware - + app = FastAPI() 
app.add_middleware(CSRFMiddleware, secret="test-secret") - + @app.post("/webhook") def webhook_endpoint(): return {"message": "webhook received"} - + client = TestClient(app) - + # POST to exempt route without CSRF token should work response = client.post("/webhook") assert response.status_code == 200 @@ -130,16 +129,16 @@ class TestCSRFMiddleware: def test_csrf_token_in_cookie(self): """CSRF token should be set in cookie for frontend to read.""" from dashboard.middleware.csrf import CSRFMiddleware - + app = FastAPI() app.add_middleware(CSRFMiddleware, secret="test-secret") - + @app.get("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # GET should set CSRF cookie response = client.get("/test") assert response.status_code == 200 @@ -148,22 +147,20 @@ class TestCSRFMiddleware: def test_csrf_middleware_allows_with_form_field(self): """POST with valid CSRF token in form field should succeed.""" from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token - + app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/test-form") async def test_endpoint(request: Request): return {"message": "success"} - + client = TestClient(app) token = generate_csrf_token() - + # POST with valid CSRF token in form field response = client.post( - "/test-form", - data={"csrf_token": token, "other": "data"}, - cookies={"csrf_token": token} + "/test-form", data={"csrf_token": token, "other": "data"}, cookies={"csrf_token": token} ) assert response.status_code == 200 assert response.json() == {"message": "success"} @@ -171,23 +168,21 @@ class TestCSRFMiddleware: def test_csrf_middleware_blocks_mismatched_token(self): """POST with mismatched token should fail.""" from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token - + app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/test") async def test_endpoint(): return {"message": "success"} - + client = TestClient(app) token1 = generate_csrf_token() 
token2 = generate_csrf_token() - + # POST with token from one session and cookie from another response = client.post( - "/test", - headers={"X-CSRF-Token": token1}, - cookies={"csrf_token": token2} + "/test", headers={"X-CSRF-Token": token1}, cookies={"csrf_token": token2} ) assert response.status_code == 403 assert "CSRF" in response.json().get("error", "") @@ -195,20 +190,17 @@ class TestCSRFMiddleware: def test_csrf_middleware_blocks_missing_cookie(self): """POST with header token but missing cookie should fail.""" from dashboard.middleware.csrf import CSRFMiddleware, generate_csrf_token - + app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/test") async def test_endpoint(): return {"message": "success"} - + client = TestClient(app) token = generate_csrf_token() - + # POST with header token but no cookie - response = client.post( - "/test", - headers={"X-CSRF-Token": token} - ) + response = client.post("/test", headers={"X-CSRF-Token": token}) assert response.status_code == 403 diff --git a/tests/dashboard/middleware/test_csrf_bypass.py b/tests/dashboard/middleware/test_csrf_bypass.py index afc2efe..85849e8 100644 --- a/tests/dashboard/middleware/test_csrf_bypass.py +++ b/tests/dashboard/middleware/test_csrf_bypass.py @@ -3,8 +3,10 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient + from dashboard.middleware.csrf import CSRFMiddleware + class TestCSRFBypass: """Test potential CSRF bypasses.""" @@ -12,6 +14,7 @@ class TestCSRFBypass: def enable_csrf(self): """Re-enable CSRF for these tests.""" from config import settings + original = settings.timmy_disable_csrf settings.timmy_disable_csrf = False yield @@ -21,19 +24,16 @@ class TestCSRFBypass: """POST should require CSRF token even with AJAX headers (if not explicitly allowed).""" app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # POST with 
X-Requested-With should STILL fail if it's not a valid CSRF token # Some older middlewares used to trust this header blindly. - response = client.post( - "/test", - headers={"X-Requested-With": "XMLHttpRequest"} - ) + response = client.post("/test", headers={"X-Requested-With": "XMLHttpRequest"}) # This should fail with 403 because no CSRF token is provided assert response.status_code == 403 @@ -41,32 +41,32 @@ class TestCSRFBypass: """Test if path traversal can bypass CSRF exempt patterns.""" app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - - # If the middleware checks path starts with /webhook, + + # If the middleware checks path starts with /webhook, # can we use /webhook/../test to bypass? # Note: TestClient/FastAPI might normalize this, but we should check the logic. response = client.post("/webhook/../test") - + # If it bypassed, it would return 200 (if normalized to /test) or 404 (if not). # But it should definitely not return 200 success without CSRF. 
if response.status_code == 200: assert response.json() != {"message": "success"} - + def test_csrf_middleware_null_byte_bypass(self): """Test if null byte in path can bypass CSRF exempt patterns.""" app = FastAPI() middleware = CSRFMiddleware(app) - + # Test directly since TestClient blocks null bytes path = "/webhook\0/test" is_exempt = middleware._is_likely_exempt(path) - + # It should either be not exempt or the null byte should be handled # In our current implementation, it might still be exempt if normalized to /webhook\0/test # But it's better than /webhook/../test diff --git a/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py b/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py index d957660..48c9443 100644 --- a/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py +++ b/tests/dashboard/middleware/test_csrf_bypass_vulnerability.py @@ -3,8 +3,10 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient + from dashboard.middleware.csrf import CSRFMiddleware + class TestCSRFBypassVulnerability: """Test CSRF bypass via path normalization and suffix matching.""" @@ -12,6 +14,7 @@ class TestCSRFBypassVulnerability: def enable_csrf(self): """Re-enable CSRF for these tests.""" from config import settings + original = settings.timmy_disable_csrf settings.timmy_disable_csrf = False yield @@ -19,25 +22,25 @@ class TestCSRFBypassVulnerability: def test_csrf_bypass_via_traversal_to_exempt_pattern(self): """Test if a non-exempt route can be accessed by traversing to an exempt pattern. - - The middleware uses os.path.normpath() on the request path and then checks + + The middleware uses os.path.normpath() on the request path and then checks if it starts with an exempt pattern. If the request is to '/webhook/../api/chat', normpath makes it '/api/chat', which DOES NOT start with '/webhook'. 
- + Wait, the vulnerability is actually the OTHER way around: If I want to access '/api/chat' (protected) but I use '/webhook/../api/chat', normpath makes it '/api/chat', which is NOT exempt. - + HOWEVER, if the middleware DOES NOT use normpath, then '/webhook/../api/chat' WOULD start with '/webhook' and be exempt. - + The current code DOES use normpath: ```python normalized_path = os.path.normpath(path) if not normalized_path.startswith("/"): normalized_path = "/" + normalized_path ``` - + Let's look at the exempt patterns again: exempt_patterns = [ "/webhook", @@ -45,23 +48,23 @@ class TestCSRFBypassVulnerability: "/lightning/webhook", "/_internal/", ] - - If I have a route '/webhook_attacker' that is NOT exempt, + + If I have a route '/webhook_attacker' that is NOT exempt, but it starts with '/webhook', it WILL be exempt. """ app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/webhook_attacker") def sensitive_endpoint(): return {"message": "sensitive data accessed"} - + client = TestClient(app) - + # This route should NOT be exempt, but it starts with '/webhook' # CSRF validation should fail (403) because we provide no token. response = client.post("/webhook_attacker") - + # If it's 200, it's a bypass! 
assert response.status_code == 403, "Route /webhook_attacker should be protected by CSRF" @@ -76,13 +79,13 @@ class TestCSRFBypassVulnerability: """Test if /webhook_secret is exempt because it starts with /webhook.""" app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/webhook_secret") def sensitive_endpoint(): return {"message": "sensitive data accessed"} - + client = TestClient(app) - + # Should be 403 response = client.post("/webhook_secret") assert response.status_code == 403, "Route /webhook_secret should be protected by CSRF" @@ -91,21 +94,21 @@ class TestCSRFBypassVulnerability: """Test that legitimate exempt paths still work.""" app = FastAPI() app.add_middleware(CSRFMiddleware) - + @app.post("/webhook") def webhook(): return {"message": "webhook received"} - + @app.post("/api/v1/chat") def api_chat(): return {"message": "api chat"} - + client = TestClient(app) - + # Legitimate /webhook (exact match) response = client.post("/webhook") assert response.status_code == 200 - + # Legitimate /api/v1/chat (prefix match) response = client.post("/api/v1/chat") assert response.status_code == 200 diff --git a/tests/dashboard/middleware/test_csrf_traversal.py b/tests/dashboard/middleware/test_csrf_traversal.py index fc8f950..83d2212 100644 --- a/tests/dashboard/middleware/test_csrf_traversal.py +++ b/tests/dashboard/middleware/test_csrf_traversal.py @@ -3,8 +3,10 @@ import pytest from fastapi import FastAPI from fastapi.testclient import TestClient + from dashboard.middleware.csrf import CSRFMiddleware + class TestCSRFTraversal: """Test path traversal CSRF bypass.""" @@ -12,6 +14,7 @@ class TestCSRFTraversal: def enable_csrf(self): """Re-enable CSRF for these tests.""" from config import settings + original = settings.timmy_disable_csrf settings.timmy_disable_csrf = False yield @@ -21,21 +24,21 @@ class TestCSRFTraversal: """Test if path traversal can bypass CSRF exempt patterns.""" app = FastAPI() app.add_middleware(CSRFMiddleware) - + 
@app.post("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) - + # We want to check if the middleware logic is flawed. # Since TestClient might normalize, we can test the _is_likely_exempt method directly. middleware = CSRFMiddleware(app) - + # This path starts with /webhook, but resolves to /test traversal_path = "/webhook/../test" - + # If this returns True, it's a vulnerability because /test is not supposed to be exempt. is_exempt = middleware._is_likely_exempt(traversal_path) - + assert is_exempt is False, f"Path {traversal_path} should not be exempt" diff --git a/tests/dashboard/middleware/test_request_logging.py b/tests/dashboard/middleware/test_request_logging.py index 4bc6f4b..974b76f 100644 --- a/tests/dashboard/middleware/test_request_logging.py +++ b/tests/dashboard/middleware/test_request_logging.py @@ -1,8 +1,9 @@ """Tests for request logging middleware.""" -import pytest import time from unittest.mock import Mock, patch + +import pytest from fastapi import FastAPI from fastapi.responses import JSONResponse from fastapi.testclient import TestClient @@ -15,23 +16,23 @@ class TestRequestLoggingMiddleware: def app_with_logging(self): """Create app with request logging middleware.""" from dashboard.middleware.request_logging import RequestLoggingMiddleware - + app = FastAPI() app.add_middleware(RequestLoggingMiddleware) - + @app.get("/test") def test_endpoint(): return {"message": "success"} - + @app.get("/slow") def slow_endpoint(): time.sleep(0.1) return {"message": "slow response"} - + @app.get("/error") def error_endpoint(): raise ValueError("Test error") - + return app def test_logs_request_method_and_path(self, app_with_logging, caplog): @@ -40,17 +41,18 @@ class TestRequestLoggingMiddleware: client = TestClient(app_with_logging) response = client.get("/test") assert response.status_code == 200 - + # Check log contains method and path - assert any("GET" in record.message and "/test" in record.message - for record 
in caplog.records) + assert any( + "GET" in record.message and "/test" in record.message for record in caplog.records + ) def test_logs_response_status_code(self, app_with_logging, caplog): """Log should include response status code.""" with caplog.at_level("INFO"): client = TestClient(app_with_logging) response = client.get("/test") - + # Check log contains status code assert any("200" in record.message for record in caplog.records) @@ -59,27 +61,30 @@ class TestRequestLoggingMiddleware: with caplog.at_level("INFO"): client = TestClient(app_with_logging) response = client.get("/slow") - + # Check log contains duration (e.g., "0.1" or "100ms") - assert any(record.message for record in caplog.records - if any(c.isdigit() for c in record.message)) + assert any( + record.message for record in caplog.records if any(c.isdigit() for c in record.message) + ) def test_logs_client_ip(self, app_with_logging, caplog): """Log should include client IP address.""" with caplog.at_level("INFO"): client = TestClient(app_with_logging) response = client.get("/test", headers={"X-Forwarded-For": "192.168.1.1"}) - + # Check log contains IP - assert any("192.168.1.1" in record.message or "127.0.0.1" in record.message - for record in caplog.records) + assert any( + "192.168.1.1" in record.message or "127.0.0.1" in record.message + for record in caplog.records + ) def test_logs_user_agent(self, app_with_logging, caplog): """Log should include User-Agent header.""" with caplog.at_level("INFO"): client = TestClient(app_with_logging) response = client.get("/test", headers={"User-Agent": "TestAgent/1.0"}) - + # Check log contains user agent assert any("TestAgent" in record.message for record in caplog.records) @@ -88,7 +93,7 @@ class TestRequestLoggingMiddleware: with caplog.at_level("ERROR"): client = TestClient(app_with_logging, raise_server_exceptions=False) response = client.get("/error") - + assert response.status_code == 500 # Should have error log assert any(record.levelname == "ERROR" 
for record in caplog.records) @@ -96,18 +101,18 @@ class TestRequestLoggingMiddleware: def test_skips_health_check_logging(self, caplog): """Health check endpoints should not be logged (to reduce noise).""" from dashboard.middleware.request_logging import RequestLoggingMiddleware - + app = FastAPI() app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"]) - + @app.get("/health") def health_endpoint(): return {"status": "ok"} - + with caplog.at_level("INFO", logger="timmy.requests"): client = TestClient(app) response = client.get("/health") - + # Should not log health check (only check our logger's records) timmy_records = [r for r in caplog.records if r.name == "timmy.requests"] assert not any("/health" in record.message for record in timmy_records) @@ -117,7 +122,7 @@ class TestRequestLoggingMiddleware: with caplog.at_level("INFO"): client = TestClient(app_with_logging) response = client.get("/test") - + # Check for correlation ID format (UUID or similar) log_messages = [record.message for record in caplog.records] assert any(len(record.message) > 20 for record in caplog.records) # Rough check for ID diff --git a/tests/dashboard/middleware/test_security_headers.py b/tests/dashboard/middleware/test_security_headers.py index 1921fe4..9409130 100644 --- a/tests/dashboard/middleware/test_security_headers.py +++ b/tests/dashboard/middleware/test_security_headers.py @@ -2,7 +2,7 @@ import pytest from fastapi import FastAPI -from fastapi.responses import JSONResponse, HTMLResponse +from fastapi.responses import HTMLResponse, JSONResponse from fastapi.testclient import TestClient @@ -13,18 +13,18 @@ class TestSecurityHeadersMiddleware: def client_with_headers(self): """Create a test client with security headers middleware.""" from dashboard.middleware.security_headers import SecurityHeadersMiddleware - + app = FastAPI() app.add_middleware(SecurityHeadersMiddleware) - + @app.get("/test") def test_endpoint(): return {"message": "success"} - + @app.get("/html") 
def html_endpoint(): return HTMLResponse(content="Test") - + return TestClient(app) def test_x_content_type_options_header(self, client_with_headers): @@ -66,17 +66,17 @@ class TestSecurityHeadersMiddleware: def test_strict_transport_security_in_production(self): """HSTS header should be set in production mode.""" from dashboard.middleware.security_headers import SecurityHeadersMiddleware - + app = FastAPI() app.add_middleware(SecurityHeadersMiddleware, production=True) - + @app.get("/test") def test_endpoint(): return {"message": "success"} - + client = TestClient(app) response = client.get("/test") - + hsts = response.headers.get("strict-transport-security") assert hsts is not None assert "max-age=" in hsts @@ -89,18 +89,18 @@ class TestSecurityHeadersMiddleware: def test_headers_on_error_response(self): """Security headers should be set even on error responses.""" from dashboard.middleware.security_headers import SecurityHeadersMiddleware - + app = FastAPI() app.add_middleware(SecurityHeadersMiddleware) - + @app.get("/error") def error_endpoint(): raise ValueError("Test error") - + # Use raise_server_exceptions=False to get the error response client = TestClient(app, raise_server_exceptions=False) response = client.get("/error") - + # Even on 500 error, security headers should be present assert response.status_code == 500 assert response.headers.get("x-content-type-options") == "nosniff" diff --git a/tests/dashboard/test_briefing.py b/tests/dashboard/test_briefing.py index db1de24..36a6372 100644 --- a/tests/dashboard/test_briefing.py +++ b/tests/dashboard/test_briefing.py @@ -6,19 +6,13 @@ from unittest.mock import MagicMock, patch import pytest -from timmy.briefing import ( - Briefing, - BriefingEngine, - _load_latest, - _save_briefing, - is_fresh, -) - +from timmy.briefing import Briefing, BriefingEngine, _load_latest, _save_briefing, is_fresh # --------------------------------------------------------------------------- # Fixtures # 
--------------------------------------------------------------------------- + @pytest.fixture() def tmp_db(tmp_path): return tmp_path / "test_briefings.db" @@ -45,6 +39,7 @@ def _make_briefing(offset_minutes: int = 0) -> Briefing: # Briefing dataclass # --------------------------------------------------------------------------- + def test_briefing_fields(): b = _make_briefing() assert isinstance(b.generated_at, datetime) @@ -64,6 +59,7 @@ def test_briefing_default_period_is_6_hours(): # is_fresh # --------------------------------------------------------------------------- + def test_is_fresh_recent_briefing(): b = _make_briefing(offset_minutes=5) assert is_fresh(b) is True @@ -84,6 +80,7 @@ def test_is_fresh_custom_max_age(): # SQLite cache (save/load round-trip) # --------------------------------------------------------------------------- + def test_save_and_load_briefing(tmp_db): b = _make_briefing() _save_briefing(b, db_path=tmp_db) @@ -104,13 +101,17 @@ def test_load_latest_returns_most_recent(tmp_db): loaded = _load_latest(db_path=tmp_db) assert loaded is not None # Should return the newer one (generated_at closest to now) - assert abs((loaded.generated_at.replace(tzinfo=timezone.utc) - new.generated_at).total_seconds()) < 5 + assert ( + abs((loaded.generated_at.replace(tzinfo=timezone.utc) - new.generated_at).total_seconds()) + < 5 + ) # --------------------------------------------------------------------------- # BriefingEngine.needs_refresh # --------------------------------------------------------------------------- + def test_needs_refresh_when_no_cache(engine, tmp_db): assert engine.needs_refresh() is True @@ -131,6 +132,7 @@ def test_needs_refresh_true_when_stale(engine, tmp_db): # BriefingEngine.get_cached # --------------------------------------------------------------------------- + def test_get_cached_returns_none_when_empty(engine): assert engine.get_cached() is None @@ -147,6 +149,7 @@ def test_get_cached_returns_briefing(engine, tmp_db): # 
BriefingEngine.generate (agent mocked) # --------------------------------------------------------------------------- + def test_generate_returns_briefing(engine): with patch.object(engine, "_call_agent", return_value="All is well."): with patch.object(engine, "_load_pending_items", return_value=[]): @@ -177,6 +180,7 @@ def test_generate_handles_agent_failure(engine): # BriefingEngine.get_or_generate # --------------------------------------------------------------------------- + def test_get_or_generate_uses_cache_when_fresh(engine, tmp_db): fresh = _make_briefing(offset_minutes=5) _save_briefing(fresh, db_path=tmp_db) @@ -209,6 +213,7 @@ def test_get_or_generate_generates_when_no_cache(engine): # BriefingEngine._call_agent (unit — mocked agent) # --------------------------------------------------------------------------- + def test_call_agent_returns_content(engine): mock_run = MagicMock() mock_run.content = "Agent said hello." @@ -232,6 +237,7 @@ def test_call_agent_falls_back_on_exception(engine): # Push notification hook # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_notify_briefing_ready_skips_when_no_approvals(caplog): """notify_briefing_ready should NOT fire native notification with 0 approvals.""" diff --git a/tests/dashboard/test_calm.py b/tests/dashboard/test_calm.py index 5958420..d353b1e 100644 --- a/tests/dashboard/test_calm.py +++ b/tests/dashboard/test_calm.py @@ -1,22 +1,21 @@ -import pytest from datetime import date + +import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import StaticPool from dashboard.app import app +from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState from dashboard.models.database import Base, get_db -from dashboard.models.calm import Task, JournalEntry, TaskState, 
TaskCertainty @pytest.fixture(name="test_db_engine") def test_db_engine_fixture(): # Use StaticPool to keep the in-memory database alive across multiple connections engine = create_engine( - "sqlite:///:memory:", - connect_args={"check_same_thread": False}, - poolclass=StaticPool + "sqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool ) Base.metadata.create_all(bind=engine) # Create tables yield engine @@ -134,7 +133,9 @@ def test_defer_now_task_promotes_next_and_later(client: TestClient, db_session: assert db_session.query(Task).filter(Task.id == task_later2.id).first().state == TaskState.LATER -def test_start_task_demotes_current_now_and_promotes_to_now(client: TestClient, db_session: Session): +def test_start_task_demotes_current_now_and_promotes_to_now( + client: TestClient, db_session: Session +): task_now = Task(title="Task NOW", state=TaskState.NOW, is_mit=True, sort_order=0) task_next = Task(title="Task NEXT", state=TaskState.NEXT, is_mit=False, sort_order=0) task_later1 = Task(title="Task LATER 1", state=TaskState.LATER, is_mit=True, sort_order=0) @@ -178,11 +179,17 @@ def test_evening_ritual_archives_active_tasks(client: TestClient, db_session: Se assert "Evening Ritual Complete" in response.text assert db_session.query(Task).filter(Task.id == task_now.id).first().state == TaskState.DEFERRED - assert db_session.query(Task).filter(Task.id == task_next.id).first().state == TaskState.DEFERRED - assert db_session.query(Task).filter(Task.id == task_later.id).first().state == TaskState.DEFERRED + assert ( + db_session.query(Task).filter(Task.id == task_next.id).first().state == TaskState.DEFERRED + ) + assert ( + db_session.query(Task).filter(Task.id == task_later.id).first().state == TaskState.DEFERRED + ) assert db_session.query(Task).filter(Task.id == task_done.id).first().state == TaskState.DONE - updated_journal = db_session.query(JournalEntry).filter(JournalEntry.id == journal_entry.id).first() + updated_journal = ( + 
db_session.query(JournalEntry).filter(JournalEntry.id == journal_entry.id).first() + ) assert updated_journal.evening_reflection == "Reflected well" assert updated_journal.gratitude == "Grateful for everything" assert updated_journal.energy_level == 8 @@ -200,9 +207,7 @@ def test_reorder_later_tasks(client: TestClient, db_session: Session): response = client.post( "/calm/tasks/reorder", - data={ - "later_task_ids": f"{task_later3.id},{task_later1.id},{task_later2.id}" - }, + data={"later_task_ids": f"{task_later3.id},{task_later1.id},{task_later2.id}"}, ) assert response.status_code == 200 @@ -223,9 +228,7 @@ def test_reorder_promote_later_to_next(client: TestClient, db_session: Session): response = client.post( "/calm/tasks/reorder", - data={ - "next_task_id": task_later1.id - }, + data={"next_task_id": task_later1.id}, ) assert response.status_code == 200 diff --git a/tests/dashboard/test_chat_api.py b/tests/dashboard/test_chat_api.py index c416025..2f5ed1e 100644 --- a/tests/dashboard/test_chat_api.py +++ b/tests/dashboard/test_chat_api.py @@ -3,7 +3,6 @@ import io from unittest.mock import patch - # ── POST /api/chat ──────────────────────────────────────────────────────────── diff --git a/tests/dashboard/test_dashboard.py b/tests/dashboard/test_dashboard.py index ad2152c..b6c01c4 100644 --- a/tests/dashboard/test_dashboard.py +++ b/tests/dashboard/test_dashboard.py @@ -1,6 +1,5 @@ from unittest.mock import AsyncMock, patch - # ── Index ───────────────────────────────────────────────────────────────────── diff --git a/tests/dashboard/test_experiments_route.py b/tests/dashboard/test_experiments_route.py new file mode 100644 index 0000000..2c25e50 --- /dev/null +++ b/tests/dashboard/test_experiments_route.py @@ -0,0 +1,41 @@ +"""Tests for the experiments dashboard route.""" + +from unittest.mock import patch + +import pytest + + +class TestExperimentsRoute: + """Tests for /experiments endpoints.""" + + def test_experiments_page_returns_200(self, client): + 
response = client.get("/experiments") + assert response.status_code == 200 + assert "Autoresearch" in response.text + + def test_experiments_page_shows_disabled_when_off(self, client): + response = client.get("/experiments") + assert response.status_code == 200 + assert "disabled" in response.text.lower() or "Disabled" in response.text + + @patch("dashboard.routes.experiments.settings") + def test_start_experiment_when_disabled(self, mock_settings, client): + mock_settings.autoresearch_enabled = False + response = client.post("/experiments/start") + assert response.status_code == 403 + + def test_experiment_detail_not_found(self, client): + response = client.get("/experiments/nonexistent-id") + assert response.status_code == 404 + + @patch("dashboard.routes.experiments.settings") + @patch("timmy.autoresearch.subprocess.run") + def test_start_experiment_when_enabled(self, mock_run, mock_settings, client): + mock_settings.autoresearch_enabled = True + mock_settings.repo_root = "/tmp" + mock_settings.autoresearch_workspace = "test-experiments" + mock_run.return_value = type("R", (), {"returncode": 0, "stdout": "", "stderr": ""})() + response = client.post("/experiments/start") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "started" diff --git a/tests/dashboard/test_input_validation.py b/tests/dashboard/test_input_validation.py index 49fc760..6a43397 100644 --- a/tests/dashboard/test_input_validation.py +++ b/tests/dashboard/test_input_validation.py @@ -1,93 +1,101 @@ import pytest from fastapi.testclient import TestClient + from dashboard.app import app + @pytest.fixture def client(): return TestClient(app) + def test_agents_chat_empty_message_validation(client): """Verify that empty messages are rejected.""" # First get a CSRF token get_resp = client.get("/agents/default/panel") csrf_token = get_resp.cookies.get("csrf_token") - + response = client.post( "/agents/default/chat", data={"message": ""}, - headers={"X-CSRF-Token": 
csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) # Empty message should either be rejected or handled gracefully # For now, we'll accept it but it should be logged assert response.status_code in [200, 422] + def test_agents_chat_oversized_message_validation(client): """Verify that oversized messages are rejected.""" # First get a CSRF token get_resp = client.get("/agents/default/panel") csrf_token = get_resp.cookies.get("csrf_token") - + # Create a message that's too large (e.g., 100KB) large_message = "x" * (100 * 1024) response = client.post( "/agents/default/chat", data={"message": large_message}, - headers={"X-CSRF-Token": csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) # Should reject or handle gracefully assert response.status_code in [200, 413, 422] + def test_memory_search_empty_query_validation(client): """Verify that empty search queries are handled.""" # First get a CSRF token get_resp = client.get("/memory") csrf_token = get_resp.cookies.get("csrf_token") - + response = client.post( "/memory/search", data={"query": ""}, - headers={"X-CSRF-Token": csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) assert response.status_code in [200, 422, 500] # 500 for missing template + def test_memory_search_oversized_query_validation(client): """Verify that oversized search queries are rejected.""" # First get a CSRF token get_resp = client.get("/memory") csrf_token = get_resp.cookies.get("csrf_token") - + large_query = "x" * (50 * 1024) response = client.post( "/memory/search", data={"query": large_query}, - headers={"X-CSRF-Token": csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) assert response.status_code in [200, 413, 422, 500] # 500 for missing template + def test_memory_fact_empty_fact_validation(client): """Verify that empty facts are rejected.""" # First get a CSRF 
token get_resp = client.get("/memory") csrf_token = get_resp.cookies.get("csrf_token") - + response = client.post( "/memory/fact", data={"fact": ""}, - headers={"X-CSRF-Token": csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) # Empty fact should be rejected assert response.status_code in [400, 422, 500] # 500 for missing template + def test_memory_fact_oversized_fact_validation(client): """Verify that oversized facts are rejected.""" # First get a CSRF token get_resp = client.get("/memory") csrf_token = get_resp.cookies.get("csrf_token") - + large_fact = "x" * (100 * 1024) response = client.post( "/memory/fact", data={"fact": large_fact}, - headers={"X-CSRF-Token": csrf_token} if csrf_token else {} + headers={"X-CSRF-Token": csrf_token} if csrf_token else {}, ) assert response.status_code in [200, 413, 422, 500] # 500 for missing template diff --git a/tests/dashboard/test_local_models.py b/tests/dashboard/test_local_models.py index 41f924d..859b498 100644 --- a/tests/dashboard/test_local_models.py +++ b/tests/dashboard/test_local_models.py @@ -11,9 +11,9 @@ Categories: import re from pathlib import Path - # ── helpers ────────────────────────────────────────────────────────────────── + def _local_html(client) -> str: return client.get("/mobile/local").text @@ -25,6 +25,7 @@ def _local_llm_js() -> str: # ── L1xx — Route & API responses ───────────────────────────────────────────── + def test_L101_mobile_local_route_returns_200(client): """The /mobile/local endpoint should return 200 OK.""" response = client.get("/mobile/local") @@ -61,9 +62,11 @@ def test_L104_local_models_config_default_values(client): # ── L2xx — Config settings ─────────────────────────────────────────────────── + def test_L201_config_has_browser_model_enabled(): """config.py should define browser_model_enabled.""" from config import settings + assert hasattr(settings, "browser_model_enabled") assert isinstance(settings.browser_model_enabled, bool) 
@@ -71,6 +74,7 @@ def test_L201_config_has_browser_model_enabled(): def test_L202_config_has_browser_model_id(): """config.py should define browser_model_id.""" from config import settings + assert hasattr(settings, "browser_model_id") assert isinstance(settings.browser_model_id, str) assert len(settings.browser_model_id) > 0 @@ -79,12 +83,14 @@ def test_L202_config_has_browser_model_id(): def test_L203_config_has_browser_model_fallback(): """config.py should define browser_model_fallback.""" from config import settings + assert hasattr(settings, "browser_model_fallback") assert isinstance(settings.browser_model_fallback, bool) # ── L3xx — Template content & UX ──────────────────────────────────────────── + def test_L301_template_includes_local_llm_script(client): """mobile_local.html must include the local_llm.js script.""" html = _local_html(client) @@ -157,6 +163,7 @@ def test_L311_template_has_backend_badge(client): # ── L4xx — JavaScript asset ────────────────────────────────────────────────── + def test_L401_local_llm_js_exists(): """static/local_llm.js must exist.""" js_path = Path(__file__).parent.parent.parent / "static" / "local_llm.js" @@ -221,15 +228,16 @@ def test_L410_local_llm_js_has_isSupported_static(): # ── L5xx — Security ───────────────────────────────────────────────────────── + def test_L501_no_innerhtml_with_user_input(client): """Template must not use innerHTML with user-controlled data.""" html = _local_html(client) # Check for dangerous patterns: innerHTML += `${message}` etc. 
blocks = re.findall(r"innerHTML\s*\+=?\s*`([^`]*)`", html, re.DOTALL) for block in blocks: - assert "${message}" not in block, ( - "innerHTML template literal contains ${message} — XSS vulnerability" - ) + assert ( + "${message}" not in block + ), "innerHTML template literal contains ${message} — XSS vulnerability" def test_L502_uses_textcontent_for_messages(client): diff --git a/tests/dashboard/test_memory_api.py b/tests/dashboard/test_memory_api.py index e2eca4e..f577392 100644 --- a/tests/dashboard/test_memory_api.py +++ b/tests/dashboard/test_memory_api.py @@ -43,6 +43,7 @@ def test_edit_fact(client): # Extract a fact ID from the page (look for fact- pattern) import re + match = re.search(r'id="fact-([^"]+)"', page.text) if match: fact_id = match.group(1) @@ -61,6 +62,7 @@ def test_delete_fact(client): page = client.get("/memory") import re + match = re.search(r'id="fact-([^"]+)"', page.text) if match: fact_id = match.group(1) diff --git a/tests/dashboard/test_middleware_migration.py b/tests/dashboard/test_middleware_migration.py index 3974ebc..6f66a15 100644 --- a/tests/dashboard/test_middleware_migration.py +++ b/tests/dashboard/test_middleware_migration.py @@ -1,11 +1,14 @@ import pytest from fastapi.testclient import TestClient + from dashboard.app import app + @pytest.fixture def client(): return TestClient(app) + def test_security_headers_middleware_is_used(client): """Verify that SecurityHeadersMiddleware is used instead of the inline function.""" response = client.get("/") @@ -14,15 +17,18 @@ def test_security_headers_middleware_is_used(client): # SecurityHeadersMiddleware also sets Permissions-Policy assert "Permissions-Policy" in response.headers + def test_request_logging_middleware_is_used(client): """Verify that RequestLoggingMiddleware is used.""" response = client.get("/") # RequestLoggingMiddleware adds X-Correlation-ID to the response assert "X-Correlation-ID" in response.headers + def test_csrf_middleware_is_used(client): """Verify that 
CSRFMiddleware is used.""" from config import settings + original = settings.timmy_disable_csrf settings.timmy_disable_csrf = False try: diff --git a/tests/dashboard/test_mission_control.py b/tests/dashboard/test_mission_control.py index 50b8024..8ccf17c 100644 --- a/tests/dashboard/test_mission_control.py +++ b/tests/dashboard/test_mission_control.py @@ -1,8 +1,9 @@ """Tests for health and sovereignty endpoints.""" -import pytest from unittest.mock import patch +import pytest + class TestSovereigntyEndpoint: """Tests for /health/sovereignty endpoint.""" diff --git a/tests/dashboard/test_mobile_scenarios.py b/tests/dashboard/test_mobile_scenarios.py index b023cc1..f1dcdf9 100644 --- a/tests/dashboard/test_mobile_scenarios.py +++ b/tests/dashboard/test_mobile_scenarios.py @@ -17,7 +17,6 @@ import re from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch - # ── helpers ─────────────────────────────────────────────────────────────────── @@ -119,9 +118,7 @@ def test_M301_input_font_size_16px_in_mobile_query(): """iOS Safari zooms in when input font-size < 16px. 
Must be exactly 16px.""" css = _css() # The mobile media-query block must override to 16px - mobile_block_match = re.search( - r"@media\s*\(max-width:\s*768px\)(.*)", css, re.DOTALL - ) + mobile_block_match = re.search(r"@media\s*\(max-width:\s*768px\)(.*)", css, re.DOTALL) assert mobile_block_match, "Mobile media query not found" mobile_block = mobile_block_match.group(1) assert "font-size: 16px" in mobile_block @@ -224,9 +221,9 @@ def test_M601_airllm_agent_has_run_method(): """TimmyAirLLMAgent must expose run() so the dashboard route can call it.""" from timmy.backends import TimmyAirLLMAgent - assert hasattr(TimmyAirLLMAgent, "run"), ( - "TimmyAirLLMAgent is missing run() — dashboard will fail with AirLLM backend" - ) + assert hasattr( + TimmyAirLLMAgent, "run" + ), "TimmyAirLLMAgent is missing run() — dashboard will fail with AirLLM backend" def test_M602_airllm_run_returns_content_attribute(): @@ -277,7 +274,7 @@ def test_M603_airllm_run_updates_history(): def test_M604_airllm_print_response_delegates_to_run(): """print_response must use run() so both interfaces share one inference path.""" with patch("timmy.backends.is_apple_silicon", return_value=False): - from timmy.backends import TimmyAirLLMAgent, RunResult + from timmy.backends import RunResult, TimmyAirLLMAgent agent = TimmyAirLLMAgent(model_size="8b") @@ -300,9 +297,7 @@ def test_M605_health_status_passes_model_to_template(client): response = client.get("/health/status") # The default model is qwen2.5:14b — it should appear from settings assert response.status_code == 200 - assert ( - "qwen2.5" in response.text - ) # rendered via template variable, not hardcoded literal + assert "qwen2.5" in response.text # rendered via template variable, not hardcoded literal # ── M7xx — XSS prevention ───────────────────────────────────────────────────── @@ -310,24 +305,14 @@ def test_M605_health_status_passes_model_to_template(client): def _mobile_html() -> str: """Read the mobile template source.""" - path = ( - 
Path(__file__).parent.parent.parent - / "src" - / "dashboard" - / "templates" - / "mobile.html" - ) + path = Path(__file__).parent.parent.parent / "src" / "dashboard" / "templates" / "mobile.html" return path.read_text() def _swarm_live_html() -> str: """Read the swarm live template source.""" path = ( - Path(__file__).parent.parent.parent - / "src" - / "dashboard" - / "templates" - / "swarm_live.html" + Path(__file__).parent.parent.parent / "src" / "dashboard" / "templates" / "swarm_live.html" ) return path.read_text() @@ -337,9 +322,9 @@ def test_M701_mobile_chat_no_raw_message_interpolation(): html = _mobile_html() # The vulnerable pattern is `${message}` inside a template literal assigned to innerHTML # After the fix, message must only appear via textContent assignment - assert "textContent = message" in html or "textContent=message" in html, ( - "mobile.html still uses innerHTML + ${message} interpolation — XSS vulnerability" - ) + assert ( + "textContent = message" in html or "textContent=message" in html + ), "mobile.html still uses innerHTML + ${message} interpolation — XSS vulnerability" def test_M702_mobile_chat_user_input_not_in_innerhtml_template_literal(): @@ -348,25 +333,23 @@ def test_M702_mobile_chat_user_input_not_in_innerhtml_template_literal(): # Find all innerHTML += `...` blocks and verify none contain ${message} blocks = re.findall(r"innerHTML\s*\+=?\s*`([^`]*)`", html, re.DOTALL) for block in blocks: - assert "${message}" not in block, ( - "innerHTML template literal still contains ${message} — XSS vulnerability" - ) + assert ( + "${message}" not in block + ), "innerHTML template literal still contains ${message} — XSS vulnerability" def test_M703_swarm_live_agent_name_not_interpolated_in_innerhtml(): """swarm_live.html must not put ${agent.name} inside innerHTML template literals.""" html = _swarm_live_html() - blocks = re.findall( - r"innerHTML\s*=\s*agents\.map\([^;]+\)\.join\([^)]*\)", html, re.DOTALL - ) - assert len(blocks) == 0, ( - 
"swarm_live.html still uses innerHTML=agents.map(…) with interpolated agent data — XSS vulnerability" - ) + blocks = re.findall(r"innerHTML\s*=\s*agents\.map\([^;]+\)\.join\([^)]*\)", html, re.DOTALL) + assert ( + len(blocks) == 0 + ), "swarm_live.html still uses innerHTML=agents.map(…) with interpolated agent data — XSS vulnerability" def test_M704_swarm_live_uses_textcontent_for_agent_data(): """swarm_live.html must use textContent (not innerHTML) to set agent name/description.""" html = _swarm_live_html() - assert "textContent" in html, ( - "swarm_live.html does not use textContent — agent data may be raw-interpolated into DOM" - ) + assert ( + "textContent" in html + ), "swarm_live.html does not use textContent — agent data may be raw-interpolated into DOM" diff --git a/tests/dashboard/test_paperclip_routes.py b/tests/dashboard/test_paperclip_routes.py index 469c0f3..b165926 100644 --- a/tests/dashboard/test_paperclip_routes.py +++ b/tests/dashboard/test_paperclip_routes.py @@ -1,9 +1,8 @@ """Tests for the Paperclip API routes.""" -from unittest.mock import AsyncMock, patch, MagicMock - -from integrations.paperclip.models import PaperclipIssue, PaperclipAgent, PaperclipGoal +from unittest.mock import AsyncMock, MagicMock, patch +from integrations.paperclip.models import PaperclipAgent, PaperclipGoal, PaperclipIssue # ── GET /api/paperclip/status ──────────────────────────────────────────────── diff --git a/tests/dashboard/test_round4_fixes.py b/tests/dashboard/test_round4_fixes.py index dbe7deb..af4d7fe 100644 --- a/tests/dashboard/test_round4_fixes.py +++ b/tests/dashboard/test_round4_fixes.py @@ -5,13 +5,13 @@ agent tools on /tools, notification bell /api/notifications, and Ollama timeout parameter. 
""" -from unittest.mock import patch, MagicMock - +from unittest.mock import MagicMock, patch # --------------------------------------------------------------------------- # Fix 1: /calm no longer returns 500 # --------------------------------------------------------------------------- + def test_calm_page_returns_200(client): """GET /calm should render without error now that tables are created.""" response = client.get("/calm") @@ -29,6 +29,7 @@ def test_calm_morning_ritual_form_returns_200(client): # Fix 2: /api/queue/status endpoint exists # --------------------------------------------------------------------------- + def test_queue_status_returns_json(client): """GET /api/queue/status returns valid JSON instead of 404.""" response = client.get("/api/queue/status?assigned_to=default") @@ -50,10 +51,13 @@ def test_queue_status_default_idle(client): def test_queue_status_reflects_running_task(client): """Queue status shows working when a task is running.""" # Create a task and set it to running - create = client.post("/api/tasks", json={ - "title": "Running task", - "assigned_to": "default", - }) + create = client.post( + "/api/tasks", + json={ + "title": "Running task", + "assigned_to": "default", + }, + ) task_id = create.json()["id"] client.patch(f"/api/tasks/{task_id}/status", json={"status": "approved"}) client.patch(f"/api/tasks/{task_id}/status", json={"status": "running"}) @@ -68,6 +72,7 @@ def test_queue_status_reflects_running_task(client): # Fix 3: Bootstrap JS present in base.html (creative tabs) # --------------------------------------------------------------------------- + def test_base_html_has_bootstrap_js(client): """base.html should include bootstrap.bundle.min.js for tab switching.""" response = client.get("/") @@ -86,6 +91,7 @@ def test_creative_page_returns_200(client): # Fix 4: Swarm Live WebSocket sends initial state # --------------------------------------------------------------------------- + def test_swarm_live_page_returns_200(client): 
"""GET /swarm/live renders the live dashboard page.""" response = client.get("/swarm/live") @@ -95,6 +101,7 @@ def test_swarm_live_page_returns_200(client): def test_swarm_live_websocket_sends_initial_state(client): """WebSocket at /swarm/live sends initial_state on connect.""" import json + with client.websocket_connect("/swarm/live") as ws: data = ws.receive_json() assert data["type"] == "initial_state" @@ -108,6 +115,7 @@ def test_swarm_live_websocket_sends_initial_state(client): # Fix 5: Agent tools populated on /tools page # --------------------------------------------------------------------------- + def test_tools_page_returns_200(client): """GET /tools loads successfully.""" response = client.get("/tools") @@ -120,6 +128,7 @@ def test_tools_page_shows_agent_capabilities(client): # The tools registry always has at least the built-in tools # If tools are registered, we should NOT see the empty message from timmy.tools import get_all_available_tools + if get_all_available_tools(): assert "No agents registered yet" not in response.text assert "Timmy" in response.text @@ -137,6 +146,7 @@ def test_tools_api_stats_returns_json(client): # Fix 6: Notification bell dropdown + /api/notifications # --------------------------------------------------------------------------- + def test_notifications_api_returns_json(client): """GET /api/notifications returns a JSON array.""" response = client.get("/api/notifications") @@ -156,14 +166,16 @@ def test_notifications_bell_dropdown_in_html(client): # Fix 0b: Ollama timeout parameter # --------------------------------------------------------------------------- + def test_create_timmy_uses_timeout_not_request_timeout(): """create_timmy() should pass timeout=300, not request_timeout.""" - with patch("timmy.agent.Ollama") as mock_ollama, \ - patch("timmy.agent.SqliteDb"), \ - patch("timmy.agent.Agent"): + with patch("timmy.agent.Ollama") as mock_ollama, patch("timmy.agent.SqliteDb"), patch( + "timmy.agent.Agent" + ): 
mock_ollama.return_value = MagicMock() from timmy.agent import create_timmy + try: create_timmy() except Exception: @@ -171,8 +183,7 @@ def test_create_timmy_uses_timeout_not_request_timeout(): if mock_ollama.called: _, kwargs = mock_ollama.call_args - assert "request_timeout" not in kwargs, \ - "Should use 'timeout', not 'request_timeout'" + assert "request_timeout" not in kwargs, "Should use 'timeout', not 'request_timeout'" assert kwargs.get("timeout") == 300 @@ -180,14 +191,18 @@ def test_create_timmy_uses_timeout_not_request_timeout(): # Task lifecycle e2e: create → approve → run → complete # --------------------------------------------------------------------------- + def test_task_full_lifecycle(client): """Test full task lifecycle: create → approve → running → completed.""" # Create - create = client.post("/api/tasks", json={ - "title": "Lifecycle test", - "priority": "high", - "assigned_to": "default", - }) + create = client.post( + "/api/tasks", + json={ + "title": "Lifecycle test", + "priority": "high", + "assigned_to": "default", + }, + ) assert create.status_code == 201 task_id = create.json()["id"] @@ -234,6 +249,7 @@ def test_task_full_lifecycle(client): # Pages that were broken — verify they return 200 # --------------------------------------------------------------------------- + def test_all_dashboard_pages_return_200(client): """Smoke test: all main dashboard routes return 200.""" pages = [ diff --git a/tests/dashboard/test_security_headers.py b/tests/dashboard/test_security_headers.py index 6d26666..7773d01 100644 --- a/tests/dashboard/test_security_headers.py +++ b/tests/dashboard/test_security_headers.py @@ -7,20 +7,20 @@ from fastapi.testclient import TestClient def test_security_headers_present(client: TestClient): """Test that security headers are present in all responses.""" response = client.get("/") - + # Check for security headers assert "X-Frame-Options" in response.headers assert response.headers["X-Frame-Options"] == "SAMEORIGIN" - + 
assert "X-Content-Type-Options" in response.headers assert response.headers["X-Content-Type-Options"] == "nosniff" - + assert "X-XSS-Protection" in response.headers assert response.headers["X-XSS-Protection"] == "1; mode=block" - + assert "Referrer-Policy" in response.headers assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" - + assert "Content-Security-Policy" in response.headers @@ -28,16 +28,16 @@ def test_csp_header_content(client: TestClient): """Test that Content Security Policy is properly configured.""" response = client.get("/") csp = response.headers.get("Content-Security-Policy", "") - + # Should restrict default-src to self assert "default-src 'self'" in csp - + # Should allow scripts from self and CDN assert "script-src 'self' 'unsafe-inline' 'unsafe-eval' cdn.jsdelivr.net" in csp - + # Should allow styles from self, CDN, and Google Fonts assert "style-src 'self' 'unsafe-inline' fonts.googleapis.com cdn.jsdelivr.net" in csp - + # Should restrict frame ancestors to self assert "frame-ancestors 'self'" in csp @@ -45,7 +45,7 @@ def test_csp_header_content(client: TestClient): def test_cors_headers_restricted(client: TestClient): """Test that CORS is properly restricted (not allow-origins: *).""" response = client.get("/") - + # Should not have overly permissive CORS # (The actual CORS headers depend on the origin of the request, # so we just verify the app doesn't crash with permissive settings) @@ -55,7 +55,7 @@ def test_cors_headers_restricted(client: TestClient): def test_health_endpoint_has_security_headers(client: TestClient): """Test that security headers are present on all endpoints.""" response = client.get("/health") - + assert "X-Frame-Options" in response.headers assert "X-Content-Type-Options" in response.headers assert "Content-Security-Policy" in response.headers diff --git a/tests/dashboard/test_tasks_api.py b/tests/dashboard/test_tasks_api.py index 314073f..8afc5a6 100644 --- 
a/tests/dashboard/test_tasks_api.py +++ b/tests/dashboard/test_tasks_api.py @@ -12,10 +12,13 @@ def test_tasks_page_returns_200(client): def test_create_task(client): """POST /api/tasks returns 201 with task JSON.""" - response = client.post("/api/tasks", json={ - "title": "Fix the memory bug", - "priority": "high", - }) + response = client.post( + "/api/tasks", + json={ + "title": "Fix the memory bug", + "priority": "high", + }, + ) assert response.status_code == 201 data = response.json() assert data["title"] == "Fix the memory bug" @@ -73,12 +76,15 @@ def test_create_task_missing_title_422(client): def test_create_task_via_form(client): """POST /tasks/create via form creates and returns task card HTML.""" - response = client.post("/tasks/create", data={ - "title": "Form task", - "description": "Created via form", - "priority": "normal", - "assigned_to": "", - }) + response = client.post( + "/tasks/create", + data={ + "title": "Form task", + "description": "Created via form", + "priority": "normal", + "assigned_to": "", + }, + ) assert response.status_code == 200 assert "Form task" in response.text diff --git a/tests/dashboard/test_work_orders_api.py b/tests/dashboard/test_work_orders_api.py index bde6441..45650a4 100644 --- a/tests/dashboard/test_work_orders_api.py +++ b/tests/dashboard/test_work_orders_api.py @@ -9,14 +9,17 @@ def test_work_orders_page_returns_200(client): def test_submit_work_order(client): """POST /work-orders/submit creates a work order.""" - response = client.post("/work-orders/submit", data={ - "title": "Fix the dashboard", - "description": "Details here", - "priority": "high", - "category": "bug", - "submitter": "dashboard", - "related_files": "src/app.py", - }) + response = client.post( + "/work-orders/submit", + data={ + "title": "Fix the dashboard", + "description": "Details here", + "priority": "high", + "category": "bug", + "submitter": "dashboard", + "related_files": "src/app.py", + }, + ) assert response.status_code == 200 @@ 
-34,12 +37,15 @@ def test_active_partial_returns_200(client): def test_submit_and_list_roundtrip(client): """Submitting a work order makes it appear in the pending section.""" - client.post("/work-orders/submit", data={ - "title": "Roundtrip WO", - "priority": "medium", - "category": "suggestion", - "submitter": "test", - }) + client.post( + "/work-orders/submit", + data={ + "title": "Roundtrip WO", + "priority": "medium", + "category": "suggestion", + "submitter": "test", + }, + ) response = client.get("/work-orders/queue/pending") assert "Roundtrip WO" in response.text @@ -47,15 +53,19 @@ def test_submit_and_list_roundtrip(client): def test_approve_work_order(client): """POST /work-orders/{id}/approve changes status.""" # Submit one first - client.post("/work-orders/submit", data={ - "title": "To approve", - "priority": "medium", - "category": "suggestion", - "submitter": "test", - }) + client.post( + "/work-orders/submit", + data={ + "title": "To approve", + "priority": "medium", + "category": "suggestion", + "submitter": "test", + }, + ) # Get ID from pending pending = client.get("/work-orders/queue/pending") import re + match = re.search(r'id="wo-([^"]+)"', pending.text) if match: wo_id = match.group(1) diff --git a/tests/e2e/test_agentic_chain.py b/tests/e2e/test_agentic_chain.py index 9057cc9..bd25191 100644 --- a/tests/e2e/test_agentic_chain.py +++ b/tests/e2e/test_agentic_chain.py @@ -4,8 +4,10 @@ These tests validate the full agentic loop pipeline: planning, execution, adaptation, and progress tracking. """ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import MagicMock, patch, AsyncMock + from timmy.agentic_loop import run_agentic_loop @@ -20,16 +22,19 @@ def _mock_run(content: str): async def test_multistep_chain_completes_all_steps(): """GREEN PATH: multi-step prompt executes all steps.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Search AI news\n2. Write to file\n3. 
Verify"), - _mock_run("Found 5 articles about AI in March 2026."), - _mock_run("Wrote summary to /tmp/ai_news.md"), - _mock_run("File exists, 15 lines."), - _mock_run("Searched, wrote, verified."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Search AI news\n2. Write to file\n3. Verify"), + _mock_run("Found 5 articles about AI in March 2026."), + _mock_run("Wrote summary to /tmp/ai_news.md"), + _mock_run("File exists, 15 lines."), + _mock_run("Searched, wrote, verified."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Search AI news and write summary to file") assert result.status == "completed" @@ -41,17 +46,20 @@ async def test_multistep_chain_completes_all_steps(): async def test_multistep_chain_adapts_on_failure(): """Step failure -> model adapts -> continues.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Read config\n2. Update setting\n3. Verify"), - _mock_run("Config: timeout=30"), - Exception("Permission denied"), - _mock_run("Adapted: wrote to ~/config.yaml instead"), - _mock_run("Verified: timeout=60"), - _mock_run("Updated config. Used ~/config.yaml due to permissions."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Read config\n2. Update setting\n3. Verify"), + _mock_run("Config: timeout=30"), + Exception("Permission denied"), + _mock_run("Adapted: wrote to ~/config.yaml instead"), + _mock_run("Verified: timeout=60"), + _mock_run("Updated config. 
Used ~/config.yaml due to permissions."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Update config timeout to 60") assert result.status == "completed" @@ -62,15 +70,18 @@ async def test_multistep_chain_adapts_on_failure(): async def test_max_steps_enforced(): """Loop stops at max_steps.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. A\n2. B\n3. C\n4. D\n5. E"), - _mock_run("A done"), - _mock_run("B done"), - _mock_run("Completed 2 of 5 steps."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. A\n2. B\n3. C\n4. D\n5. E"), + _mock_run("A done"), + _mock_run("B done"), + _mock_run("Completed 2 of 5 steps."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Do 5 things", max_steps=2) assert len(result.steps) == 2 @@ -86,15 +97,18 @@ async def test_progress_events_fire(): events.append((step, total)) mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Do A\n2. Do B"), - _mock_run("A done"), - _mock_run("B done"), - _mock_run("All done"), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Do A\n2. 
Do B"), + _mock_run("A done"), + _mock_run("B done"), + _mock_run("All done"), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): await run_agentic_loop("Do A and B", on_progress=on_progress) assert len(events) == 2 diff --git a/tests/e2e/test_ollama_integration.py b/tests/e2e/test_ollama_integration.py index a4ef0ae..b11aecc 100644 --- a/tests/e2e/test_ollama_integration.py +++ b/tests/e2e/test_ollama_integration.py @@ -4,17 +4,19 @@ These tests verify that Ollama models are correctly loaded, Timmy can interact with them, and fallback mechanisms work as expected. """ +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock + from config import settings @pytest.mark.asyncio async def test_ollama_connection(): """Test that we can connect to Ollama and retrieve available models.""" - import urllib.request import json - + import urllib.request + try: url = settings.ollama_url.replace("localhost", "127.0.0.1") req = urllib.request.Request( @@ -33,15 +35,15 @@ async def test_ollama_connection(): @pytest.mark.asyncio async def test_model_fallback_chain(): """Test that the model fallback chain works correctly.""" - from timmy.agent import _resolve_model_with_fallback, DEFAULT_MODEL_FALLBACKS - + from timmy.agent import DEFAULT_MODEL_FALLBACKS, _resolve_model_with_fallback + # Test with a non-existent model model, is_fallback = _resolve_model_with_fallback( requested_model="nonexistent-model", require_vision=False, auto_pull=False, ) - + # When a model doesn't exist and auto_pull=False, the system falls back to an available model # or the last resort (the requested model itself if nothing else is available). 
# In tests, if no models are available in the mock environment, it might return the requested model. @@ -56,7 +58,7 @@ async def test_model_fallback_chain(): async def test_timmy_agent_with_available_model(): """Test that Timmy agent can be created with an available model.""" from timmy.agent import create_timmy - + try: agent = create_timmy(db_file=":memory:") assert agent is not None, "Agent should be created" @@ -70,13 +72,15 @@ async def test_timmy_agent_with_available_model(): async def test_timmy_chat_with_simple_query(): """Test that Timmy can respond to a simple chat query.""" from timmy.session import chat - + try: response = chat("Hello, who are you?") assert response is not None, "Response should not be None" assert isinstance(response, str), "Response should be a string" assert len(response) > 0, "Response should not be empty" - assert "Timmy" in response or "agent" in response.lower(), "Response should mention Timmy or agent" + assert ( + "Timmy" in response or "agent" in response.lower() + ), "Response should mention Timmy or agent" except Exception as e: pytest.skip(f"Chat failed: {e}") @@ -85,15 +89,17 @@ async def test_timmy_chat_with_simple_query(): async def test_model_supports_tools(): """Test the model tool support detection.""" from timmy.agent import _model_supports_tools - + # Small models should not support tools assert _model_supports_tools("llama3.2") == False, "llama3.2 should not support tools" assert _model_supports_tools("llama3.2:3b") == False, "llama3.2:3b should not support tools" - + # Larger models should support tools assert _model_supports_tools("llama3.1") == True, "llama3.1 should support tools" - assert _model_supports_tools("llama3.1:8b-instruct") == True, "llama3.1:8b-instruct should support tools" - + assert ( + _model_supports_tools("llama3.1:8b-instruct") == True + ), "llama3.1:8b-instruct should support tools" + # Unknown models default to True assert _model_supports_tools("unknown-model") == True, "Unknown models 
should default to True" @@ -102,10 +108,10 @@ async def test_model_supports_tools(): async def test_system_prompt_selection(): """Test that the correct system prompt is selected based on tool capability.""" from timmy.prompts import get_system_prompt - + prompt_with_tools = get_system_prompt(tools_enabled=True) prompt_without_tools = get_system_prompt(tools_enabled=False) - + assert prompt_with_tools is not None, "Prompt with tools should not be None" assert prompt_without_tools is not None, "Prompt without tools should not be None" @@ -121,7 +127,7 @@ async def test_system_prompt_selection(): async def test_ollama_model_availability_check(): """Test the Ollama model availability check function.""" from timmy.agent import _check_model_available - + try: # llama3.2 should be available (we pulled it earlier) result = _check_model_available("llama3.2") @@ -135,7 +141,7 @@ async def test_ollama_model_availability_check(): async def test_memory_system_initialization(): """Test that the memory system initializes correctly.""" from timmy.memory_system import memory_system - + context = memory_system.get_system_context() assert context is not None, "Memory context should not be None" assert isinstance(context, str), "Memory context should be a string" diff --git a/tests/fixtures/media.py b/tests/fixtures/media.py index da8508f..0218c68 100644 --- a/tests/fixtures/media.py +++ b/tests/fixtures/media.py @@ -12,16 +12,15 @@ from pathlib import Path import numpy as np from PIL import Image, ImageDraw - # ── Color palettes for visual variety ───────────────────────────────────────── SCENE_COLORS = [ - (30, 60, 120), # dark blue — "night sky" - (200, 100, 30), # warm orange — "sunrise" - (50, 150, 50), # forest green — "mountain forest" - (20, 120, 180), # teal blue — "river" - (180, 60, 60), # crimson — "sunset" - (40, 40, 80), # deep purple — "twilight" + (30, 60, 120), # dark blue — "night sky" + (200, 100, 30), # warm orange — "sunrise" + (50, 150, 50), # forest green — 
"mountain forest" + (20, 120, 180), # teal blue — "river" + (180, 60, 60), # crimson — "sunset" + (40, 40, 80), # deep purple — "twilight" ] @@ -170,9 +169,14 @@ def make_scene_clips( c2 = SCENE_COLORS[(i + 1) % len(SCENE_COLORS)] path = output_dir / f"clip_{i:03d}.mp4" make_video_clip( - path, duration_seconds=duration_per_clip, fps=fps, - width=width, height=height, - color_start=c1, color_end=c2, label=label, + path, + duration_seconds=duration_per_clip, + fps=fps, + width=width, + height=height, + color_start=c1, + color_end=c2, + label=label, ) clips.append(path) return clips diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index ac5a7b3..da470a5 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -24,7 +24,7 @@ def is_server_running(): @pytest.fixture(scope="session") def live_server(): """Start the real Timmy server for E2E tests. - + Yields the base URL (http://localhost:8000). Kills the server after tests complete. """ @@ -33,27 +33,36 @@ def live_server(): print(f"\n📡 Using existing server at {DASHBOARD_URL}") yield DASHBOARD_URL return - + # Start server in subprocess print(f"\n🚀 Starting server on {DASHBOARD_URL}...") - + env = os.environ.copy() env["PYTHONPATH"] = "src" env["TIMMY_ENV"] = "test" # Use test config if available - + # Determine project root project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - + proc = subprocess.Popen( - [sys.executable, "-m", "uvicorn", "dashboard.app:app", - "--host", "127.0.0.1", "--port", "8000", - "--log-level", "warning"], + [ + sys.executable, + "-m", + "uvicorn", + "dashboard.app:app", + "--host", + "127.0.0.1", + "--port", + "8000", + "--log-level", + "warning", + ], cwd=project_root, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - + # Wait for server to start max_retries = 30 for i in range(max_retries): @@ -66,9 +75,9 @@ def live_server(): proc.terminate() proc.wait() raise RuntimeError("Server failed to start") - + yield 
DASHBOARD_URL - + # Cleanup print("\n🛑 Stopping server...") proc.terminate() @@ -83,11 +92,13 @@ def live_server(): @pytest.fixture def app_client(): """FastAPI test client for functional tests. - + Same as the 'client' fixture in root conftest but available here. """ from fastapi.testclient import TestClient + from dashboard.app import app + with TestClient(app) as c: yield c @@ -96,7 +107,9 @@ def app_client(): def timmy_runner(): """Typer CLI runner for timmy CLI tests.""" from typer.testing import CliRunner + from timmy.cli import app + yield CliRunner(), app @@ -104,17 +117,20 @@ def timmy_runner(): def serve_runner(): """Typer CLI runner for timmy-serve CLI tests.""" from typer.testing import CliRunner + from timmy_serve.cli import app + yield CliRunner(), app @pytest.fixture def docker_stack(): """Docker stack URL for container-level tests. - + Skips if FUNCTIONAL_DOCKER env var is not set to "1". """ import os + if os.environ.get("FUNCTIONAL_DOCKER") != "1": pytest.skip("Set FUNCTIONAL_DOCKER=1 to run Docker tests") yield "http://localhost:18000" @@ -124,8 +140,10 @@ def docker_stack(): def serve_client(): """FastAPI test client for timmy-serve app.""" pytest.importorskip("timmy_serve.app", reason="timmy_serve not available") - from timmy_serve.app import create_timmy_serve_app from fastapi.testclient import TestClient + + from timmy_serve.app import create_timmy_serve_app + app = create_timmy_serve_app() with TestClient(app) as c: yield c @@ -145,5 +163,3 @@ def pytest_addoption(parser): def headed_mode(request): """Check if --headed flag was passed.""" return request.config.getoption("--headed") - - diff --git a/tests/functional/test_cli.py b/tests/functional/test_cli.py index 17f3152..8f7527c 100644 --- a/tests/functional/test_cli.py +++ b/tests/functional/test_cli.py @@ -7,7 +7,6 @@ user scenario we want to verify. 
import pytest - # ── timmy CLI ───────────────────────────────────────────────────────────────── diff --git a/tests/functional/test_docker_swarm.py b/tests/functional/test_docker_swarm.py index 27934b4..b2ba995 100644 --- a/tests/functional/test_docker_swarm.py +++ b/tests/functional/test_docker_swarm.py @@ -25,7 +25,7 @@ COMPOSE_TEST = PROJECT_ROOT / "docker-compose.test.yml" pytestmark = pytest.mark.skipif( subprocess.run(["which", "docker"], capture_output=True).returncode != 0 or subprocess.run(["which", "docker-compose"], capture_output=True).returncode != 0, - reason="Docker or docker-compose not installed" + reason="Docker or docker-compose not installed", ) @@ -188,9 +188,7 @@ class TestDockerAgentSwarm: resp = httpx.get(f"{docker_stack}/swarm/agents", timeout=10) agents = resp.json()["agents"] # Should have at least the 3 agents we started (plus possibly Timmy and auto-spawned ones) - worker_count = sum( - 1 for a in agents if "Worker" in a["name"] or "TestWorker" in a["name"] - ) + worker_count = sum(1 for a in agents if "Worker" in a["name"] or "TestWorker" in a["name"]) assert worker_count >= 1 # At least some registered _compose("--profile", "agents", "down", timeout=30) diff --git a/tests/functional/test_fast_e2e.py b/tests/functional/test_fast_e2e.py index 8e5c02e..e1582e8 100644 --- a/tests/functional/test_fast_e2e.py +++ b/tests/functional/test_fast_e2e.py @@ -5,8 +5,8 @@ RUN: SELENIUM_UI=1 pytest tests/functional/test_fast_e2e.py -v import os -import pytest import httpx +import pytest try: from selenium import webdriver @@ -14,6 +14,7 @@ try: from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait + HAS_SELENIUM = True except ImportError: HAS_SELENIUM = False @@ -77,9 +78,10 @@ class TestAllPagesLoad: WebDriverWait(driver, 5).until( EC.presence_of_element_located((By.TAG_NAME, "body")) ) - + # Give a small extra buffer for animations 
(fadeUp in style.css) import time + time.sleep(0.5) # Verify page has expected content @@ -102,10 +104,9 @@ class TestAllFeaturesWork: # 1. Event Log - verify events display driver.get(f"{dashboard_url}/swarm/events") - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 5).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) import time + time.sleep(0.5) # Should have header and either events or empty state @@ -124,21 +125,15 @@ class TestAllFeaturesWork: # 2. Memory - verify search works driver.get(f"{dashboard_url}/memory?query=test") - WebDriverWait(driver, 3).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 3).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) # Should have search input - search = driver.find_elements( - By.CSS_SELECTOR, "input[type='search'], input[name='query']" - ) + search = driver.find_elements(By.CSS_SELECTOR, "input[type='search'], input[name='query']") assert search, "Memory page missing search input" # 3. 
Ledger - verify balance display driver.get(f"{dashboard_url}/lightning/ledger") - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 5).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) time.sleep(0.5) body = driver.find_element(By.TAG_NAME, "body").text @@ -155,10 +150,9 @@ class TestCascadeRouter: # Check router status page driver.get(f"{dashboard_url}/router/status") - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 5).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) import time + time.sleep(0.5) body = driver.find_element(By.TAG_NAME, "body").text @@ -172,9 +166,7 @@ class TestCascadeRouter: # Check nav has router link driver.get(dashboard_url) - WebDriverWait(driver, 3).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 3).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) nav_links = driver.find_elements(By.XPATH, "//a[contains(@href, '/router')]") assert nav_links, "Navigation missing router link" @@ -187,30 +179,23 @@ class TestUpgradeQueue: """Verify upgrade queue page loads with expected elements.""" driver.get(f"{dashboard_url}/self-modify/queue") - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 5).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) import time + time.sleep(0.5) body = driver.find_element(By.TAG_NAME, "body").text # Should have queue header - assert "upgrade" in body.lower() or "queue" in body.lower(), ( - "Missing queue header" - ) + assert "upgrade" in body.lower() or "queue" in body.lower(), "Missing queue header" # Should have pending section or empty state has_pending = "pending" in body.lower() or "no pending" in body.lower() assert has_pending, "Missing pending upgrades section" # Check for approve/reject buttons if upgrades exist - approve_btns = 
driver.find_elements( - By.XPATH, "//button[contains(text(), 'Approve')]" - ) - reject_btns = driver.find_elements( - By.XPATH, "//button[contains(text(), 'Reject')]" - ) + approve_btns = driver.find_elements(By.XPATH, "//button[contains(text(), 'Approve')]") + reject_btns = driver.find_elements(By.XPATH, "//button[contains(text(), 'Reject')]") # Either no upgrades (no buttons) or buttons exist # This is a soft check - page structure is valid either way @@ -223,18 +208,15 @@ class TestActivityFeed: """Verify swarm live page has activity feed elements.""" driver.get(f"{dashboard_url}/swarm/live") - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.TAG_NAME, "body")) - ) + WebDriverWait(driver, 5).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) import time + time.sleep(0.5) body = driver.find_element(By.TAG_NAME, "body").text # Should have live indicator or activity section - has_live = any( - x in body.lower() for x in ["live", "activity", "swarm", "agents", "tasks"] - ) + has_live = any(x in body.lower() for x in ["live", "activity", "swarm", "agents", "tasks"]) assert has_live, "Swarm live page missing content" # Check for WebSocket connection indicator (if implemented) @@ -261,9 +243,7 @@ class TestFastSmoke: for route in routes: try: - r = httpx.get( - f"{dashboard_url}{route}", timeout=3, follow_redirects=True - ) + r = httpx.get(f"{dashboard_url}{route}", timeout=3, follow_redirects=True) if r.status_code != 200: failures.append(f"{route}: {r.status_code}") except Exception as exc: diff --git a/tests/functional/test_ollama_chat.py b/tests/functional/test_ollama_chat.py index 0399085..09da23a 100644 --- a/tests/functional/test_ollama_chat.py +++ b/tests/functional/test_ollama_chat.py @@ -38,14 +38,20 @@ TEST_MODEL = os.environ.get("OLLAMA_TEST_MODEL", "qwen2.5:0.5b") def _compose(*args, timeout=120): """Run a docker compose command against the test compose file.""" cmd = [ - "docker", "compose", - "-f", str(COMPOSE_TEST), - "-p", 
"timmy-test", + "docker", + "compose", + "-f", + str(COMPOSE_TEST), + "-p", + "timmy-test", *args, ] return subprocess.run( - cmd, capture_output=True, text=True, - timeout=timeout, cwd=str(PROJECT_ROOT), + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=str(PROJECT_ROOT), ) @@ -67,10 +73,16 @@ def _pull_model(model: str, retries=3): for attempt in range(retries): result = subprocess.run( [ - "docker", "exec", "timmy-test-ollama", - "ollama", "pull", model, + "docker", + "exec", + "timmy-test-ollama", + "ollama", + "pull", + model, ], - capture_output=True, text=True, timeout=600, + capture_output=True, + text=True, + timeout=600, ) if result.returncode == 0: return True @@ -94,7 +106,10 @@ def ollama_stack(): # Verify Docker daemon docker_check = subprocess.run( - ["docker", "info"], capture_output=True, text=True, timeout=10, + ["docker", "info"], + capture_output=True, + text=True, + timeout=10, ) if docker_check.returncode != 0: pytest.skip(f"Docker daemon not available: {docker_check.stderr.strip()}") @@ -108,14 +123,24 @@ def ollama_stack(): } result = subprocess.run( [ - "docker", "compose", - "-f", str(COMPOSE_TEST), - "-p", "timmy-test", - "--profile", "ollama", - "up", "-d", "--build", "--wait", + "docker", + "compose", + "-f", + str(COMPOSE_TEST), + "-p", + "timmy-test", + "--profile", + "ollama", + "up", + "-d", + "--build", + "--wait", ], - capture_output=True, text=True, timeout=300, - cwd=str(PROJECT_ROOT), env=env, + capture_output=True, + text=True, + timeout=300, + cwd=str(PROJECT_ROOT), + env=env, ) if result.returncode != 0: pytest.fail(f"docker compose up failed:\n{result.stderr}") @@ -138,13 +163,20 @@ def ollama_stack(): # Teardown subprocess.run( [ - "docker", "compose", - "-f", str(COMPOSE_TEST), - "-p", "timmy-test", - "--profile", "ollama", - "down", "-v", + "docker", + "compose", + "-f", + str(COMPOSE_TEST), + "-p", + "timmy-test", + "--profile", + "ollama", + "down", + "-v", ], - capture_output=True, text=True, 
timeout=60, + capture_output=True, + text=True, + timeout=60, cwd=str(PROJECT_ROOT), ) @@ -161,9 +193,7 @@ class TestOllamaHealth: assert resp.status_code == 200 data = resp.json() services = data.get("services", {}) - assert services.get("ollama") == "up", ( - f"Expected ollama=up, got: {services}" - ) + assert services.get("ollama") == "up", f"Expected ollama=up, got: {services}" class TestOllamaChat: @@ -179,9 +209,9 @@ class TestOllamaChat: assert resp.status_code == 200 body = resp.text.lower() # The response should contain actual content, not an error fallback - assert "error" not in body or "hello" in body, ( - f"Expected LLM response, got error:\n{resp.text[:500]}" - ) + assert ( + "error" not in body or "hello" in body + ), f"Expected LLM response, got error:\n{resp.text[:500]}" def test_chat_history_contains_response(self, ollama_stack): """After chatting, history should include both user and agent messages.""" @@ -226,10 +256,16 @@ class TestOllamaDirectAPI: # Ollama isn't port-mapped, so we exec into the container result = subprocess.run( [ - "docker", "exec", "timmy-test-ollama", - "curl", "-sf", "http://localhost:11434/api/tags", + "docker", + "exec", + "timmy-test-ollama", + "curl", + "-sf", + "http://localhost:11434/api/tags", ], - capture_output=True, text=True, timeout=10, + capture_output=True, + text=True, + timeout=10, ) assert result.returncode == 0 assert TEST_MODEL.split(":")[0] in result.stdout diff --git a/tests/functional/test_setup_prod.py b/tests/functional/test_setup_prod.py index 591e851..fdcb5a0 100644 --- a/tests/functional/test_setup_prod.py +++ b/tests/functional/test_setup_prod.py @@ -1,9 +1,10 @@ import os -import subprocess import shutil -import pytest -from pathlib import Path +import subprocess import time +from pathlib import Path + +import pytest # Production-like paths for functional testing PROD_PROJECT_DIR = Path("/home/ubuntu/prod-sovereign-stack") @@ -15,29 +16,28 @@ pytestmark = pytest.mark.skipif( reason=f"Setup 
script not found at {SETUP_SCRIPT_PATH}", ) + @pytest.fixture(scope="module", autouse=True) def setup_prod_env(): """Ensure a clean environment and run the full installation.""" if PROD_PROJECT_DIR.exists(): shutil.rmtree(PROD_PROJECT_DIR) - + # Run the actual install command env = os.environ.copy() env["PROJECT_DIR"] = str(PROD_PROJECT_DIR) env["VAULT_DIR"] = str(PROD_VAULT_DIR) - + result = subprocess.run( - [str(SETUP_SCRIPT_PATH), "install"], - capture_output=True, - text=True, - env=env + [str(SETUP_SCRIPT_PATH), "install"], capture_output=True, text=True, env=env ) - + assert result.returncode == 0, f"Install failed: {result.stderr}" yield # Cleanup after all tests in module # shutil.rmtree(PROD_PROJECT_DIR) + def test_prod_directory_structure(): """Verify the directory structure matches production expectations.""" assert PROD_PROJECT_DIR.exists() @@ -47,12 +47,19 @@ def test_prod_directory_structure(): assert (PROD_PROJECT_DIR / "logs").exists() assert (PROD_PROJECT_DIR / "pids").exists() + def test_prod_paperclip_dependencies(): """Verify that Paperclip dependencies were actually installed (node_modules exists).""" node_modules = PROD_PROJECT_DIR / "paperclip/node_modules" assert node_modules.exists(), "Paperclip node_modules should exist after installation" # Check for a common package to ensure it's not just an empty dir - assert (node_modules / "typescript").exists() or (node_modules / "vite").exists() or (node_modules / "next").exists() or any(node_modules.iterdir()) + assert ( + (node_modules / "typescript").exists() + or (node_modules / "vite").exists() + or (node_modules / "next").exists() + or any(node_modules.iterdir()) + ) + def test_prod_openfang_config(): """Verify OpenFang agent configuration.""" @@ -63,14 +70,15 @@ def test_prod_openfang_config(): assert 'name = "hello-timmy"' in content assert 'model = "default"' in content + def test_prod_obsidian_vault_content(): """Verify the initial content of the Obsidian vault.""" hello_note = 
PROD_VAULT_DIR / "Hello World.md" soul_note = PROD_VAULT_DIR / "SOUL.md" - + assert hello_note.exists() assert soul_note.exists() - + with open(hello_note, "r") as f: content = f.read() assert "# Hello World" in content @@ -82,6 +90,7 @@ def test_prod_obsidian_vault_content(): assert "I am Timmy" in content assert "sovereign AI agent" in content + def test_prod_service_lifecycle(): """Verify that services can be started, checked, and stopped.""" env = os.environ.copy() @@ -90,39 +99,27 @@ def test_prod_service_lifecycle(): # Start services start_result = subprocess.run( - [str(SETUP_SCRIPT_PATH), "start"], - capture_output=True, - text=True, - env=env + [str(SETUP_SCRIPT_PATH), "start"], capture_output=True, text=True, env=env ) assert start_result.returncode == 0 - + # Wait a moment for processes to initialize time.sleep(2) - + # Check status status_result = subprocess.run( - [str(SETUP_SCRIPT_PATH), "status"], - capture_output=True, - text=True, - env=env + [str(SETUP_SCRIPT_PATH), "status"], capture_output=True, text=True, env=env ) assert "running" in status_result.stdout - + # Stop services stop_result = subprocess.run( - [str(SETUP_SCRIPT_PATH), "stop"], - capture_output=True, - text=True, - env=env + [str(SETUP_SCRIPT_PATH), "stop"], capture_output=True, text=True, env=env ) assert stop_result.returncode == 0 - + # Final status check final_status = subprocess.run( - [str(SETUP_SCRIPT_PATH), "status"], - capture_output=True, - text=True, - env=env + [str(SETUP_SCRIPT_PATH), "status"], capture_output=True, text=True, env=env ) assert "stopped" in final_status.stdout diff --git a/tests/functional/test_ui_selenium.py b/tests/functional/test_ui_selenium.py index 461ae4b..a38e56e 100644 --- a/tests/functional/test_ui_selenium.py +++ b/tests/functional/test_ui_selenium.py @@ -20,6 +20,7 @@ try: from selenium.webdriver.common.keys import Keys from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait + 
HAS_SELENIUM = True except ImportError: HAS_SELENIUM = False @@ -66,18 +67,14 @@ def _load_dashboard(driver): """Navigate to dashboard and wait for Timmy panel to load.""" driver.get(DASHBOARD_URL) WebDriverWait(driver, 15).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'TIMMY INTERFACE')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'TIMMY INTERFACE')]")) ) def _wait_for_sidebar(driver): """Wait for the agent sidebar to finish its HTMX load.""" WebDriverWait(driver, 15).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'SWARM AGENTS')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'SWARM AGENTS')]")) ) @@ -100,8 +97,7 @@ def _send_chat_and_wait(driver, message): # Wait for a NEW agent response (not one from a prior test) WebDriverWait(driver, 30).until( - lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) - > existing + lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) > existing ) return existing @@ -144,9 +140,7 @@ class TestPageLoad: def test_health_panel_loads(self, driver): _load_dashboard(driver) WebDriverWait(driver, 10).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'SYSTEM HEALTH')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'SYSTEM HEALTH')]")) ) @@ -168,9 +162,7 @@ class TestChatInteraction: lambda d: d.execute_script("return document.readyState") == "complete" ) - existing_agents = len( - driver.find_elements(By.CSS_SELECTOR, ".chat-message.agent") - ) + existing_agents = len(driver.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) inp = driver.find_element(By.CSS_SELECTOR, "input[name='message']") inp.send_keys("hello from selenium") @@ -183,8 +175,7 @@ class TestChatInteraction: # 2. 
Agent response arrives WebDriverWait(driver, 30).until( - lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) - > existing_agents + lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) > existing_agents ) # 3. Input cleared (regression test) @@ -195,12 +186,8 @@ class TestChatInteraction: # 4. Chat scrolled to bottom (regression test) chat_log = driver.find_element(By.ID, "chat-log") scroll_top = driver.execute_script("return arguments[0].scrollTop", chat_log) - scroll_height = driver.execute_script( - "return arguments[0].scrollHeight", chat_log - ) - client_height = driver.execute_script( - "return arguments[0].clientHeight", chat_log - ) + scroll_height = driver.execute_script("return arguments[0].scrollHeight", chat_log) + client_height = driver.execute_script("return arguments[0].clientHeight", chat_log) if scroll_height > client_height: gap = scroll_height - scroll_top - client_height @@ -217,18 +204,14 @@ class TestTaskPanel: """Task panel loads correctly when navigated to directly.""" driver.get(f"{DASHBOARD_URL}/swarm/tasks/panel") WebDriverWait(driver, 10).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'CREATE TASK')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'CREATE TASK')]")) ) def test_task_panel_has_form(self, driver): """Task creation panel has description and agent fields.""" driver.get(f"{DASHBOARD_URL}/swarm/tasks/panel") WebDriverWait(driver, 10).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'CREATE TASK')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'CREATE TASK')]")) ) driver.find_element(By.CSS_SELECTOR, "textarea[name='description']") @@ -249,9 +232,7 @@ class TestTaskPanel: task_btn.click() WebDriverWait(driver, 10).until( - EC.presence_of_element_located( - (By.XPATH, "//*[contains(text(), 'CREATE TASK')]") - ) + EC.presence_of_element_located((By.XPATH, "//*[contains(text(), 'CREATE 
TASK')]")) ) diff --git a/tests/infrastructure/test_error_capture.py b/tests/infrastructure/test_error_capture.py index 98c701e..3236c5a 100644 --- a/tests/infrastructure/test_error_capture.py +++ b/tests/infrastructure/test_error_capture.py @@ -1,15 +1,16 @@ """Tests for infrastructure.error_capture module.""" -import pytest -from unittest.mock import patch, MagicMock from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest from infrastructure.error_capture import ( - _stack_hash, - _is_duplicate, - _get_git_context, - capture_error, _dedup_cache, + _get_git_context, + _is_duplicate, + _stack_hash, + capture_error, ) diff --git a/tests/infrastructure/test_event_broadcaster.py b/tests/infrastructure/test_event_broadcaster.py index c03631d..de5cf85 100644 --- a/tests/infrastructure/test_event_broadcaster.py +++ b/tests/infrastructure/test_event_broadcaster.py @@ -1,23 +1,24 @@ """Tests for the event broadcaster (infrastructure.events.broadcaster).""" -import pytest -from unittest.mock import AsyncMock, MagicMock, patch from dataclasses import dataclass from enum import Enum +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from infrastructure.events.broadcaster import ( - EventBroadcaster, - event_broadcaster, - get_event_icon, - get_event_label, - format_event_for_display, EVENT_ICONS, EVENT_LABELS, + EventBroadcaster, + event_broadcaster, + format_event_for_display, + get_event_icon, + get_event_label, ) - # ── Fake EventLogEntry for testing ────────────────────────────────────────── + class FakeEventType(Enum): TASK_CREATED = "task.created" TASK_ASSIGNED = "task.assigned" diff --git a/tests/infrastructure/test_event_bus.py b/tests/infrastructure/test_event_bus.py index 24296c3..6b88b37 100644 --- a/tests/infrastructure/test_event_bus.py +++ b/tests/infrastructure/test_event_bus.py @@ -1,8 +1,10 @@ """Tests for the async event bus (infrastructure.events.bus).""" import asyncio + import pytest -from 
infrastructure.events.bus import EventBus, Event, emit, on, event_bus + +from infrastructure.events.bus import Event, EventBus, emit, event_bus, on class TestEvent: diff --git a/tests/infrastructure/test_functional_router.py b/tests/infrastructure/test_functional_router.py index 4b0199e..99babcb 100644 --- a/tests/infrastructure/test_functional_router.py +++ b/tests/infrastructure/test_functional_router.py @@ -10,17 +10,17 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from infrastructure.router.cascade import CascadeRouter, Provider, ProviderStatus, CircuitState +from infrastructure.router.cascade import CascadeRouter, CircuitState, Provider, ProviderStatus class TestCascadeRouterFunctional: """Functional tests for Cascade Router with mocked providers.""" - + @pytest.fixture def router(self): """Create a router with no config file.""" return CascadeRouter(config_path=Path("/nonexistent")) - + @pytest.fixture def mock_healthy_provider(self): """Create a mock healthy provider.""" @@ -32,7 +32,7 @@ class TestCascadeRouterFunctional: models=[{"name": "test-model", "default": True}], ) return provider - + @pytest.fixture def mock_failing_provider(self): """Create a mock failing provider.""" @@ -44,12 +44,12 @@ class TestCascadeRouterFunctional: models=[{"name": "test-model", "default": True}], ) return provider - + @pytest.mark.asyncio async def test_successful_completion_single_provider(self, router, mock_healthy_provider): """Test successful completion with a single working provider.""" router.providers = [mock_healthy_provider] - + # Mock the provider's call method with patch.object(router, "_try_provider") as mock_try: mock_try.return_value = { @@ -57,16 +57,16 @@ class TestCascadeRouterFunctional: "model": "test-model", "latency_ms": 100.0, } - + result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) - + assert result["content"] == "Hello, world!" 
assert result["provider"] == "test-healthy" assert result["model"] == "test-model" assert result["latency_ms"] == 100.0 - + @pytest.mark.asyncio async def test_failover_to_second_provider(self, router): """Test failover when first provider fails.""" @@ -85,23 +85,23 @@ class TestCascadeRouterFunctional: models=[{"name": "model", "default": True}], ) router.providers = [provider1, provider2] - + call_count = [0] - + async def side_effect(*args, **kwargs): call_count[0] += 1 if call_count[0] <= router.config.max_retries_per_provider: raise RuntimeError("Connection failed") return {"content": "Backup works!", "model": "model"} - + with patch.object(router, "_try_provider", side_effect=side_effect): result = await router.complete( messages=[{"role": "user", "content": "Hi"}], ) - + assert result["content"] == "Backup works!" assert result["provider"] == "backup" - + @pytest.mark.asyncio async def test_all_providers_fail_raises_error(self, router): """Test that RuntimeError is raised when all providers fail.""" @@ -113,15 +113,15 @@ class TestCascadeRouterFunctional: models=[{"name": "model", "default": True}], ) router.providers = [provider] - + with patch.object(router, "_try_provider") as mock_try: mock_try.side_effect = RuntimeError("Always fails") - + with pytest.raises(RuntimeError) as exc_info: await router.complete(messages=[{"role": "user", "content": "Hi"}]) - + assert "All providers failed" in str(exc_info.value) - + @pytest.mark.asyncio async def test_circuit_breaker_opens_after_failures(self, router): """Test circuit breaker opens after threshold failures.""" @@ -134,14 +134,14 @@ class TestCascadeRouterFunctional: ) router.providers = [provider] router.config.circuit_breaker_failure_threshold = 3 - + # Record 3 failures for _ in range(3): router._record_failure(provider) - + assert provider.circuit_state == CircuitState.OPEN assert provider.status == ProviderStatus.UNHEALTHY - + def test_metrics_tracking(self, router): """Test that metrics are tracked 
correctly.""" provider = Provider( @@ -151,14 +151,14 @@ class TestCascadeRouterFunctional: priority=1, ) router.providers = [provider] - + # Record some successes and failures router._record_success(provider, 100.0) router._record_success(provider, 200.0) router._record_failure(provider) - + metrics = router.get_metrics() - + assert len(metrics["providers"]) == 1 p_metrics = metrics["providers"][0] assert p_metrics["metrics"]["total_requests"] == 3 @@ -166,7 +166,7 @@ class TestCascadeRouterFunctional: assert p_metrics["metrics"]["failed"] == 1 # Average latency is over ALL requests (including failures with 0 latency) assert p_metrics["metrics"]["avg_latency_ms"] == 100.0 # (100+200+0)/3 - + @pytest.mark.asyncio async def test_skips_disabled_providers(self, router): """Test that disabled providers are skipped.""" @@ -185,23 +185,23 @@ class TestCascadeRouterFunctional: models=[{"name": "model", "default": True}], ) router.providers = [disabled, enabled] - + # The router should try enabled provider with patch.object(router, "_try_provider") as mock_try: mock_try.return_value = {"content": "Success", "model": "model"} - + result = await router.complete(messages=[{"role": "user", "content": "Hi"}]) - + assert result["provider"] == "enabled" class TestProviderAvailability: """Test provider availability checking.""" - + @pytest.fixture def router(self): return CascadeRouter(config_path=Path("/nonexistent")) - + def test_openai_available_with_key(self, router): """Test OpenAI provider is available when API key is set.""" provider = Provider( @@ -211,9 +211,9 @@ class TestProviderAvailability: priority=1, api_key="sk-test123", ) - + assert router._check_provider_available(provider) is True - + def test_openai_unavailable_without_key(self, router): """Test OpenAI provider is unavailable without API key.""" provider = Provider( @@ -223,9 +223,9 @@ class TestProviderAvailability: priority=1, api_key=None, ) - + assert router._check_provider_available(provider) is False - + 
def test_anthropic_available_with_key(self, router): """Test Anthropic provider is available when API key is set.""" provider = Provider( @@ -235,17 +235,17 @@ class TestProviderAvailability: priority=1, api_key="sk-test123", ) - + assert router._check_provider_available(provider) is True class TestRouterConfigLoading: """Test router configuration loading.""" - + def test_loads_timeout_from_config(self, tmp_path): """Test that timeout is loaded from config.""" import yaml - + config = { "cascade": { "timeout_seconds": 60, @@ -253,18 +253,18 @@ class TestRouterConfigLoading: }, "providers": [], } - + config_path = tmp_path / "providers.yaml" config_path.write_text(yaml.dump(config)) - + router = CascadeRouter(config_path=config_path) - + assert router.config.timeout_seconds == 60 assert router.config.max_retries_per_provider == 3 - + def test_uses_defaults_without_config(self): """Test that defaults are used when config file doesn't exist.""" router = CascadeRouter(config_path=Path("/nonexistent")) - + assert router.config.timeout_seconds == 30 assert router.config.max_retries_per_provider == 2 diff --git a/tests/infrastructure/test_model_registry.py b/tests/infrastructure/test_model_registry.py index d8e70d4..0235582 100644 --- a/tests/infrastructure/test_model_registry.py +++ b/tests/infrastructure/test_model_registry.py @@ -6,12 +6,7 @@ from unittest.mock import patch import pytest -from infrastructure.models.registry import ( - CustomModel, - ModelFormat, - ModelRegistry, - ModelRole, -) +from infrastructure.models.registry import CustomModel, ModelFormat, ModelRegistry, ModelRole @pytest.fixture @@ -199,9 +194,7 @@ class TestCustomModelDataclass: """Test CustomModel construction.""" def test_default_registered_at(self): - model = CustomModel( - name="test", format=ModelFormat.OLLAMA, path="test" - ) + model = CustomModel(name="test", format=ModelFormat.OLLAMA, path="test") assert model.registered_at != "" def test_model_roles(self): diff --git 
a/tests/infrastructure/test_models_api.py b/tests/infrastructure/test_models_api.py index 212c513..3b2e654 100644 --- a/tests/infrastructure/test_models_api.py +++ b/tests/infrastructure/test_models_api.py @@ -1,15 +1,10 @@ """Tests for the custom models API routes.""" -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest -from infrastructure.models.registry import ( - CustomModel, - ModelFormat, - ModelRegistry, - ModelRole, -) +from infrastructure.models.registry import CustomModel, ModelFormat, ModelRegistry, ModelRole @pytest.fixture @@ -27,9 +22,7 @@ class TestModelsAPIList: def test_list_models_empty(self, client, tmp_path): db = tmp_path / "api.db" with patch("infrastructure.models.registry.DB_PATH", db): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.list_models.return_value = [] resp = client.get("/api/v1/models") assert resp.status_code == 200 @@ -44,9 +37,7 @@ class TestModelsAPIList: path="llama3.2", role=ModelRole.GENERAL, ) - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.list_models.return_value = [model] resp = client.get("/api/v1/models") assert resp.status_code == 200 @@ -59,9 +50,7 @@ class TestModelsAPIRegister: """Test model registration via the API.""" def test_register_ollama_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.register.return_value = CustomModel( name="my-model", format=ModelFormat.OLLAMA, @@ -111,17 +100,13 @@ class TestModelsAPIDelete: """Test model deletion via the API.""" def test_delete_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: 
mock_reg.unregister.return_value = True resp = client.delete("/api/v1/models/my-model") assert resp.status_code == 200 def test_delete_nonexistent(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.unregister.return_value = False resp = client.delete("/api/v1/models/nonexistent") assert resp.status_code == 404 @@ -137,18 +122,14 @@ class TestModelsAPIGet: path="llama3.2", role=ModelRole.GENERAL, ) - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get.return_value = model resp = client.get("/api/v1/models/my-model") assert resp.status_code == 200 assert resp.json()["name"] == "my-model" def test_get_nonexistent(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get.return_value = None resp = client.get("/api/v1/models/nonexistent") assert resp.status_code == 404 @@ -158,9 +139,7 @@ class TestModelsAPIAssignments: """Test agent model assignment endpoints.""" def test_assign_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.assign_model.return_value = True resp = client.post( "/api/v1/models/assignments", @@ -169,9 +148,7 @@ class TestModelsAPIAssignments: assert resp.status_code == 200 def test_assign_nonexistent_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.assign_model.return_value = False resp = client.post( "/api/v1/models/assignments", @@ -180,25 +157,19 @@ class TestModelsAPIAssignments: assert resp.status_code == 404 def test_unassign_model(self, client): - with patch( - 
"dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.unassign_model.return_value = True resp = client.delete("/api/v1/models/assignments/agent-1") assert resp.status_code == 200 def test_unassign_nonexistent(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.unassign_model.return_value = False resp = client.delete("/api/v1/models/assignments/nonexistent") assert resp.status_code == 404 def test_list_assignments(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get_agent_assignments.return_value = { "agent-1": "model-a", "agent-2": "model-b", @@ -219,9 +190,7 @@ class TestModelsAPIRoles: path="deepseek-r1:1.5b", role=ModelRole.REWARD, ) - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get_reward_model.return_value = model resp = client.get("/api/v1/models/roles/reward") assert resp.status_code == 200 @@ -229,18 +198,14 @@ class TestModelsAPIRoles: assert data["reward_model"]["name"] == "reward-m" def test_get_reward_model_none(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get_reward_model.return_value = None resp = client.get("/api/v1/models/roles/reward") assert resp.status_code == 200 assert resp.json()["reward_model"] is None def test_get_teacher_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.get_teacher_model.return_value = None resp = client.get("/api/v1/models/roles/teacher") assert resp.status_code == 200 @@ 
-251,9 +216,7 @@ class TestModelsAPISetActive: """Test enable/disable model endpoint.""" def test_enable_model(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.set_active.return_value = True resp = client.patch( "/api/v1/models/my-model/active", @@ -262,9 +225,7 @@ class TestModelsAPISetActive: assert resp.status_code == 200 def test_disable_nonexistent(self, client): - with patch( - "dashboard.routes.models.model_registry" - ) as mock_reg: + with patch("dashboard.routes.models.model_registry") as mock_reg: mock_reg.set_active.return_value = False resp = client.patch( "/api/v1/models/nonexistent/active", diff --git a/tests/infrastructure/test_router_api.py b/tests/infrastructure/test_router_api.py index d9c9083..a7a5ec4 100644 --- a/tests/infrastructure/test_router_api.py +++ b/tests/infrastructure/test_router_api.py @@ -5,14 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient +from infrastructure.router.api import get_cascade_router, router from infrastructure.router.cascade import CircuitState, Provider, ProviderStatus -from infrastructure.router.api import router, get_cascade_router def make_mock_router(): """Create a mock CascadeRouter.""" router = MagicMock() - + # Create test providers provider1 = Provider( name="ollama-local", @@ -24,7 +24,7 @@ def make_mock_router(): ) provider1.status = ProviderStatus.HEALTHY provider1.circuit_state = CircuitState.CLOSED - + provider2 = Provider( name="openai-backup", type="openai", @@ -35,12 +35,12 @@ def make_mock_router(): ) provider2.status = ProviderStatus.DEGRADED provider2.circuit_state = CircuitState.CLOSED - + router.providers = [provider1, provider2] router.config.timeout_seconds = 30 router.config.max_retries_per_provider = 2 router.config.circuit_breaker_failure_threshold = 5 - + return router @@ -48,74 +48,87 @@ def 
make_mock_router(): def mock_router(): """Create test client with mocked router.""" from fastapi import FastAPI - + app = FastAPI() app.include_router(router) - + # Create mock router mock = make_mock_router() - + # Override dependency async def mock_get_router(): return mock - + app.dependency_overrides[get_cascade_router] = mock_get_router - + client = TestClient(app) return client, mock class TestCompleteEndpoint: """Test /complete endpoint.""" - + def test_complete_success(self, mock_router): """Test successful completion.""" client, mock = mock_router - mock.complete = AsyncMock(return_value={ - "content": "Hello! How can I help?", - "provider": "ollama-local", - "model": "llama3.2", - "latency_ms": 250.5, - }) - - response = client.post("/api/v1/router/complete", json={ - "messages": [{"role": "user", "content": "Hi"}], - "model": "llama3.2", - "temperature": 0.7, - }) - + mock.complete = AsyncMock( + return_value={ + "content": "Hello! How can I help?", + "provider": "ollama-local", + "model": "llama3.2", + "latency_ms": 250.5, + } + ) + + response = client.post( + "/api/v1/router/complete", + json={ + "messages": [{"role": "user", "content": "Hi"}], + "model": "llama3.2", + "temperature": 0.7, + }, + ) + assert response.status_code == 200 data = response.json() assert data["content"] == "Hello! How can I help?" 
assert data["provider"] == "ollama-local" assert data["latency_ms"] == 250.5 - + def test_complete_all_providers_fail(self, mock_router): """Test 503 when all providers fail.""" client, mock = mock_router mock.complete = AsyncMock(side_effect=RuntimeError("All providers failed")) - - response = client.post("/api/v1/router/complete", json={ - "messages": [{"role": "user", "content": "Hi"}], - }) - + + response = client.post( + "/api/v1/router/complete", + json={ + "messages": [{"role": "user", "content": "Hi"}], + }, + ) + assert response.status_code == 503 assert "All providers failed" in response.json()["detail"] - + def test_complete_default_temperature(self, mock_router): """Test completion with default temperature.""" client, mock = mock_router - mock.complete = AsyncMock(return_value={ - "content": "Response", - "provider": "ollama-local", - "model": "llama3.2", - "latency_ms": 100.0, - }) - - response = client.post("/api/v1/router/complete", json={ - "messages": [{"role": "user", "content": "Hi"}], - }) - + mock.complete = AsyncMock( + return_value={ + "content": "Response", + "provider": "ollama-local", + "model": "llama3.2", + "latency_ms": 100.0, + } + ) + + response = client.post( + "/api/v1/router/complete", + json={ + "messages": [{"role": "user", "content": "Hi"}], + }, + ) + assert response.status_code == 200 # Check that complete was called with correct temperature call_args = mock.complete.call_args @@ -124,35 +137,37 @@ class TestCompleteEndpoint: class TestStatusEndpoint: """Test /status endpoint.""" - + def test_get_status(self, mock_router): """Test getting router status.""" client, mock = mock_router - mock.get_status = MagicMock(return_value={ - "total_providers": 2, - "healthy_providers": 1, - "degraded_providers": 1, - "unhealthy_providers": 0, - "providers": [ - { - "name": "ollama-local", - "type": "ollama", - "status": "healthy", - "priority": 1, - "default_model": "llama3.2", - }, - { - "name": "openai-backup", - "type": "openai", - 
"status": "degraded", - "priority": 2, - "default_model": "gpt-4o-mini", - }, - ], - }) - + mock.get_status = MagicMock( + return_value={ + "total_providers": 2, + "healthy_providers": 1, + "degraded_providers": 1, + "unhealthy_providers": 0, + "providers": [ + { + "name": "ollama-local", + "type": "ollama", + "status": "healthy", + "priority": 1, + "default_model": "llama3.2", + }, + { + "name": "openai-backup", + "type": "openai", + "status": "degraded", + "priority": 2, + "default_model": "gpt-4o-mini", + }, + ], + } + ) + response = client.get("/api/v1/router/status") - + assert response.status_code == 200 data = response.json() assert data["total_providers"] == 2 @@ -163,31 +178,33 @@ class TestStatusEndpoint: class TestMetricsEndpoint: """Test /metrics endpoint.""" - + def test_get_metrics(self, mock_router): """Test getting detailed metrics.""" client, mock = mock_router # Setup the mock return value on the mock_router object - mock.get_metrics = MagicMock(return_value={ - "providers": [ - { - "name": "ollama-local", - "type": "ollama", - "status": "healthy", - "circuit_state": "closed", - "metrics": { - "total_requests": 100, - "successful": 98, - "failed": 2, - "error_rate": 0.02, - "avg_latency_ms": 150.5, + mock.get_metrics = MagicMock( + return_value={ + "providers": [ + { + "name": "ollama-local", + "type": "ollama", + "status": "healthy", + "circuit_state": "closed", + "metrics": { + "total_requests": 100, + "successful": 98, + "failed": 2, + "error_rate": 0.02, + "avg_latency_ms": 150.5, + }, }, - }, - ], - }) - + ], + } + ) + response = client.get("/api/v1/router/metrics") - + assert response.status_code == 200 data = response.json() assert len(data["providers"]) == 1 @@ -199,17 +216,17 @@ class TestMetricsEndpoint: class TestListProvidersEndpoint: """Test /providers endpoint.""" - + def test_list_providers(self, mock_router): """Test listing all providers.""" client, mock = mock_router - + response = client.get("/api/v1/router/providers") - + 
assert response.status_code == 200 data = response.json() assert len(data) == 2 - + # Check first provider assert data[0]["name"] == "ollama-local" assert data[0]["type"] == "ollama" @@ -221,40 +238,38 @@ class TestListProvidersEndpoint: class TestControlProviderEndpoint: """Test /providers/{name}/control endpoint.""" - + def test_disable_provider(self, mock_router): """Test disabling a provider.""" client, mock = mock_router - + response = client.post( - "/api/v1/router/providers/ollama-local/control", - json={"action": "disable"} + "/api/v1/router/providers/ollama-local/control", json={"action": "disable"} ) - + assert response.status_code == 200 assert "disabled" in response.json()["message"] - + # Check that the provider was disabled provider = mock.providers[0] assert provider.enabled is False assert provider.status == ProviderStatus.DISABLED - + def test_enable_provider(self, mock_router): """Test enabling a provider.""" client, mock = mock_router # First disable it mock.providers[0].enabled = False mock.providers[0].status = ProviderStatus.DISABLED - + response = client.post( - "/api/v1/router/providers/ollama-local/control", - json={"action": "enable"} + "/api/v1/router/providers/ollama-local/control", json={"action": "enable"} ) - + assert response.status_code == 200 assert "enabled" in response.json()["message"] assert mock.providers[0].enabled is True - + def test_reset_circuit(self, mock_router): """Test resetting circuit breaker.""" client, mock = mock_router @@ -262,73 +277,70 @@ class TestControlProviderEndpoint: mock.providers[0].circuit_state = CircuitState.OPEN mock.providers[0].status = ProviderStatus.UNHEALTHY mock.providers[0].metrics.consecutive_failures = 10 - + response = client.post( - "/api/v1/router/providers/ollama-local/control", - json={"action": "reset_circuit"} + "/api/v1/router/providers/ollama-local/control", json={"action": "reset_circuit"} ) - + assert response.status_code == 200 assert "reset" in response.json()["message"] - + 
provider = mock.providers[0] assert provider.circuit_state == CircuitState.CLOSED assert provider.status == ProviderStatus.HEALTHY assert provider.metrics.consecutive_failures == 0 - + def test_control_unknown_provider(self, mock_router): """Test controlling unknown provider returns 404.""" client, mock = mock_router response = client.post( - "/api/v1/router/providers/unknown/control", - json={"action": "disable"} + "/api/v1/router/providers/unknown/control", json={"action": "disable"} ) - + assert response.status_code == 404 assert "not found" in response.json()["detail"] - + def test_control_unknown_action(self, mock_router): """Test unknown action returns 400.""" client, mock = mock_router response = client.post( - "/api/v1/router/providers/ollama-local/control", - json={"action": "invalid_action"} + "/api/v1/router/providers/ollama-local/control", json={"action": "invalid_action"} ) - + assert response.status_code == 400 assert "Unknown action" in response.json()["detail"] class TestHealthCheckEndpoint: """Test /health-check endpoint.""" - + def test_health_check_all_healthy(self, mock_router): """Test health check when all providers are healthy.""" client, mock = mock_router - + with patch.object(mock, "_check_provider_available") as mock_check: mock_check.return_value = True - + response = client.post("/api/v1/router/health-check") - + assert response.status_code == 200 data = response.json() assert data["healthy_count"] == 2 assert len(data["providers"]) == 2 - + for p in data["providers"]: assert p["healthy"] is True - + def test_health_check_with_failure(self, mock_router): """Test health check when some providers fail.""" client, mock = mock_router - + with patch.object(mock, "_check_provider_available") as mock_check: # First provider fails, second succeeds mock_check.side_effect = [False, True] - + response = client.post("/api/v1/router/health-check") - + assert response.status_code == 200 data = response.json() assert data["healthy_count"] == 1 @@ 
-338,21 +350,21 @@ class TestHealthCheckEndpoint: class TestGetConfigEndpoint: """Test /config endpoint.""" - + def test_get_config(self, mock_router): """Test getting router configuration.""" client, mock = mock_router - + response = client.get("/api/v1/router/config") - + assert response.status_code == 200 data = response.json() - + assert data["timeout_seconds"] == 30 assert data["max_retries_per_provider"] == 2 assert "circuit_breaker" in data assert data["circuit_breaker"]["failure_threshold"] == 5 - + # Check providers list (without secrets) assert len(data["providers"]) == 2 assert "api_key" not in data["providers"][0] diff --git a/tests/integrations/test_chat_bridge.py b/tests/integrations/test_chat_bridge.py index ef9a8d4..cdd0d54 100644 --- a/tests/integrations/test_chat_bridge.py +++ b/tests/integrations/test_chat_bridge.py @@ -1,8 +1,9 @@ """Tests for the chat_bridge base classes, registry, and invite parser.""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest + from integrations.chat_bridge.base import ( ChatMessage, ChatPlatform, @@ -13,7 +14,6 @@ from integrations.chat_bridge.base import ( ) from integrations.chat_bridge.registry import PlatformRegistry - # ── Base dataclass tests ─────────────────────────────────────────────────────── @@ -122,9 +122,7 @@ class _FakePlatform(ChatPlatform): ) async def create_thread(self, channel_id, title, initial_message=None): - return ChatThread( - thread_id="t1", title=title, channel_id=channel_id, platform=self._name - ) + return ChatThread(thread_id="t1", title=title, channel_id=channel_id, platform=self._name) async def join_from_invite(self, invite_code) -> bool: return True @@ -217,18 +215,14 @@ class TestInviteParser: def test_parse_text_discord_com_invite(self): from integrations.chat_bridge.invite_parser import invite_parser - result = invite_parser.parse_text( - "Link: https://discord.com/invite/myServer2024" - ) + result = invite_parser.parse_text("Link: 
https://discord.com/invite/myServer2024") assert result is not None assert result.code == "myServer2024" def test_parse_text_discordapp(self): from integrations.chat_bridge.invite_parser import invite_parser - result = invite_parser.parse_text( - "https://discordapp.com/invite/test-code" - ) + result = invite_parser.parse_text("https://discordapp.com/invite/test-code") assert result is not None assert result.code == "test-code" diff --git a/tests/integrations/test_discord_vendor.py b/tests/integrations/test_discord_vendor.py index 770d5d7..a81a1c7 100644 --- a/tests/integrations/test_discord_vendor.py +++ b/tests/integrations/test_discord_vendor.py @@ -1,12 +1,12 @@ """Tests for the Discord vendor and dashboard routes.""" import json -import pytest from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch -from integrations.chat_bridge.base import PlatformState +import pytest +from integrations.chat_bridge.base import PlatformState # ── DiscordVendor unit tests ────────────────────────────────────────────────── @@ -52,9 +52,9 @@ class TestDiscordVendor: assert loaded == "test-token-abc" def test_load_token_missing_file(self, tmp_path, monkeypatch): + from config import settings from integrations.chat_bridge.vendors import discord as discord_mod from integrations.chat_bridge.vendors.discord import DiscordVendor - from config import settings state_file = tmp_path / "nonexistent.json" monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file) @@ -186,9 +186,7 @@ class TestDiscordRoutes: new_callable=AsyncMock, return_value=False, ): - resp = client.post( - "/discord/setup", json={"token": "fake-token-123"} - ) + resp = client.post("/discord/setup", json={"token": "fake-token-123"}) assert resp.status_code == 200 data = resp.json() # Will fail because discord.py is mocked, but route handles it diff --git a/tests/integrations/test_paperclip_bridge.py b/tests/integrations/test_paperclip_bridge.py index e1ca36c..c64d891 100644 --- 
a/tests/integrations/test_paperclip_bridge.py +++ b/tests/integrations/test_paperclip_bridge.py @@ -1,6 +1,6 @@ """Tests for the Paperclip bridge (CEO orchestration logic).""" -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/integrations/test_paperclip_task_runner.py b/tests/integrations/test_paperclip_task_runner.py index a2032f5..f298db2 100644 --- a/tests/integrations/test_paperclip_task_runner.py +++ b/tests/integrations/test_paperclip_task_runner.py @@ -27,12 +27,9 @@ import pytest from integrations.paperclip.bridge import PaperclipBridge from integrations.paperclip.client import PaperclipClient -from integrations.paperclip.models import ( - PaperclipIssue, -) +from integrations.paperclip.models import PaperclipIssue from integrations.paperclip.task_runner import TaskRunner - # ── Constants ───────────────────────────────────────────────────────────────── TIMMY_AGENT_ID = "agent-timmy" @@ -53,9 +50,7 @@ class StubOrchestrator: def __init__(self) -> None: self.calls: list[dict] = [] - async def execute_task( - self, task_id: str, description: str, context: dict - ) -> dict: + async def execute_task(self, task_id: str, description: str, context: dict) -> dict: call_record = { "task_id": task_id, "description": description, @@ -121,8 +116,9 @@ def bridge(mock_client): @pytest.fixture def settings_patch(): """Patch settings for all task runner tests.""" - with patch("integrations.paperclip.task_runner.settings") as ts, \ - patch("integrations.paperclip.bridge.settings") as bs: + with patch("integrations.paperclip.task_runner.settings") as ts, patch( + "integrations.paperclip.bridge.settings" + ) as bs: for s in (ts, bs): s.paperclip_enabled = True s.paperclip_agent_id = TIMMY_AGENT_ID @@ -179,7 +175,11 @@ class TestOrchestratorWiring: """Verify the orchestrator parameter actually connects to the pipe.""" async def test_orchestrator_execute_task_is_called( - self, 
mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """When orchestrator is wired, process_task calls execute_task.""" issue = _make_issue() @@ -193,7 +193,11 @@ class TestOrchestratorWiring: assert call["context"]["title"] == "Muse about task automation" async def test_orchestrator_receives_full_context( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """Context dict passed to execute_task includes all issue metadata.""" issue = _make_issue( @@ -213,7 +217,11 @@ class TestOrchestratorWiring: assert ctx["labels"] == ["automation", "meta"] async def test_orchestrator_dict_result_unwrapped( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """When execute_task returns a dict, the 'result' key is extracted.""" issue = _make_issue() @@ -226,7 +234,10 @@ class TestOrchestratorWiring: assert "issue-1" in result async def test_orchestrator_string_result_passthrough( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): """When execute_task returns a plain string, it passes through.""" @@ -240,7 +251,11 @@ class TestOrchestratorWiring: assert result == "Plain string result for issue-1" async def test_process_fn_overrides_orchestrator( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """Explicit process_fn takes priority over orchestrator.""" @@ -248,7 +263,9 @@ class TestOrchestratorWiring: return "override wins" runner = TaskRunner( - bridge=bridge, orchestrator=stub_orchestrator, process_fn=override, + bridge=bridge, + orchestrator=stub_orchestrator, + process_fn=override, ) result = await runner.process_task(_make_issue()) @@ -314,7 +331,11 @@ class TestProcessTask: """Verify 
checkout + orchestrator invocation + result flow.""" async def test_checkout_before_orchestrator( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """Issue must be checked out before orchestrator runs.""" issue = _make_issue() @@ -323,9 +344,7 @@ class TestProcessTask: original_execute = stub_orchestrator.execute_task async def tracking_execute(task_id, desc, ctx): - checkout_happened["before_execute"] = ( - mock_client.checkout_issue.await_count > 0 - ) + checkout_happened["before_execute"] = mock_client.checkout_issue.await_count > 0 return await original_execute(task_id, desc, ctx) stub_orchestrator.execute_task = tracking_execute @@ -336,7 +355,11 @@ class TestProcessTask: assert checkout_happened["before_execute"], "checkout must happen before execute_task" async def test_orchestrator_output_flows_to_result( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """The string returned by process_task comes from the orchestrator.""" issue = _make_issue(id="flow-1", title="Flow verification", priority="high") @@ -350,7 +373,10 @@ class TestProcessTask: assert "high" in result async def test_default_fallback_without_orchestrator( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): """Without orchestrator or process_fn, a default message is returned.""" issue = _make_issue(title="Fallback test") @@ -368,7 +394,11 @@ class TestCompleteTask: """Verify orchestrator output flows into the completion comment.""" async def test_orchestrator_output_in_comment( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """The comment posted to Paperclip contains the orchestrator's output.""" issue = _make_issue(id="cmt-1", title="Comment pipe test") @@ -386,7 +416,10 @@ class 
TestCompleteTask: assert "Comment pipe test" in comment_content async def test_marks_issue_done( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): issue = _make_issue() mock_client.update_issue.return_value = _make_done() @@ -399,7 +432,10 @@ class TestCompleteTask: assert update_req.status == "done" async def test_returns_false_on_close_failure( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): mock_client.update_issue.return_value = None runner = TaskRunner(bridge=bridge) @@ -415,7 +451,11 @@ class TestCreateFollowUp: """Verify orchestrator output flows into the follow-up description.""" async def test_follow_up_contains_orchestrator_output( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """The follow-up description includes the orchestrator's result text.""" issue = _make_issue(id="fu-1", title="Follow-up pipe test") @@ -431,7 +471,10 @@ class TestCreateFollowUp: assert "fu-1" in create_req.description async def test_follow_up_assigned_to_self( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): mock_client.create_issue.return_value = _make_follow_up() runner = TaskRunner(bridge=bridge) @@ -441,7 +484,10 @@ class TestCreateFollowUp: assert req.assignee_id == TIMMY_AGENT_ID async def test_follow_up_preserves_priority( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): mock_client.create_issue.return_value = _make_follow_up() runner = TaskRunner(bridge=bridge) @@ -475,7 +521,11 @@ class TestGreenPathWithOrchestrator: """ async def test_full_cycle_orchestrator_output_everywhere( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """Orchestrator result appears in comment, follow-up, and summary.""" 
original = _make_issue( @@ -527,7 +577,11 @@ class TestGreenPathWithOrchestrator: assert mock_client.create_issue.await_count == 1 async def test_no_tasks_returns_none( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): mock_client.list_issues.return_value = [] runner = TaskRunner(bridge=bridge, orchestrator=stub_orchestrator) @@ -535,7 +589,11 @@ class TestGreenPathWithOrchestrator: assert len(stub_orchestrator.calls) == 0 async def test_close_failure_still_creates_follow_up( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): mock_client.list_issues.return_value = [_make_issue()] mock_client.update_issue.return_value = None # close fails @@ -558,7 +616,11 @@ class TestExternalTaskInjection: """External system creates a task → Timmy's orchestrator processes it.""" async def test_external_task_flows_through_orchestrator( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): external = _make_issue( id="ext-1", @@ -581,7 +643,11 @@ class TestExternalTaskInjection: assert "Review quarterly metrics" in summary["result"] async def test_skips_tasks_for_other_agents( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): other = _make_issue(id="other-1", assignee_id="agent-codex") mine = _make_issue(id="mine-1", title="My task") @@ -605,17 +671,26 @@ class TestRecursiveChain: """Multi-cycle chains where each follow-up becomes the next task.""" async def test_two_cycle_chain( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): task_a = _make_issue(id="A", title="Initial musing") fu_b = PaperclipIssue( - id="B", title="Follow-up: Initial musing", - 
description="Continue", status="open", - assignee_id=TIMMY_AGENT_ID, priority="normal", + id="B", + title="Follow-up: Initial musing", + description="Continue", + status="open", + assignee_id=TIMMY_AGENT_ID, + priority="normal", ) fu_c = PaperclipIssue( - id="C", title="Follow-up: Follow-up", - status="open", assignee_id=TIMMY_AGENT_ID, + id="C", + title="Follow-up: Follow-up", + status="open", + assignee_id=TIMMY_AGENT_ID, ) # Cycle 1 @@ -643,14 +718,20 @@ class TestRecursiveChain: assert stub_orchestrator.calls[1]["task_id"] == "B" async def test_three_cycle_chain_all_through_orchestrator( - self, mock_client, bridge, stub_orchestrator, settings_patch, + self, + mock_client, + bridge, + stub_orchestrator, + settings_patch, ): """Three cycles — every task goes through the orchestrator pipe.""" tasks = [_make_issue(id=f"c-{i}", title=f"Chain {i}") for i in range(3)] follow_ups = [ PaperclipIssue( - id=f"c-{i + 1}", title=f"Follow-up: Chain {i}", - status="open", assignee_id=TIMMY_AGENT_ID, + id=f"c-{i + 1}", + title=f"Follow-up: Chain {i}", + status="open", + assignee_id=TIMMY_AGENT_ID, ) for i in range(3) ] @@ -676,7 +757,6 @@ class TestRecursiveChain: class TestLifecycle: - async def test_stop_halts_loop(self, mock_client, bridge, settings_patch): runner = TaskRunner(bridge=bridge) runner._running = True @@ -684,7 +764,10 @@ class TestLifecycle: assert runner._running is False async def test_start_disabled_when_interval_zero( - self, mock_client, bridge, settings_patch, + self, + mock_client, + bridge, + settings_patch, ): settings_patch.paperclip_poll_interval = 0 runner = TaskRunner(bridge=bridge) @@ -701,6 +784,7 @@ def _ollama_reachable() -> tuple[bool, list[str]]: """Return (reachable, model_names).""" try: import httpx + resp = httpx.get("http://localhost:11434/api/tags", timeout=3) resp.raise_for_status() names = [m["name"] for m in resp.json().get("models", [])] @@ -726,9 +810,7 @@ class LiveOllamaOrchestrator: self.model_name = model_name self.calls: 
list[dict] = [] - async def execute_task( - self, task_id: str, description: str, context: dict - ) -> str: + async def execute_task(self, task_id: str, description: str, context: dict) -> str: import httpx as hx self.calls.append({"task_id": task_id, "description": description}) @@ -814,13 +896,18 @@ class TestLiveOllamaGreenPath: task_a = _make_issue(id="live-A", title="Initial reflection") fu_b = PaperclipIssue( - id="live-B", title="Follow-up: Initial reflection", - description="Continue reflecting", status="open", - assignee_id=TIMMY_AGENT_ID, priority="normal", + id="live-B", + title="Follow-up: Initial reflection", + description="Continue reflecting", + status="open", + assignee_id=TIMMY_AGENT_ID, + priority="normal", ) fu_c = PaperclipIssue( - id="live-C", title="Follow-up: Follow-up", - status="open", assignee_id=TIMMY_AGENT_ID, + id="live-C", + title="Follow-up: Follow-up", + status="open", + assignee_id=TIMMY_AGENT_ID, ) live_orch = LiveOllamaOrchestrator(chosen) diff --git a/tests/integrations/test_shortcuts.py b/tests/integrations/test_shortcuts.py index 9991d82..82e1692 100644 --- a/tests/integrations/test_shortcuts.py +++ b/tests/integrations/test_shortcuts.py @@ -1,6 +1,6 @@ """Tests for shortcuts/siri.py — Siri Shortcuts integration.""" -from integrations.shortcuts.siri import get_setup_guide, SHORTCUT_ACTIONS +from integrations.shortcuts.siri import SHORTCUT_ACTIONS, get_setup_guide def test_setup_guide_has_title(): diff --git a/tests/integrations/test_telegram_bot.py b/tests/integrations/test_telegram_bot.py index 06d5030..4cc55d5 100644 --- a/tests/integrations/test_telegram_bot.py +++ b/tests/integrations/test_telegram_bot.py @@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - # ── TelegramBot unit tests ──────────────────────────────────────────────────── @@ -17,6 +16,7 @@ class TestTelegramBotTokenHelpers: monkeypatch.setattr("integrations.telegram_bot.bot._STATE_FILE", state_file) from 
integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() bot.save_token("test-token-123") @@ -38,6 +38,7 @@ class TestTelegramBotTokenHelpers: with patch("integrations.telegram_bot.bot._load_token_from_file", return_value=None): with patch("config.settings", mock_settings): from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() result = bot.load_token() assert result is None @@ -45,6 +46,7 @@ class TestTelegramBotTokenHelpers: def test_token_set_property(self): """token_set reflects whether a token has been applied.""" from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() assert not bot.token_set bot._token = "tok" @@ -52,6 +54,7 @@ class TestTelegramBotTokenHelpers: def test_is_running_property(self): from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() assert not bot.is_running bot._running = True @@ -66,6 +69,7 @@ class TestTelegramBotLifecycle: monkeypatch.setattr("integrations.telegram_bot.bot._STATE_FILE", state_file) from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() with patch.object(bot, "load_token", return_value=None): result = await bot.start() @@ -75,6 +79,7 @@ class TestTelegramBotLifecycle: @pytest.mark.asyncio async def test_start_already_running_returns_true(self): from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() bot._running = True result = await bot.start(token="any") @@ -84,10 +89,12 @@ class TestTelegramBotLifecycle: async def test_start_import_error_returns_false(self): """start() returns False gracefully when python-telegram-bot absent.""" from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() - with patch.object(bot, "load_token", return_value="tok"), \ - patch.dict("sys.modules", {"telegram": None, "telegram.ext": None}): + with patch.object(bot, "load_token", return_value="tok"), patch.dict( + "sys.modules", {"telegram": None, "telegram.ext": None} + ): result = await 
bot.start(token="tok") assert result is False assert not bot.is_running @@ -95,6 +102,7 @@ class TestTelegramBotLifecycle: @pytest.mark.asyncio async def test_stop_when_not_running_is_noop(self): from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() # Should not raise await bot.stop() @@ -103,6 +111,7 @@ class TestTelegramBotLifecycle: async def test_stop_calls_shutdown(self): """stop() invokes the Application shutdown sequence.""" from integrations.telegram_bot.bot import TelegramBot + bot = TelegramBot() bot._running = True @@ -126,6 +135,7 @@ class TestTelegramRoutes: def test_status_not_running(self, client): """GET /telegram/status returns running=False when bot is idle.""" from integrations.telegram_bot.bot import telegram_bot + telegram_bot._running = False telegram_bot._token = None @@ -138,6 +148,7 @@ class TestTelegramRoutes: def test_status_running(self, client): """GET /telegram/status returns running=True when bot is active.""" from integrations.telegram_bot.bot import telegram_bot + telegram_bot._running = True telegram_bot._token = "tok" @@ -164,8 +175,9 @@ class TestTelegramRoutes: from integrations.telegram_bot.bot import telegram_bot telegram_bot._running = False - with patch.object(telegram_bot, "save_token") as mock_save, \ - patch.object(telegram_bot, "start", new_callable=AsyncMock, return_value=True): + with patch.object(telegram_bot, "save_token") as mock_save, patch.object( + telegram_bot, "start", new_callable=AsyncMock, return_value=True + ): resp = client.post("/telegram/setup", json={"token": "bot123:abc"}) assert resp.status_code == 200 @@ -178,8 +190,9 @@ class TestTelegramRoutes: from integrations.telegram_bot.bot import telegram_bot telegram_bot._running = False - with patch.object(telegram_bot, "save_token"), \ - patch.object(telegram_bot, "start", new_callable=AsyncMock, return_value=False): + with patch.object(telegram_bot, "save_token"), patch.object( + telegram_bot, "start", new_callable=AsyncMock, 
return_value=False + ): resp = client.post("/telegram/setup", json={"token": "bad-token"}) assert resp.status_code == 200 @@ -190,11 +203,14 @@ class TestTelegramRoutes: def test_setup_stops_running_bot_first(self, client): """POST /telegram/setup stops any running bot before starting new one.""" from integrations.telegram_bot.bot import telegram_bot + telegram_bot._running = True - with patch.object(telegram_bot, "save_token"), \ - patch.object(telegram_bot, "stop", new_callable=AsyncMock) as mock_stop, \ - patch.object(telegram_bot, "start", new_callable=AsyncMock, return_value=True): + with patch.object(telegram_bot, "save_token"), patch.object( + telegram_bot, "stop", new_callable=AsyncMock + ) as mock_stop, patch.object( + telegram_bot, "start", new_callable=AsyncMock, return_value=True + ): resp = client.post("/telegram/setup", json={"token": "new-token"}) mock_stop.assert_awaited_once() @@ -207,5 +223,6 @@ class TestTelegramRoutes: def test_module_singleton_exists(): """telegram_bot module exposes a singleton TelegramBot instance.""" - from integrations.telegram_bot.bot import telegram_bot, TelegramBot + from integrations.telegram_bot.bot import TelegramBot, telegram_bot + assert isinstance(telegram_bot, TelegramBot) diff --git a/tests/integrations/test_voice_nlu.py b/tests/integrations/test_voice_nlu.py index 69770bf..d1b45bc 100644 --- a/tests/integrations/test_voice_nlu.py +++ b/tests/integrations/test_voice_nlu.py @@ -2,9 +2,9 @@ from integrations.voice.nlu import detect_intent, extract_command - # ── Intent detection ───────────────────────────────────────────────────────── + def test_status_intent(): intent = detect_intent("What is your status?") assert intent.name == "status" @@ -55,6 +55,7 @@ def test_intent_has_raw_text(): # ── Entity extraction ──────────────────────────────────────────────────────── + def test_entity_agent_name(): intent = detect_intent("spawn agent Echo") assert "agent_name" in intent.entities @@ -69,6 +70,7 @@ def 
test_entity_number(): # ── Command extraction ────────────────────────────────────────────────────── + def test_slash_command(): cmd = extract_command("/status") assert cmd == "status" diff --git a/tests/integrations/test_voice_tts_functional.py b/tests/integrations/test_voice_tts_functional.py index fc69925..c60746a 100644 --- a/tests/integrations/test_voice_tts_functional.py +++ b/tests/integrations/test_voice_tts_functional.py @@ -4,7 +4,7 @@ pyttsx3 is not available in CI, so all tests mock the engine. """ import threading -from unittest.mock import patch, MagicMock, PropertyMock +from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -29,6 +29,7 @@ class TestVoiceTTS: """When pyttsx3 import fails, VoiceTTS degrades gracefully.""" with patch.dict("sys.modules", {"pyttsx3": None}): from importlib import reload + import timmy_serve.voice_tts as mod tts = mod.VoiceTTS.__new__(mod.VoiceTTS) diff --git a/tests/integrations/test_websocket.py b/tests/integrations/test_websocket.py index ed6428c..de45736 100644 --- a/tests/integrations/test_websocket.py +++ b/tests/integrations/test_websocket.py @@ -24,6 +24,7 @@ def test_ws_manager_initial_state(): async def test_ws_manager_event_history_limit(): """History is trimmed to maxlen after broadcasts.""" import collections + mgr = WebSocketManager() mgr._event_history = collections.deque(maxlen=5) for i in range(10): diff --git a/tests/integrations/test_websocket_extended.py b/tests/integrations/test_websocket_extended.py index 37d4f37..557e7c8 100644 --- a/tests/integrations/test_websocket_extended.py +++ b/tests/integrations/test_websocket_extended.py @@ -68,6 +68,7 @@ class TestWebSocketManagerBroadcast: @pytest.mark.asyncio async def test_broadcast_trims_history(self): import collections + mgr = WebSocketManager() mgr._event_history = collections.deque(maxlen=3) for i in range(5): @@ -90,9 +91,7 @@ class TestWebSocketManagerConnect: mgr = WebSocketManager() # Pre-populate history for i in range(3): - 
mgr._event_history.append( - WSEvent(event=f"e{i}", data={}, timestamp="t") - ) + mgr._event_history.append(WSEvent(event=f"e{i}", data={}, timestamp="t")) ws = AsyncMock() await mgr.connect(ws) # Should have sent 3 history events diff --git a/tests/security/test_security_fixes_xss.py b/tests/security/test_security_fixes_xss.py index ba07448..c8a9aa3 100644 --- a/tests/security/test_security_fixes_xss.py +++ b/tests/security/test_security_fixes_xss.py @@ -1,58 +1,65 @@ import pytest from fastapi.templating import Jinja2Templates + def test_agent_chat_msg_xss_prevention(): """Verify XSS prevention in agent_chat_msg.html.""" templates = Jinja2Templates(directory="src/dashboard/templates") payload = "" + class MockAgent: def __init__(self): self.name = "TestAgent" self.id = "test-agent" - - response = templates.get_template("partials/agent_chat_msg.html").render({ - "message": payload, - "response": payload, - "error": payload, - "agent": MockAgent(), - "timestamp": "12:00:00" - }) - + + response = templates.get_template("partials/agent_chat_msg.html").render( + { + "message": payload, + "response": payload, + "error": payload, + "agent": MockAgent(), + "timestamp": "12:00:00", + } + ) + # Check that payload is escaped assert "<script>alert('xss')</script>" in response assert payload not in response + def test_agent_panel_xss_prevention(): """Verify XSS prevention in agent_panel.html.""" templates = Jinja2Templates(directory="src/dashboard/templates") payload = "" + class MockAgent: def __init__(self): self.name = payload self.id = "test-agent" self.status = "idle" self.capabilities = payload - + class MockTask: def __init__(self): self.id = "task-1" - self.status = type('obj', (object,), {'value': 'completed'}) + self.status = type("obj", (object,), {"value": "completed"}) self.created_at = "2026-02-26T12:00:00" self.description = payload self.result = payload - response = templates.get_template("partials/agent_panel.html").render({ - "agent": MockAgent(), - "tasks": 
[MockTask()] - }) - + response = templates.get_template("partials/agent_panel.html").render( + {"agent": MockAgent(), "tasks": [MockTask()]} + ) + assert "<script>alert('xss')</script>" in response assert payload not in response + def test_swarm_sidebar_xss_prevention(): """Verify XSS prevention in swarm_agents_sidebar.html.""" templates = Jinja2Templates(directory="src/dashboard/templates") payload = "" + class MockAgent: def __init__(self): self.name = payload @@ -61,9 +68,9 @@ def test_swarm_sidebar_xss_prevention(): self.capabilities = payload self.last_seen = "2026-02-26T12:00:00" - response = templates.get_template("partials/swarm_agents_sidebar.html").render({ - "agents": [MockAgent()] - }) - + response = templates.get_template("partials/swarm_agents_sidebar.html").render( + {"agents": [MockAgent()]} + ) + assert "<script>alert('xss')</script>" in response assert payload not in response diff --git a/tests/security/test_security_regression.py b/tests/security/test_security_regression.py index cd2424b..24804af 100644 --- a/tests/security/test_security_regression.py +++ b/tests/security/test_security_regression.py @@ -1,5 +1,6 @@ import pytest + def test_xss_protection_in_templates(): """Verify that templates now use the escape filter for user-controlled content.""" templates_to_check = [ @@ -9,9 +10,8 @@ def test_xss_protection_in_templates(): ("src/dashboard/templates/partials/approval_card_single.html", "{{ item.title | e }}"), ("src/dashboard/templates/marketplace.html", "{{ agent.name | e }}"), ] - + for path, expected_snippet in templates_to_check: with open(path, "r") as f: content = f.read() assert expected_snippet in content, f"XSS fix missing in {path}" - diff --git a/tests/security/test_xss_vulnerabilities.py b/tests/security/test_xss_vulnerabilities.py index fc24234..a205399 100644 --- a/tests/security/test_xss_vulnerabilities.py +++ b/tests/security/test_xss_vulnerabilities.py @@ -1,12 +1,16 @@ +import html + import pytest from fastapi.testclient 
import TestClient + from dashboard.app import app -import html + @pytest.fixture def client(): return TestClient(app) + def test_health_status_xss_vulnerability(client, monkeypatch): """Verify that the health status page escapes the model name.""" malicious_model = '">' @@ -19,6 +23,7 @@ def test_health_status_xss_vulnerability(client, monkeypatch): assert escaped_model in response.text assert malicious_model not in response.text + def test_grok_toggle_xss_vulnerability(client, monkeypatch): """Verify that the grok toggle card escapes the model name.""" malicious_model = '">' diff --git a/tests/spark/test_spark.py b/tests/spark/test_spark.py index ce046af..ce015d2 100644 --- a/tests/spark/test_spark.py +++ b/tests/spark/test_spark.py @@ -13,9 +13,9 @@ from pathlib import Path import pytest - # ── Fixtures ──────────────────────────────────────────────────────────────── + @pytest.fixture(autouse=True) def tmp_spark_db(tmp_path, monkeypatch): """Redirect all Spark SQLite writes to a temp directory.""" @@ -31,29 +31,34 @@ def tmp_spark_db(tmp_path, monkeypatch): class TestImportanceScoring: def test_failure_scores_high(self): from spark.memory import score_importance + score = score_importance("task_failed", {}) assert score >= 0.9 def test_bid_scores_low(self): from spark.memory import score_importance + score = score_importance("bid_submitted", {}) assert score <= 0.3 def test_high_bid_boosts_score(self): from spark.memory import score_importance + low = score_importance("bid_submitted", {"bid_sats": 10}) high = score_importance("bid_submitted", {"bid_sats": 100}) assert high > low def test_unknown_event_default(self): from spark.memory import score_importance + score = score_importance("unknown_type", {}) assert score == 0.5 class TestEventRecording: def test_record_and_query(self): - from spark.memory import record_event, get_events + from spark.memory import get_events, record_event + eid = record_event("task_posted", "Test task", task_id="t1") assert eid events 
= get_events(task_id="t1") @@ -62,22 +67,26 @@ class TestEventRecording: assert events[0].description == "Test task" def test_record_with_agent(self): - from spark.memory import record_event, get_events - record_event("bid_submitted", "Agent bid", agent_id="a1", task_id="t2", - data='{"bid_sats": 50}') + from spark.memory import get_events, record_event + + record_event( + "bid_submitted", "Agent bid", agent_id="a1", task_id="t2", data='{"bid_sats": 50}' + ) events = get_events(agent_id="a1") assert len(events) == 1 assert events[0].agent_id == "a1" def test_filter_by_event_type(self): - from spark.memory import record_event, get_events + from spark.memory import get_events, record_event + record_event("task_posted", "posted", task_id="t3") record_event("task_completed", "completed", task_id="t3") posted = get_events(event_type="task_posted") assert len(posted) == 1 def test_filter_by_min_importance(self): - from spark.memory import record_event, get_events + from spark.memory import get_events, record_event + record_event("bid_submitted", "low", importance=0.1) record_event("task_failed", "high", importance=0.9) high_events = get_events(min_importance=0.5) @@ -85,7 +94,8 @@ class TestEventRecording: assert high_events[0].event_type == "task_failed" def test_count_events(self): - from spark.memory import record_event, count_events + from spark.memory import count_events, record_event + record_event("task_posted", "a") record_event("task_posted", "b") record_event("task_completed", "c") @@ -93,7 +103,8 @@ class TestEventRecording: assert count_events("task_posted") == 2 def test_limit_results(self): - from spark.memory import record_event, get_events + from spark.memory import get_events, record_event + for i in range(10): record_event("bid_submitted", f"bid {i}") events = get_events(limit=3) @@ -102,7 +113,8 @@ class TestEventRecording: class TestMemoryConsolidation: def test_store_and_query_memory(self): - from spark.memory import store_memory, get_memories + from 
spark.memory import get_memories, store_memory + mid = store_memory("pattern", "agent-x", "Strong performer", confidence=0.8) assert mid memories = get_memories(subject="agent-x") @@ -110,7 +122,8 @@ class TestMemoryConsolidation: assert memories[0].content == "Strong performer" def test_filter_by_type(self): - from spark.memory import store_memory, get_memories + from spark.memory import get_memories, store_memory + store_memory("pattern", "system", "Good pattern") store_memory("anomaly", "system", "Bad anomaly") patterns = get_memories(memory_type="pattern") @@ -118,7 +131,8 @@ class TestMemoryConsolidation: assert patterns[0].memory_type == "pattern" def test_filter_by_confidence(self): - from spark.memory import store_memory, get_memories + from spark.memory import get_memories, store_memory + store_memory("pattern", "a", "Low conf", confidence=0.2) store_memory("pattern", "b", "High conf", confidence=0.9) high = get_memories(min_confidence=0.5) @@ -126,7 +140,8 @@ class TestMemoryConsolidation: assert high[0].content == "High conf" def test_count_memories(self): - from spark.memory import store_memory, count_memories + from spark.memory import count_memories, store_memory + store_memory("pattern", "a", "X") store_memory("anomaly", "b", "Y") assert count_memories() == 2 @@ -138,7 +153,8 @@ class TestMemoryConsolidation: class TestPredictions: def test_predict_stores_prediction(self): - from spark.eidos import predict_task_outcome, get_predictions + from spark.eidos import get_predictions, predict_task_outcome + result = predict_task_outcome("t1", "Fix the bug", ["agent-a", "agent-b"]) assert "prediction_id" in result assert result["likely_winner"] == "agent-a" @@ -147,12 +163,15 @@ class TestPredictions: def test_predict_with_history(self): from spark.eidos import predict_task_outcome + history = { "agent-a": {"success_rate": 0.3, "avg_winning_bid": 40}, "agent-b": {"success_rate": 0.9, "avg_winning_bid": 30}, } result = predict_task_outcome( - "t2", "Research 
topic", ["agent-a", "agent-b"], + "t2", + "Research topic", + ["agent-a", "agent-b"], agent_history=history, ) assert result["likely_winner"] == "agent-b" @@ -160,20 +179,23 @@ class TestPredictions: def test_predict_empty_candidates(self): from spark.eidos import predict_task_outcome + result = predict_task_outcome("t3", "No agents", []) assert result["likely_winner"] is None class TestEvaluation: def test_evaluate_correct_prediction(self): - from spark.eidos import predict_task_outcome, evaluate_prediction + from spark.eidos import evaluate_prediction, predict_task_outcome + predict_task_outcome("t4", "Task", ["agent-a"]) result = evaluate_prediction("t4", "agent-a", task_succeeded=True, winning_bid=30) assert result is not None assert result["accuracy"] > 0.0 def test_evaluate_wrong_prediction(self): - from spark.eidos import predict_task_outcome, evaluate_prediction + from spark.eidos import evaluate_prediction, predict_task_outcome + predict_task_outcome("t5", "Task", ["agent-a"]) result = evaluate_prediction("t5", "agent-b", task_succeeded=False) assert result is not None @@ -182,11 +204,13 @@ class TestEvaluation: def test_evaluate_no_prediction_returns_none(self): from spark.eidos import evaluate_prediction + result = evaluate_prediction("no-task", "agent-a", task_succeeded=True) assert result is None def test_double_evaluation_returns_none(self): - from spark.eidos import predict_task_outcome, evaluate_prediction + from spark.eidos import evaluate_prediction, predict_task_outcome + predict_task_outcome("t6", "Task", ["agent-a"]) evaluate_prediction("t6", "agent-a", task_succeeded=True) # Second evaluation should return None (already evaluated) @@ -197,13 +221,15 @@ class TestEvaluation: class TestAccuracyStats: def test_empty_stats(self): from spark.eidos import get_accuracy_stats + stats = get_accuracy_stats() assert stats["total_predictions"] == 0 assert stats["evaluated"] == 0 assert stats["avg_accuracy"] == 0.0 def test_stats_after_evaluations(self): - 
from spark.eidos import predict_task_outcome, evaluate_prediction, get_accuracy_stats + from spark.eidos import evaluate_prediction, get_accuracy_stats, predict_task_outcome + for i in range(3): predict_task_outcome(f"task-{i}", "Description", ["agent-a"]) evaluate_prediction(f"task-{i}", "agent-a", task_succeeded=True, winning_bid=30) @@ -217,6 +243,7 @@ class TestAccuracyStats: class TestComputeAccuracy: def test_perfect_prediction(self): from spark.eidos import _compute_accuracy + predicted = { "likely_winner": "agent-a", "success_probability": 1.0, @@ -228,6 +255,7 @@ class TestComputeAccuracy: def test_all_wrong(self): from spark.eidos import _compute_accuracy + predicted = { "likely_winner": "agent-a", "success_probability": 1.0, @@ -239,6 +267,7 @@ class TestComputeAccuracy: def test_partial_credit(self): from spark.eidos import _compute_accuracy + predicted = { "likely_winner": "agent-a", "success_probability": 0.5, @@ -256,36 +285,38 @@ class TestComputeAccuracy: class TestAdvisor: def test_insufficient_data(self): from spark.advisor import generate_advisories + advisories = generate_advisories() assert len(advisories) >= 1 assert advisories[0].category == "system_health" assert "Insufficient" in advisories[0].title def test_failure_detection(self): - from spark.memory import record_event from spark.advisor import generate_advisories + from spark.memory import record_event + # Record enough events to pass the minimum threshold for i in range(5): - record_event("task_failed", f"Failed task {i}", - agent_id="agent-bad", task_id=f"t-{i}") + record_event("task_failed", f"Failed task {i}", agent_id="agent-bad", task_id=f"t-{i}") advisories = generate_advisories() failure_advisories = [a for a in advisories if a.category == "failure_prevention"] assert len(failure_advisories) >= 1 assert "agent-ba" in failure_advisories[0].title def test_advisories_sorted_by_priority(self): - from spark.memory import record_event from spark.advisor import generate_advisories + 
from spark.memory import record_event + for i in range(4): record_event("task_posted", f"posted {i}", task_id=f"p-{i}") - record_event("task_completed", f"done {i}", - agent_id="agent-good", task_id=f"p-{i}") + record_event("task_completed", f"done {i}", agent_id="agent-good", task_id=f"p-{i}") advisories = generate_advisories() if len(advisories) >= 2: assert advisories[0].priority >= advisories[-1].priority def test_no_activity_advisory(self): from spark.advisor import _check_system_activity + advisories = _check_system_activity() assert len(advisories) >= 1 assert "No swarm activity" in advisories[0].title @@ -297,11 +328,13 @@ class TestAdvisor: class TestSparkEngine: def test_engine_enabled(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=True) assert engine.enabled def test_engine_disabled(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=False) result = engine.on_task_posted("t1", "Ignored task") assert result is None @@ -309,6 +342,7 @@ class TestSparkEngine: def test_on_task_posted(self): from spark.engine import SparkEngine from spark.memory import get_events + engine = SparkEngine(enabled=True) eid = engine.on_task_posted("t1", "Test task", ["agent-a"]) assert eid is not None @@ -318,6 +352,7 @@ class TestSparkEngine: def test_on_bid_submitted(self): from spark.engine import SparkEngine from spark.memory import get_events + engine = SparkEngine(enabled=True) eid = engine.on_bid_submitted("t1", "agent-a", 50) assert eid is not None @@ -327,6 +362,7 @@ class TestSparkEngine: def test_on_task_assigned(self): from spark.engine import SparkEngine from spark.memory import get_events + engine = SparkEngine(enabled=True) eid = engine.on_task_assigned("t1", "agent-a") assert eid is not None @@ -334,8 +370,9 @@ class TestSparkEngine: assert len(events) == 1 def test_on_task_completed_evaluates_prediction(self): - from spark.engine import SparkEngine from spark.eidos import get_predictions + from spark.engine 
import SparkEngine + engine = SparkEngine(enabled=True) engine.on_task_posted("t1", "Fix bug", ["agent-a"]) eid = engine.on_task_completed("t1", "agent-a", "Fixed it") @@ -347,6 +384,7 @@ class TestSparkEngine: def test_on_task_failed(self): from spark.engine import SparkEngine from spark.memory import get_events + engine = SparkEngine(enabled=True) engine.on_task_posted("t1", "Deploy server", ["agent-a"]) eid = engine.on_task_failed("t1", "agent-a", "Connection timeout") @@ -357,6 +395,7 @@ class TestSparkEngine: def test_on_agent_joined(self): from spark.engine import SparkEngine from spark.memory import get_events + engine = SparkEngine(enabled=True) eid = engine.on_agent_joined("agent-a", "Echo") assert eid is not None @@ -365,6 +404,7 @@ class TestSparkEngine: def test_status(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=True) engine.on_task_posted("t1", "Test", ["agent-a"]) engine.on_bid_submitted("t1", "agent-a", 30) @@ -376,18 +416,21 @@ class TestSparkEngine: def test_get_advisories(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=True) advisories = engine.get_advisories() assert isinstance(advisories, list) def test_get_advisories_disabled(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=False) advisories = engine.get_advisories() assert advisories == [] def test_get_timeline(self): from spark.engine import SparkEngine + engine = SparkEngine(enabled=True) engine.on_task_posted("t1", "Task 1") engine.on_task_posted("t2", "Task 2") @@ -397,6 +440,7 @@ class TestSparkEngine: def test_memory_consolidation(self): from spark.engine import SparkEngine from spark.memory import get_memories + engine = SparkEngine(enabled=True) # Generate enough completions to trigger consolidation (>=5 events, >=3 outcomes) for i in range(6): diff --git a/tests/spark/test_spark_tools_creative.py b/tests/spark/test_spark_tools_creative.py index 7934bfd..5f857ab 100644 --- 
a/tests/spark/test_spark_tools_creative.py +++ b/tests/spark/test_spark_tools_creative.py @@ -7,7 +7,7 @@ in Phase 6. import pytest from spark.engine import SparkEngine -from spark.memory import get_events, count_events +from spark.memory import count_events, get_events @pytest.fixture(autouse=True) diff --git a/tests/test_agentic_loop.py b/tests/test_agentic_loop.py index 1d5541b..f856670 100644 --- a/tests/test_agentic_loop.py +++ b/tests/test_agentic_loop.py @@ -4,20 +4,17 @@ Tests cover planning, execution, max_steps enforcement, failure adaptation, progress callbacks, and response cleaning. """ -import pytest -from unittest.mock import MagicMock, patch, AsyncMock -from timmy.agentic_loop import ( - run_agentic_loop, - _parse_steps, - AgenticResult, - AgenticStep, -) +from unittest.mock import AsyncMock, MagicMock, patch +import pytest + +from timmy.agentic_loop import AgenticResult, AgenticStep, _parse_steps, run_agentic_loop # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _mock_run(content: str): """Create a mock return value for agent.run().""" m = MagicMock() @@ -29,6 +26,7 @@ def _mock_run(content: str): # _parse_steps # --------------------------------------------------------------------------- + class TestParseSteps: def test_numbered_with_dot(self): text = "1. Search for data\n2. Write to file\n3. Verify" @@ -50,20 +48,24 @@ class TestParseSteps: # run_agentic_loop # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_planning_phase_produces_steps(): """Planning prompt returns numbered step list.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Search AI news\n2. Write to file\n3. 
Verify"), - _mock_run("Found 5 articles about AI."), - _mock_run("Wrote summary to /tmp/ai_news.md"), - _mock_run("File verified, 15 lines."), - _mock_run("Searched, wrote, verified."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Search AI news\n2. Write to file\n3. Verify"), + _mock_run("Found 5 articles about AI."), + _mock_run("Wrote summary to /tmp/ai_news.md"), + _mock_run("File verified, 15 lines."), + _mock_run("Searched, wrote, verified."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Search AI news and write summary") assert result.status == "completed" @@ -74,15 +76,18 @@ async def test_planning_phase_produces_steps(): async def test_loop_executes_all_steps(): """Loop calls agent.run() for plan + each step + summary.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Do A\n2. Do B"), - _mock_run("A done"), - _mock_run("B done"), - _mock_run("All done"), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Do A\n2. 
Do B"), + _mock_run("A done"), + _mock_run("B done"), + _mock_run("All done"), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Do A and B") # plan + 2 steps + summary = 4 calls @@ -94,15 +99,18 @@ async def test_loop_executes_all_steps(): async def test_loop_respects_max_steps(): """Loop stops at max_steps and returns status='partial'.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. A\n2. B\n3. C\n4. D\n5. E"), - _mock_run("A done"), - _mock_run("B done"), - _mock_run("Completed 2 of 5 steps."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. A\n2. B\n3. C\n4. D\n5. E"), + _mock_run("A done"), + _mock_run("B done"), + _mock_run("Completed 2 of 5 steps."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Do 5 things", max_steps=2) assert len(result.steps) == 2 @@ -113,17 +121,20 @@ async def test_loop_respects_max_steps(): async def test_failure_triggers_adaptation(): """Failed step feeds error back to model, step marked as adapted.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Read config\n2. Update setting\n3. 
Verify"), - _mock_run("Config: timeout=30"), - Exception("Permission denied"), - _mock_run("Adapted: wrote to ~/config.yaml instead"), - _mock_run("Verified: timeout=60"), - _mock_run("Updated config via alternative path."), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Read config\n2. Update setting\n3. Verify"), + _mock_run("Config: timeout=30"), + Exception("Permission denied"), + _mock_run("Adapted: wrote to ~/config.yaml instead"), + _mock_run("Verified: timeout=60"), + _mock_run("Updated config via alternative path."), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Update config timeout to 60") assert result.status == "completed" @@ -139,15 +150,18 @@ async def test_progress_callback_fires(): events.append((step, total)) mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Do A\n2. Do B"), - _mock_run("A done"), - _mock_run("B done"), - _mock_run("All done"), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Do A\n2. 
Do B"), + _mock_run("A done"), + _mock_run("B done"), + _mock_run("All done"), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): await run_agentic_loop("Do A and B", on_progress=on_progress) assert len(events) == 2 @@ -159,15 +173,18 @@ async def test_progress_callback_fires(): async def test_result_contains_step_metadata(): """AgenticResult.steps has status and duration per step.""" mock_agent = MagicMock() - mock_agent.run = MagicMock(side_effect=[ - _mock_run("1. Search\n2. Write"), - _mock_run("Found results"), - _mock_run("Written to file"), - _mock_run("Done"), - ]) + mock_agent.run = MagicMock( + side_effect=[ + _mock_run("1. Search\n2. Write"), + _mock_run("Found results"), + _mock_run("Written to file"), + _mock_run("Done"), + ] + ) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Search and write") for step in result.steps: @@ -191,8 +208,9 @@ async def test_config_default_used(): mock_agent.run = MagicMock(side_effect=side_effects) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Do 14 things", max_steps=0) # Should be capped at 10 (config default) @@ -205,8 +223,9 @@ async def 
test_planning_failure_returns_failed(): mock_agent = MagicMock() mock_agent.run = MagicMock(side_effect=Exception("Model offline")) - with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), \ - patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock): + with patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent), patch( + "timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock + ): result = await run_agentic_loop("Do something") assert result.status == "failed" diff --git a/tests/test_hands_git.py b/tests/test_hands_git.py index da743cc..ea78a2e 100644 --- a/tests/test_hands_git.py +++ b/tests/test_hands_git.py @@ -10,11 +10,11 @@ Covers: import pytest - # --------------------------------------------------------------------------- # Destructive operation gating # --------------------------------------------------------------------------- + def test_is_destructive_detects_force_push(): """Force-push should be flagged as destructive.""" from infrastructure.hands.git import GitHand @@ -71,6 +71,7 @@ async def test_run_allows_destructive_with_flag(): # Successful operations # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_run_git_status(): """git status should succeed in a git repo.""" @@ -97,6 +98,7 @@ async def test_run_git_log(): # Convenience wrappers # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_status_wrapper(): """status() convenience wrapper should work.""" @@ -145,6 +147,7 @@ async def test_diff_staged_wrapper(): # GitResult dataclass # --------------------------------------------------------------------------- + def test_git_result_defaults(): """GitResult should have sensible defaults.""" from infrastructure.hands.git import GitResult @@ -161,6 +164,7 @@ def test_git_result_defaults(): # Info summary # 
--------------------------------------------------------------------------- + def test_info_returns_summary(): """info() should return a dict with repo_dir and timeout.""" from infrastructure.hands.git import GitHand @@ -175,6 +179,7 @@ def test_info_returns_summary(): # Tool registration # --------------------------------------------------------------------------- + def test_persona_hand_map(): """Forge and Helm should have both shell and git access.""" from infrastructure.hands.tools import get_local_hands_for_persona diff --git a/tests/test_hands_shell.py b/tests/test_hands_shell.py index 74c7aa0..27e7ecd 100644 --- a/tests/test_hands_shell.py +++ b/tests/test_hands_shell.py @@ -13,11 +13,11 @@ import asyncio import pytest - # --------------------------------------------------------------------------- # Command validation # --------------------------------------------------------------------------- + def test_validate_allows_safe_commands(): """Commands matching the allow-list should pass validation.""" from infrastructure.hands.shell import ShellHand @@ -72,6 +72,7 @@ def test_validate_strips_path_prefix(): # Execution — success path # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_run_echo(): """A simple echo command should succeed.""" @@ -100,6 +101,7 @@ async def test_run_python_expression(): # Execution — failure path # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_run_blocked_command(): """Running a blocked command returns success=False without executing.""" @@ -126,6 +128,7 @@ async def test_run_nonzero_exit(): # Timeout # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_run_timeout(): """A command exceeding timeout should be killed and return timed_out=True.""" @@ -158,6 +161,7 @@ async def test_run_timeout(): # ShellResult dataclass # 
--------------------------------------------------------------------------- + def test_shell_result_defaults(): """ShellResult should have sensible defaults.""" from infrastructure.hands.shell import ShellResult @@ -176,6 +180,7 @@ def test_shell_result_defaults(): # Status summary # --------------------------------------------------------------------------- + def test_status_returns_summary(): """status() should return a dict with enabled, working_dir, etc.""" from infrastructure.hands.shell import ShellHand diff --git a/tests/test_openfang_client.py b/tests/test_openfang_client.py index c5c999f..a48f12c 100644 --- a/tests/test_openfang_client.py +++ b/tests/test_openfang_client.py @@ -15,11 +15,11 @@ from unittest.mock import MagicMock, patch import pytest - # --------------------------------------------------------------------------- # Health checks # --------------------------------------------------------------------------- + def test_health_check_false_when_unreachable(): """Client should report unhealthy when OpenFang is not running.""" from infrastructure.openfang.client import OpenFangClient @@ -47,6 +47,7 @@ def test_health_check_caching(): # execute_hand — unknown hand # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_execute_hand_unknown_hand(): """Requesting an unknown hand returns success=False immediately.""" @@ -62,16 +63,19 @@ async def test_execute_hand_unknown_hand(): # execute_hand — success path (mocked HTTP) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_execute_hand_success_mocked(): """When OpenFang returns 200 with output, HandResult.success is True.""" from infrastructure.openfang.client import OpenFangClient - response_body = json.dumps({ - "success": True, - "output": "Page loaded successfully", - "metadata": {"url": "https://example.com"}, - }).encode() + response_body = json.dumps( + { + "success": 
True, + "output": "Page loaded successfully", + "metadata": {"url": "https://example.com"}, + } + ).encode() mock_resp = MagicMock() mock_resp.status = 200 @@ -93,6 +97,7 @@ async def test_execute_hand_success_mocked(): # execute_hand — graceful degradation on connection error # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_execute_hand_connection_error(): """When OpenFang is unreachable, HandResult.success is False (no crash).""" @@ -110,6 +115,7 @@ async def test_execute_hand_connection_error(): # Convenience wrappers # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_browse_calls_browser_hand(): """browse() should delegate to execute_hand('browser', ...).""" @@ -180,6 +186,7 @@ async def test_predict_calls_predictor_hand(): # HandResult dataclass # --------------------------------------------------------------------------- + def test_hand_result_defaults(): """HandResult should have sensible defaults.""" from infrastructure.openfang.client import HandResult @@ -195,6 +202,7 @@ def test_hand_result_defaults(): # OPENFANG_HANDS constant # --------------------------------------------------------------------------- + def test_openfang_hands_tuple(): """The OPENFANG_HANDS constant should list all 7 hands.""" from infrastructure.openfang.client import OPENFANG_HANDS @@ -213,6 +221,7 @@ def test_openfang_hands_tuple(): # status() summary # --------------------------------------------------------------------------- + def test_status_returns_summary(): """status() should return a dict with url, healthy flag, and hands list.""" from infrastructure.openfang.client import OpenFangClient diff --git a/tests/test_setup_script.py b/tests/test_setup_script.py index d604b89..cc426d6 100644 --- a/tests/test_setup_script.py +++ b/tests/test_setup_script.py @@ -1,9 +1,10 @@ import os -import subprocess import shutil -import pytest +import subprocess 
from pathlib import Path +import pytest + # Constants for testing TEST_PROJECT_DIR = Path("/home/ubuntu/test-sovereign-stack") TEST_VAULT_DIR = TEST_PROJECT_DIR / "TimmyVault" @@ -14,6 +15,7 @@ pytestmark = pytest.mark.skipif( reason=f"Setup script not found at {SETUP_SCRIPT_PATH}", ) + @pytest.fixture(scope="module", autouse=True) def cleanup_test_env(): """Ensure a clean environment before and after tests.""" @@ -23,26 +25,25 @@ def cleanup_test_env(): # We keep the test env for manual inspection if needed, or cleanup # shutil.rmtree(TEST_PROJECT_DIR) + def run_setup_command(args): """Helper to run the setup script with arguments.""" result = subprocess.run( - [str(SETUP_SCRIPT_PATH)] + args, - capture_output=True, - text=True, - cwd="/home/ubuntu" + [str(SETUP_SCRIPT_PATH)] + args, capture_output=True, text=True, cwd="/home/ubuntu" ) return result + def test_setup_install_creates_directories(): """Test that './setup_timmy.sh install' creates the expected directory structure.""" # Note: We expect the script to be present at SETUP_SCRIPT_PATH assert SETUP_SCRIPT_PATH.exists(), "Setup script must exist before testing" - + result = run_setup_command(["install"]) - + # Check if command succeeded assert result.returncode == 0, f"Setup install failed: {result.stderr}" - + # Check directory structure assert TEST_PROJECT_DIR.exists() assert (TEST_PROJECT_DIR / "paperclip").exists() @@ -50,6 +51,7 @@ def test_setup_install_creates_directories(): assert TEST_VAULT_DIR.exists() assert (TEST_PROJECT_DIR / "logs").exists() + def test_setup_install_creates_files(): """Test that './setup_timmy.sh install' creates the expected configuration and notes.""" # Check Agent config @@ -64,11 +66,12 @@ def test_setup_install_creates_files(): soul_note = TEST_VAULT_DIR / "SOUL.md" assert hello_note.exists() assert soul_note.exists() - + with open(soul_note, "r") as f: content = f.read() assert "I am Timmy" in content + def test_setup_install_dependencies(): """Test that dependencies are 
correctly handled (OpenFang, Paperclip deps).""" # Check if Paperclip node_modules exists (implies pnpm install ran) @@ -76,11 +79,12 @@ def test_setup_install_dependencies(): node_modules = TEST_PROJECT_DIR / "paperclip/node_modules" assert node_modules.exists() + def test_setup_start_stop_logic(): """Test the start/stop command logic (simulated).""" # This is harder to test fully without actually running the services, # but we can check if the script handles the commands without crashing. - + # Mocking start (it might fail if ports are taken, so we check return code) # For the sake of this test, we just check if the script recognizes the command result = run_setup_command(["status"]) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 1f2fed5..4bba87d 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -13,6 +13,7 @@ from fastapi.testclient import TestClient @pytest.fixture def client(): from dashboard.app import app + with TestClient(app, raise_server_exceptions=False) as c: yield c @@ -21,6 +22,7 @@ def client(): # Core pages — these MUST return 200 # --------------------------------------------------------------------------- + class TestCorePages: """Every core dashboard page loads without error.""" @@ -49,6 +51,7 @@ class TestCorePages: # Feature pages — should return 200 (or 307 redirect, never 500) # --------------------------------------------------------------------------- + class TestFeaturePages: """Feature pages load without 500 errors.""" @@ -109,6 +112,7 @@ class TestFeaturePages: # JSON API endpoints — should return valid JSON, never 500 # --------------------------------------------------------------------------- + class TestAPIEndpoints: """API endpoints return valid JSON without server errors.""" @@ -179,49 +183,53 @@ class TestAPIEndpoints: # No 500s — every GET route should survive without server error # --------------------------------------------------------------------------- + class TestNo500: """Verify that no page 
returns a 500 Internal Server Error.""" - @pytest.mark.parametrize("path", [ - "/", - "/health", - "/health/status", - "/health/sovereignty", - "/health/components", - "/agents/default/panel", - "/agents/default/history", - "/briefing", - "/thinking", - "/thinking/api", - "/tools", - "/tools/api/stats", - "/memory", - "/calm", - "/tasks", - "/tasks/pending", - "/tasks/active", - "/tasks/completed", - "/work-orders/queue", - "/work-orders/queue/pending", - "/work-orders/queue/active", - "/mobile", - "/mobile/status", - "/spark", - "/models", - "/swarm/live", - "/swarm/events", - "/marketplace", - "/api/queue/status", - "/api/tasks", - "/api/chat/history", - "/api/notifications", - "/router/api/providers", - "/discord/status", - "/telegram/status", - "/grok/status", - "/grok/stats", - "/api/paperclip/status", - ]) + @pytest.mark.parametrize( + "path", + [ + "/", + "/health", + "/health/status", + "/health/sovereignty", + "/health/components", + "/agents/default/panel", + "/agents/default/history", + "/briefing", + "/thinking", + "/thinking/api", + "/tools", + "/tools/api/stats", + "/memory", + "/calm", + "/tasks", + "/tasks/pending", + "/tasks/active", + "/tasks/completed", + "/work-orders/queue", + "/work-orders/queue/pending", + "/work-orders/queue/active", + "/mobile", + "/mobile/status", + "/spark", + "/models", + "/swarm/live", + "/swarm/events", + "/marketplace", + "/api/queue/status", + "/api/tasks", + "/api/chat/history", + "/api/notifications", + "/router/api/providers", + "/discord/status", + "/telegram/status", + "/grok/status", + "/grok/stats", + "/api/paperclip/status", + ], + ) def test_no_500(self, client, path): r = client.get(path) assert r.status_code != 500, f"GET {path} returned 500" diff --git a/tests/timmy/test_agent.py b/tests/timmy/test_agent.py index 40ee278..d336974 100644 --- a/tests/timmy/test_agent.py +++ b/tests/timmy/test_agent.py @@ -3,14 +3,14 @@ from unittest.mock import MagicMock, patch def test_create_timmy_returns_agent(): 
"""create_timmy should delegate to Agno Agent with correct config.""" - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): mock_instance = MagicMock() MockAgent.return_value = mock_instance from timmy.agent import create_timmy + result = create_timmy() assert result is mock_instance @@ -18,11 +18,11 @@ def test_create_timmy_returns_agent(): def test_create_timmy_agent_name(): - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs @@ -30,11 +30,11 @@ def test_create_timmy_agent_name(): def test_create_timmy_history_config(): - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs @@ -44,11 +44,11 @@ def test_create_timmy_history_config(): def test_create_timmy_custom_db_file(): - with patch("timmy.agent.Agent"), \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb") as MockDb: - + with patch("timmy.agent.Agent"), patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ) as MockDb: from timmy.agent import create_timmy + create_timmy(db_file="custom.db") MockDb.assert_called_once_with(db_file="custom.db") @@ -57,11 +57,11 @@ def test_create_timmy_custom_db_file(): def test_create_timmy_embeds_system_prompt(): from timmy.prompts import SYSTEM_PROMPT - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - 
patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs @@ -72,17 +72,18 @@ def test_create_timmy_embeds_system_prompt(): # ── Ollama host regression (container connectivity) ───────────────────────── + def test_create_timmy_passes_ollama_url_to_model(): """Regression: Ollama model must receive settings.ollama_url as host. Without this, containers default to localhost:11434 which is unreachable when Ollama runs on the Docker host. """ - with patch("timmy.agent.Agent"), \ - patch("timmy.agent.Ollama") as MockOllama, \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent"), patch("timmy.agent.Ollama") as MockOllama, patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockOllama.call_args.kwargs @@ -93,17 +94,16 @@ def test_create_timmy_passes_ollama_url_to_model(): def test_create_timmy_respects_custom_ollama_url(): """Ollama host should follow OLLAMA_URL when overridden in config.""" custom_url = "http://host.docker.internal:11434" - with patch("timmy.agent.Agent"), \ - patch("timmy.agent.Ollama") as MockOllama, \ - patch("timmy.agent.SqliteDb"), \ - patch("timmy.agent.settings") as mock_settings: - + with patch("timmy.agent.Agent"), patch("timmy.agent.Ollama") as MockOllama, patch( + "timmy.agent.SqliteDb" + ), patch("timmy.agent.settings") as mock_settings: mock_settings.ollama_model = "llama3.2" mock_settings.ollama_url = custom_url mock_settings.timmy_model_backend = "ollama" mock_settings.airllm_model_size = "70b" from timmy.agent import create_timmy + create_timmy() kwargs = MockOllama.call_args.kwargs @@ -112,6 +112,7 @@ def test_create_timmy_respects_custom_ollama_url(): # ── AirLLM path ────────────────────────────────────────────────────────────── + def test_create_timmy_airllm_returns_airllm_agent(): 
"""backend='airllm' must return a TimmyAirLLMAgent, not an Agno Agent.""" with patch("timmy.backends.is_apple_silicon", return_value=False): @@ -125,10 +126,11 @@ def test_create_timmy_airllm_returns_airllm_agent(): def test_create_timmy_airllm_does_not_call_agno_agent(): """When using the airllm backend, Agno Agent should never be instantiated.""" - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.backends.is_apple_silicon", return_value=False): - + with patch("timmy.agent.Agent") as MockAgent, patch( + "timmy.backends.is_apple_silicon", return_value=False + ): from timmy.agent import create_timmy + create_timmy(backend="airllm", model_size="8b") MockAgent.assert_not_called() @@ -136,11 +138,11 @@ def test_create_timmy_airllm_does_not_call_agno_agent(): def test_create_timmy_explicit_ollama_ignores_autodetect(): """backend='ollama' must always use Ollama, even on Apple Silicon.""" - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy(backend="ollama") MockAgent.assert_called_once() @@ -148,8 +150,10 @@ def test_create_timmy_explicit_ollama_ignores_autodetect(): # ── _resolve_backend ───────────────────────────────────────────────────────── + def test_resolve_backend_explicit_takes_priority(): from timmy.agent import _resolve_backend + assert _resolve_backend("airllm") == "airllm" assert _resolve_backend("ollama") == "ollama" @@ -157,38 +161,45 @@ def test_resolve_backend_explicit_takes_priority(): def test_resolve_backend_defaults_to_ollama_without_config(): """Default config (timmy_model_backend='ollama') → 'ollama'.""" from timmy.agent import _resolve_backend + assert _resolve_backend(None) == "ollama" def test_resolve_backend_auto_uses_airllm_on_apple_silicon(): """'auto' on Apple Silicon with airllm stubbed → 
'airllm'.""" - with patch("timmy.backends.is_apple_silicon", return_value=True), \ - patch("timmy.agent.settings") as mock_settings: + with patch("timmy.backends.is_apple_silicon", return_value=True), patch( + "timmy.agent.settings" + ) as mock_settings: mock_settings.timmy_model_backend = "auto" mock_settings.airllm_model_size = "70b" mock_settings.ollama_model = "llama3.2" from timmy.agent import _resolve_backend + assert _resolve_backend(None) == "airllm" def test_resolve_backend_auto_falls_back_on_non_apple(): """'auto' on non-Apple Silicon → 'ollama'.""" - with patch("timmy.backends.is_apple_silicon", return_value=False), \ - patch("timmy.agent.settings") as mock_settings: + with patch("timmy.backends.is_apple_silicon", return_value=False), patch( + "timmy.agent.settings" + ) as mock_settings: mock_settings.timmy_model_backend = "auto" mock_settings.airllm_model_size = "70b" mock_settings.ollama_model = "llama3.2" from timmy.agent import _resolve_backend + assert _resolve_backend(None) == "ollama" # ── _model_supports_tools ──────────────────────────────────────────────────── + def test_model_supports_tools_llama32_returns_false(): """llama3.2 (3B) is too small for reliable tool calling.""" from timmy.agent import _model_supports_tools + assert _model_supports_tools("llama3.2") is False assert _model_supports_tools("llama3.2:latest") is False @@ -196,6 +207,7 @@ def test_model_supports_tools_llama32_returns_false(): def test_model_supports_tools_llama31_returns_true(): """llama3.1 (8B+) can handle tool calling.""" from timmy.agent import _model_supports_tools + assert _model_supports_tools("llama3.1") is True assert _model_supports_tools("llama3.3") is True @@ -203,6 +215,7 @@ def test_model_supports_tools_llama31_returns_true(): def test_model_supports_tools_other_small_models(): """Other known small models should not get tools.""" from timmy.agent import _model_supports_tools + assert _model_supports_tools("phi-3") is False assert 
_model_supports_tools("tinyllama") is False @@ -210,19 +223,21 @@ def test_model_supports_tools_other_small_models(): def test_model_supports_tools_unknown_model_gets_tools(): """Unknown models default to tool-capable (optimistic).""" from timmy.agent import _model_supports_tools + assert _model_supports_tools("mistral") is True assert _model_supports_tools("qwen2.5:72b") is True # ── Tool gating in create_timmy ────────────────────────────────────────────── + def test_create_timmy_no_tools_for_small_model(): """llama3.2 should get no tools.""" - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs @@ -234,12 +249,11 @@ def test_create_timmy_includes_tools_for_large_model(): """A tool-capable model (e.g. llama3.1) should attempt to include tools.""" mock_toolkit = MagicMock() - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"), \ - patch("timmy.agent.create_full_toolkit", return_value=mock_toolkit), \ - patch("timmy.agent.settings") as mock_settings: - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ), patch("timmy.agent.create_full_toolkit", return_value=mock_toolkit), patch( + "timmy.agent.settings" + ) as mock_settings: mock_settings.ollama_model = "llama3.1" mock_settings.ollama_url = "http://localhost:11434" mock_settings.timmy_model_backend = "ollama" @@ -247,6 +261,7 @@ def test_create_timmy_includes_tools_for_large_model(): mock_settings.telemetry_enabled = False from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs @@ -255,11 +270,11 @@ def test_create_timmy_includes_tools_for_large_model(): def 
test_create_timmy_show_tool_calls_matches_tool_capability(): """show_tool_calls should be True when tools are enabled, False otherwise.""" - with patch("timmy.agent.Agent") as MockAgent, \ - patch("timmy.agent.Ollama"), \ - patch("timmy.agent.SqliteDb"): - + with patch("timmy.agent.Agent") as MockAgent, patch("timmy.agent.Ollama"), patch( + "timmy.agent.SqliteDb" + ): from timmy.agent import create_timmy + create_timmy() kwargs = MockAgent.call_args.kwargs diff --git a/tests/timmy/test_agent_core.py b/tests/timmy/test_agent_core.py index 220fc08..a2154e0 100644 --- a/tests/timmy/test_agent_core.py +++ b/tests/timmy/test_agent_core.py @@ -11,11 +11,11 @@ from unittest.mock import MagicMock, patch import pytest from timmy.agent_core.interface import ( + Action, ActionType, AgentCapability, AgentEffect, AgentIdentity, - Action, Communication, Memory, Perception, @@ -23,7 +23,6 @@ from timmy.agent_core.interface import ( TimAgent, ) - # ── AgentIdentity ───────────────────────────────────────────────────────────── @@ -166,12 +165,23 @@ class TestTimAgentABC: def test_concrete_subclass_works(self): class Dummy(TimAgent): - def perceive(self, p): return Memory(id="1", content=p.data, created_at="") - def reason(self, q, c): return Action.respond(q) - def act(self, a): return a.payload - def remember(self, m): pass - def recall(self, q, limit=5): return [] - def communicate(self, m): return True + def perceive(self, p): + return Memory(id="1", content=p.data, created_at="") + + def reason(self, q, c): + return Action.respond(q) + + def act(self, a): + return a.payload + + def remember(self, m): + pass + + def recall(self, q, limit=5): + return [] + + def communicate(self, m): + return True d = Dummy(AgentIdentity.generate("Dummy")) assert d.identity.name == "Dummy" @@ -179,12 +189,23 @@ class TestTimAgentABC: def test_has_capability(self): class Dummy(TimAgent): - def perceive(self, p): pass - def reason(self, q, c): pass - def act(self, a): pass - def remember(self, m): 
pass - def recall(self, q, limit=5): return [] - def communicate(self, m): return True + def perceive(self, p): + pass + + def reason(self, q, c): + pass + + def act(self, a): + pass + + def remember(self, m): + pass + + def recall(self, q, limit=5): + return [] + + def communicate(self, m): + return True d = Dummy(AgentIdentity.generate("D")) d._capabilities.add(AgentCapability.REASONING) @@ -193,12 +214,23 @@ class TestTimAgentABC: def test_capabilities_returns_copy(self): class Dummy(TimAgent): - def perceive(self, p): pass - def reason(self, q, c): pass - def act(self, a): pass - def remember(self, m): pass - def recall(self, q, limit=5): return [] - def communicate(self, m): return True + def perceive(self, p): + pass + + def reason(self, q, c): + pass + + def act(self, a): + pass + + def remember(self, m): + pass + + def recall(self, q, limit=5): + return [] + + def communicate(self, m): + return True d = Dummy(AgentIdentity.generate("D")) caps = d.capabilities @@ -207,12 +239,23 @@ class TestTimAgentABC: def test_get_state(self): class Dummy(TimAgent): - def perceive(self, p): pass - def reason(self, q, c): pass - def act(self, a): pass - def remember(self, m): pass - def recall(self, q, limit=5): return [] - def communicate(self, m): return True + def perceive(self, p): + pass + + def reason(self, q, c): + pass + + def act(self, a): + pass + + def remember(self, m): + pass + + def recall(self, q, limit=5): + return [] + + def communicate(self, m): + return True d = Dummy(AgentIdentity.generate("D")) state = d.get_state() @@ -222,12 +265,23 @@ class TestTimAgentABC: def test_shutdown_does_not_raise(self): class Dummy(TimAgent): - def perceive(self, p): pass - def reason(self, q, c): pass - def act(self, a): pass - def remember(self, m): pass - def recall(self, q, limit=5): return [] - def communicate(self, m): return True + def perceive(self, p): + pass + + def reason(self, q, c): + pass + + def act(self, a): + pass + + def remember(self, m): + pass + + def 
recall(self, q, limit=5): + return [] + + def communicate(self, m): + return True d = Dummy(AgentIdentity.generate("D")) d.shutdown() # should not raise @@ -311,6 +365,7 @@ class TestOllamaAgent: mock_ct.return_value = mock_timmy from timmy.agent_core.ollama_adapter import OllamaAgent + identity = AgentIdentity.generate("TestTimmy") return OllamaAgent(identity, effect_log="/tmp/test_effects") @@ -427,6 +482,7 @@ class TestOllamaAgent: mock_timmy = MagicMock() mock_ct.return_value = mock_timmy from timmy.agent_core.ollama_adapter import OllamaAgent + identity = AgentIdentity.generate("NoLog") agent = OllamaAgent(identity) # no effect_log assert agent.get_effect_log() is None diff --git a/tests/timmy/test_agents_timmy.py b/tests/timmy/test_agents_timmy.py index ffae023..aba316e 100644 --- a/tests/timmy/test_agents_timmy.py +++ b/tests/timmy/test_agents_timmy.py @@ -1,9 +1,10 @@ """Tests for timmy.agents.timmy — orchestrator, personas, context building.""" import sys -import pytest -from unittest.mock import patch, MagicMock, AsyncMock from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest # Ensure mcp.registry stub with tool_registry exists before importing agents if "mcp" not in sys.modules: @@ -16,14 +17,14 @@ if "mcp" not in sys.modules: sys.modules["mcp.registry"] = _mock_registry_mod from timmy.agents.timmy import ( - _load_hands_async, - build_timmy_context_sync, - build_timmy_context_async, - format_timmy_prompt, - TimmyOrchestrator, - create_timmy_swarm, _PERSONAS, ORCHESTRATOR_PROMPT_BASE, + TimmyOrchestrator, + _load_hands_async, + build_timmy_context_async, + build_timmy_context_sync, + create_timmy_swarm, + format_timmy_prompt, ) @@ -65,6 +66,7 @@ class TestBuildContext: # Patch HotMemory path so it reads from tmp_path from timmy.memory_system import memory_system + original_path = memory_system.hot.path memory_system.hot.path = memory_file memory_system.hot._content = None # Clear cache @@ -183,6 +185,7 @@ class 
TestTimmyOrchestrator: orch = TimmyOrchestrator() from timmy.agents.base import SubAgent + agent = SubAgent( agent_id="test-agent", name="Test", @@ -250,8 +253,8 @@ class TestPersonas: ids = [p["agent_id"] for p in _PERSONAS] assert len(ids) == len(set(ids)) - def test_five_personas(self): - assert len(_PERSONAS) == 5 + def test_six_personas(self): + assert len(_PERSONAS) == 6 class TestOrchestratorPrompt: diff --git a/tests/timmy/test_api_rate_limiting.py b/tests/timmy/test_api_rate_limiting.py index 940fe9e..f758018 100644 --- a/tests/timmy/test_api_rate_limiting.py +++ b/tests/timmy/test_api_rate_limiting.py @@ -1,32 +1,44 @@ """Tests for API rate limiting in Timmy Serve.""" -import pytest import time + +import pytest from fastapi.testclient import TestClient + from timmy_serve.app import create_timmy_serve_app + @pytest.fixture def client(): app = create_timmy_serve_app() return TestClient(app) + def test_health_check_no_rate_limit(client): """Health check should not be rate limited (or have a very high limit).""" for _ in range(10): response = client.get("/health") assert response.status_code == 200 + def test_chat_rate_limiting(client, monkeypatch): """Chat endpoint should be rate limited.""" # Mock create_timmy to avoid heavy LLM initialization - monkeypatch.setattr("timmy_serve.app.create_timmy", lambda: type('obj', (object,), {'run': lambda self, m, stream: type('obj', (object,), {'content': 'reply'})()})()) - + monkeypatch.setattr( + "timmy_serve.app.create_timmy", + lambda: type( + "obj", + (object,), + {"run": lambda self, m, stream: type("obj", (object,), {"content": "reply"})()}, + )(), + ) + # Send requests up to the limit (assuming limit is small for tests or we just test it's there) # Since we haven't implemented it yet, this test should fail if we assert 429 responses = [] for _ in range(20): responses.append(client.post("/serve/chat", json={"message": "hi"})) - + # If rate limiting is implemented, some of these should be 429 status_codes = 
[r.status_code for r in responses] assert 429 in status_codes diff --git a/tests/timmy/test_approvals.py b/tests/timmy/test_approvals.py index 9d118d6..81db768 100644 --- a/tests/timmy/test_approvals.py +++ b/tests/timmy/test_approvals.py @@ -1,20 +1,21 @@ """Tests for timmy.approvals — approval workflow and Golden Timmy rule.""" -import pytest -from pathlib import Path from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest from timmy.approvals import ( GOLDEN_TIMMY, ApprovalItem, - create_item, - list_pending, - list_all, - get_item, - approve, - reject, - expire_old, _get_conn, + approve, + create_item, + expire_old, + get_item, + list_all, + list_pending, + reject, ) diff --git a/tests/timmy/test_autoresearch.py b/tests/timmy/test_autoresearch.py new file mode 100644 index 0000000..bacd247 --- /dev/null +++ b/tests/timmy/test_autoresearch.py @@ -0,0 +1,179 @@ +"""Tests for the autoresearch module — autonomous ML experiment loops.""" + +import json +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +class TestPrepareExperiment: + """Tests for prepare_experiment().""" + + def test_clones_repo_when_not_present(self, tmp_path): + from timmy.autoresearch import prepare_experiment + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + result = prepare_experiment(tmp_path, "https://example.com/repo.git") + + assert mock_run.call_count >= 1 + clone_call = mock_run.call_args_list[0] + assert "git" in clone_call.args[0] + assert "clone" in clone_call.args[0] + + def test_skips_clone_when_present(self, tmp_path): + from timmy.autoresearch import prepare_experiment + + repo_dir = tmp_path / "autoresearch" + repo_dir.mkdir() + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + result = prepare_experiment(tmp_path) + + # Should not 
call git clone + if mock_run.called: + assert "clone" not in str(mock_run.call_args_list[0]) + + def test_clone_failure_returns_error(self, tmp_path): + from timmy.autoresearch import prepare_experiment + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=1, stdout="", stderr="auth failed") + result = prepare_experiment(tmp_path) + + assert "failed" in result.lower() + + +class TestRunExperiment: + """Tests for run_experiment().""" + + def test_successful_run_extracts_metric(self, tmp_path): + from timmy.autoresearch import run_experiment + + repo_dir = tmp_path / "autoresearch" + repo_dir.mkdir() + (repo_dir / "train.py").write_text("print('training')") + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.return_value = MagicMock( + returncode=0, + stdout="step 1000 val_bpb: 1.2345\nstep 2000 val_bpb: 1.1234", + stderr="", + ) + result = run_experiment(tmp_path, timeout=60) + + assert result["success"] is True + assert result["metric"] == pytest.approx(1.1234) + assert result["error"] is None + + def test_timeout_returns_error(self, tmp_path): + import subprocess + + from timmy.autoresearch import run_experiment + + repo_dir = tmp_path / "autoresearch" + repo_dir.mkdir() + (repo_dir / "train.py").write_text("print('training')") + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.side_effect = subprocess.TimeoutExpired(cmd="python", timeout=5) + result = run_experiment(tmp_path, timeout=5) + + assert result["success"] is False + assert "timed out" in result["error"].lower() + + def test_missing_train_py(self, tmp_path): + from timmy.autoresearch import run_experiment + + repo_dir = tmp_path / "autoresearch" + repo_dir.mkdir() + # No train.py + + result = run_experiment(tmp_path) + assert result["success"] is False + assert "not found" in result["error"].lower() + + def test_no_metric_in_output(self, tmp_path): + from timmy.autoresearch import run_experiment + + 
repo_dir = tmp_path / "autoresearch" + repo_dir.mkdir() + (repo_dir / "train.py").write_text("print('done')") + + with patch("timmy.autoresearch.subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="no metrics here", stderr="") + result = run_experiment(tmp_path) + + assert result["success"] is True + assert result["metric"] is None + + +class TestEvaluateResult: + """Tests for evaluate_result().""" + + def test_improvement_detected(self): + from timmy.autoresearch import evaluate_result + + result = evaluate_result(1.10, 1.20) + assert "improvement" in result.lower() + + def test_regression_detected(self): + from timmy.autoresearch import evaluate_result + + result = evaluate_result(1.30, 1.20) + assert "regression" in result.lower() + + def test_no_change(self): + from timmy.autoresearch import evaluate_result + + result = evaluate_result(1.20, 1.20) + assert "no change" in result.lower() + + +class TestExperimentHistory: + """Tests for get_experiment_history().""" + + def test_empty_workspace(self, tmp_path): + from timmy.autoresearch import get_experiment_history + + history = get_experiment_history(tmp_path) + assert history == [] + + def test_reads_jsonl(self, tmp_path): + from timmy.autoresearch import get_experiment_history + + results_file = tmp_path / "results.jsonl" + results_file.write_text( + json.dumps({"metric": 1.2, "success": True}) + + "\n" + + json.dumps({"metric": 1.1, "success": True}) + + "\n" + ) + + history = get_experiment_history(tmp_path) + assert len(history) == 2 + # Most recent first + assert history[0]["metric"] == 1.1 + + +class TestExtractMetric: + """Tests for _extract_metric().""" + + def test_extracts_last_value(self): + from timmy.autoresearch import _extract_metric + + output = "val_bpb: 1.5\nval_bpb: 1.3\nval_bpb: 1.1" + assert _extract_metric(output) == pytest.approx(1.1) + + def test_no_match_returns_none(self): + from timmy.autoresearch import _extract_metric + + assert _extract_metric("no 
metrics here") is None + + def test_custom_metric_name(self): + from timmy.autoresearch import _extract_metric + + output = "loss: 0.45\nloss: 0.32" + assert _extract_metric(output, "loss") == pytest.approx(0.32) diff --git a/tests/timmy/test_backends.py b/tests/timmy/test_backends.py index ace393c..2efc6b1 100644 --- a/tests/timmy/test_backends.py +++ b/tests/timmy/test_backends.py @@ -5,35 +5,43 @@ from unittest.mock import MagicMock, patch import pytest - # ── is_apple_silicon ────────────────────────────────────────────────────────── + def test_is_apple_silicon_true_on_arm_darwin(): - with patch("timmy.backends.platform.system", return_value="Darwin"), \ - patch("timmy.backends.platform.machine", return_value="arm64"): + with patch("timmy.backends.platform.system", return_value="Darwin"), patch( + "timmy.backends.platform.machine", return_value="arm64" + ): from timmy.backends import is_apple_silicon + assert is_apple_silicon() is True def test_is_apple_silicon_false_on_linux(): - with patch("timmy.backends.platform.system", return_value="Linux"), \ - patch("timmy.backends.platform.machine", return_value="x86_64"): + with patch("timmy.backends.platform.system", return_value="Linux"), patch( + "timmy.backends.platform.machine", return_value="x86_64" + ): from timmy.backends import is_apple_silicon + assert is_apple_silicon() is False def test_is_apple_silicon_false_on_intel_mac(): - with patch("timmy.backends.platform.system", return_value="Darwin"), \ - patch("timmy.backends.platform.machine", return_value="x86_64"): + with patch("timmy.backends.platform.system", return_value="Darwin"), patch( + "timmy.backends.platform.machine", return_value="x86_64" + ): from timmy.backends import is_apple_silicon + assert is_apple_silicon() is False # ── airllm_available ───────────────────────────────────────────────────────── + def test_airllm_available_true_when_stub_in_sys_modules(): # conftest already stubs 'airllm' — importable → True. 
from timmy.backends import airllm_available + assert airllm_available() is True @@ -42,6 +50,7 @@ def test_airllm_available_false_when_not_importable(): saved = sys.modules.pop("airllm", None) try: from timmy.backends import airllm_available + assert airllm_available() is False finally: if saved is not None: @@ -50,8 +59,10 @@ def test_airllm_available_false_when_not_importable(): # ── TimmyAirLLMAgent construction ──────────────────────────────────────────── + def test_airllm_agent_raises_on_unknown_size(): from timmy.backends import TimmyAirLLMAgent + with pytest.raises(ValueError, match="Unknown model size"): TimmyAirLLMAgent(model_size="3b") @@ -60,6 +71,7 @@ def test_airllm_agent_uses_automodel_on_non_apple(): """Non-Apple-Silicon path uses AutoModel.from_pretrained.""" with patch("timmy.backends.is_apple_silicon", return_value=False): from timmy.backends import TimmyAirLLMAgent + agent = TimmyAirLLMAgent(model_size="8b") # sys.modules["airllm"] is a MagicMock; AutoModel.from_pretrained was called. 
assert sys.modules["airllm"].AutoModel.from_pretrained.called @@ -69,25 +81,27 @@ def test_airllm_agent_uses_mlx_on_apple_silicon(): """Apple Silicon path uses AirLLMMLX, not AutoModel.""" with patch("timmy.backends.is_apple_silicon", return_value=True): from timmy.backends import TimmyAirLLMAgent + agent = TimmyAirLLMAgent(model_size="8b") assert sys.modules["airllm"].AirLLMMLX.called def test_airllm_agent_resolves_correct_model_id_for_70b(): with patch("timmy.backends.is_apple_silicon", return_value=False): - from timmy.backends import TimmyAirLLMAgent, _AIRLLM_MODELS + from timmy.backends import _AIRLLM_MODELS, TimmyAirLLMAgent + TimmyAirLLMAgent(model_size="70b") - sys.modules["airllm"].AutoModel.from_pretrained.assert_called_with( - _AIRLLM_MODELS["70b"] - ) + sys.modules["airllm"].AutoModel.from_pretrained.assert_called_with(_AIRLLM_MODELS["70b"]) # ── TimmyAirLLMAgent.print_response ────────────────────────────────────────── + def _make_agent(model_size: str = "8b") -> "TimmyAirLLMAgent": """Helper: create an agent with a fully mocked underlying model.""" with patch("timmy.backends.is_apple_silicon", return_value=False): from timmy.backends import TimmyAirLLMAgent + agent = TimmyAirLLMAgent(model_size=model_size) # Replace the underlying model with a clean mock that returns predictable output. 
@@ -151,6 +165,7 @@ def test_claude_available_false_when_no_key(): with patch("config.settings") as mock_settings: mock_settings.anthropic_api_key = "" from timmy.backends import claude_available + assert claude_available() is False @@ -159,12 +174,14 @@ def test_claude_available_true_when_key_set(): with patch("config.settings") as mock_settings: mock_settings.anthropic_api_key = "sk-ant-test-key" from timmy.backends import claude_available + assert claude_available() is True def test_claude_backend_init_with_explicit_params(): """ClaudeBackend can be created with explicit api_key and model.""" from timmy.backends import ClaudeBackend + backend = ClaudeBackend(api_key="sk-ant-test", model="haiku") assert backend._api_key == "sk-ant-test" assert "haiku" in backend._model @@ -172,7 +189,8 @@ def test_claude_backend_init_with_explicit_params(): def test_claude_backend_init_resolves_short_names(): """ClaudeBackend resolves short model names to full IDs.""" - from timmy.backends import ClaudeBackend, CLAUDE_MODELS + from timmy.backends import CLAUDE_MODELS, ClaudeBackend + backend = ClaudeBackend(api_key="sk-test", model="sonnet") assert backend._model == CLAUDE_MODELS["sonnet"] @@ -180,6 +198,7 @@ def test_claude_backend_init_resolves_short_names(): def test_claude_backend_init_passes_through_full_model_id(): """ClaudeBackend passes through full model IDs unchanged.""" from timmy.backends import ClaudeBackend + backend = ClaudeBackend(api_key="sk-test", model="claude-haiku-4-5-20251001") assert backend._model == "claude-haiku-4-5-20251001" @@ -187,6 +206,7 @@ def test_claude_backend_init_passes_through_full_model_id(): def test_claude_backend_run_no_key_returns_error(): """run() gracefully returns error message when no API key.""" from timmy.backends import ClaudeBackend + backend = ClaudeBackend(api_key="", model="haiku") result = backend.run("hello") assert "not configured" in result.content diff --git a/tests/timmy/test_conversation.py 
b/tests/timmy/test_conversation.py index ad0219c..b91401f 100644 --- a/tests/timmy/test_conversation.py +++ b/tests/timmy/test_conversation.py @@ -1,6 +1,7 @@ """Tests for timmy.conversation — conversation context and tool routing.""" import pytest + from timmy.conversation import ConversationContext, ConversationManager diff --git a/tests/timmy/test_grok_backend.py b/tests/timmy/test_grok_backend.py index 693e049..69787d4 100644 --- a/tests/timmy/test_grok_backend.py +++ b/tests/timmy/test_grok_backend.py @@ -4,15 +4,16 @@ from unittest.mock import MagicMock, patch import pytest - # ── grok_available ─────────────────────────────────────────────────────────── + def test_grok_available_false_when_disabled(): """Grok not available when GROK_ENABLED is false.""" with patch("config.settings") as mock_settings: mock_settings.grok_enabled = False mock_settings.xai_api_key = "xai-test-key" from timmy.backends import grok_available + assert grok_available() is False @@ -22,6 +23,7 @@ def test_grok_available_false_when_no_key(): mock_settings.grok_enabled = True mock_settings.xai_api_key = "" from timmy.backends import grok_available + assert grok_available() is False @@ -31,14 +33,17 @@ def test_grok_available_true_when_enabled_and_key_set(): mock_settings.grok_enabled = True mock_settings.xai_api_key = "xai-test-key" from timmy.backends import grok_available + assert grok_available() is True # ── GrokBackend construction ──────────────────────────────────────────────── + def test_grok_backend_init_with_explicit_params(): """GrokBackend can be created with explicit api_key and model.""" from timmy.backends import GrokBackend + backend = GrokBackend(api_key="xai-test", model="grok-3-fast") assert backend._api_key == "xai-test" assert backend._model == "grok-3-fast" @@ -51,6 +56,7 @@ def test_grok_backend_init_from_settings(): mock_settings.xai_api_key = "xai-from-env" mock_settings.grok_default_model = "grok-3" from timmy.backends import GrokBackend + backend = 
GrokBackend() assert backend._api_key == "xai-from-env" assert backend._model == "grok-3" @@ -59,6 +65,7 @@ def test_grok_backend_init_from_settings(): def test_grok_backend_run_no_key_returns_error(): """run() gracefully returns error message when no API key.""" from timmy.backends import GrokBackend + backend = GrokBackend(api_key="", model="grok-3-fast") result = backend.run("hello") assert "not configured" in result.content @@ -191,6 +198,7 @@ def test_grok_backend_build_messages(): # ── get_grok_backend singleton ────────────────────────────────────────────── + def test_get_grok_backend_returns_singleton(): """get_grok_backend returns the same instance on repeated calls.""" import timmy.backends as backends_mod @@ -208,18 +216,22 @@ def test_get_grok_backend_returns_singleton(): # ── GROK_MODELS constant ─────────────────────────────────────────────────── + def test_grok_models_dict_has_expected_entries(): from timmy.backends import GROK_MODELS + assert "grok-3-fast" in GROK_MODELS assert "grok-3" in GROK_MODELS # ── consult_grok tool ────────────────────────────────────────────────────── + def test_consult_grok_returns_unavailable_when_disabled(): """consult_grok tool returns error when Grok is not available.""" with patch("timmy.backends.grok_available", return_value=False): from timmy.tools import consult_grok + result = consult_grok("test query") assert "not available" in result @@ -233,13 +245,14 @@ def test_consult_grok_calls_backend_when_available(): mock_backend.stats = MagicMock() mock_backend.stats.total_latency_ms = 100 - with patch("timmy.backends.grok_available", return_value=True), \ - patch("timmy.backends.get_grok_backend", return_value=mock_backend), \ - patch("config.settings") as mock_settings: + with patch("timmy.backends.grok_available", return_value=True), patch( + "timmy.backends.get_grok_backend", return_value=mock_backend + ), patch("config.settings") as mock_settings: mock_settings.grok_free = True mock_settings.grok_enabled = True 
mock_settings.xai_api_key = "xai-test" from timmy.tools import consult_grok + result = consult_grok("complex question") assert "Grok answer" in result @@ -248,6 +261,7 @@ def test_consult_grok_calls_backend_when_available(): # ── Grok dashboard route tests ───────────────────────────────────────────── + def test_grok_status_endpoint(client): """GET /grok/status returns HTML dashboard page.""" response = client.get("/grok/status") @@ -281,4 +295,8 @@ def test_grok_chat_without_key(client): ) assert response.status_code == 200 # Should contain error since GROK_ENABLED is false in test mode - assert "not available" in response.text.lower() or "error" in response.text.lower() or "grok" in response.text.lower() + assert ( + "not available" in response.text.lower() + or "error" in response.text.lower() + or "grok" in response.text.lower() + ) diff --git a/tests/timmy/test_introspection.py b/tests/timmy/test_introspection.py index f0f2351..0c689f7 100644 --- a/tests/timmy/test_introspection.py +++ b/tests/timmy/test_introspection.py @@ -18,8 +18,8 @@ def test_get_system_info_returns_dict(): def test_get_system_info_contains_model(): """System info should include model name.""" - from timmy.tools_intro import get_system_info from config import settings + from timmy.tools_intro import get_system_info info = get_system_info() @@ -30,8 +30,8 @@ def test_get_system_info_contains_model(): def test_get_system_info_contains_repo_root(): """System info should include repo_root.""" - from timmy.tools_intro import get_system_info from config import settings + from timmy.tools_intro import get_system_info info = get_system_info() diff --git a/tests/timmy/test_ollama_timeout.py b/tests/timmy/test_ollama_timeout.py index 7a045c0..0f7ea90 100644 --- a/tests/timmy/test_ollama_timeout.py +++ b/tests/timmy/test_ollama_timeout.py @@ -5,13 +5,12 @@ This caused socket read errors in production. The agno Ollama class uses ``timeout`` (not ``request_timeout``). 
""" -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch def test_base_agent_sets_timeout(): """BaseAgent creates Ollama with timeout=300.""" - with patch("timmy.agents.base.Ollama") as mock_ollama, \ - patch("timmy.agents.base.Agent"): + with patch("timmy.agents.base.Ollama") as mock_ollama, patch("timmy.agents.base.Agent"): mock_ollama.return_value = MagicMock() # Import after patching to get the patched version @@ -36,19 +35,20 @@ def test_base_agent_sets_timeout(): # Verify Ollama was called with timeout if mock_ollama.called: _, kwargs = mock_ollama.call_args - assert kwargs.get("timeout") == 300, ( - f"Expected timeout=300, got {kwargs.get('timeout')}" - ) + assert ( + kwargs.get("timeout") == 300 + ), f"Expected timeout=300, got {kwargs.get('timeout')}" def test_main_agent_sets_timeout(): """create_timmy() creates Ollama with timeout=300.""" - with patch("timmy.agent.Ollama") as mock_ollama, \ - patch("timmy.agent.SqliteDb"), \ - patch("timmy.agent.Agent"): + with patch("timmy.agent.Ollama") as mock_ollama, patch("timmy.agent.SqliteDb"), patch( + "timmy.agent.Agent" + ): mock_ollama.return_value = MagicMock() from timmy.agent import create_timmy + try: create_timmy() except Exception: @@ -56,6 +56,6 @@ def test_main_agent_sets_timeout(): if mock_ollama.called: _, kwargs = mock_ollama.call_args - assert kwargs.get("timeout") == 300, ( - f"Expected timeout=300, got {kwargs.get('timeout')}" - ) + assert ( + kwargs.get("timeout") == 300 + ), f"Expected timeout=300, got {kwargs.get('timeout')}" diff --git a/tests/timmy/test_prompts.py b/tests/timmy/test_prompts.py index da6ce65..56db51b 100644 --- a/tests/timmy/test_prompts.py +++ b/tests/timmy/test_prompts.py @@ -1,4 +1,4 @@ -from timmy.prompts import SYSTEM_PROMPT, STATUS_PROMPT, get_system_prompt +from timmy.prompts import STATUS_PROMPT, SYSTEM_PROMPT, get_system_prompt def test_system_prompt_not_empty(): diff --git a/tests/timmy/test_semantic_memory.py 
b/tests/timmy/test_semantic_memory.py index 5518c47..c6d9205 100644 --- a/tests/timmy/test_semantic_memory.py +++ b/tests/timmy/test_semantic_memory.py @@ -1,19 +1,20 @@ """Tests for timmy.semantic_memory — semantic search, chunking, indexing.""" -import pytest from pathlib import Path from unittest.mock import patch +import pytest + from timmy.semantic_memory import ( - _simple_hash_embedding, - embed_text, - cosine_similarity, - SemanticMemory, - MemorySearcher, MemoryChunk, - memory_search, - memory_read, + MemorySearcher, + SemanticMemory, _get_embedding_model, + _simple_hash_embedding, + cosine_similarity, + embed_text, + memory_read, + memory_search, ) @@ -38,6 +39,7 @@ class TestSimpleHashEmbedding: def test_normalized(self): import math + vec = _simple_hash_embedding("test normalization") magnitude = math.sqrt(sum(x * x for x in vec)) assert abs(magnitude - 1.0) < 0.01 @@ -112,7 +114,9 @@ class TestSemanticMemory: def test_index_file(self, mem): md_file = mem.vault_path / "test.md" - md_file.write_text("# Title\n\nThis is a test document with enough content to index properly.\n\nAnother paragraph with more content here.") + md_file.write_text( + "# Title\n\nThis is a test document with enough content to index properly.\n\nAnother paragraph with more content here." + ) count = mem.index_file(md_file) assert count > 0 @@ -129,8 +133,12 @@ class TestSemanticMemory: assert count2 == 0 # Already indexed, same hash def test_index_vault(self, mem): - (mem.vault_path / "a.md").write_text("# File A\n\nContent of file A with some meaningful text here.") - (mem.vault_path / "b.md").write_text("# File B\n\nContent of file B with different meaningful text.") + (mem.vault_path / "a.md").write_text( + "# File A\n\nContent of file A with some meaningful text here." + ) + (mem.vault_path / "b.md").write_text( + "# File B\n\nContent of file B with different meaningful text." 
+ ) total = mem.index_vault() assert total >= 2 @@ -148,6 +156,7 @@ class TestSemanticMemory: # Wipe and re-test via index_vault import sqlite3 + conn = sqlite3.connect(str(mem.db_path)) conn.execute("DELETE FROM chunks") conn.commit() @@ -164,7 +173,9 @@ class TestSemanticMemory: def test_search_returns_results(self, mem): md = mem.vault_path / "searchable.md" - md.write_text("# Python\n\nPython is a programming language used for web development and data science.") + md.write_text( + "# Python\n\nPython is a programming language used for web development and data science." + ) mem.index_file(md) results = mem.search("programming language") @@ -179,7 +190,9 @@ class TestSemanticMemory: def test_get_relevant_context(self, mem): md = mem.vault_path / "context.md" - md.write_text("# Important\n\nThis is very important information about the system architecture.") + md.write_text( + "# Important\n\nThis is very important information about the system architecture." + ) mem.index_file(md) ctx = mem.get_relevant_context("architecture") diff --git a/tests/timmy/test_session.py b/tests/timmy/test_session.py index f66a411..8c062d6 100644 --- a/tests/timmy/test_session.py +++ b/tests/timmy/test_session.py @@ -4,15 +4,16 @@ from unittest.mock import MagicMock, patch import pytest - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def _reset_session_singleton(): """Reset the module-level singleton between tests.""" import timmy.session as mod + mod._agent = None yield mod._agent = None @@ -22,6 +23,7 @@ def _reset_session_singleton(): # chat() # --------------------------------------------------------------------------- + def test_chat_returns_string(): """chat() should return a plain string response.""" mock_agent = MagicMock() @@ -29,6 +31,7 @@ def test_chat_returns_string(): with patch("timmy.session._get_agent", 
return_value=mock_agent): from timmy.session import chat + result = chat("Hi Timmy") assert isinstance(result, str) @@ -42,6 +45,7 @@ def test_chat_passes_session_id(): with patch("timmy.session._get_agent", return_value=mock_agent): from timmy.session import chat + chat("test", session_id="my-session") _, kwargs = mock_agent.run.call_args @@ -55,6 +59,7 @@ def test_chat_uses_default_session_id(): with patch("timmy.session._get_agent", return_value=mock_agent): from timmy.session import chat + chat("test") _, kwargs = mock_agent.run.call_args @@ -68,6 +73,7 @@ def test_chat_singleton_agent_reused(): with patch("timmy.agent.create_timmy", return_value=mock_agent) as mock_factory: from timmy.session import chat + chat("first message") chat("second message") @@ -82,9 +88,11 @@ def test_chat_extracts_user_name(): mock_mem = MagicMock() - with patch("timmy.session._get_agent", return_value=mock_agent), \ - patch("timmy.memory_system.memory_system", mock_mem): + with patch("timmy.session._get_agent", return_value=mock_agent), patch( + "timmy.memory_system.memory_system", mock_mem + ): from timmy.session import chat + chat("my name is Alex") mock_mem.update_user_fact.assert_called_once_with("Name", "Alex") @@ -95,11 +103,13 @@ def test_chat_graceful_degradation_on_memory_failure(): mock_agent = MagicMock() mock_agent.run.return_value = MagicMock(content="I'm operational.") - with patch("timmy.session._get_agent", return_value=mock_agent), \ - patch("timmy.conversation.conversation_manager") as mock_cm: + with patch("timmy.session._get_agent", return_value=mock_agent), patch( + "timmy.conversation.conversation_manager" + ) as mock_cm: mock_cm.extract_user_name.side_effect = Exception("memory broken") from timmy.session import chat + result = chat("test message") assert "operational" in result @@ -109,6 +119,7 @@ def test_chat_graceful_degradation_on_memory_failure(): # _clean_response() # --------------------------------------------------------------------------- + def 
test_clean_response_strips_json_tool_calls(): """JSON tool call blocks should be removed from response text.""" from timmy.session import _clean_response @@ -158,12 +169,14 @@ def test_clean_response_preserves_normal_text(): def test_clean_response_handles_empty_string(): """Empty string should be returned as-is.""" from timmy.session import _clean_response + assert _clean_response("") == "" def test_clean_response_handles_none(): """None should be returned as-is.""" from timmy.session import _clean_response + assert _clean_response(None) is None @@ -171,10 +184,12 @@ def test_clean_response_handles_none(): # reset_session() # --------------------------------------------------------------------------- + def test_reset_session_clears_context(): """reset_session() should clear the conversation context.""" with patch("timmy.conversation.conversation_manager") as mock_cm: from timmy.session import reset_session + reset_session("test-session") mock_cm.clear_context.assert_called_once_with("test-session") diff --git a/tests/timmy/test_session_logging.py b/tests/timmy/test_session_logging.py index a4139d2..009a718 100644 --- a/tests/timmy/test_session_logging.py +++ b/tests/timmy/test_session_logging.py @@ -1,10 +1,11 @@ """Tests for session logging.""" -import pytest -import tempfile import json +import tempfile from pathlib import Path +import pytest + def test_session_logger_records_message(): """Should record a user message.""" diff --git a/tests/timmy/test_thinking.py b/tests/timmy/test_thinking.py index 6977851..9494f5d 100644 --- a/tests/timmy/test_thinking.py +++ b/tests/timmy/test_thinking.py @@ -6,14 +6,15 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _make_engine(tmp_path: Path): """Create a ThinkingEngine with an isolated temp DB.""" from timmy.thinking import 
ThinkingEngine + db_path = tmp_path / "thoughts.db" return ThinkingEngine(db_path=db_path) @@ -22,9 +23,11 @@ def _make_engine(tmp_path: Path): # Config # --------------------------------------------------------------------------- + def test_thinking_config_defaults(): """Settings should expose thinking_enabled and thinking_interval_seconds.""" from config import Settings + s = Settings() assert s.thinking_enabled is True assert s.thinking_interval_seconds == 300 @@ -39,6 +42,7 @@ def test_thinking_config_override(): def _settings_with(**kwargs): from config import Settings + return Settings(**kwargs) @@ -46,6 +50,7 @@ def _settings_with(**kwargs): # ThinkingEngine init # --------------------------------------------------------------------------- + def test_engine_init_creates_table(tmp_path): """ThinkingEngine should create the thoughts SQLite table on init.""" engine = _make_engine(tmp_path) @@ -71,6 +76,7 @@ def test_engine_init_empty(tmp_path): # Store and retrieve # --------------------------------------------------------------------------- + def test_store_and_retrieve_thought(tmp_path): """Storing a thought should make it retrievable.""" engine = _make_engine(tmp_path) @@ -107,6 +113,7 @@ def test_store_thought_chains(tmp_path): # Thought chain retrieval # --------------------------------------------------------------------------- + def test_get_thought_chain(tmp_path): """get_thought_chain should return the full chain in chronological order.""" engine = _make_engine(tmp_path) @@ -145,6 +152,7 @@ def test_get_thought_chain_missing(tmp_path): # Recent thoughts # --------------------------------------------------------------------------- + def test_get_recent_thoughts_limit(tmp_path): """get_recent_thoughts should respect the limit parameter.""" engine = _make_engine(tmp_path) @@ -174,9 +182,11 @@ def test_count_thoughts(tmp_path): # Seed gathering # --------------------------------------------------------------------------- + def 
test_gather_seed_returns_valid_type(tmp_path): """_gather_seed should return a valid seed_type from SEED_TYPES.""" from timmy.thinking import SEED_TYPES + engine = _make_engine(tmp_path) # Run many times to cover randomness @@ -216,6 +226,7 @@ def test_seed_from_memory_graceful(tmp_path): # Continuity context # --------------------------------------------------------------------------- + def test_continuity_first_thought(tmp_path): """First thought should get a special 'first thought' context.""" engine = _make_engine(tmp_path) @@ -238,14 +249,17 @@ def test_continuity_includes_recent(tmp_path): # think_once (async) # --------------------------------------------------------------------------- + @pytest.mark.asyncio async def test_think_once_stores_thought(tmp_path): """think_once should store a thought in the DB.""" engine = _make_engine(tmp_path) - with patch.object(engine, "_call_agent", return_value="I am alive and pondering."), \ - patch.object(engine, "_log_event"), \ - patch.object(engine, "_broadcast", new_callable=AsyncMock): + with patch.object( + engine, "_call_agent", return_value="I am alive and pondering." 
+ ), patch.object(engine, "_log_event"), patch.object( + engine, "_broadcast", new_callable=AsyncMock + ): thought = await engine.think_once() assert thought is not None @@ -258,9 +272,9 @@ async def test_think_once_logs_event(tmp_path): """think_once should log a swarm event.""" engine = _make_engine(tmp_path) - with patch.object(engine, "_call_agent", return_value="A thought."), \ - patch.object(engine, "_log_event") as mock_log, \ - patch.object(engine, "_broadcast", new_callable=AsyncMock): + with patch.object(engine, "_call_agent", return_value="A thought."), patch.object( + engine, "_log_event" + ) as mock_log, patch.object(engine, "_broadcast", new_callable=AsyncMock): await engine.think_once() mock_log.assert_called_once() @@ -273,9 +287,9 @@ async def test_think_once_broadcasts(tmp_path): """think_once should broadcast via WebSocket.""" engine = _make_engine(tmp_path) - with patch.object(engine, "_call_agent", return_value="Broadcast this."), \ - patch.object(engine, "_log_event"), \ - patch.object(engine, "_broadcast", new_callable=AsyncMock) as mock_bc: + with patch.object(engine, "_call_agent", return_value="Broadcast this."), patch.object( + engine, "_log_event" + ), patch.object(engine, "_broadcast", new_callable=AsyncMock) as mock_bc: await engine.think_once() mock_bc.assert_called_once() @@ -300,9 +314,9 @@ async def test_think_once_skips_empty_response(tmp_path): """think_once should skip storing when agent returns empty string.""" engine = _make_engine(tmp_path) - with patch.object(engine, "_call_agent", return_value=" "), \ - patch.object(engine, "_log_event"), \ - patch.object(engine, "_broadcast", new_callable=AsyncMock): + with patch.object(engine, "_call_agent", return_value=" "), patch.object( + engine, "_log_event" + ), patch.object(engine, "_broadcast", new_callable=AsyncMock): thought = await engine.think_once() assert thought is None @@ -326,9 +340,11 @@ async def test_think_once_chains_thoughts(tmp_path): """Successive think_once calls 
should chain thoughts via parent_id.""" engine = _make_engine(tmp_path) - with patch.object(engine, "_call_agent", side_effect=["First.", "Second.", "Third."]), \ - patch.object(engine, "_log_event"), \ - patch.object(engine, "_broadcast", new_callable=AsyncMock): + with patch.object( + engine, "_call_agent", side_effect=["First.", "Second.", "Third."] + ), patch.object(engine, "_log_event"), patch.object( + engine, "_broadcast", new_callable=AsyncMock + ): t1 = await engine.think_once() t2 = await engine.think_once() t3 = await engine.think_once() @@ -342,6 +358,7 @@ async def test_think_once_chains_thoughts(tmp_path): # Dashboard route # --------------------------------------------------------------------------- + def test_thinking_route_returns_200(client): """GET /thinking should return 200.""" response = client.get("/thinking") diff --git a/tests/timmy/test_timmy_tools.py b/tests/timmy/test_timmy_tools.py index 7b30a4f..5cd54cd 100644 --- a/tests/timmy/test_timmy_tools.py +++ b/tests/timmy/test_timmy_tools.py @@ -4,17 +4,17 @@ Covers tool usage statistics, persona-to-toolkit mapping, catalog generation, and graceful degradation when Agno is unavailable. 
""" -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch import pytest from timmy.tools import ( _TOOL_USAGE, + PERSONA_TOOLKITS, _track_tool_usage, + get_all_available_tools, get_tool_stats, get_tools_for_persona, - get_all_available_tools, - PERSONA_TOOLKITS, ) @@ -104,6 +104,7 @@ class TestPersonaToolkits: "seer", "forge", "quill", + "lab", "pixel", "lyra", "reel", @@ -163,9 +164,9 @@ class TestToolCatalog: "list_files", } for tool_id in base_tools: - assert "orchestrator" in catalog[tool_id]["available_in"], ( - f"Orchestrator missing tool: {tool_id}" - ) + assert ( + "orchestrator" in catalog[tool_id]["available_in"] + ), f"Orchestrator missing tool: {tool_id}" def test_catalog_echo_research_tools(self): catalog = get_all_available_tools() @@ -196,9 +197,10 @@ class TestAiderTool: This is a smoke test - we just verify it returns something. """ - from timmy.tools import create_aider_tool from pathlib import Path + from timmy.tools import create_aider_tool + tool = create_aider_tool(Path.cwd()) # Call with a simple prompt - should return something (even if error) diff --git a/tests/timmy/test_tools_extended.py b/tests/timmy/test_tools_extended.py index 34a7089..0d74206 100644 --- a/tests/timmy/test_tools_extended.py +++ b/tests/timmy/test_tools_extended.py @@ -1,17 +1,18 @@ """Extended tests for timmy.tools — covers tool tracking, stats, and create_* functions.""" +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from timmy.tools import ( - _track_tool_usage, - get_tool_stats, - calculator, _TOOL_USAGE, - ToolStats, AgentTools, PersonaTools, + ToolStats, + _track_tool_usage, + calculator, create_aider_tool, + get_tool_stats, ) @@ -97,6 +98,7 @@ class TestCalculatorExtended: def test_math_functions(self): import math + assert calculator("math.sqrt(144)") == "12.0" assert calculator("math.pi") == str(math.pi) assert calculator("math.log(100, 10)") == str(math.log(100, 10)) @@ 
-132,6 +134,7 @@ class TestCreateToolFunctions: with patch("timmy.tools._ImportError", ImportError("no agno")): with pytest.raises(ImportError): from timmy.tools import create_research_tools + create_research_tools() def test_create_code_tools_no_agno(self): @@ -139,6 +142,7 @@ class TestCreateToolFunctions: with patch("timmy.tools._ImportError", ImportError("no agno")): with pytest.raises(ImportError): from timmy.tools import create_code_tools + create_code_tools() def test_create_data_tools_no_agno(self): @@ -146,6 +150,7 @@ class TestCreateToolFunctions: with patch("timmy.tools._ImportError", ImportError("no agno")): with pytest.raises(ImportError): from timmy.tools import create_data_tools + create_data_tools() def test_create_writing_tools_no_agno(self): @@ -153,6 +158,7 @@ class TestCreateToolFunctions: with patch("timmy.tools._ImportError", ImportError("no agno")): with pytest.raises(ImportError): from timmy.tools import create_writing_tools + create_writing_tools() @@ -187,6 +193,7 @@ class TestAiderTool: @patch("subprocess.run") def test_aider_timeout(self, mock_run, tmp_path): import subprocess + mock_run.side_effect = subprocess.TimeoutExpired(cmd="aider", timeout=120) tool = create_aider_tool(tmp_path) result = tool.run_aider("slow task") diff --git a/tests/timmy/test_vector_store.py b/tests/timmy/test_vector_store.py index f9113e6..ce6041d 100644 --- a/tests/timmy/test_vector_store.py +++ b/tests/timmy/test_vector_store.py @@ -1,17 +1,18 @@ """Tests for vector store (semantic memory) system.""" import pytest + from timmy.memory.vector_store import ( - store_memory, - search_memories, - get_memory_context, - recall_personal_facts, - store_personal_fact, - delete_memory, - get_memory_stats, - prune_memories, _cosine_similarity, _keyword_overlap, + delete_memory, + get_memory_context, + get_memory_stats, + prune_memories, + recall_personal_facts, + search_memories, + store_memory, + store_personal_fact, ) @@ -25,13 +26,13 @@ class TestVectorStore: 
source="test_agent", context_type="conversation", ) - + assert entry.content == "This is a test memory" assert entry.source == "test_agent" assert entry.context_type == "conversation" assert entry.id is not None assert entry.timestamp is not None - + def test_store_memory_with_metadata(self): """Test storing memory with metadata.""" entry = store_memory( @@ -43,27 +44,27 @@ class TestVectorStore: session_id="session-456", metadata={"importance": "high", "tags": ["test"]}, ) - + assert entry.agent_id == "agent-001" assert entry.task_id == "task-123" assert entry.session_id == "session-456" assert entry.metadata == {"importance": "high", "tags": ["test"]} - + def test_search_memories_basic(self): """Test basic memory search.""" # Store some memories store_memory("Bitcoin is a decentralized currency", source="user") store_memory("Lightning Network enables fast payments", source="user") store_memory("Python is a programming language", source="user") - + # Search for Bitcoin-related memories results = search_memories("cryptocurrency", limit=5) - + # Should find at least one relevant result assert len(results) > 0 # Check that results have relevance scores assert all(r.relevance_score is not None for r in results) - + def test_search_with_filters(self): """Test searching with filters.""" # Store memories with different types @@ -85,67 +86,67 @@ class TestVectorStore: context_type="conversation", agent_id="agent-2", ) - + # Filter by context type facts = search_memories("AI", context_type="fact", limit=5) assert all(f.context_type == "fact" for f in facts) - + # Filter by agent agent1_memories = search_memories("conversation", agent_id="agent-1", limit=5) assert all(m.agent_id == "agent-1" for m in agent1_memories) - + def test_get_memory_context(self): """Test getting formatted memory context.""" # Store memories store_memory("Important fact about the project", source="user") store_memory("Another relevant detail", source="agent") - + # Get context context = 
get_memory_context("project details", max_tokens=500) - + assert isinstance(context, str) assert len(context) > 0 assert "Relevant context from memory:" in context - + def test_personal_facts(self): """Test storing and recalling personal facts.""" # Store a personal fact fact = store_personal_fact("User prefers dark mode", agent_id="agent-1") - + assert fact.context_type == "fact" assert fact.content == "User prefers dark mode" - + # Recall facts facts = recall_personal_facts(agent_id="agent-1") assert "User prefers dark mode" in facts - + def test_delete_memory(self): """Test deleting a memory entry.""" # Create a memory entry = store_memory("To be deleted", source="test") - + # Delete it deleted = delete_memory(entry.id) assert deleted is True - + # Verify it's gone (search shouldn't find it) results = search_memories("To be deleted", limit=10) assert not any(r.id == entry.id for r in results) - + # Deleting non-existent should return False deleted_again = delete_memory(entry.id) assert deleted_again is False - + def test_get_memory_stats(self): """Test memory statistics.""" stats = get_memory_stats() - + assert "total_entries" in stats assert "by_type" in stats assert "with_embeddings" in stats assert "has_embedding_model" in stats assert isinstance(stats["total_entries"], int) - + def test_prune_memories(self): """Test pruning old memories.""" # This just verifies the function works without error @@ -156,48 +157,48 @@ class TestVectorStore: class TestVectorStoreUtils: """Test utility functions.""" - + def test_cosine_similarity_identical(self): """Test cosine similarity of identical vectors.""" vec = [1.0, 0.0, 0.0] similarity = _cosine_similarity(vec, vec) assert similarity == pytest.approx(1.0) - + def test_cosine_similarity_orthogonal(self): """Test cosine similarity of orthogonal vectors.""" vec1 = [1.0, 0.0, 0.0] vec2 = [0.0, 1.0, 0.0] similarity = _cosine_similarity(vec1, vec2) assert similarity == pytest.approx(0.0) - + def 
test_cosine_similarity_opposite(self): """Test cosine similarity of opposite vectors.""" vec1 = [1.0, 0.0, 0.0] vec2 = [-1.0, 0.0, 0.0] similarity = _cosine_similarity(vec1, vec2) assert similarity == pytest.approx(-1.0) - + def test_cosine_similarity_zero_vector(self): """Test cosine similarity with zero vector.""" vec1 = [1.0, 0.0, 0.0] vec2 = [0.0, 0.0, 0.0] similarity = _cosine_similarity(vec1, vec2) assert similarity == 0.0 - + def test_keyword_overlap_exact(self): """Test keyword overlap with exact match.""" query = "bitcoin lightning" content = "bitcoin lightning network" overlap = _keyword_overlap(query, content) assert overlap == 1.0 - + def test_keyword_overlap_partial(self): """Test keyword overlap with partial match.""" query = "bitcoin lightning" content = "bitcoin is great" overlap = _keyword_overlap(query, content) assert overlap == 0.5 - + def test_keyword_overlap_none(self): """Test keyword overlap with no match.""" query = "bitcoin" @@ -208,7 +209,7 @@ class TestVectorStoreUtils: class TestVectorStoreIntegration: """Integration tests for vector store workflow.""" - + def test_memory_workflow(self): """Test complete memory workflow: store -> search -> retrieve.""" # Store memories @@ -230,33 +231,33 @@ class TestVectorStoreIntegration: context_type="conversation", session_id="session-1", ) - + # Search for deadline-related memories results = search_memories("when is the deadline", limit=5) - + # Should find the deadline memory assert len(results) > 0 # Check that the most relevant result contains "deadline" assert any("deadline" in r.content.lower() for r in results[:3]) - + # Get context for a prompt context = get_memory_context("project timeline", session_id="session-1") assert "deadline" in context.lower() or "implement" in context.lower() - + def test_embedding_vs_keyword_fallback(self): """Test that the system works with or without embedding model.""" stats = get_memory_stats() - + # Store a memory entry = store_memory( "Testing embedding 
functionality", source="test", compute_embedding=True, ) - + # Should have embedding (even if it's fallback) assert entry.embedding is not None - + # Search should work regardless results = search_memories("embedding test", limit=5) assert len(results) > 0 diff --git a/tests/timmy_serve/test_inter_agent.py b/tests/timmy_serve/test_inter_agent.py index dead093..524042b 100644 --- a/tests/timmy_serve/test_inter_agent.py +++ b/tests/timmy_serve/test_inter_agent.py @@ -16,8 +16,10 @@ class TestAgentMessage: def test_custom_fields(self): msg = AgentMessage( - from_agent="seer", to_agent="forge", - content="hello", message_type="command", + from_agent="seer", + to_agent="forge", + content="hello", + message_type="command", ) assert msg.from_agent == "seer" assert msg.to_agent == "forge" diff --git a/tox.ini b/tox.ini index 816bb56..86f09f5 100644 --- a/tox.ini +++ b/tox.ini @@ -8,11 +8,11 @@ commands_pre = poetry install --with dev --quiet [testenv:unit] description = Fast unit tests (no I/O, no external services) -commands = poetry run pytest tests/ -q --tb=short -m "unit and not ollama and not docker and not selenium and not external_api" +commands = poetry run pytest tests/ -q --tb=short -m "unit and not integration and not ollama and not docker and not selenium and not external_api" [testenv:integration] description = Integration tests (may use SQLite, but no external services) -commands = poetry run pytest tests/ -q --tb=short -m "integration and not ollama and not docker and not selenium and not external_api" +commands = poetry run pytest tests/ -q --tb=short -m "integration and not unit and not ollama and not docker and not selenium and not external_api" [testenv:ollama] description = Live LLM tests via Ollama (requires Ollama running with a tiny model)