"""Tenant-aware ChromaDB memory: schema compartilhado + tool-usage por loja.
Vanna's `ChromaAgentMemory` é single-collection — todas as memórias caem no mesmo
pool, sem scoping. No app ClubPetro isso vazaria perguntas/aprendizados entre
clientes (program × loja) que compartilham o mesmo deploy.
Esta classe compõe DUAS instâncias de `ChromaAgentMemory`:
• shared (text memories) — collection `{base_collection_name}` — schema docs do `train.py`.
Mesma collection que o app sempre usou; intacto.
• per-tenant (tool usage) — collection `{base_collection_name}__p{program_id}__s{store_id}`,
instanciada lazy na primeira chamada de cada (program_id, store_id).
Roteamento (1 método da ABC = 1 chamada delegada):
text_memory: save / search / get_recent / delete → shared
tool_usage: save / search / get_recent / delete → tenant
clear_memories: → tenant (preserva schema)
`context.user.program_id` / `store_id` são lidos por chamada — RLS já valida via
`_require_id` no resolver, então só chega aqui já checado. `train.py` roda com
user "trainer" sem program/store; só usa text_memory, então não toca a lógica
per-tenant — funciona sem mudança.
"""
from __future__ import annotations

import hashlib
import re
from typing import Any, Dict, List, Optional, Tuple

from vanna.capabilities.agent_memory import (
    AgentMemory,
    TextMemory,
    TextMemorySearchResult,
    ToolMemory,
    ToolMemorySearchResult,
)
from vanna.core.tool import ToolContext
from vanna.integrations.chromadb import ChromaAgentMemory
# Matches every run of characters that is NOT a lowercase ASCII letter or digit.
_SLUG_RE = re.compile(r"[^a-z0-9]+")

# Number of hex digits of the SHA-256 digest appended to lossy slugs.
_SLUG_DIGEST_LEN = 10


def _slug(value: str) -> str:
    """Turn a tenant ID into a fragment usable inside a ChromaDB collection name.

    The value is lower-cased and every character outside ``[a-z0-9]`` is
    removed. An ID that is already in that form is returned unchanged, so its
    collection name is the same as it has always been.

    Any other ID loses information during normalization (letter case, ``_``,
    ``-``), which previously let two distinct IDs such as ``"a-b"`` and
    ``"ab"`` resolve to the same collection and share memories. For those IDs
    a short SHA-256 digest of the *original* value is appended after a ``-``.
    A ``-`` can never appear in an unchanged slug, so the two forms cannot
    clash with each other.

    NOTE(review): collections created earlier for IDs containing uppercase
    letters, ``_`` or ``-`` were stored under the old (ambiguous) name and
    will not be found under the new one — confirm whether a migration is
    needed.

    Args:
        value: Raw program or store ID. Presumably already validated upstream
            against ``^[A-Za-z0-9_-]+$`` — verify against the caller.

    Returns:
        A non-empty string made of ``[a-z0-9]``, optionally followed by
        ``-`` and ``_SLUG_DIGEST_LEN`` lowercase hex digits. It always starts
        and ends with an alphanumeric character.
    """
    cleaned = _SLUG_RE.sub("", value.lower())
    if cleaned and cleaned == value:
        # Lossless: nothing was stripped or lower-cased, keep the legacy name.
        return cleaned

    # Lossy (or empty) normalization: disambiguate with a digest of the
    # original value so distinct IDs never share a collection.
    digest = hashlib.sha256(value.encode("utf-8", "surrogatepass")).hexdigest()
    return f"{cleaned or 'x'}-{digest[:_SLUG_DIGEST_LEN]}"
class TenantAwareChromaMemory(AgentMemory):
    """Agent memory split between one shared store and per-tenant stores.

    Text memories are delegated to a single shared ``ChromaAgentMemory``.
    Tool-usage memories are delegated to a ``ChromaAgentMemory`` dedicated to
    the calling ``(program_id, store_id)`` pair; that instance is created on
    first use and then reused from an in-memory cache.

    Args:
        persist_directory: Value passed as ``persist_directory`` to every
            underlying ``ChromaAgentMemory``.
        base_collection_name: Collection name of the shared store. It is also
            the prefix of every per-tenant collection name.
        embedding_function: Optional embedding function forwarded unchanged
            to the shared store and to every tenant store.
    """

    def __init__(
        self,
        persist_directory: str,
        base_collection_name: str,
        embedding_function: Optional[Any] = None,
    ):
        # Cache of tenant stores, keyed by (program_id, store_id) as strings.
        self._tenants: Dict[Tuple[str, str], ChromaAgentMemory] = {}

        # Kept so tenant stores can be built later with the same settings.
        self._ef = embedding_function
        self._base = base_collection_name
        self._persist_directory = persist_directory

        # The shared store is created eagerly; tenant stores are lazy.
        self._shared = ChromaAgentMemory(
            persist_directory=persist_directory,
            collection_name=base_collection_name,
            embedding_function=embedding_function,
        )

    @staticmethod
    def _tenant_ids(context: ToolContext) -> Tuple[str, str]:
        """Read the tenant identifiers from ``context.user``.

        Returns:
            ``(program_id, store_id)``, both converted with ``str``.

        Raises:
            PermissionError: If the context has no (truthy) user, or either
                ``program_id`` or ``store_id`` is missing or falsy.
        """
        principal = getattr(context, "user", None)
        if principal:
            program_id = getattr(principal, "program_id", None)
            store_id = getattr(principal, "store_id", None)
        else:
            program_id = store_id = None

        if program_id and store_id:
            return str(program_id), str(store_id)

        raise PermissionError(
            "Tool-usage memory requires User.program_id and User.store_id; "
            "got program_id={!r} store_id={!r}".format(program_id, store_id)
        )

    def _tenant(self, context: ToolContext) -> ChromaAgentMemory:
        """Return the store for the caller's tenant, creating it if needed.

        Raises:
            PermissionError: Propagated from ``_tenant_ids`` when the context
                does not identify a tenant.
        """
        tenant_key = self._tenant_ids(context)
        memory = self._tenants.get(tenant_key)
        if memory is None:
            program_id, store_id = tenant_key
            program_part = f"p{_slug(program_id)}"
            store_part = f"s{_slug(store_id)}"
            memory = ChromaAgentMemory(
                persist_directory=self._persist_directory,
                collection_name=f"{self._base}__{program_part}__{store_part}",
                embedding_function=self._ef,
            )
            self._tenants[tenant_key] = memory
        return memory

    # ------------------------------------------------------------------
    # Tool usage: routed to the per-tenant store
    # ------------------------------------------------------------------

    async def save_tool_usage(
        self,
        question: str,
        tool_name: str,
        args: Dict[str, Any],
        context: ToolContext,
        success: bool = True,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a tool invocation in the caller's tenant store."""
        backend = self._tenant(context)
        return await backend.save_tool_usage(
            question=question,
            tool_name=tool_name,
            args=args,
            context=context,
            success=success,
            metadata=metadata,
        )

    async def search_similar_usage(
        self,
        question: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
        tool_name_filter: Optional[str] = None,
    ) -> List[ToolMemorySearchResult]:
        """Search the caller's tenant store for similar tool usages."""
        backend = self._tenant(context)
        return await backend.search_similar_usage(
            question=question,
            context=context,
            limit=limit,
            similarity_threshold=similarity_threshold,
            tool_name_filter=tool_name_filter,
        )

    async def get_recent_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[ToolMemory]:
        """Return recent tool-usage memories from the caller's tenant store."""
        backend = self._tenant(context)
        return await backend.get_recent_memories(context=context, limit=limit)

    async def delete_by_id(self, context: ToolContext, memory_id: str) -> bool:
        """Delete one tool-usage memory from the caller's tenant store."""
        backend = self._tenant(context)
        return await backend.delete_by_id(context=context, memory_id=memory_id)

    # ------------------------------------------------------------------
    # Text memory: routed to the shared (schema) store
    # ------------------------------------------------------------------

    async def save_text_memory(
        self, content: str, context: ToolContext
    ) -> TextMemory:
        """Save a text memory in the shared store."""
        shared = self._shared
        return await shared.save_text_memory(content=content, context=context)

    async def search_text_memories(
        self,
        query: str,
        context: ToolContext,
        *,
        limit: int = 10,
        similarity_threshold: float = 0.7,
    ) -> List[TextMemorySearchResult]:
        """Search text memories in the shared store."""
        shared = self._shared
        return await shared.search_text_memories(
            query=query,
            context=context,
            limit=limit,
            similarity_threshold=similarity_threshold,
        )

    async def get_recent_text_memories(
        self, context: ToolContext, limit: int = 10
    ) -> List[TextMemory]:
        """Return recent text memories from the shared store."""
        shared = self._shared
        return await shared.get_recent_text_memories(context=context, limit=limit)

    async def delete_text_memory(self, context: ToolContext, memory_id: str) -> bool:
        """Delete one text memory from the shared store."""
        shared = self._shared
        return await shared.delete_text_memory(context=context, memory_id=memory_id)

    # ------------------------------------------------------------------
    # Clear: tenant store only, the shared store is never touched here
    # ------------------------------------------------------------------

    async def clear_memories(
        self,
        context: ToolContext,
        tool_name: Optional[str] = None,
        before_date: Optional[str] = None,
    ) -> int:
        """Clear memories in the caller's tenant store only.

        The call is delegated to the tenant store with the same filters; the
        shared store is not involved.
        """
        backend = self._tenant(context)
        return await backend.clear_memories(
            context=context, tool_name=tool_name, before_date=before_date
        )