diff --git a/apps/dev-playground/client/src/routeTree.gen.ts b/apps/dev-playground/client/src/routeTree.gen.ts index 948df49f..1f2aff3e 100644 --- a/apps/dev-playground/client/src/routeTree.gen.ts +++ b/apps/dev-playground/client/src/routeTree.gen.ts @@ -16,6 +16,7 @@ import { Route as ReconnectRouteRouteImport } from './routes/reconnect.route' import { Route as LakebaseRouteRouteImport } from './routes/lakebase.route' import { Route as GenieRouteRouteImport } from './routes/genie.route' import { Route as DataVisualizationRouteRouteImport } from './routes/data-visualization.route' +import { Route as ChartInferenceRouteRouteImport } from './routes/chart-inference.route' import { Route as ArrowAnalyticsRouteRouteImport } from './routes/arrow-analytics.route' import { Route as AnalyticsRouteRouteImport } from './routes/analytics.route' import { Route as IndexRouteImport } from './routes/index' @@ -55,6 +56,11 @@ const DataVisualizationRouteRoute = DataVisualizationRouteRouteImport.update({ path: '/data-visualization', getParentRoute: () => rootRouteImport, } as any) +const ChartInferenceRouteRoute = ChartInferenceRouteRouteImport.update({ + id: '/chart-inference', + path: '/chart-inference', + getParentRoute: () => rootRouteImport, +} as any) const ArrowAnalyticsRouteRoute = ArrowAnalyticsRouteRouteImport.update({ id: '/arrow-analytics', path: '/arrow-analytics', @@ -75,6 +81,7 @@ export interface FileRoutesByFullPath { '/': typeof IndexRoute '/analytics': typeof AnalyticsRouteRoute '/arrow-analytics': typeof ArrowAnalyticsRouteRoute + '/chart-inference': typeof ChartInferenceRouteRoute '/data-visualization': typeof DataVisualizationRouteRoute '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute @@ -87,6 +94,7 @@ export interface FileRoutesByTo { '/': typeof IndexRoute '/analytics': typeof AnalyticsRouteRoute '/arrow-analytics': typeof ArrowAnalyticsRouteRoute + '/chart-inference': typeof ChartInferenceRouteRoute '/data-visualization': typeof 
DataVisualizationRouteRoute '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute @@ -100,6 +108,7 @@ export interface FileRoutesById { '/': typeof IndexRoute '/analytics': typeof AnalyticsRouteRoute '/arrow-analytics': typeof ArrowAnalyticsRouteRoute + '/chart-inference': typeof ChartInferenceRouteRoute '/data-visualization': typeof DataVisualizationRouteRoute '/genie': typeof GenieRouteRoute '/lakebase': typeof LakebaseRouteRoute @@ -114,6 +123,7 @@ export interface FileRouteTypes { | '/' | '/analytics' | '/arrow-analytics' + | '/chart-inference' | '/data-visualization' | '/genie' | '/lakebase' @@ -126,6 +136,7 @@ export interface FileRouteTypes { | '/' | '/analytics' | '/arrow-analytics' + | '/chart-inference' | '/data-visualization' | '/genie' | '/lakebase' @@ -138,6 +149,7 @@ export interface FileRouteTypes { | '/' | '/analytics' | '/arrow-analytics' + | '/chart-inference' | '/data-visualization' | '/genie' | '/lakebase' @@ -151,6 +163,7 @@ export interface RootRouteChildren { IndexRoute: typeof IndexRoute AnalyticsRouteRoute: typeof AnalyticsRouteRoute ArrowAnalyticsRouteRoute: typeof ArrowAnalyticsRouteRoute + ChartInferenceRouteRoute: typeof ChartInferenceRouteRoute DataVisualizationRouteRoute: typeof DataVisualizationRouteRoute GenieRouteRoute: typeof GenieRouteRoute LakebaseRouteRoute: typeof LakebaseRouteRoute @@ -211,6 +224,13 @@ declare module '@tanstack/react-router' { preLoaderRoute: typeof DataVisualizationRouteRouteImport parentRoute: typeof rootRouteImport } + '/chart-inference': { + id: '/chart-inference' + path: '/chart-inference' + fullPath: '/chart-inference' + preLoaderRoute: typeof ChartInferenceRouteRouteImport + parentRoute: typeof rootRouteImport + } '/arrow-analytics': { id: '/arrow-analytics' path: '/arrow-analytics' @@ -239,6 +259,7 @@ const rootRouteChildren: RootRouteChildren = { IndexRoute: IndexRoute, AnalyticsRouteRoute: AnalyticsRouteRoute, ArrowAnalyticsRouteRoute: ArrowAnalyticsRouteRoute, + 
/** Column descriptor used by the sample datasets: a name plus its SQL type name. */
interface SampleColumn {
  name: string;
  type_name: string;
}

/**
 * Builds a Genie-shaped statement_response from simple definitions.
 *
 * The supplied arrays are placed into the result by reference (no copying):
 * `columns` under `manifest.schema.columns`, `rows` under `result.data_array`.
 *
 * @param columns - Column descriptors, in the same order as the cells of each row.
 * @param rows - Row data; every cell is a string or `null`.
 * @returns An object of the form `{ manifest: { schema: { columns } }, result: { data_array } }`.
 */
function makeStatementResponse(
  columns: SampleColumn[],
  rows: (string | null)[][],
) {
  const manifest = { schema: { columns } };
  const result = { data_array: rows };
  return { manifest, result };
}
line chart", + expected: "line", + data: makeStatementResponse( + [ + { name: "date", type_name: "DATE" }, + { name: "revenue", type_name: "DECIMAL" }, + ], + [ + ["2024-01-01", "12000"], + ["2024-02-01", "15500"], + ["2024-03-01", "13200"], + ["2024-04-01", "17800"], + ["2024-05-01", "19200"], + ["2024-06-01", "21000"], + ["2024-07-01", "18500"], + ["2024-08-01", "22100"], + ["2024-09-01", "24500"], + ["2024-10-01", "23000"], + ["2024-11-01", "26800"], + ["2024-12-01", "29000"], + ], + ), + }, + { + title: "Few categories (region + sales)", + description: "Rule 2: STRING + 1 numeric, 3 categories → pie chart", + expected: "pie", + data: makeStatementResponse( + [ + { name: "region", type_name: "STRING" }, + { name: "sales", type_name: "DECIMAL" }, + ], + [ + ["North America", "45000"], + ["Europe", "32000"], + ["Asia Pacific", "28000"], + ], + ), + }, + { + title: "Moderate categories (product + revenue)", + description: "Rule 3: STRING + 1 numeric, 15 categories → bar chart", + expected: "bar", + data: makeStatementResponse( + [ + { name: "product", type_name: "STRING" }, + { name: "revenue", type_name: "DECIMAL" }, + ], + Array.from({ length: 15 }, (_, i) => [ + `Product ${String.fromCharCode(65 + i)}`, + String(Math.round(5000 + Math.sin(i) * 3000)), + ]), + ), + }, + { + title: "Many categories (city + population)", + description: "Rule 4: STRING + 1 numeric, 150 categories → line chart", + expected: "line", + data: makeStatementResponse( + [ + { name: "city", type_name: "STRING" }, + { name: "population", type_name: "INT" }, + ], + Array.from({ length: 150 }, (_, i) => [ + `City ${i + 1}`, + String(Math.round(10000 + Math.random() * 90000)), + ]), + ), + }, + { + title: "Multi-series timeseries (month + revenue + cost)", + description: "Rule 1: DATE + multiple numerics → line chart", + expected: "line", + data: makeStatementResponse( + [ + { name: "month", type_name: "DATE" }, + { name: "revenue", type_name: "DECIMAL" }, + { name: "cost", type_name: "DECIMAL" 
}, + ], + [ + ["2024-01-01", "12000", "8000"], + ["2024-02-01", "15500", "9200"], + ["2024-03-01", "13200", "8800"], + ["2024-04-01", "17800", "10500"], + ["2024-05-01", "19200", "11000"], + ["2024-06-01", "21000", "12500"], + ], + ), + }, + { + title: "Grouped bar (department + budget + actual)", + description: "Rule 5: STRING + N numerics, 8 categories → bar chart", + expected: "bar", + data: makeStatementResponse( + [ + { name: "department", type_name: "STRING" }, + { name: "budget", type_name: "DECIMAL" }, + { name: "actual", type_name: "DECIMAL" }, + ], + [ + ["Engineering", "500000", "480000"], + ["Marketing", "300000", "320000"], + ["Sales", "400000", "410000"], + ["Support", "200000", "190000"], + ["HR", "150000", "145000"], + ["Finance", "180000", "175000"], + ["Legal", "120000", "115000"], + ["Operations", "250000", "240000"], + ], + ), + }, + { + title: "Scatter (height + weight)", + description: "Rule 7: 2 numerics only → scatter chart", + expected: "scatter", + data: makeStatementResponse( + [ + { name: "height_cm", type_name: "DOUBLE" }, + { name: "weight_kg", type_name: "DOUBLE" }, + ], + Array.from({ length: 30 }, (_, i) => [ + String(150 + i * 1.2), + String(Math.round(45 + i * 1.5 + (Math.random() - 0.5) * 10)), + ]), + ), + }, + { + title: "Single row (name + value)", + description: "Skip: < 2 rows → table only", + expected: "none (table only)", + data: makeStatementResponse( + [ + { name: "metric", type_name: "STRING" }, + { name: "value", type_name: "DECIMAL" }, + ], + [["Total Revenue", "125000"]], + ), + }, + { + title: "All strings (first_name + last_name + city)", + description: "Skip: no numeric columns → table only", + expected: "none (table only)", + data: makeStatementResponse( + [ + { name: "first_name", type_name: "STRING" }, + { name: "last_name", type_name: "STRING" }, + { name: "city", type_name: "STRING" }, + ], + [ + ["Alice", "Smith", "New York"], + ["Bob", "Jones", "London"], + ["Carol", "Lee", "Tokyo"], + ], + ), + }, +]; + 
+// --------------------------------------------------------------------------- +// Per-sample card component +// --------------------------------------------------------------------------- + +function SampleCard({ + title, + description, + expected, + data, +}: (typeof SAMPLES)[number]) { + const transformed = useMemo(() => transformGenieData(data), [data]); + const inference = useMemo( + () => + transformed + ? inferChartType(transformed.rows, transformed.columns) + : null, + [transformed], + ); + + return ( + +
+

{title}

+

{description}

+
+ +
+ + Expected: {expected} + + + Inferred:{" "} + {inference + ? `${inference.chartType} (x: ${inference.xKey}, y: ${Array.isArray(inference.yKey) ? inference.yKey.join(", ") : inference.yKey})` + : "null (no chart)"} + + + Rows: {transformed?.rows.length ?? 0} + + + Columns: {transformed?.columns.length ?? 0} + +
+ + +
+ ); +} + +// --------------------------------------------------------------------------- +// Route component +// --------------------------------------------------------------------------- + +function ChartInferenceRoute() { + return ( +
+
+
+
+

+ Chart Inference Demo +

+

+ Sample datasets exercising each Genie chart inference rule. Each + card shows the inferred chart type, axes, and the rendered + visualization. +

+
+ +
+ {SAMPLES.map((sample) => ( + + ))} +
+
+
+
+ ); +} diff --git a/packages/appkit-ui/src/react/charts/base.tsx b/packages/appkit-ui/src/react/charts/base.tsx index 4a781c96..6a623eb4 100644 --- a/packages/appkit-ui/src/react/charts/base.tsx +++ b/packages/appkit-ui/src/react/charts/base.tsx @@ -172,7 +172,7 @@ export function BaseChart({ // Memoize option building const option = useMemo(() => { - const { xData, yFields, chartType: detectedChartType } = normalized; + const { xData, yFields, xField, chartType: detectedChartType } = normalized; if (xData.length === 0) return null; @@ -190,6 +190,7 @@ export function BaseChart({ colors, title, showLegend, + xField, }; const isPie = chartType === "pie" || chartType === "donut"; const isRadar = chartType === "radar"; diff --git a/packages/appkit-ui/src/react/charts/normalize.ts b/packages/appkit-ui/src/react/charts/normalize.ts index 8736512d..7fe3a4f7 100644 --- a/packages/appkit-ui/src/react/charts/normalize.ts +++ b/packages/appkit-ui/src/react/charts/normalize.ts @@ -8,7 +8,11 @@ import type { Orientation, } from "./types"; import { isArrowTable } from "./types"; -import { sortTimeSeriesAscending, toChartArray } from "./utils"; +import { + sortNumericAscending, + sortTimeSeriesAscending, + toChartArray, +} from "./utils"; // ============================================================================ // Type Detection Helpers @@ -244,6 +248,12 @@ export function normalizeChartData( yDataMap, resolvedYKeys, )); + } else if (xData.length > 0 && xData.every(isNumericValue)) { + ({ xData, yDataMap } = sortNumericAscending( + xData, + yDataMap, + resolvedYKeys, + )); } return { @@ -283,6 +293,12 @@ export function normalizeChartData( yDataMap, resolvedYKeys, )); + } else if (xData.length > 0 && xData.every(isNumericValue)) { + ({ xData, yDataMap } = sortNumericAscending( + xData, + yDataMap, + resolvedYKeys, + )); } return { diff --git a/packages/appkit-ui/src/react/charts/options.ts b/packages/appkit-ui/src/react/charts/options.ts index 3bd99bd9..5a7bcfca 100644 
--- a/packages/appkit-ui/src/react/charts/options.ts +++ b/packages/appkit-ui/src/react/charts/options.ts @@ -12,6 +12,7 @@ export interface OptionBuilderContext { colors: string[]; title?: string; showLegend: boolean; + xField?: string; } export interface CartesianContext extends OptionBuilderContext { @@ -256,10 +257,11 @@ export function buildCartesianOption( ctx; const hasMultipleSeries = ctx.yFields.length > 1; const seriesType = chartType === "area" ? "line" : chartType; + const isScatter = chartType === "scatter"; return { ...buildBaseOption(ctx), - tooltip: { trigger: "axis" }, + tooltip: { trigger: isScatter ? "item" : "axis" }, legend: ctx.showLegend && hasMultipleSeries ? { top: "bottom" } : undefined, grid: { left: "10%", @@ -268,27 +270,34 @@ export function buildCartesianOption( bottom: ctx.showLegend && hasMultipleSeries ? "20%" : "15%", }, xAxis: { - type: isTimeSeries ? "time" : "category", - data: isTimeSeries ? undefined : ctx.xData, - axisLabel: isTimeSeries - ? undefined - : { - rotate: ctx.xData.length > 10 ? 45 : 0, - formatter: (v: string) => truncateLabel(String(v), 10), - }, + type: isScatter ? "value" : isTimeSeries ? "time" : "category", + data: isScatter || isTimeSeries ? undefined : ctx.xData, + name: ctx.xField ? formatLabel(ctx.xField) : undefined, + axisLabel: + isScatter || isTimeSeries + ? { show: true } + : { + rotate: ctx.xData.length > 10 ? 45 : 0, + formatter: (v: string) => truncateLabel(String(v), 10), + }, + }, + yAxis: { + type: "value", + name: ctx.yFields.length === 1 ? formatLabel(ctx.yFields[0]) : undefined, }, - yAxis: { type: "value" }, series: ctx.yFields.map((key, idx) => ({ name: formatLabel(key), type: seriesType, - data: isTimeSeries - ? createTimeSeriesData(ctx.xData, ctx.yDataMap[key]) - : ctx.yDataMap[key], + data: isScatter + ? ctx.xData.map((x, i) => [x, ctx.yDataMap[key][i]]) + : isTimeSeries + ? 
/**
 * Sorts numeric x-data in ascending order, reordering y-data to match.
 *
 * Every returned x-value, and every y-value of the series named in `yFields`,
 * is coerced with `Number()`. This also holds for single-point input, so the
 * element type of the result never depends on how many points were passed.
 *
 * @param xData - X-values to sort. They are compared via `Number()`, so callers
 *   are expected to pass only numbers or numeric strings.
 * @param yDataMap - Y-series keyed by field name; each series is index-aligned
 *   with `xData`.
 * @param yFields - Names of the series to reorder. For non-empty input, only
 *   these keys appear in the returned map.
 * @returns New `xData` / `yDataMap` in ascending x order. Empty input is
 *   returned unchanged.
 */
export function sortNumericAscending(
  xData: (string | number)[],
  yDataMap: Record<string, (string | number)[]>,
  yFields: string[],
): {
  xData: (string | number)[];
  yDataMap: Record<string, (string | number)[]>;
} {
  // Nothing to sort and nothing to coerce.
  if (xData.length === 0) {
    return { xData, yDataMap };
  }

  // Coerce once up front rather than on every comparator call.
  const numericX = xData.map((x) => Number(x));

  // Sort a permutation of indices so each y-series can be reordered the same way.
  const order = numericX.map((_, i) => i);
  order.sort((a, b) => numericX[a] - numericX[b]);

  const sortedYDataMap: Record<string, (string | number)[]> = {};
  for (const key of yFields) {
    const series = yDataMap[key];
    // NOTE(review): Number() maps a null cell to 0 — confirm whether y-series
    // can contain nulls at runtime and whether 0 is the desired rendering.
    sortedYDataMap[key] = order.map((i) => Number(series[i]));
  }

  return { xData: order.map((i) => numericX[i]), yDataMap: sortedYDataMap };
}
/**
 * Builds column metadata from `[name, category]` pairs.
 * `typeName` is simply the category upper-cased (e.g. "numeric" -> "NUMERIC").
 */
function cols(
  ...defs: Array<[string, "numeric" | "date" | "string"]>
): GenieColumnMeta[] {
  const metas: GenieColumnMeta[] = [];
  for (const [name, category] of defs) {
    metas.push({ name, typeName: category.toUpperCase(), category });
  }
  return metas;
}

/**
 * Zips positional row data into keyed records.
 * Cell `i` of each row is stored under `keys[i]`; a row shorter than `keys`
 * yields `undefined` for the missing positions.
 */
function makeRows(
  keys: string[],
  data: unknown[][],
): Record<string, unknown>[] {
  const records: Record<string, unknown>[] = [];
  for (const cells of data) {
    const record: Record<string, unknown> = {};
    for (const [position, key] of keys.entries()) {
      record[key] = cells[position];
    }
    records.push(record);
  }
  return records;
}
expect(inferChartType(rows, columns)).toBeNull(); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 1: DATE + numeric(s) → line +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 1: timeseries", () => { + test("date + single numeric → line", () => { + const columns = cols(["day", "date"], ["revenue", "numeric"]); + const rows = makeRows( + ["day", "revenue"], + [ + ["2024-01-01", 100], + ["2024-01-02", 200], + ], + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "line", + xKey: "day", + yKey: "revenue", + }); + }); + + test("date + multiple numerics → line with yKey array", () => { + const columns = cols( + ["month", "date"], + ["revenue", "numeric"], + ["cost", "numeric"], + ); + const rows = makeRows( + ["month", "revenue", "cost"], + [ + ["2024-01", 100, 80], + ["2024-02", 200, 150], + ], + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "line", + xKey: "month", + yKey: ["revenue", "cost"], + }); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 2: STRING + 1 numeric, ≤7 categories → pie +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 2: pie", () => { + test("string + 1 numeric, 3 categories → pie", () => { + const columns = cols(["region", "string"], ["sales", "numeric"]); + const rows = makeRows( + ["region", "sales"], + [ + ["North", 100], + ["South", 200], + ["East", 150], + ], + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "pie", + xKey: "region", + yKey: "sales", + }); + }); + + test("string + 1 numeric, exactly 7 categories → pie", () => { + const columns = cols(["cat", "string"], ["val", "numeric"]); + const rows = makeRows( + ["cat", "val"], + Array.from({ length: 7 }, (_, i) => [`cat${i}`, i 
* 10]), + ); + const result = inferChartType(rows, columns); + expect(result?.chartType).toBe("pie"); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 3: STRING + 1 numeric, ≤100 categories → bar +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 3: bar", () => { + test("string + 1 numeric, 15 categories → bar", () => { + const columns = cols(["product", "string"], ["revenue", "numeric"]); + const rows = makeRows( + ["product", "revenue"], + Array.from({ length: 15 }, (_, i) => [`product${i}`, i * 100]), + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "bar", + xKey: "product", + yKey: "revenue", + }); + }); + + test("boundary: 8 categories (just above pie threshold) → bar", () => { + const columns = cols(["cat", "string"], ["val", "numeric"]); + const rows = makeRows( + ["cat", "val"], + Array.from({ length: 8 }, (_, i) => [`cat${i}`, i]), + ); + const result = inferChartType(rows, columns); + expect(result?.chartType).toBe("bar"); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 4: STRING + 1 numeric, >100 categories → line +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 4: many categories → line", () => { + test("string + 1 numeric, 150 categories → line", () => { + const columns = cols(["city", "string"], ["population", "numeric"]); + const rows = makeRows( + ["city", "population"], + Array.from({ length: 150 }, (_, i) => [`city${i}`, i * 1000]), + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "line", + xKey: "city", + yKey: "population", + }); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 5: STRING + N numerics, ≤50 categories → bar (grouped) +// 
--------------------------------------------------------------------------- + +describe("inferChartType — Rule 5: grouped bar", () => { + test("string + 2 numerics, 8 categories → bar", () => { + const columns = cols( + ["department", "string"], + ["budget", "numeric"], + ["actual", "numeric"], + ); + const rows = makeRows( + ["department", "budget", "actual"], + Array.from({ length: 8 }, (_, i) => [`dept${i}`, i * 100, i * 90]), + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "bar", + xKey: "department", + yKey: ["budget", "actual"], + }); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 6: STRING + N numerics, >50 categories → line +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 6: multi-series line", () => { + test("string + 2 numerics, 60 categories → line", () => { + const columns = cols( + ["item", "string"], + ["metric_a", "numeric"], + ["metric_b", "numeric"], + ); + const rows = makeRows( + ["item", "metric_a", "metric_b"], + Array.from({ length: 60 }, (_, i) => [`item${i}`, i, i * 2]), + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "line", + xKey: "item", + yKey: ["metric_a", "metric_b"], + }); + }); +}); + +// --------------------------------------------------------------------------- +// Rule 7: 2+ numerics only → scatter +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 7: scatter", () => { + test("2 numerics, no strings → scatter", () => { + const columns = cols(["height", "numeric"], ["weight", "numeric"]); + const rows = makeRows( + ["height", "weight"], + [ + [170, 70], + [180, 80], + [160, 55], + ], + ); + const result = inferChartType(rows, columns); + expect(result).toEqual({ + chartType: "scatter", + xKey: "height", + yKey: "weight", + }); + }); +}); + +// 
--------------------------------------------------------------------------- +// Rule 8: fallback +// --------------------------------------------------------------------------- + +describe("inferChartType — Rule 8: fallback", () => { + test("date + no numeric → null", () => { + const columns = cols(["day", "date"], ["label", "string"]); + const rows = makeRows( + ["day", "label"], + [ + ["2024-01-01", "a"], + ["2024-01-02", "b"], + ], + ); + expect(inferChartType(rows, columns)).toBeNull(); + }); +}); + +// --------------------------------------------------------------------------- +// Priority: date takes precedence over string +// --------------------------------------------------------------------------- + +describe("inferChartType — priority", () => { + test("date + string + numeric → uses date (line), not string", () => { + const columns = cols( + ["day", "date"], + ["region", "string"], + ["sales", "numeric"], + ); + const rows = makeRows( + ["day", "region", "sales"], + [ + ["2024-01-01", "North", 100], + ["2024-01-02", "South", 200], + ], + ); + const result = inferChartType(rows, columns); + expect(result?.chartType).toBe("line"); + expect(result?.xKey).toBe("day"); + }); +}); diff --git a/packages/appkit-ui/src/react/genie/__tests__/genie-query-transform.test.ts b/packages/appkit-ui/src/react/genie/__tests__/genie-query-transform.test.ts new file mode 100644 index 00000000..3eba66b0 --- /dev/null +++ b/packages/appkit-ui/src/react/genie/__tests__/genie-query-transform.test.ts @@ -0,0 +1,182 @@ +import { describe, expect, test } from "vitest"; +import { classifySqlType, transformGenieData } from "../genie-query-transform"; + +// --------------------------------------------------------------------------- +// classifySqlType +// --------------------------------------------------------------------------- + +describe("classifySqlType", () => { + test("classifies numeric types", () => { + for (const t of [ + "DECIMAL", + "INT", + "INTEGER", + "BIGINT", + 
"LONG", + "FLOAT", + "DOUBLE", + "SMALLINT", + "TINYINT", + "SHORT", + "BYTE", + ]) { + expect(classifySqlType(t)).toBe("numeric"); + } + }); + + test("classifies date types", () => { + for (const t of ["DATE", "TIMESTAMP", "TIMESTAMP_NTZ"]) { + expect(classifySqlType(t)).toBe("date"); + } + }); + + test("classifies string types", () => { + for (const t of [ + "STRING", + "VARCHAR", + "CHAR", + "BOOLEAN", + "BINARY", + "UNKNOWN", + ]) { + expect(classifySqlType(t)).toBe("string"); + } + }); + + test("is case-insensitive", () => { + expect(classifySqlType("decimal")).toBe("numeric"); + expect(classifySqlType("Timestamp")).toBe("date"); + }); +}); + +// --------------------------------------------------------------------------- +// transformGenieData +// --------------------------------------------------------------------------- + +describe("transformGenieData", () => { + function makeResponse( + columns: Array<{ name: string; type_name: string }>, + dataArray: (string | null)[][], + ) { + return { + manifest: { schema: { columns } }, + result: { data_array: dataArray }, + }; + } + + test("transforms basic numeric and string data", () => { + const data = makeResponse( + [ + { name: "region", type_name: "STRING" }, + { name: "sales", type_name: "DECIMAL" }, + ], + [ + ["North", "1000.50"], + ["South", "2000.75"], + ], + ); + + const result = transformGenieData(data); + expect(result).not.toBeNull(); + expect(result?.columns).toHaveLength(2); + expect(result?.columns[0]).toEqual({ + name: "region", + typeName: "STRING", + category: "string", + }); + expect(result?.columns[1]).toEqual({ + name: "sales", + typeName: "DECIMAL", + category: "numeric", + }); + expect(result?.rows).toEqual([ + { region: "North", sales: 1000.5 }, + { region: "South", sales: 2000.75 }, + ]); + }); + + test("handles date columns as strings", () => { + const data = makeResponse( + [ + { name: "day", type_name: "DATE" }, + { name: "revenue", type_name: "INT" }, + ], + [["2024-01-15", "500"]], + 
); + + const result = transformGenieData(data); + expect(result?.rows[0]).toEqual({ day: "2024-01-15", revenue: 500 }); + expect(result?.columns[0].category).toBe("date"); + }); + + test("handles null values", () => { + const data = makeResponse( + [ + { name: "name", type_name: "STRING" }, + { name: "value", type_name: "INT" }, + ], + [ + [null, "10"], + ["foo", null], + ], + ); + + const result = transformGenieData(data); + expect(result?.rows).toEqual([ + { name: null, value: 10 }, + { name: "foo", value: null }, + ]); + }); + + test("handles non-numeric strings in numeric columns", () => { + const data = makeResponse( + [ + { name: "name", type_name: "STRING" }, + { name: "value", type_name: "INT" }, + ], + [["a", "not_a_number"]], + ); + + const result = transformGenieData(data); + expect(result?.rows[0].value).toBeNull(); + }); + + test("returns null for empty data_array", () => { + const data = makeResponse([{ name: "a", type_name: "STRING" }], []); + expect(transformGenieData(data)).toBeNull(); + }); + + test("returns null for missing columns", () => { + expect( + transformGenieData({ + manifest: { schema: { columns: [] } }, + result: { data_array: [["x"]] }, + }), + ).toBeNull(); + }); + + test("returns null for null/undefined input", () => { + expect(transformGenieData(null)).toBeNull(); + expect(transformGenieData(undefined)).toBeNull(); + expect(transformGenieData("string")).toBeNull(); + }); + + test("returns null for malformed structure", () => { + expect(transformGenieData({})).toBeNull(); + expect(transformGenieData({ manifest: {} })).toBeNull(); + expect(transformGenieData({ manifest: { schema: {} } })).toBeNull(); + }); + + test("handles rows shorter than columns (missing cells)", () => { + const data = makeResponse( + [ + { name: "a", type_name: "STRING" }, + { name: "b", type_name: "INT" }, + ], + [["hello"]], + ); + + const result = transformGenieData(data); + expect(result?.rows[0]).toEqual({ a: "hello", b: null }); + }); +}); diff --git 
a/packages/appkit-ui/src/react/genie/genie-chart-inference.ts b/packages/appkit-ui/src/react/genie/genie-chart-inference.ts new file mode 100644 index 00000000..ab811703 --- /dev/null +++ b/packages/appkit-ui/src/react/genie/genie-chart-inference.ts @@ -0,0 +1,163 @@ +/** + * ┌─────────────────────────────────────────────────────────────────────┐ + * │ CHART INFERENCE RULES │ + * │ │ + * │ These rules determine what chart type is shown for Genie query │ + * │ results. Modify thresholds and chart type mappings here. │ + * │ │ + * │ Column types are classified from SQL type_name: │ + * │ DATE: DATE, TIMESTAMP, TIMESTAMP_NTZ │ + * │ NUMERIC: DECIMAL, INT, DOUBLE, FLOAT, LONG, etc. │ + * │ STRING: STRING, VARCHAR, CHAR, and any other type │ + * │ │ + * │ Rules (applied in priority order): │ + * │ │ + * │ SKIP (return null): │ + * │ - < 2 rows │ + * │ - < 2 columns │ + * │ - No numeric columns │ + * │ │ + * │ MATCH: │ + * │ 1. DATE + numeric(s) → line (timeseries) │ + * │ 2. STRING + 1 numeric, ≤7 cats → pie (no negative values) │ + * │ 3. STRING + 1 numeric, ≤100 cats → bar │ + * │ 4. STRING + 1 numeric, >100 cats → line │ + * │ 5. STRING + N numerics, ≤50 cats → bar (grouped) │ + * │ 6. STRING + N numerics, >50 cats → line (multi-series) │ + * │ 7. 2+ numerics only → scatter │ + * │ 8. 
Otherwise → null (skip) │ + * │ │ + * │ KNOWN LIMITATIONS: │ + * │ - First-column heuristic: picks first string col as category │ + * │ - No semantic understanding (can't tell ID from meaningful val) │ + * └─────────────────────────────────────────────────────────────────────┘ + */ + +import type { ChartType } from "../charts/types"; +import type { GenieColumnMeta } from "./genie-query-transform"; + +// --------------------------------------------------------------------------- +// Configuration — edit thresholds here +// --------------------------------------------------------------------------- + +const INFERENCE_CONFIG = { + /** Min rows required to show any chart */ + minRows: 2, + /** Max unique categories for pie chart */ + pieMaxCategories: 7, + /** Max unique categories for bar chart (single series) */ + barMaxCategories: 100, + /** Max unique categories for grouped bar chart (multi series) */ + groupedBarMaxCategories: 50, +} as const; + +export interface ChartInference { + chartType: ChartType; + xKey: string; + yKey: string | string[]; +} + +function countUnique(rows: Record[], key: string): number { + const seen = new Set(); + for (const row of rows) { + seen.add(row[key]); + } + return seen.size; +} + +function hasNegativeValues( + rows: Record[], + key: string, +): boolean { + for (const row of rows) { + if (Number(row[key]) < 0) return true; + } + return false; +} + +// --------------------------------------------------------------------------- +// Main inference function +// --------------------------------------------------------------------------- + +/** + * Infer the best chart type for the given Genie query result. + * Returns `null` when the data is not suitable for charting. 
+ */ +export function inferChartType( + rows: Record[], + columns: GenieColumnMeta[], +): ChartInference | null { + // Guard: need at least minRows and 2 columns + if (rows.length < INFERENCE_CONFIG.minRows || columns.length < 2) { + return null; + } + + const dateCols = columns.filter((c) => c.category === "date"); + const numericCols = columns.filter((c) => c.category === "numeric"); + const stringCols = columns.filter((c) => c.category === "string"); + + // Guard: must have at least one numeric column + if (numericCols.length === 0) return null; + + // Rule 1: DATE + numeric(s) → line (timeseries) + if (dateCols.length > 0 && numericCols.length >= 1) { + return { + chartType: "line", + xKey: dateCols[0].name, + yKey: + numericCols.length === 1 + ? numericCols[0].name + : numericCols.map((c) => c.name), + }; + } + + // Rules 2–6: STRING + numeric(s) + if (stringCols.length > 0 && numericCols.length >= 1) { + const xKey = stringCols[0].name; + const uniqueCategories = countUnique(rows, xKey); + + if (numericCols.length === 1) { + const yKey = numericCols[0].name; + + // Rule 2: few categories → pie (skip if negative values) + if ( + uniqueCategories <= INFERENCE_CONFIG.pieMaxCategories && + !hasNegativeValues(rows, yKey) + ) { + return { chartType: "pie", xKey, yKey }; + } + // Rule 3: moderate categories → bar + if (uniqueCategories <= INFERENCE_CONFIG.barMaxCategories) { + return { chartType: "bar", xKey, yKey }; + } + // Rule 4: many categories → line + return { chartType: "line", xKey, yKey }; + } + + // Multiple numerics + const yKey = numericCols.map((c) => c.name); + + // Rule 5: moderate categories → bar (grouped) + if (uniqueCategories <= INFERENCE_CONFIG.groupedBarMaxCategories) { + return { chartType: "bar", xKey, yKey }; + } + // Rule 6: many categories → line (multi-series) + return { chartType: "line", xKey, yKey }; + } + + // Rule 7: 2+ numerics only (no string, no date) → scatter + if ( + numericCols.length >= 2 && + stringCols.length === 0 && + 
dateCols.length === 0 + ) { + return { + chartType: "scatter", + xKey: numericCols[0].name, + yKey: numericCols[1].name, + }; + } + + // Rule 8: fallback — no chart + return null; +} diff --git a/packages/appkit-ui/src/react/genie/genie-chat-message.tsx b/packages/appkit-ui/src/react/genie/genie-chat-message.tsx index 04f15f04..4c979ffe 100644 --- a/packages/appkit-ui/src/react/genie/genie-chat-message.tsx +++ b/packages/appkit-ui/src/react/genie/genie-chat-message.tsx @@ -3,6 +3,7 @@ import { useMemo } from "react"; import { cn } from "../lib/utils"; import { Avatar, AvatarFallback } from "../ui/avatar"; import { Card } from "../ui/card"; +import { GenieQueryVisualization } from "./genie-query-visualization"; import type { GenieAttachmentResponse, GenieMessageItem } from "./types"; /** @@ -91,30 +92,41 @@ export function GenieChatMessage({ {queryAttachments.length > 0 && (
- {queryAttachments.map((att) => ( - -
- - {att.query?.title ?? "SQL Query"} - -
- {att.query?.description && ( - - {att.query.description} - - )} - {att.query?.query && ( -
-                        {att.query.query}
-                      
- )} -
-
-
- ))} + {queryAttachments.map((att) => { + const key = att.attachmentId ?? "query"; + const queryResult = att.attachmentId + ? message.queryResults.get(att.attachmentId) + : undefined; + + return ( +
+ +
+ + {att.query?.title ?? "SQL Query"} + +
+ {att.query?.description && ( + + {att.query.description} + + )} + {att.query?.query && ( +
+                            {att.query.query}
+                          
+ )} +
+
+
+ {queryResult != null && ( + + + + )} +
+ ); + })}
/**
 * Converts Genie's statement_response data into a flat record array
 * suitable for charting.
 *
 * The Genie API returns `{ manifest.schema.columns, result.data_array }`
 * where each column carries a SQL `type_name`. This module parses values
 * according to those types so downstream chart code receives proper
 * numbers and strings.
 */

// SQL type_name values that map to numeric JS values
const NUMERIC_SQL_TYPES = new Set<string>([
  "DECIMAL",
  "INT",
  "INTEGER",
  "BIGINT",
  "LONG",
  "SMALLINT",
  "TINYINT",
  "FLOAT",
  "DOUBLE",
  "SHORT",
  "BYTE",
]);

// SQL type_name values that map to date/timestamp strings
const DATE_SQL_TYPES = new Set<string>(["DATE", "TIMESTAMP", "TIMESTAMP_NTZ"]);

export type ColumnCategory = "numeric" | "date" | "string";

export interface GenieColumnMeta {
  name: string;
  typeName: string;
  category: ColumnCategory;
}

export interface TransformedGenieData {
  rows: Record<string, unknown>[];
  columns: GenieColumnMeta[];
}

/** Narrow an unknown value to a non-null object that can be indexed by key. */
function isRecord(value: unknown): value is Record<string, unknown> {
  return typeof value === "object" && value !== null;
}

/**
 * Classify a SQL type_name into a high-level category.
 *
 * Matching is case-insensitive and ignores a parenthesized suffix, so
 * `decimal(10,2)` is classified like `DECIMAL`. Anything unrecognized —
 * including a non-string value that slipped through from untyped API data —
 * is classified as `"string"`.
 */
export function classifySqlType(typeName: string): ColumnCategory {
  // Runtime guard: callers feed this from unvalidated API payloads.
  if (typeof typeName !== "string") return "string";

  const upper = typeName.toUpperCase();
  const parenAt = upper.indexOf("(");
  const base = (parenAt === -1 ? upper : upper.slice(0, parenAt)).trim();

  if (NUMERIC_SQL_TYPES.has(base)) return "numeric";
  if (DATE_SQL_TYPES.has(base)) return "date";
  return "string";
}

/**
 * Parse a single cell value based on its column category.
 *
 * Numeric cells that are null, blank, or not parseable become `null`.
 * Date and string cells are returned untouched.
 */
function parseValue(raw: unknown, category: ColumnCategory): unknown {
  if (raw == null) return null;
  // Dates and strings stay as strings — normalizeChartData detects ISO dates
  if (category !== "numeric") return raw;
  // Number("") and Number("  ") evaluate to 0, which would fabricate a data
  // point for a blank cell; treat blank as missing instead.
  if (typeof raw === "string" && raw.trim() === "") return null;
  const parsed = Number(raw);
  return Number.isNaN(parsed) ? null : parsed;
}

/**
 * Transform a Genie statement_response into chart-ready rows + column metadata.
 *
 * Expects `data` to have the shape:
 * ```
 * {
 *   manifest: { schema: { columns: [{ name, type_name }, ...] } },
 *   result: { data_array: [["val", ...], ...] }
 * }
 * ```
 *
 * Cells missing from a row (row shorter than the column list, or a row that
 * is not an array at all) are emitted as `null`. A column without a string
 * `type_name` is treated as a string column.
 *
 * @returns `null` when the data is empty or malformed (no columns, no rows,
 *   or a column entry without a string `name`); never throws on bad input.
 */
export function transformGenieData(data: unknown): TransformedGenieData | null {
  if (!isRecord(data)) return null;

  // Extract columns schema
  const schema = isRecord(data.manifest) ? data.manifest.schema : undefined;
  const rawColumns = isRecord(schema) ? schema.columns : undefined;
  if (!Array.isArray(rawColumns) || rawColumns.length === 0) {
    return null;
  }

  // Extract data rows
  const dataArray = isRecord(data.result) ? data.result.data_array : undefined;
  if (!Array.isArray(dataArray) || dataArray.length === 0) {
    return null;
  }

  const columns: GenieColumnMeta[] = [];
  for (const col of rawColumns as unknown[]) {
    // A column we cannot name cannot be keyed into a record: malformed.
    if (!isRecord(col) || typeof col.name !== "string") return null;
    const typeName = typeof col.type_name === "string" ? col.type_name : "";
    columns.push({
      name: col.name,
      typeName,
      category: classifySqlType(typeName),
    });
  }

  const rows = (dataArray as unknown[]).map((row) => {
    const cells: unknown[] = Array.isArray(row) ? row : [];
    const record: Record<string, unknown> = {};
    columns.forEach((col, i) => {
      record[col.name] = parseValue(cells[i] ?? null, col.category);
    });
    return record;
  });

  return { rows, columns };
}
+ + + + {columns.map((col) => ( + {col.name} + ))} + + + + {displayRows.map((row, i) => ( + // biome-ignore lint/suspicious/noArrayIndexKey: tabular data rows have no unique identifier + + {columns.map((col) => ( + + {row[col.name] != null ? String(row[col.name]) : ""} + + ))} + + ))} + +
+ {truncated && ( +

+ Showing {TABLE_ROW_LIMIT} of {rows.length} rows +

+ )} +
+ ); + + if (!inference) { + return
{dataTable}
; + } + + return ( + + + Chart + Table + + + + + + + {dataTable} + + ); +} diff --git a/packages/appkit-ui/src/react/genie/index.ts b/packages/appkit-ui/src/react/genie/index.ts index 0055abff..24254adc 100644 --- a/packages/appkit-ui/src/react/genie/index.ts +++ b/packages/appkit-ui/src/react/genie/index.ts @@ -1,6 +1,17 @@ +export { + type ChartInference, + inferChartType, +} from "./genie-chart-inference"; export { GenieChat } from "./genie-chat"; export { GenieChatInput } from "./genie-chat-input"; export { GenieChatMessage } from "./genie-chat-message"; export { GenieChatMessageList } from "./genie-chat-message-list"; +export { + type ColumnCategory, + type GenieColumnMeta, + type TransformedGenieData, + transformGenieData, +} from "./genie-query-transform"; +export { GenieQueryVisualization } from "./genie-query-visualization"; export type * from "./types"; export { useGenieChat } from "./use-genie-chat";