Wrapper application around upstream Vanna with: - Tenant-aware ChromaDB memory (per program/store) - ClickHouse RLS runner with introspection guards - PT-BR system prompt and chat translations - Custom Plotly chart generator (ranked bar, datetime coercion) - Embed bootstrap (theme pierce + i18n + markdown) shared by demo and React app - Event sink for chat turn observability
381 lines
15 KiB
Python
381 lines
15 KiB
Python
"""Override do VisualizeDataTool com `chart_type` controlável pelo LLM
|
|
+ generator que evita o fallback `go.Table` do upstream.
|
|
|
|
Três customizações sobre upstream:
|
|
|
|
1. Arg novo `chart_type` (line/bar/scatter/histogram/area) no schema. Quando
|
|
passado pelo LLM, força o tipo. Heurística automática só roda como
|
|
fallback quando o LLM omite (a guidance no system_prompt + descrição do
|
|
tool empurra o LLM a sempre passar).
|
|
|
|
2. `ClubPetroChartGenerator` continua removendo a heurística "4+ colunas
|
|
→ go.Table" do upstream (`vanna/src/vanna/integrations/plotly/chart_generator.py:51-55`),
|
|
que duplicava visualmente o dataframe rich.
|
|
|
|
3. `_coerce_datetime_columns` foi movida pra rodar nos DOIS caminhos
|
|
(<4 e >=4 colunas). Antes só rodava no 4+, então queries 2-col
|
|
`SELECT data_da_compra, faturamento` nunca eram detectadas como
|
|
time series e viravam bar. Agora a coerção uniforme torna o
|
|
comportamento previsível e o `chart_type=line` forçado funciona
|
|
consistente.
|
|
|
|
`chart_type` é threadado via `ContextVar` em vez de monkey-patch — o
|
|
`VisualizeDataTool` é singleton no `ToolRegistry` e pode ter execuções
|
|
concorrentes em conversas paralelas.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextvars
|
|
import json
|
|
from typing import Any, Dict, List, Literal, Optional, Type
|
|
|
|
import pandas as pd
|
|
import plotly.graph_objects as go
|
|
import plotly.io as pio
|
|
from pydantic import Field
|
|
|
|
from vanna.core.tool import ToolContext, ToolResult
|
|
from vanna.integrations.plotly import PlotlyChartGenerator
|
|
from vanna.tools import VisualizeDataTool
|
|
from vanna.tools.visualize_data import VisualizeDataArgs
|
|
|
|
from events_sink import record_chart
|
|
|
|
|
|
ChartType = Literal["line", "bar", "scatter", "histogram", "area"]
|
|
|
|
_chart_type_override: contextvars.ContextVar[Optional[ChartType]] = (
|
|
contextvars.ContextVar("chart_type_override", default=None)
|
|
)
|
|
|
|
|
|
class VisualizeDataArgsPT(VisualizeDataArgs):
|
|
chart_type: Optional[ChartType] = Field(
|
|
default=None,
|
|
description=(
|
|
"Tipo do gráfico. Passe SEMPRE que souber o tipo certo: "
|
|
"'line' = série temporal (evolução, tendência ao longo do tempo); "
|
|
"'bar' = ranking, top N, comparação categórica; "
|
|
"'scatter' = correlação entre 2 métricas numéricas; "
|
|
"'histogram' = distribuição de 1 métrica; "
|
|
"'area' = acumulado temporal. "
|
|
"Omita SÓ se genuinamente em dúvida — sem hint o sistema cai em "
|
|
"heurística por shape do DataFrame, que pode escolher errado."
|
|
),
|
|
)
|
|
|
|
|
|
_DESCRIPTION = (
|
|
"Renderiza um gráfico Plotly a partir do CSV produzido por run_sql na rodada anterior. "
|
|
"CHAME SEMPRE que o resultado da query for naturalmente visual — não espere o usuário pedir "
|
|
"a palavra 'gráfico' explicitamente; deduza do tipo de pergunta.\n"
|
|
"\n"
|
|
"GATILHOS OBRIGATÓRIOS (chame após o run_sql):\n"
|
|
"• Ranking / Top N — 'top 10 produtos', 'quais atendentes mais venderam', 'maiores clientes', "
|
|
"'ranking de X', 'melhores/piores Y' → chart_type='bar'.\n"
|
|
"• Série temporal — 'evolução', 'ao longo de', 'por dia/semana/mês/hora', 'últimos N dias', "
|
|
"'tendência', 'crescimento', 'comparar meses' → chart_type='line'.\n"
|
|
"• Comparação categórica — 'vendas por categoria', 'faturamento por produto', 'X por Y', "
|
|
"'distribuição', 'breakdown' → chart_type='bar'.\n"
|
|
"• Participação / share — 'participação', '% de', 'share' → chart_type='bar' "
|
|
"(evitar pizza com >3 fatias).\n"
|
|
"• Correlação — 'relação entre X e Y', 'preço vs volume', 'desconto vs faturamento' "
|
|
"→ chart_type='scatter'.\n"
|
|
"• Distribuição de UMA métrica — 'distribuição dos tickets', 'histograma de Y' "
|
|
"→ chart_type='histogram'.\n"
|
|
"• Acumulado temporal — 'faturamento acumulado por mês' → chart_type='area'.\n"
|
|
"\n"
|
|
"FORMA IDEAL DO CSV antes de chamar (responsabilidade do run_sql anterior):\n"
|
|
"• Série temporal: 1 coluna de tempo + 1-3 métricas.\n"
|
|
"• Ranking/categórico: 1 coluna de categoria + 1 métrica, máx 20-40 linhas.\n"
|
|
"• Scatter: 2 colunas numéricas.\n"
|
|
"• ≤3 colunas no total é o ponto ideal — se a query trouxer 4+ colunas, esta ferramenta "
|
|
"ainda gera chart real (não tabela), mas o resultado fica mais limpo com SELECT focado.\n"
|
|
"\n"
|
|
"QUANDO NÃO CHAMAR:\n"
|
|
"• Pergunta com resposta numérica única ('qual o faturamento de hoje?') — o número no texto "
|
|
"basta, gráfico não agrega.\n"
|
|
"• Usuário pediu explicitamente só a tabela ('me lista', 'exporta a tabela').\n"
|
|
"\n"
|
|
"ARGUMENTOS: filename = caminho do CSV retornado pelo run_sql anterior; "
|
|
"chart_type = tipo do gráfico (passe sempre que souber, ver lista acima); "
|
|
"title (opcional) = rótulo curto em pt-BR (ex: 'Faturamento por dia — últimos 30 dias')."
|
|
)
|
|
|
|
|
|
class ClubPetroChartGenerator(PlotlyChartGenerator):
|
|
"""Generator com:
|
|
|
|
- Coerção de datetime aplicada pra QUALQUER nº de colunas (uniforme).
|
|
- Override `chart_type` lido de ContextVar (vem do `VisualizeDataToolPT.execute`).
|
|
- Drop do fallback "4+ colunas → go.Table" do upstream — sempre chart real.
|
|
"""
|
|
|
|
def generate_chart(
|
|
self,
|
|
df: pd.DataFrame,
|
|
title: str = "Chart",
|
|
chart_type: Optional[ChartType] = None,
|
|
) -> Dict[str, Any]:
|
|
if df.empty:
|
|
raise ValueError("Cannot visualize empty DataFrame")
|
|
|
|
chart_type = chart_type or _chart_type_override.get()
|
|
df = self._coerce_datetime_columns(df)
|
|
|
|
if chart_type is not None:
|
|
fig = self._render_forced(df, title, chart_type)
|
|
return json.loads(pio.to_json(fig))
|
|
|
|
if len(df.columns) < 4:
|
|
return super().generate_chart(df, title)
|
|
|
|
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
|
categorical_cols = df.select_dtypes(
|
|
include=["object", "category"]
|
|
).columns.tolist()
|
|
datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
|
|
|
|
if datetime_cols and numeric_cols:
|
|
fig = self._create_time_series_chart(
|
|
df.sort_values(datetime_cols[0]),
|
|
datetime_cols[0],
|
|
numeric_cols[:3],
|
|
title,
|
|
)
|
|
elif categorical_cols and numeric_cols:
|
|
cat = categorical_cols[0]
|
|
num = numeric_cols[0]
|
|
agg = (
|
|
df.groupby(cat)[num]
|
|
.sum()
|
|
.reset_index()
|
|
.sort_values(num, ascending=False)
|
|
.head(40)
|
|
)
|
|
fig = self._create_ranked_bar_chart(agg, cat, num, title)
|
|
elif len(numeric_cols) >= 2:
|
|
fig = self._create_scatter_plot(df, numeric_cols[0], numeric_cols[1], title)
|
|
else:
|
|
fig = self._create_generic_chart(df, df.columns[0], df.columns[1], title)
|
|
|
|
return json.loads(pio.to_json(fig))
|
|
|
|
def _render_forced(
|
|
self, df: pd.DataFrame, title: str, chart_type: ChartType
|
|
) -> go.Figure:
|
|
numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
|
|
categorical_cols = df.select_dtypes(
|
|
include=["object", "category"]
|
|
).columns.tolist()
|
|
datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist()
|
|
|
|
if chart_type == "line":
|
|
if not numeric_cols:
|
|
raise ValueError(
|
|
"chart_type='line' requires at least one numeric column"
|
|
)
|
|
if datetime_cols:
|
|
return self._create_time_series_chart(
|
|
df.sort_values(datetime_cols[0]),
|
|
datetime_cols[0],
|
|
numeric_cols[:5],
|
|
title,
|
|
)
|
|
x_col = (
|
|
categorical_cols[0]
|
|
if categorical_cols
|
|
else df.columns[0]
|
|
)
|
|
y_cols = [c for c in numeric_cols if c != x_col][:5]
|
|
if not y_cols:
|
|
raise ValueError(
|
|
"chart_type='line' needs a numeric column distinct from the x-axis"
|
|
)
|
|
return self._create_time_series_chart(df, x_col, y_cols, title)
|
|
|
|
if chart_type == "area":
|
|
if not numeric_cols:
|
|
raise ValueError(
|
|
"chart_type='area' requires at least one numeric column"
|
|
)
|
|
x_col = (
|
|
datetime_cols[0]
|
|
if datetime_cols
|
|
else (categorical_cols[0] if categorical_cols else df.columns[0])
|
|
)
|
|
ordered = df.sort_values(x_col) if datetime_cols else df
|
|
y_cols = [c for c in numeric_cols if c != x_col][:5]
|
|
if not y_cols:
|
|
raise ValueError(
|
|
"chart_type='area' needs a numeric column distinct from the x-axis"
|
|
)
|
|
return self._create_area_chart(ordered, x_col, y_cols, title)
|
|
|
|
if chart_type == "bar":
|
|
if not numeric_cols:
|
|
raise ValueError(
|
|
"chart_type='bar' requires at least one numeric column"
|
|
)
|
|
x_col = categorical_cols[0] if categorical_cols else df.columns[0]
|
|
y_col = next((c for c in numeric_cols if c != x_col), None)
|
|
if y_col is None:
|
|
raise ValueError(
|
|
"chart_type='bar' needs a numeric column distinct from the x-axis"
|
|
)
|
|
agg = (
|
|
df.groupby(x_col)[y_col]
|
|
.sum()
|
|
.reset_index()
|
|
.sort_values(y_col, ascending=False)
|
|
.head(40)
|
|
)
|
|
return self._create_ranked_bar_chart(agg, x_col, y_col, title)
|
|
|
|
if chart_type == "scatter":
|
|
if len(numeric_cols) < 2:
|
|
raise ValueError(
|
|
"chart_type='scatter' requires 2 numeric columns"
|
|
)
|
|
return self._create_scatter_plot(
|
|
df, numeric_cols[0], numeric_cols[1], title
|
|
)
|
|
|
|
if chart_type == "histogram":
|
|
if not numeric_cols:
|
|
raise ValueError(
|
|
"chart_type='histogram' requires at least one numeric column"
|
|
)
|
|
return self._create_histogram(df, numeric_cols[0], title)
|
|
|
|
raise ValueError(f"Unknown chart_type: {chart_type!r}")
|
|
|
|
def _create_ranked_bar_chart(
|
|
self,
|
|
df: pd.DataFrame,
|
|
x_col: str,
|
|
y_col: str,
|
|
title: str,
|
|
) -> go.Figure:
|
|
"""Bar chart preservando a ordem do `df` (sem re-groupby).
|
|
|
|
Upstream `_create_bar_chart` re-agrega com `groupby(x_col).sum()`
|
|
e perde a ordenação descendente — bars saem alfabéticos. Pra
|
|
ranking real precisamos travar `categoryorder=array`.
|
|
"""
|
|
order = df[x_col].astype(str).tolist()
|
|
fig = go.Figure(
|
|
data=[
|
|
go.Bar(
|
|
x=order,
|
|
y=df[y_col].tolist(),
|
|
marker_color=self.THEME_COLORS["orange"],
|
|
)
|
|
]
|
|
)
|
|
fig.update_layout(
|
|
title=title,
|
|
xaxis_title=x_col,
|
|
yaxis_title=y_col,
|
|
xaxis={"categoryorder": "array", "categoryarray": order},
|
|
)
|
|
self._apply_standard_layout(fig)
|
|
return fig
|
|
|
|
def _create_area_chart(
|
|
self,
|
|
df: pd.DataFrame,
|
|
x_col: str,
|
|
y_cols: List[str],
|
|
title: str,
|
|
) -> go.Figure:
|
|
fig = go.Figure()
|
|
for i, col in enumerate(y_cols):
|
|
color = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)]
|
|
fig.add_trace(
|
|
go.Scatter(
|
|
x=df[x_col],
|
|
y=df[col],
|
|
mode="lines",
|
|
name=col,
|
|
fill="tozeroy" if i == 0 else "tonexty",
|
|
line=dict(color=color),
|
|
)
|
|
)
|
|
fig.update_layout(
|
|
title=title,
|
|
xaxis_title=x_col,
|
|
yaxis_title="Value",
|
|
hovermode="x unified",
|
|
)
|
|
self._apply_standard_layout(fig)
|
|
return fig
|
|
|
|
@staticmethod
|
|
def _coerce_datetime_columns(df: pd.DataFrame) -> pd.DataFrame:
|
|
"""Converte colunas object que parecem date/datetime em datetime64.
|
|
|
|
ClickHouse devolve datas como string no CSV; pandas lê como object e
|
|
a heurística de datetime do upstream falha. Aceita dois formatos:
|
|
- ISO `2026-01-01` (ou `2026/01/01`) — quando a query devolve o
|
|
valor cru (DateTime sem coerção downstream).
|
|
- BR `01/01/2026` — quando o `RLSClickHouseRunner` formata
|
|
colunas datetime pra exibição no rich component (vê
|
|
`_format_date_columns` em rls_runner.py). pandas precisa
|
|
`dayfirst=True` ou seria parseado como mês/dia (US).
|
|
Best-effort (`errors="raise"` dentro do try) — só converte colunas
|
|
que são parseáveis 100%.
|
|
"""
|
|
for col in df.select_dtypes(include=["object"]).columns:
|
|
sample = df[col].dropna().astype(str).head(5).tolist()
|
|
if not sample:
|
|
continue
|
|
looks_like_iso = all(
|
|
len(s) >= 8 and s[:4].isdigit() and s[4] in "-/"
|
|
for s in sample
|
|
)
|
|
looks_like_br = all(
|
|
len(s) >= 10
|
|
and s[:2].isdigit()
|
|
and s[2] == "/"
|
|
and s[3:5].isdigit()
|
|
and s[5] == "/"
|
|
and s[6:10].isdigit()
|
|
for s in sample
|
|
)
|
|
if not (looks_like_iso or looks_like_br):
|
|
continue
|
|
try:
|
|
df[col] = pd.to_datetime(
|
|
df[col], dayfirst=looks_like_br, errors="raise"
|
|
)
|
|
except (ValueError, TypeError):
|
|
pass
|
|
return df
|
|
|
|
|
|
class VisualizeDataToolPT(VisualizeDataTool):
|
|
"""VisualizeDataTool com schema PT-BR (chart_type) + generator custom."""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
kwargs.setdefault("plotly_generator", ClubPetroChartGenerator())
|
|
super().__init__(*args, **kwargs)
|
|
|
|
@property
|
|
def description(self) -> str:
|
|
return _DESCRIPTION
|
|
|
|
def get_args_schema(self) -> Type[VisualizeDataArgsPT]:
|
|
return VisualizeDataArgsPT
|
|
|
|
async def execute(
|
|
self, context: ToolContext, args: VisualizeDataArgsPT
|
|
) -> ToolResult:
|
|
record_chart(args.chart_type or "auto", args.title or "")
|
|
token = _chart_type_override.set(args.chart_type)
|
|
try:
|
|
return await super().execute(context, args)
|
|
finally:
|
|
_chart_type_override.reset(token)
|