vanna-clubpetro/agent.py
leonardosalazar-cp 736cc0f277
Some checks failed
CD / build (pull_request) Failing after 1s
CD / deploy (pull_request) Has been skipped
fix: troca header markdown por negrito na saudação inicial
O `####` aparecia literal no balão da saudação porque o markdown
processor do bootstrap renderiza headers, mas o estilo destoava
do resto. Usa **negrito** pra destaque visual consistente.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-06 16:21:48 -03:00

183 lines
6.8 KiB
Python

"""Vanna agent factory: OpenAI + ChromaDB (local) + ClickHouse Cloud."""
import os
from typing import Optional
from dotenv import load_dotenv
from vanna import Agent, AgentConfig, User
from vanna.core.registry import ToolRegistry, ToolRejection
from vanna.core.system_prompt import DefaultSystemPromptBuilder
from vanna.core.tool import Tool, ToolContext
from vanna.core.user import RequestContext, UserResolver
from vanna.core.workflow import DefaultWorkflowHandler
from vanna.integrations.openai import OpenAILlmService
from vanna.tools import LocalFileSystem, RunSqlTool
from vanna.tools.agent_memory import (
SaveQuestionToolArgsTool,
SaveTextMemoryTool,
SearchSavedCorrectToolUsesTool,
)
from events_sink import record_sql, record_tool
from rls_runner import RLSClickHouseRunner, _require_id
from system_prompt import SYSTEM_PROMPT
from tenant_memory import TenantAwareChromaMemory
from viz_tool import VisualizeDataToolPT
class EventCapturingToolRegistry(ToolRegistry):
    """ToolRegistry that logs agent tool activity into the active TurnRecord.

    Per the original design notes, the active TurnRecord is reached through a
    ContextVar inside ``events_sink``. Upstream calls ``transform_args`` on
    EVERY tool execution, right before ``tool.execute``, so overriding it is
    the single hook needed to observe the agent without wrapping each tool.

    The tool name is always recorded. For ``run_sql`` the ``sql`` argument is
    also recorded, but only when it is a string. Argument transformation
    itself is delegated unchanged to the parent registry.
    """

    async def transform_args(
        self,
        tool: Tool,
        args,
        user: User,
        context: ToolContext,
    ):
        record_tool(tool.name)

        # Only run_sql carries a query worth capturing; for other tools the
        # candidate stays None and the isinstance check below skips it.
        candidate_sql = getattr(args, "sql", None) if tool.name == "run_sql" else None
        if isinstance(candidate_sql, str):
            record_sql(candidate_sql)

        return await super().transform_args(tool, args, user, context)
# Populate os.environ from a local .env file at import time so the
# os.environ lookups in build_agent (OPENAI_*, CLICKHOUSE_*, RLS_*) can see
# those values.
# NOTE(review): this call runs after the project imports above
# (events_sink, rls_runner, tenant_memory, ...). If any of those modules read
# os.environ at import time, they will not see .env values — confirm.
load_dotenv()
class StaticUserResolver(UserResolver):
    """Resolver that maps every request to one fixed, pre-built user.

    Meant for CLI / local runs where there is only a single identity. The
    incoming request context is never inspected.
    """

    def __init__(self, user: User):
        # Identity handed back verbatim by resolve_user.
        self._user = user

    async def resolve_user(self, request_context: RequestContext) -> User:
        fixed_identity = self._user
        return fixed_identity
class RequestContextUserResolver(UserResolver):
    """Resolve the tenant identity (program_id, store_id) per request.

    For each id, ``request_context.query_params`` is consulted first (the web
    component sends the ids in the endpoint URL). ``request_context.metadata``
    is the fallback, for server-to-server callers using
    ChatRequest.metadata. Both ids pass through ``_require_id`` immediately,
    so a missing or invalid RLS context fails before any tool runs.
    """

    async def resolve_user(self, request_context: RequestContext) -> User:
        from_query = request_context.query_params or {}
        from_meta = request_context.metadata or {}

        def lookup(key):
            # A falsy query-string value (e.g. "") also falls through to
            # metadata.
            return from_query.get(key) or from_meta.get(key)

        program_id = _require_id("program_id", lookup("program_id"))
        store_id = _require_id("store_id", lookup("store_id"))

        # NOTE(review): every web tenant shares id/username "web"; only
        # program_id/store_id tell them apart. Confirm nothing downstream
        # (e.g. conversation storage) scopes data by user id alone.
        return User(
            id="web",
            username="web",
            program_id=program_id,
            store_id=store_id,
        )
# Greeting shown by the workflow handler at the start of a chat. Emphasis
# uses **bold** rather than a markdown header.
_WELCOME_MESSAGE = (
    "**👋 Olá! Aqui é a ClubPetro IA**\n\n"
    "Sua assistente de inteligência de dados. Eu transformo dados complexos em "
    "respostas claras, direto ao ponto. Precisa de um relatório de faturamento, "
    "entender a performance da sua equipe ou aprofundar no comportamento de "
    "compra de seus clientes? É só perguntar. Eu cuido dos números e gráficos "
    "para você focar no que importa: **lucrar mais**.\n\n"
    "**Experimente:**\n"
    "- *Faturamento por mês no último semestre*\n"
    "- *Top 10 produtos da semana*\n"
    "- *Clientes que ganharam mais desconto*\n"
    "- *Frentistas com maior índice de fidelidade*"
)

# Spellings (after strip + lower) that count as "on" for boolean env flags.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: str) -> bool:
    """Read a boolean flag from the environment.

    ``default`` is used when the variable is unset. The value is stripped and
    lower-cased, so "TRUE", " true" and "1" all enable the flag. Any other
    value (including "false") disables it.
    """
    return os.environ.get(name, default).strip().lower() in _TRUTHY_ENV_VALUES


def _static_user_resolver(
    program_id: Optional[str], store_id: Optional[str]
) -> UserResolver:
    """Build the single-identity resolver used for CLI / local runs.

    Missing ids fall back to RLS_PROGRAM_ID / RLS_STORE_ID. Raises
    RuntimeError if either id is still empty after the fallback.
    """
    program_id = program_id or os.environ.get("RLS_PROGRAM_ID")
    store_id = store_id or os.environ.get("RLS_STORE_ID")
    if not program_id or not store_id:
        raise RuntimeError(
            "RLS requires program_id and store_id. "
            "Pass via build_agent(...) or set RLS_PROGRAM_ID / RLS_STORE_ID."
        )
    return StaticUserResolver(
        User(
            id="local",
            username="local",
            program_id=program_id,
            store_id=store_id,
        )
    )


def _build_sql_runner() -> RLSClickHouseRunner:
    """Create the ClickHouse SQL runner from the CLICKHOUSE_* env vars.

    Raises KeyError when a required variable is unset, and ValueError when
    CLICKHOUSE_PORT is not an integer.
    """
    return RLSClickHouseRunner(
        host=os.environ["CLICKHOUSE_HOST"],
        port=int(os.environ["CLICKHOUSE_PORT"]),
        user=os.environ["CLICKHOUSE_USER"],
        password=os.environ["CLICKHOUSE_PASSWORD"],
        database=os.environ["CLICKHOUSE_DATABASE"],
        # Previously only the exact string "true" (any case) enabled this, so
        # "1"/"yes"/"true " silently turned it off. "true"/"false" are
        # unchanged.
        secure=_env_flag("CLICKHOUSE_SECURE", "true"),
    )


def _build_tool_registry(sql_runner) -> EventCapturingToolRegistry:
    """Register every tool the agent may call, all with empty access_groups."""
    file_system = LocalFileSystem(working_directory="./data_storage")
    registry = EventCapturingToolRegistry()
    registry.register_local_tool(
        RunSqlTool(sql_runner=sql_runner, file_system=file_system),
        access_groups=[],
    )
    registry.register_local_tool(
        VisualizeDataToolPT(file_system=file_system),
        access_groups=[],
    )
    # Memory tools close the self-learning loop. Search is registered first to
    # encourage "look up before generating new SQL"; the Save* tools follow.
    # All are zero-arg: they read agent_memory from the ToolContext at runtime.
    registry.register_local_tool(SearchSavedCorrectToolUsesTool(), access_groups=[])
    registry.register_local_tool(SaveQuestionToolArgsTool(), access_groups=[])
    registry.register_local_tool(SaveTextMemoryTool(), access_groups=[])
    return registry


def build_agent(
    program_id: Optional[str] = None,
    store_id: Optional[str] = None,
    user_resolver: Optional[UserResolver] = None,
) -> Agent:
    """Assemble the Vanna agent.

    The agent combines the OpenAI LLM, the ClickHouse runner, tools, and
    tenant-aware memory.

    Args:
        program_id: Program id for the static (CLI/local) identity. Falls
            back to RLS_PROGRAM_ID. Ignored when ``user_resolver`` is given.
        store_id: Store id for the static identity. Falls back to
            RLS_STORE_ID. Ignored when ``user_resolver`` is given.
        user_resolver: Per-request resolver (e.g. RequestContextUserResolver
            for the web server). When omitted, a StaticUserResolver is built
            from program_id / store_id.

    Environment:
        Required: OPENAI_MODEL, OPENAI_API_KEY, CLICKHOUSE_HOST,
        CLICKHOUSE_PORT, CLICKHOUSE_USER, CLICKHOUSE_PASSWORD,
        CLICKHOUSE_DATABASE.
        Optional: CLICKHOUSE_SECURE (default "true"; "1"/"true"/"yes"/"on",
        case-insensitive, set the runner's ``secure`` flag) and
        OPENAI_TEMPERATURE (default 1.0).

    Raises:
        RuntimeError: no user_resolver was given and program_id/store_id
            cannot be resolved.
        KeyError: a required environment variable is unset.
        ValueError: CLICKHOUSE_PORT or OPENAI_TEMPERATURE is not numeric.
    """
    if user_resolver is None:
        user_resolver = _static_user_resolver(program_id, store_id)

    llm = OpenAILlmService(
        model=os.environ["OPENAI_MODEL"],
        api_key=os.environ["OPENAI_API_KEY"],
    )
    tools = _build_tool_registry(_build_sql_runner())

    # Multi-tenant memory. Text memories (schema docs from train.py) live in
    # the shared `vanna_clickhouse_gold` collection. Tool-usage memories
    # (question -> args pairs saved by the LLM at runtime) go to a
    # per-(program_id, store_id) collection, which avoids leaks between
    # tenants.
    # NOTE(review): "./chroma_db" and "./data_storage" are relative paths,
    # presumably resolved against the process CWD — launching from another
    # directory would point at a different (possibly empty) store.
    memory = TenantAwareChromaMemory(
        persist_directory="./chroma_db",
        base_collection_name="vanna_clickhouse_gold",
    )

    # Default 1.0 stays compatible with reasoning / gpt-5* models, which
    # reject other values. For models that accept tuning (e.g. gpt-4o), set
    # OPENAI_TEMPERATURE=0.2 in .env for more deterministic SQL generation.
    temperature = float(os.environ.get("OPENAI_TEMPERATURE", "1.0"))

    return Agent(
        llm_service=llm,
        tool_registry=tools,
        user_resolver=user_resolver,
        agent_memory=memory,
        config=AgentConfig(stream_responses=False, temperature=temperature),
        system_prompt_builder=DefaultSystemPromptBuilder(base_prompt=SYSTEM_PROMPT),
        workflow_handler=DefaultWorkflowHandler(welcome_message=_WELCOME_MESSAGE),
    )
def local_request_context() -> RequestContext:
    """Build the RequestContext used for local / CLI invocations.

    The request is attributed to the loopback address and tagged with
    ``{"source": "cli"}`` metadata.

    NOTE(review): the metadata has no program_id/store_id, so this context
    presumably only works with StaticUserResolver (RequestContextUserResolver
    would have nothing to validate) — confirm callers pair them that way.
    """
    cli_metadata = {"source": "cli"}
    return RequestContext(metadata=cli_metadata, remote_addr="127.0.0.1")