vanna-clubpetro/viz_tool.py
leonardosalazar-cp 1d152c0dce Initial commit: Vanna 2.0 deployment for ClubPetro
Wrapper application around upstream Vanna with:
- Tenant-aware ChromaDB memory (per program/store)
- ClickHouse RLS runner with introspection guards
- PT-BR system prompt and chat translations
- Custom Plotly chart generator (ranked bar, datetime coercion)
- Embed bootstrap (theme pierce + i18n + markdown) shared by demo and React app
- Event sink for chat turn observability
2026-04-29 17:22:05 -03:00

381 lines
15 KiB
Python

"""Override do VisualizeDataTool com `chart_type` controlável pelo LLM
+ generator que evita o fallback `go.Table` do upstream.
Três customizações sobre upstream:
1. Arg novo `chart_type` (line/bar/scatter/histogram/area) no schema. Quando
passado pelo LLM, força o tipo. Heurística automática só roda como
fallback quando o LLM omite (a guidance no system_prompt + descrição do
tool empurra o LLM a sempre passar).
2. `ClubPetroChartGenerator` continua removendo a heurística "4+ colunas
→ go.Table" do upstream (`vanna/src/vanna/integrations/plotly/chart_generator.py:51-55`),
que duplicava visualmente o dataframe rich.
3. `_coerce_datetime_columns` foi movida pra rodar nos DOIS caminhos
(<4 e >=4 colunas). Antes só rodava no 4+, então queries 2-col
`SELECT data_da_compra, faturamento` nunca eram detectadas como
time series e viravam bar. Agora a coerção uniforme torna o
comportamento previsível e o `chart_type=line` forçado funciona
consistente.
`chart_type` é threadado via `ContextVar` em vez de monkey-patch — o
`VisualizeDataTool` é singleton no `ToolRegistry` e pode ter execuções
concorrentes em conversas paralelas.
"""
from __future__ import annotations
import contextvars
import json
from typing import Any, Dict, List, Literal, Optional, Type
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from pydantic import Field
from vanna.core.tool import ToolContext, ToolResult
from vanna.integrations.plotly import PlotlyChartGenerator
from vanna.tools import VisualizeDataTool
from vanna.tools.visualize_data import VisualizeDataArgs
from events_sink import record_chart
ChartType = Literal["line", "bar", "scatter", "histogram", "area"]
_chart_type_override: contextvars.ContextVar[Optional[ChartType]] = (
contextvars.ContextVar("chart_type_override", default=None)
)
class VisualizeDataArgsPT(VisualizeDataArgs):
chart_type: Optional[ChartType] = Field(
default=None,
description=(
"Tipo do gráfico. Passe SEMPRE que souber o tipo certo: "
"'line' = série temporal (evolução, tendência ao longo do tempo); "
"'bar' = ranking, top N, comparação categórica; "
"'scatter' = correlação entre 2 métricas numéricas; "
"'histogram' = distribuição de 1 métrica; "
"'area' = acumulado temporal. "
"Omita SÓ se genuinamente em dúvida — sem hint o sistema cai em "
"heurística por shape do DataFrame, que pode escolher errado."
),
)
_DESCRIPTION = (
"Renderiza um gráfico Plotly a partir do CSV produzido por run_sql na rodada anterior. "
"CHAME SEMPRE que o resultado da query for naturalmente visual — não espere o usuário pedir "
"a palavra 'gráfico' explicitamente; deduza do tipo de pergunta.\n"
"\n"
"GATILHOS OBRIGATÓRIOS (chame após o run_sql):\n"
"• Ranking / Top N — 'top 10 produtos', 'quais atendentes mais venderam', 'maiores clientes', "
"'ranking de X', 'melhores/piores Y' → chart_type='bar'.\n"
"• Série temporal — 'evolução', 'ao longo de', 'por dia/semana/mês/hora', 'últimos N dias', "
"'tendência', 'crescimento', 'comparar meses' → chart_type='line'.\n"
"• Comparação categórica — 'vendas por categoria', 'faturamento por produto', 'X por Y', "
"'distribuição', 'breakdown' → chart_type='bar'.\n"
"• Participação / share — 'participação', '% de', 'share' → chart_type='bar' "
"(evitar pizza com >3 fatias).\n"
"• Correlação — 'relação entre X e Y', 'preço vs volume', 'desconto vs faturamento' "
"→ chart_type='scatter'.\n"
"• Distribuição de UMA métrica — 'distribuição dos tickets', 'histograma de Y' "
"→ chart_type='histogram'.\n"
"• Acumulado temporal — 'faturamento acumulado por mês' → chart_type='area'.\n"
"\n"
"FORMA IDEAL DO CSV antes de chamar (responsabilidade do run_sql anterior):\n"
"• Série temporal: 1 coluna de tempo + 1-3 métricas.\n"
"• Ranking/categórico: 1 coluna de categoria + 1 métrica, máx 20-40 linhas.\n"
"• Scatter: 2 colunas numéricas.\n"
"• ≤3 colunas no total é o ponto ideal — se a query trouxer 4+ colunas, esta ferramenta "
"ainda gera chart real (não tabela), mas o resultado fica mais limpo com SELECT focado.\n"
"\n"
"QUANDO NÃO CHAMAR:\n"
"• Pergunta com resposta numérica única ('qual o faturamento de hoje?') — o número no texto "
"basta, gráfico não agrega.\n"
"• Usuário pediu explicitamente só a tabela ('me lista', 'exporta a tabela').\n"
"\n"
"ARGUMENTOS: filename = caminho do CSV retornado pelo run_sql anterior; "
"chart_type = tipo do gráfico (passe sempre que souber, ver lista acima); "
"title (opcional) = rótulo curto em pt-BR (ex: 'Faturamento por dia — últimos 30 dias')."
)
class ClubPetroChartGenerator(PlotlyChartGenerator):
"""Generator com:
- Coerção de datetime aplicada pra QUALQUER nº de colunas (uniforme).
- Override `chart_type` lido de ContextVar (vem do `VisualizeDataToolPT.execute`).
- Drop do fallback "4+ colunas → go.Table" do upstream — sempre chart real.
"""
def generate_chart(
self,
df: pd.DataFrame,
title: str = "Chart",
chart_type: Optional[ChartType] = None,
) -> Dict[str, Any]:
if df.empty:
raise ValueError("Cannot visualize empty DataFrame")
chart_type = chart_type or _chart_type_override.get()
df = self._coerce_datetime_columns(df)
if chart_type is not None:
fig = self._render_forced(df, title, chart_type)
return json.loads(pio.to_json(fig))
if len(df.columns) < 4:
return super().generate_chart(df, title)
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
categorical_cols = df.select_dtypes(
include=["object", "category"]
).columns.tolist()
datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
if datetime_cols and numeric_cols:
fig = self._create_time_series_chart(
df.sort_values(datetime_cols[0]),
datetime_cols[0],
numeric_cols[:3],
title,
)
elif categorical_cols and numeric_cols:
cat = categorical_cols[0]
num = numeric_cols[0]
agg = (
df.groupby(cat)[num]
.sum()
.reset_index()
.sort_values(num, ascending=False)
.head(40)
)
fig = self._create_ranked_bar_chart(agg, cat, num, title)
elif len(numeric_cols) >= 2:
fig = self._create_scatter_plot(df, numeric_cols[0], numeric_cols[1], title)
else:
fig = self._create_generic_chart(df, df.columns[0], df.columns[1], title)
return json.loads(pio.to_json(fig))
def _render_forced(
self, df: pd.DataFrame, title: str, chart_type: ChartType
) -> go.Figure:
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
categorical_cols = df.select_dtypes(
include=["object", "category"]
).columns.tolist()
datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
if chart_type == "line":
if not numeric_cols:
raise ValueError(
"chart_type='line' requires at least one numeric column"
)
if datetime_cols:
return self._create_time_series_chart(
df.sort_values(datetime_cols[0]),
datetime_cols[0],
numeric_cols[:5],
title,
)
x_col = (
categorical_cols[0]
if categorical_cols
else df.columns[0]
)
y_cols = [c for c in numeric_cols if c != x_col][:5]
if not y_cols:
raise ValueError(
"chart_type='line' needs a numeric column distinct from the x-axis"
)
return self._create_time_series_chart(df, x_col, y_cols, title)
if chart_type == "area":
if not numeric_cols:
raise ValueError(
"chart_type='area' requires at least one numeric column"
)
x_col = (
datetime_cols[0]
if datetime_cols
else (categorical_cols[0] if categorical_cols else df.columns[0])
)
ordered = df.sort_values(x_col) if datetime_cols else df
y_cols = [c for c in numeric_cols if c != x_col][:5]
if not y_cols:
raise ValueError(
"chart_type='area' needs a numeric column distinct from the x-axis"
)
return self._create_area_chart(ordered, x_col, y_cols, title)
if chart_type == "bar":
if not numeric_cols:
raise ValueError(
"chart_type='bar' requires at least one numeric column"
)
x_col = categorical_cols[0] if categorical_cols else df.columns[0]
y_col = next((c for c in numeric_cols if c != x_col), None)
if y_col is None:
raise ValueError(
"chart_type='bar' needs a numeric column distinct from the x-axis"
)
agg = (
df.groupby(x_col)[y_col]
.sum()
.reset_index()
.sort_values(y_col, ascending=False)
.head(40)
)
return self._create_ranked_bar_chart(agg, x_col, y_col, title)
if chart_type == "scatter":
if len(numeric_cols) < 2:
raise ValueError(
"chart_type='scatter' requires 2 numeric columns"
)
return self._create_scatter_plot(
df, numeric_cols[0], numeric_cols[1], title
)
if chart_type == "histogram":
if not numeric_cols:
raise ValueError(
"chart_type='histogram' requires at least one numeric column"
)
return self._create_histogram(df, numeric_cols[0], title)
raise ValueError(f"Unknown chart_type: {chart_type!r}")
def _create_ranked_bar_chart(
self,
df: pd.DataFrame,
x_col: str,
y_col: str,
title: str,
) -> go.Figure:
"""Bar chart preservando a ordem do `df` (sem re-groupby).
Upstream `_create_bar_chart` re-agrega com `groupby(x_col).sum()`
e perde a ordenação descendente — bars saem alfabéticos. Pra
ranking real precisamos travar `categoryorder=array`.
"""
order = df[x_col].astype(str).tolist()
fig = go.Figure(
data=[
go.Bar(
x=order,
y=df[y_col].tolist(),
marker_color=self.THEME_COLORS["orange"],
)
]
)
fig.update_layout(
title=title,
xaxis_title=x_col,
yaxis_title=y_col,
xaxis={"categoryorder": "array", "categoryarray": order},
)
self._apply_standard_layout(fig)
return fig
def _create_area_chart(
self,
df: pd.DataFrame,
x_col: str,
y_cols: List[str],
title: str,
) -> go.Figure:
fig = go.Figure()
for i, col in enumerate(y_cols):
color = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]
fig.add_trace(
go.Scatter(
x=df[x_col],
y=df[col],
mode="lines",
name=col,
fill="tozeroy" if i == 0 else "tonexty",
line=dict(color=color),
)
)
fig.update_layout(
title=title,
xaxis_title=x_col,
yaxis_title="Value",
hovermode="x unified",
)
self._apply_standard_layout(fig)
return fig
@staticmethod
def _coerce_datetime_columns(df: pd.DataFrame) -> pd.DataFrame:
"""Converte colunas object que parecem date/datetime em datetime64.
ClickHouse devolve datas como string no CSV; pandas lê como object e
a heurística de datetime do upstream falha. Aceita dois formatos:
- ISO `2026-01-01` (ou `2026/01/01`) — quando a query devolve o
valor cru (DateTime sem coerção downstream).
- BR `01/01/2026` — quando o `RLSClickHouseRunner` formata
colunas datetime pra exibição no rich component (vê
`_format_date_columns` em rls_runner.py). pandas precisa
`dayfirst=True` ou seria parseado como mês/dia (US).
Best-effort (`errors="raise"` dentro do try) — só converte colunas
que são parseáveis 100%.
"""
for col in df.select_dtypes(include=["object"]).columns:
sample = df[col].dropna().astype(str).head(5).tolist()
if not sample:
continue
looks_like_iso = all(
len(s) >= 8 and s[:4].isdigit() and s[4] in "-/"
for s in sample
)
looks_like_br = all(
len(s) >= 10
and s[:2].isdigit()
and s[2] == "/"
and s[3:5].isdigit()
and s[5] == "/"
and s[6:10].isdigit()
for s in sample
)
if not (looks_like_iso or looks_like_br):
continue
try:
df[col] = pd.to_datetime(
df[col], dayfirst=looks_like_br, errors="raise"
)
except (ValueError, TypeError):
pass
return df
class VisualizeDataToolPT(VisualizeDataTool):
"""VisualizeDataTool com schema PT-BR (chart_type) + generator custom."""
def __init__(self, *args, **kwargs):
kwargs.setdefault("plotly_generator", ClubPetroChartGenerator())
super().__init__(*args, **kwargs)
@property
def description(self) -> str:
return _DESCRIPTION
def get_args_schema(self) -> Type[VisualizeDataArgsPT]:
return VisualizeDataArgsPT
async def execute(
self, context: ToolContext, args: VisualizeDataArgsPT
) -> ToolResult:
record_chart(args.chart_type or "auto", args.title or "")
token = _chart_type_override.set(args.chart_type)
try:
return await super().execute(context, args)
finally:
_chart_type_override.reset(token)