"""Override do VisualizeDataTool com `chart_type` controlável pelo LLM + generator que evita o fallback `go.Table` do upstream. Três customizações sobre upstream: 1. Arg novo `chart_type` (line/bar/scatter/histogram/area) no schema. Quando passado pelo LLM, força o tipo. Heurística automática só roda como fallback quando o LLM omite (a guidance no system_prompt + descrição do tool empurra o LLM a sempre passar). 2. `ClubPetroChartGenerator` continua removendo a heurística "4+ colunas → go.Table" do upstream (`vanna/src/vanna/integrations/plotly/chart_generator.py:51-55`), que duplicava visualmente o dataframe rich. 3. `_coerce_datetime_columns` foi movida pra rodar nos DOIS caminhos (<4 e >=4 colunas). Antes só rodava no 4+, então queries 2-col `SELECT data_da_compra, faturamento` nunca eram detectadas como time series e viravam bar. Agora a coerção uniforme torna o comportamento previsível e o `chart_type=line` forçado funciona consistente. `chart_type` é threadado via `ContextVar` em vez de monkey-patch — o `VisualizeDataTool` é singleton no `ToolRegistry` e pode ter execuções concorrentes em conversas paralelas. """ from __future__ import annotations import contextvars import json from typing import Any, Dict, List, Literal, Optional, Type import pandas as pd import plotly.graph_objects as go import plotly.io as pio from pydantic import Field from vanna.core.tool import ToolContext, ToolResult from vanna.integrations.plotly import PlotlyChartGenerator from vanna.tools import VisualizeDataTool from vanna.tools.visualize_data import VisualizeDataArgs from events_sink import record_chart ChartType = Literal["line", "bar", "scatter", "histogram", "area"] _chart_type_override: contextvars.ContextVar[Optional[ChartType]] = ( contextvars.ContextVar("chart_type_override", default=None) ) class VisualizeDataArgsPT(VisualizeDataArgs): chart_type: Optional[ChartType] = Field( default=None, description=( "Tipo do gráfico. Passe SEMPRE que souber o tipo certo: " "'line' = série temporal (evolução, tendência ao longo do tempo); " "'bar' = ranking, top N, comparação categórica; " "'scatter' = correlação entre 2 métricas numéricas; " "'histogram' = distribuição de 1 métrica; " "'area' = acumulado temporal. " "Omita SÓ se genuinamente em dúvida — sem hint o sistema cai em " "heurística por shape do DataFrame, que pode escolher errado." ), ) _DESCRIPTION = ( "Renderiza um gráfico Plotly a partir do CSV produzido por run_sql na rodada anterior. " "CHAME SEMPRE que o resultado da query for naturalmente visual — não espere o usuário pedir " "a palavra 'gráfico' explicitamente; deduza do tipo de pergunta.\n" "\n" "GATILHOS OBRIGATÓRIOS (chame após o run_sql):\n" "• Ranking / Top N — 'top 10 produtos', 'quais atendentes mais venderam', 'maiores clientes', " "'ranking de X', 'melhores/piores Y' → chart_type='bar'.\n" "• Série temporal — 'evolução', 'ao longo de', 'por dia/semana/mês/hora', 'últimos N dias', " "'tendência', 'crescimento', 'comparar meses' → chart_type='line'.\n" "• Comparação categórica — 'vendas por categoria', 'faturamento por produto', 'X por Y', " "'distribuição', 'breakdown' → chart_type='bar'.\n" "• Participação / share — 'participação', '% de', 'share' → chart_type='bar' " "(evitar pizza com >3 fatias).\n" "• Correlação — 'relação entre X e Y', 'preço vs volume', 'desconto vs faturamento' " "→ chart_type='scatter'.\n" "• Distribuição de UMA métrica — 'distribuição dos tickets', 'histograma de Y' " "→ chart_type='histogram'.\n" "• Acumulado temporal — 'faturamento acumulado por mês' → chart_type='area'.\n" "\n" "FORMA IDEAL DO CSV antes de chamar (responsabilidade do run_sql anterior):\n" "• Série temporal: 1 coluna de tempo + 1-3 métricas.\n" "• Ranking/categórico: 1 coluna de categoria + 1 métrica, máx 20-40 linhas.\n" "• Scatter: 2 colunas numéricas.\n" "• ≤3 colunas no total é o ponto ideal — se a query trouxer 4+ colunas, esta ferramenta " "ainda gera chart real (não tabela), mas o resultado fica mais limpo com SELECT focado.\n" "\n" "QUANDO NÃO CHAMAR:\n" "• Pergunta com resposta numérica única ('qual o faturamento de hoje?') — o número no texto " "basta, gráfico não agrega.\n" "• Usuário pediu explicitamente só a tabela ('me lista', 'exporta a tabela').\n" "\n" "ARGUMENTOS: filename = caminho do CSV retornado pelo run_sql anterior; " "chart_type = tipo do gráfico (passe sempre que souber, ver lista acima); " "title (opcional) = rótulo curto em pt-BR (ex: 'Faturamento por dia — últimos 30 dias')." ) class ClubPetroChartGenerator(PlotlyChartGenerator): """Generator com: - Coerção de datetime aplicada pra QUALQUER nº de colunas (uniforme). - Override `chart_type` lido de ContextVar (vem do `VisualizeDataToolPT.execute`). - Drop do fallback "4+ colunas → go.Table" do upstream — sempre chart real. """ def generate_chart( self, df: pd.DataFrame, title: str = "Chart", chart_type: Optional[ChartType] = None, ) -> Dict[str, Any]: if df.empty: raise ValueError("Cannot visualize empty DataFrame") chart_type = chart_type or _chart_type_override.get() df = self._coerce_datetime_columns(df) if chart_type is not None: fig = self._render_forced(df, title, chart_type) return json.loads(pio.to_json(fig)) if len(df.columns) < 4: return super().generate_chart(df, title) numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() categorical_cols = df.select_dtypes( include=["object", "category"] ).columns.tolist() datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist() if datetime_cols and numeric_cols: fig = self._create_time_series_chart( df.sort_values(datetime_cols[0]), datetime_cols[0], numeric_cols[:3], title, ) elif categorical_cols and numeric_cols: cat = categorical_cols[0] num = numeric_cols[0] agg = ( df.groupby(cat)[num] .sum() .reset_index() .sort_values(num, ascending=False) .head(40) ) fig = self._create_ranked_bar_chart(agg, cat, num, title) elif len(numeric_cols) >= 2: fig = self._create_scatter_plot(df, numeric_cols[0], numeric_cols[1], title) else: fig = self._create_generic_chart(df, df.columns[0], df.columns[1], title) return json.loads(pio.to_json(fig)) def _render_forced( self, df: pd.DataFrame, title: str, chart_type: ChartType ) -> go.Figure: numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() categorical_cols = df.select_dtypes( include=["object", "category"] ).columns.tolist() datetime_cols = df.select_dtypes(include=["datetime64"]).columns.tolist() if chart_type == "line": if not numeric_cols: raise ValueError( "chart_type='line' requires at least one numeric column" ) if datetime_cols: return self._create_time_series_chart( df.sort_values(datetime_cols[0]), datetime_cols[0], numeric_cols[:5], title, ) x_col = ( categorical_cols[0] if categorical_cols else df.columns[0] ) y_cols = [c for c in numeric_cols if c != x_col][:5] if not y_cols: raise ValueError( "chart_type='line' needs a numeric column distinct from the x-axis" ) return self._create_time_series_chart(df, x_col, y_cols, title) if chart_type == "area": if not numeric_cols: raise ValueError( "chart_type='area' requires at least one numeric column" ) x_col = ( datetime_cols[0] if datetime_cols else (categorical_cols[0] if categorical_cols else df.columns[0]) ) ordered = df.sort_values(x_col) if datetime_cols else df y_cols = [c for c in numeric_cols if c != x_col][:5] if not y_cols: raise ValueError( "chart_type='area' needs a numeric column distinct from the x-axis" ) return self._create_area_chart(ordered, x_col, y_cols, title) if chart_type == "bar": if not numeric_cols: raise ValueError( "chart_type='bar' requires at least one numeric column" ) x_col = categorical_cols[0] if categorical_cols else df.columns[0] y_col = next((c for c in numeric_cols if c != x_col), None) if y_col is None: raise ValueError( "chart_type='bar' needs a numeric column distinct from the x-axis" ) agg = ( df.groupby(x_col)[y_col] .sum() .reset_index() .sort_values(y_col, ascending=False) .head(40) ) return self._create_ranked_bar_chart(agg, x_col, y_col, title) if chart_type == "scatter": if len(numeric_cols) < 2: raise ValueError( "chart_type='scatter' requires 2 numeric columns" ) return self._create_scatter_plot( df, numeric_cols[0], numeric_cols[1], title ) if chart_type == "histogram": if not numeric_cols: raise ValueError( "chart_type='histogram' requires at least one numeric column" ) return self._create_histogram(df, numeric_cols[0], title) raise ValueError(f"Unknown chart_type: {chart_type!r}") def _create_ranked_bar_chart( self, df: pd.DataFrame, x_col: str, y_col: str, title: str, ) -> go.Figure: """Bar chart preservando a ordem do `df` (sem re-groupby). Upstream `_create_bar_chart` re-agrega com `groupby(x_col).sum()` e perde a ordenação descendente — bars saem alfabéticos. Pra ranking real precisamos travar `categoryorder=array`. """ order = df[x_col].astype(str).tolist() fig = go.Figure( data=[ go.Bar( x=order, y=df[y_col].tolist(), marker_color=self.THEME_COLORS["orange"], ) ] ) fig.update_layout( title=title, xaxis_title=x_col, yaxis_title=y_col, xaxis={"categoryorder": "array", "categoryarray": order}, ) self._apply_standard_layout(fig) return fig def _create_area_chart( self, df: pd.DataFrame, x_col: str, y_cols: List[str], title: str, ) -> go.Figure: fig = go.Figure() for i, col in enumerate(y_cols): color = self.COLOR_PALETTE[i % len(self.COLOR_PALETTE)] fig.add_trace( go.Scatter( x=df[x_col], y=df[col], mode="lines", name=col, fill="tozeroy" if i == 0 else "tonexty", line=dict(color=color), ) ) fig.update_layout( title=title, xaxis_title=x_col, yaxis_title="Value", hovermode="x unified", ) self._apply_standard_layout(fig) return fig @staticmethod def _coerce_datetime_columns(df: pd.DataFrame) -> pd.DataFrame: """Converte colunas object que parecem date/datetime em datetime64. ClickHouse devolve datas como string no CSV; pandas lê como object e a heurística de datetime do upstream falha. Aceita dois formatos: - ISO `2026-01-01` (ou `2026/01/01`) — quando a query devolve o valor cru (DateTime sem coerção downstream). - BR `01/01/2026` — quando o `RLSClickHouseRunner` formata colunas datetime pra exibição no rich component (vê `_format_date_columns` em rls_runner.py). pandas precisa `dayfirst=True` ou seria parseado como mês/dia (US). Best-effort (`errors="raise"` dentro do try) — só converte colunas que são parseáveis 100%. """ for col in df.select_dtypes(include=["object"]).columns: sample = df[col].dropna().astype(str).head(5).tolist() if not sample: continue looks_like_iso = all( len(s) >= 8 and s[:4].isdigit() and s[4] in "-/" for s in sample ) looks_like_br = all( len(s) >= 10 and s[:2].isdigit() and s[2] == "/" and s[3:5].isdigit() and s[5] == "/" and s[6:10].isdigit() for s in sample ) if not (looks_like_iso or looks_like_br): continue try: df[col] = pd.to_datetime( df[col], dayfirst=looks_like_br, errors="raise" ) except (ValueError, TypeError): pass return df class VisualizeDataToolPT(VisualizeDataTool): """VisualizeDataTool com schema PT-BR (chart_type) + generator custom.""" def __init__(self, *args, **kwargs): kwargs.setdefault("plotly_generator", ClubPetroChartGenerator()) super().__init__(*args, **kwargs) @property def description(self) -> str: return _DESCRIPTION def get_args_schema(self) -> Type[VisualizeDataArgsPT]: return VisualizeDataArgsPT async def execute( self, context: ToolContext, args: VisualizeDataArgsPT ) -> ToolResult: record_chart(args.chart_type or "auto", args.title or "") token = _chart_type_override.set(args.chart_type) try: return await super().execute(context, args) finally: _chart_type_override.reset(token)