diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000000..fb5da2bc5a --- /dev/null +++ b/package-lock.json @@ -0,0 +1,58 @@ +{ + "name": "VChart", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "dependencies": { + "js-yaml": "^4.1.0" + }, + "devDependencies": { + "@types/js-yaml": "^4.0.9" + } + }, + "node_modules/@types/js-yaml": { + "version": "4.0.9", + "resolved": "https://registry.npmjs.org/@types/js-yaml/-/js-yaml-4.0.9.tgz", + "integrity": "sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==", + "dev": true + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==" + }, + "node_modules/js-yaml": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "dependencies": { + "argparse": "^2.0.1" + }, + "bin": { + "js-yaml": "bin/js-yaml.js" + } + } + }, + "dependencies": { + "@types/js-yaml": { + "version": "4.0.9", + "resolved": "https://registry.npmjs.org/@types/js-yaml/-/js-yaml-4.0.9.tgz", + "integrity": "sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==", + "dev": true + }, + "argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==" + }, + "js-yaml": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "requires": { + "argparse": "^2.0.1" + } + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000000..9a8b89242c --- /dev/null +++ b/package.json @@ -0,0 +1,8 @@ +{ + "dependencies": { + "js-yaml": "^4.1.0" + }, + "devDependencies": { + "@types/js-yaml": "^4.0.9" + } +} diff --git a/packages/vmind/__tests__/browser/src/pages/ChartPreview.tsx b/packages/vmind/__tests__/browser/src/pages/ChartPreview.tsx index cfb3bced8b..aad426f409 100644 --- a/packages/vmind/__tests__/browser/src/pages/ChartPreview.tsx +++ b/packages/vmind/__tests__/browser/src/pages/ChartPreview.tsx @@ -36,7 +36,7 @@ export function ChartPreview(props: IPropsType) { const [outType, setOutType] = useState<'gif' | 'video' | ''>(''); const [src, setSrc] = useState(''); - const vmind = new VMind(import.meta.OPENAI_KEY!); + const vmind = new VMind({}); // const [describe, setDescribe] = useState(mockUserInput6.input); // const [csv, setCsv] = useState(mockUserInput6.csv); // const [loading, setLoading] = useState(false); diff --git a/packages/vmind/__tests__/browser/src/pages/DataInput.tsx b/packages/vmind/__tests__/browser/src/pages/DataInput.tsx index 420857055f..acdf978ddd 100644 --- a/packages/vmind/__tests__/browser/src/pages/DataInput.tsx +++ b/packages/vmind/__tests__/browser/src/pages/DataInput.tsx @@ -23,7 +23,6 @@ import { mockUserInput14, mockUserInput16 } from '../constants/mockData'; -import { excel2csv } from '../../../../src/excel'; import VMind from '../../../../src/index'; import { Model } from '../../../../src/typings'; @@ -68,15 +67,24 @@ export function DataInput(props: IPropsType) { const [spec, setSpec] = useState(''); const [time, setTime] = useState(1000); const [loading, setLoading] = useState(false); - const vmind = new VMind(import.meta.env.OPENAI_KEY!, { - url: import.meta.env.VITE_OPENAI_URL ?? undefined + //const vmind = new VMind({ + // url: import.meta.env.VITE_OPENAI_URL ?? undefined, + // model:Model.GPT3_5 + //}); + + const vmind = new VMind({ + url: import.meta.env.VITE_SKYLARK_URL ?? undefined, + model: Model.SKYLARK, + headers: { + 'api-key': import.meta.env.VITE_SKYLARK_KEY + } }); const askGPT = useCallback(async () => { setLoading(true); - //const {fieldInfo,dataset}=vmind.parseCSVData(csv) - const { fieldInfo, dataset } = await vmind.parseDataWithGPT(csv, describe); - const { spec, time } = await vmind.generateChart(Model.GPT3_5, describe, fieldInfo, dataset); + const { fieldInfo, dataset } = vmind.parseCSVData(csv); + //const { fieldInfo, dataset } = await vmind.parseCSVDataWithLLM(csv, describe); + const { spec, time } = await vmind.generateChart(describe, fieldInfo, dataset); props.onSpecGenerate(spec, time as any); setLoading(false); }, [vmind, csv, describe, props]); diff --git a/packages/vmind/src/common/dataProcess/index.ts b/packages/vmind/src/common/dataProcess/index.ts index ff22011a9b..3d9f62fef6 100644 --- a/packages/vmind/src/common/dataProcess/index.ts +++ b/packages/vmind/src/common/dataProcess/index.ts @@ -1,5 +1,5 @@ import { DataSet, DataView, csvParser, fold } from '@visactor/vdataset'; -import { DataItem, IGPTOptions, SimpleFieldInfo } from '../../typings'; +import { DataItem, SimpleFieldInfo } from '../../typings'; import { getFieldInfoFromDataset } from './utils'; export const parseCSVWithVChart = (csvString: string) => { diff --git a/packages/vmind/src/common/schema.ts b/packages/vmind/src/common/schema.ts new file mode 100644 index 0000000000..e7c2298aba --- /dev/null +++ b/packages/vmind/src/common/schema.ts @@ -0,0 +1,23 @@ +import { LOCATION, SimpleFieldInfo, VizSchema } from '../typings'; + +/** + * generate vizSchema from fieldInfo + * @param fieldInfo SimpleFieldInfo + * @returns + */ +export const getSchemaFromFieldInfo = (fieldInfo: SimpleFieldInfo[]): Partial => { + const schema = { + fields: fieldInfo + //.filter(d => usefulFields.includes(d.fieldName)) + .map(d => ({ + id: d.fieldName, + alias: d.fieldName, + description: d.description, + visible: true, + type: d.type, + role: d.role, + location: d.role as unknown as LOCATION + })) + }; + return schema; +}; diff --git a/packages/vmind/src/common/vizDataToSpec/constants.ts b/packages/vmind/src/common/vizDataToSpec/constants.ts index 5c37be1b29..112150d2ed 100644 --- a/packages/vmind/src/common/vizDataToSpec/constants.ts +++ b/packages/vmind/src/common/vizDataToSpec/constants.ts @@ -19,7 +19,7 @@ export const oneByOneGroupSize = 10; //one-by-one动画 10个点一组 export const DEFAULT_VIDEO_LENGTH = 2000; export const DEFAULT_PIE_VIDEO_LENGTH = 5000; export const DEFAULT_VIDEO_LENGTH_LONG = 10000; -export const CHARTTYP_VIDEO_ELENGTH: Record = { +export const VIDEO_LENGTH_BY_CHART_TYPE: Record = { pie: DEFAULT_PIE_VIDEO_LENGTH, wordCloud: DEFAULT_VIDEO_LENGTH_LONG, wordcloud: DEFAULT_VIDEO_LENGTH_LONG diff --git a/packages/vmind/src/common/vizDataToSpec/index.ts b/packages/vmind/src/common/vizDataToSpec/index.ts index e6859082a0..5828fa59fa 100644 --- a/packages/vmind/src/common/vizDataToSpec/index.ts +++ b/packages/vmind/src/common/vizDataToSpec/index.ts @@ -1,2 +1,2 @@ -export { vizDataToSpec, patchChartTypeAndCell, checkChartTypeAndCell } from './vizDataToSpec'; +export { vizDataToSpec, checkChartTypeAndCell } from './vizDataToSpec'; export { SUPPORTED_CHART_LIST } from './constants'; diff --git a/packages/vmind/src/common/vizDataToSpec/utils.ts b/packages/vmind/src/common/vizDataToSpec/utils.ts index 7fde46ce26..56e773de3b 100644 --- a/packages/vmind/src/common/vizDataToSpec/utils.ts +++ b/packages/vmind/src/common/vizDataToSpec/utils.ts @@ -1,3 +1,5 @@ +import { VIDEO_LENGTH_BY_CHART_TYPE, DEFAULT_VIDEO_LENGTH } from './constants'; + export const detectAxesType = (values: any[], field: string) => { const isNumber = values.every(d => !d[field] || !isNaN(Number(d[field]))); if (isNumber) { @@ -17,3 +19,27 @@ export const CARTESIAN_CHART_LIST = [ 'Waterfall Chart', 'Box Plot Chart' ]; + +export const estimateVideoTime = (chartType: string, spec: any, parsedTime?: number) => { + //估算视频长度 + if (chartType === 'DYNAMIC BAR CHART') { + const frameNumber = spec.player.specs.length; + const duration = spec.player.interval; + return { + totalTime: parsedTime ?? frameNumber * duration, + frameArr: parsedTime + ? Array.from(new Array(frameNumber).keys()).map(n => Number(parsedTime / frameNumber)) + : Array.from(new Array(frameNumber).keys()).map(n => duration) + }; + } + + // chartType不是真实的图表类型,转一次 + const map: Record = { + 'PIE CHART': 'pie', + 'WORD CLOUD': 'wordCloud' + }; + return { + totalTime: parsedTime ?? VIDEO_LENGTH_BY_CHART_TYPE[map[chartType]] ?? DEFAULT_VIDEO_LENGTH, + frameArr: [] + }; +}; diff --git a/packages/vmind/src/common/vizDataToSpec/vizDataToSpec.ts b/packages/vmind/src/common/vizDataToSpec/vizDataToSpec.ts index 4fdfe9d38e..31d5dcb7e0 100644 --- a/packages/vmind/src/common/vizDataToSpec/vizDataToSpec.ts +++ b/packages/vmind/src/common/vizDataToSpec/vizDataToSpec.ts @@ -65,187 +65,6 @@ export const vizDataToSpec = ( return spec; }; -export const patchChartTypeAndCell = (chartTypeOutter: string, cell: any, dataset: any[]) => { - //对GPT返回结果进行修正 - //某些时候由于用户输入的意图不明确,GPT返回的cell中可能缺少字段。 - //此时需要根据规则补全 - //TODO: 多个y字段时,使用fold - - const { x, y } = cell; - - let chartType = chartTypeOutter; - // y轴字段有多个时,处理方式: - // 1. 图表类型为: 箱型图, 图表类型不做矫正 - // 2. 图表类型为: 柱状图 或 折线图, 图表类型矫正为双轴图 - // 3. 其他情况, 图表类型矫正为散点图 - if (y && typeof y !== 'string' && y.length > 1) { - if (chartType === 'BOX PLOT CHART') { - return { - chartTypeNew: chartType, - cellNew: cell - }; - } - if (chartType === 'BAR CHART' || chartType === 'LINE CHART') { - chartType = 'DUAL AXIS CHART'; - } else { - return { - chartTypeNew: 'SCATTER PLOT', - cellNew: { - ...cell, - x: y[0], - y: y[1], - color: typeof x === 'string' ? x : x[0] - } - }; - } - } - //双轴图 订正yLeft和yRight - if (chartType === 'DUAL AXIS CHART' && cell.yLeft && cell.yRight) { - return { - chartTypeNew: chartType, - cellNew: { ...cell, y: [cell.yLeft, cell.yRight] } - }; - } - //饼图 必须有color字段和angle字段 - if (chartType === 'PIE CHART') { - const cellNew = { ...cell }; - if (!cellNew.color || !cellNew.angle) { - const usedFields = Object.values(cell); - const dataFields = Object.keys(dataset[0]); - const remainedFields = dataFields.filter(f => !usedFields.includes(f)); - if (!cellNew.color) { - //没有分配颜色字段,从剩下的字段里选择一个离散字段分配到颜色上 - const colorField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'band'; - }); - if (colorField) { - cellNew.color = colorField; - } else { - cellNew.color = remainedFields[0]; - } - } - if (!cellNew.angle) { - //没有分配角度字段,从剩下的字段里选择一个连续字段分配到角度上 - const angleField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'linear'; - }); - if (angleField) { - cellNew.angle = angleField; - } else { - cellNew.angle = remainedFields[0]; - } - } - } - return { - chartTypeNew: chartType, - cellNew - }; - } - //词云 必须有color字段和size字段 - if (chartType === 'WORD CLOUD') { - const cellNew = { ...cell }; - if (!cellNew.size || !cellNew.color || cellNew.color === cellNew.size) { - const usedFields = Object.values(cell); - const dataFields = Object.keys(dataset[0]); - const remainedFields = dataFields.filter(f => !usedFields.includes(f)); - //首先根据cell中的其他字段选择size和color - //若没有,则从数据的剩余字段中选择 - if (!cellNew.size || cellNew.size === cellNew.color) { - const newSize = cellNew.weight ?? cellNew.fontSize; - if (newSize) { - cellNew.size = newSize; - } else { - const sizeField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'linear'; - }); - if (sizeField) { - cellNew.size = sizeField; - } else { - cellNew.size = remainedFields[0]; - } - } - } - if (!cellNew.color) { - const newColor = cellNew.text ?? cellNew.word ?? cellNew.label ?? cellNew.x; - if (newColor) { - cellNew.color = newColor; - } else { - const colorField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'band'; - }); - if (colorField) { - cellNew.color = colorField; - } else { - cellNew.color = remainedFields[0]; - } - } - } - } - return { - chartTypeNew: chartType, - cellNew - }; - } - if (chartType === 'DYNAMIC BAR CHART') { - const cellNew = { ...cell }; - - if (!cell.time || cell.time === '' || cell.time.length === 0) { - const flattenedXField = Array.isArray(cell.x) ? cell.x : [cell.x]; - const usedFields = Object.values(cellNew).filter(f => !Array.isArray(f)); - usedFields.push(...flattenedXField); - const dataFields = Object.keys(dataset[0]); - const remainedFields = dataFields.filter(f => !usedFields.includes(f)); - - //动态条形图没有time字段,选择一个离散字段作为time - const timeField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'band'; - }); - if (timeField) { - cellNew.time = timeField; - } else { - cellNew.time = remainedFields[0]; - } - } - return { - chartTypeNew: chartType, - cellNew - }; - } - //直角坐标图表 必须有x字段 - if (CARTESIAN_CHART_LIST.map(chart => chart.toUpperCase()).includes(chartType)) { - const cellNew = { ...cell }; - if (!cellNew.x) { - const usedFields = Object.values(cell); - const dataFields = Object.keys(dataset[0]); - const remainedFields = dataFields.filter(f => !usedFields.includes(f)); - //没有分配x字段,从剩下的字段里选择一个离散字段分配到x上 - const xField = remainedFields.find(f => { - const fieldType = detectAxesType(dataset, f); - return fieldType === 'band'; - }); - if (xField) { - cellNew.x = xField; - } else { - cellNew.x = remainedFields[0]; - } - } - return { - chartTypeNew: chartType, - cellNew - }; - } - - return { - chartTypeNew: chartType, - cellNew: cell - }; -}; - export const checkChartTypeAndCell = (chartType: string, cell: any): boolean => { switch (chartType) { case 'BAR CHART': diff --git a/packages/vmind/src/core/VMind.ts b/packages/vmind/src/core/VMind.ts index 2062dec7ca..2145634003 100644 --- a/packages/vmind/src/core/VMind.ts +++ b/packages/vmind/src/core/VMind.ts @@ -1,26 +1,40 @@ import { _chatToVideoWasm } from '../chart-to-video'; import { generateChartWithGPT } from '../gpt/chart-generation/NLToChart'; -import { IGPTOptions, TimeType, Model, SimpleFieldInfo, DataItem } from '../typings'; +import { ILLMOptions, TimeType, Model, SimpleFieldInfo, DataItem } from '../typings'; import type { FFmpeg } from '@ffmpeg/ffmpeg'; import { parseCSVDataWithGPT } from '../gpt/dataProcess'; import { parseCSVData as parseCSVDataWithRule } from '../common/dataProcess'; +import { generateChartWithSkylark } from '../skylark/chart-generation'; class VMind { private _FPS = 30; - private _options: IGPTOptions | undefined; + private _options: ILLMOptions | undefined; private _model: Model; - constructor(options?: IGPTOptions) { - this._options = options; + constructor(options?: ILLMOptions) { + this._options = { ...(options ?? {}) }; this._model = options.model; } + /** + * parse csv string and get the name, type of each field using rule-based method. + * @param csvString csv data user want to visualize + * @returns fieldInfo and raw dataset. + */ parseCSVData(csvString: string): { fieldInfo: SimpleFieldInfo[]; dataset: DataItem[] } { //Parse CSV Data without LLM //return dataset and fieldInfo return parseCSVDataWithRule(csvString); } + /** + * call LLM to parse csv data. return fieldInfo and raw dataset. + * fieldInfo includes name, type, role, description of each field. + * NOTE: This will transfer your data to LLM. + * @param csvString csv data user want to visualize + * @param userPrompt + * @returns + */ parseCSVDataWithLLM(csvString: string, userPrompt: string) { if ([Model.GPT3_5, Model.GPT4].includes(this._model)) { return parseCSVDataWithGPT(csvString, userPrompt, this._options); @@ -28,6 +42,15 @@ class VMind { console.error('Unsupported Model!'); } + /** + * + * @param userPrompt user's visualization intention (what aspect they want to show in the data) + * @param fieldInfo information about fields in the dataset. field name, type, etc. You can get fieldInfo using parseCSVData or parseCSVDataWithLLM + * @param dataset raw dataset used in the chart + * @param colorPalette color palette of the chart + * @param animationDuration duration of chart animation. + * @returns spec and time duration of the chart. + */ async generateChart( userPrompt: string, //user's intent of visualization, usually aspect in data that they want to visualize fieldInfo: SimpleFieldInfo[], @@ -37,11 +60,9 @@ class VMind { ) { if ([Model.GPT3_5, Model.GPT4].includes(this._model)) { return generateChartWithGPT(userPrompt, fieldInfo, dataset, this._options, colorPalette, animationDuration); - } else if (this._model == Model.SKYLARK) { - return {}; } if (this._model === Model.SKYLARK) { - return generateChartWithSkylark(userPrompt, fieldInfo); + return generateChartWithSkylark(userPrompt, fieldInfo, dataset, this._options, colorPalette, animationDuration); } return {}; } diff --git a/packages/vmind/src/gpt/chart-generation/NLToChart.ts b/packages/vmind/src/gpt/chart-generation/NLToChart.ts index 69a27d94d3..2ed635ec16 100644 --- a/packages/vmind/src/gpt/chart-generation/NLToChart.ts +++ b/packages/vmind/src/gpt/chart-generation/NLToChart.ts @@ -1,44 +1,18 @@ -import { - DEFAULT_VIDEO_LENGTH, - CHARTTYP_VIDEO_ELENGTH, - SUPPORTED_CHART_LIST -} from '../../common/vizDataToSpec/constants'; -import { DataItem, GPTChartAdvisorResult, IGPTOptions, LOCATION, SimpleFieldInfo, VizSchema } from '../../typings'; -import { checkChartTypeAndCell, patchChartTypeAndCell, vizDataToSpec } from '../../common/vizDataToSpec'; +import { SUPPORTED_CHART_LIST } from '../../common/vizDataToSpec/constants'; +import { DataItem, GPTChartAdvisorResult, ILLMOptions, LOCATION, SimpleFieldInfo, VizSchema } from '../../typings'; +import { checkChartTypeAndCell, vizDataToSpec } from '../../common/vizDataToSpec'; import { parseGPTResponse, requestGPT } from '../utils'; -import { patchUserInput } from './utils'; +import { patchChartTypeAndCell, patchUserInput } from './utils'; import { ChartAdvisorPromptEnglish } from './prompts'; import { chartAdvisorHandler } from '../../common/chartAdvisor'; - -export const estimateVideoTime = (chartType: string, spec: any, parsedTime?: number) => { - //估算视频长度 - if (chartType === 'DYNAMIC BAR CHART') { - const frameNumber = spec.player.specs.length; - const duration = spec.player.interval; - return { - totalTime: parsedTime ?? frameNumber * duration, - frameArr: parsedTime - ? Array.from(new Array(frameNumber).keys()).map(n => Number(parsedTime / frameNumber)) - : Array.from(new Array(frameNumber).keys()).map(n => duration) - }; - } - - // chartType不是真实的图表类型,转一次 - const map: Record = { - 'PIE CHART': 'pie', - 'WORD CLOUD': 'wordCloud' - }; - return { - totalTime: parsedTime ?? CHARTTYP_VIDEO_ELENGTH[map[chartType]] ?? DEFAULT_VIDEO_LENGTH, - frameArr: [] - }; -}; +import { estimateVideoTime } from '../../common/vizDataToSpec/utils'; +import { getSchemaFromFieldInfo } from '../../common/schema'; export const generateChartWithGPT = async ( userPrompt: string, //user's intent of visualization, usually aspect in data that they want to visualize fieldInfo: SimpleFieldInfo[], propsDataset: DataItem[], - options: IGPTOptions, + options: ILLMOptions, colorPalette?: string[], animationDuration?: number ) => { @@ -50,7 +24,7 @@ export const generateChartWithGPT = async ( let dataset: DataItem[] = propsDataset; try { // throw 'test chartAdvisorHandler'; - const resJson: any = await chartAdvisorGPT(schema, fieldInfo, userInputFinal, options); + const resJson: any = await chartAdvisorGPT(schema, userInputFinal, options); const chartTypeRes = resJson['CHART_TYPE'].toUpperCase(); const cellRes = resJson['FIELD_MAP']; @@ -62,6 +36,7 @@ export const generateChartWithGPT = async ( } catch (err) { console.warn(err); console.warn('LLM generation error, use rule generation.'); + // call rule-based method to get recommended chart type and fieldMap(cell) const advisorResult = chartAdvisorHandler(schema, dataset); chartType = advisorResult.chartType; cell = advisorResult.cell; @@ -82,13 +57,20 @@ export const generateChartWithGPT = async ( }; }; +/** + * call GPT to get recommended chart type and fieldMap + * @param schema VizSchema + * @param userInput user input about their intention + * @param options vmind options + * @returns + */ export const chartAdvisorGPT = async ( schema: Partial, - fieldInfo: SimpleFieldInfo[], userInput: string, - options: IGPTOptions | undefined + options: ILLMOptions | undefined ) => { - const filteredFields = fieldInfo.filter( + //call GPT + const filteredFields = schema.fields.filter( field => true //usefulFields.includes(field.fieldName) ); @@ -108,20 +90,3 @@ export const chartAdvisorGPT = async ( } return advisorResJson; }; - -export const getSchemaFromFieldInfo = (fieldInfo: SimpleFieldInfo[]): Partial => { - const schema = { - fields: fieldInfo - //.filter(d => usefulFields.includes(d.fieldName)) - .map(d => ({ - id: d.fieldName, - alias: d.fieldName, - description: d.description, - visible: true, - type: d.type, - role: d.role, - location: d.role as unknown as LOCATION - })) - }; - return schema; -}; diff --git a/packages/vmind/src/gpt/chart-generation/utils.ts b/packages/vmind/src/gpt/chart-generation/utils.ts index bc8aa115a7..b4f7e4fae3 100644 --- a/packages/vmind/src/gpt/chart-generation/utils.ts +++ b/packages/vmind/src/gpt/chart-generation/utils.ts @@ -1,3 +1,5 @@ +import { CARTESIAN_CHART_LIST, detectAxesType } from '../../common/vizDataToSpec/utils'; + export const patchUserInput = (userInput: string) => { const FULL_WIDTH_SYMBOLS = [',', '。']; const HALF_WIDTH_SYMBOLS = [',', '.']; @@ -26,3 +28,184 @@ export const patchUserInput = (userInput: string) => { '严格按照prompt中的格式回复,不要有任何多余内容。 Use the original fieldName and DO NOT change or translate any word of the data fields in the response.'; return finalStr; }; + +export const patchChartTypeAndCell = (chartTypeOutter: string, cell: any, dataset: any[]) => { + //对GPT返回结果进行修正 + //某些时候由于用户输入的意图不明确,GPT返回的cell中可能缺少字段。 + //此时需要根据规则补全 + //TODO: 多个y字段时,使用fold + + const { x, y } = cell; + + let chartType = chartTypeOutter; + // y轴字段有多个时,处理方式: + // 1. 图表类型为: 箱型图, 图表类型不做矫正 + // 2. 图表类型为: 柱状图 或 折线图, 图表类型矫正为双轴图 + // 3. 其他情况, 图表类型矫正为散点图 + if (y && typeof y !== 'string' && y.length > 1) { + if (chartType === 'BOX PLOT CHART') { + return { + chartTypeNew: chartType, + cellNew: cell + }; + } + if (chartType === 'BAR CHART' || chartType === 'LINE CHART') { + chartType = 'DUAL AXIS CHART'; + } else { + return { + chartTypeNew: 'SCATTER PLOT', + cellNew: { + ...cell, + x: y[0], + y: y[1], + color: typeof x === 'string' ? x : x[0] + } + }; + } + } + //双轴图 订正yLeft和yRight + if (chartType === 'DUAL AXIS CHART' && cell.yLeft && cell.yRight) { + return { + chartTypeNew: chartType, + cellNew: { ...cell, y: [cell.yLeft, cell.yRight] } + }; + } + //饼图 必须有color字段和angle字段 + if (chartType === 'PIE CHART') { + const cellNew = { ...cell }; + if (!cellNew.color || !cellNew.angle) { + const usedFields = Object.values(cell); + const dataFields = Object.keys(dataset[0]); + const remainedFields = dataFields.filter(f => !usedFields.includes(f)); + if (!cellNew.color) { + //没有分配颜色字段,从剩下的字段里选择一个离散字段分配到颜色上 + const colorField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'band'; + }); + if (colorField) { + cellNew.color = colorField; + } else { + cellNew.color = remainedFields[0]; + } + } + if (!cellNew.angle) { + //没有分配角度字段,从剩下的字段里选择一个连续字段分配到角度上 + const angleField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'linear'; + }); + if (angleField) { + cellNew.angle = angleField; + } else { + cellNew.angle = remainedFields[0]; + } + } + } + return { + chartTypeNew: chartType, + cellNew + }; + } + //词云 必须有color字段和size字段 + if (chartType === 'WORD CLOUD') { + const cellNew = { ...cell }; + if (!cellNew.size || !cellNew.color || cellNew.color === cellNew.size) { + const usedFields = Object.values(cell); + const dataFields = Object.keys(dataset[0]); + const remainedFields = dataFields.filter(f => !usedFields.includes(f)); + //首先根据cell中的其他字段选择size和color + //若没有,则从数据的剩余字段中选择 + if (!cellNew.size || cellNew.size === cellNew.color) { + const newSize = cellNew.weight ?? cellNew.fontSize; + if (newSize) { + cellNew.size = newSize; + } else { + const sizeField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'linear'; + }); + if (sizeField) { + cellNew.size = sizeField; + } else { + cellNew.size = remainedFields[0]; + } + } + } + if (!cellNew.color) { + const newColor = cellNew.text ?? cellNew.word ?? cellNew.label ?? cellNew.x; + if (newColor) { + cellNew.color = newColor; + } else { + const colorField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'band'; + }); + if (colorField) { + cellNew.color = colorField; + } else { + cellNew.color = remainedFields[0]; + } + } + } + } + return { + chartTypeNew: chartType, + cellNew + }; + } + if (chartType === 'DYNAMIC BAR CHART') { + const cellNew = { ...cell }; + + if (!cell.time || cell.time === '' || cell.time.length === 0) { + const flattenedXField = Array.isArray(cell.x) ? cell.x : [cell.x]; + const usedFields = Object.values(cellNew).filter(f => !Array.isArray(f)); + usedFields.push(...flattenedXField); + const dataFields = Object.keys(dataset[0]); + const remainedFields = dataFields.filter(f => !usedFields.includes(f)); + + //动态条形图没有time字段,选择一个离散字段作为time + const timeField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'band'; + }); + if (timeField) { + cellNew.time = timeField; + } else { + cellNew.time = remainedFields[0]; + } + } + return { + chartTypeNew: chartType, + cellNew + }; + } + //直角坐标图表 必须有x字段 + if (CARTESIAN_CHART_LIST.map(chart => chart.toUpperCase()).includes(chartType)) { + const cellNew = { ...cell }; + if (!cellNew.x) { + const usedFields = Object.values(cell); + const dataFields = Object.keys(dataset[0]); + const remainedFields = dataFields.filter(f => !usedFields.includes(f)); + //没有分配x字段,从剩下的字段里选择一个离散字段分配到x上 + const xField = remainedFields.find(f => { + const fieldType = detectAxesType(dataset, f); + return fieldType === 'band'; + }); + if (xField) { + cellNew.x = xField; + } else { + cellNew.x = remainedFields[0]; + } + } + return { + chartTypeNew: chartType, + cellNew + }; + } + + return { + chartTypeNew: chartType, + cellNew: cell + }; +}; diff --git a/packages/vmind/src/gpt/dataProcess/index.ts b/packages/vmind/src/gpt/dataProcess/index.ts index 30f2ed011e..32d4d8bfbb 100644 --- a/packages/vmind/src/gpt/dataProcess/index.ts +++ b/packages/vmind/src/gpt/dataProcess/index.ts @@ -1,6 +1,6 @@ import { getDataset, parseCSVData } from '../../common/dataProcess'; import { readTopNLine } from '../../common/dataProcess/utils'; -import { IGPTOptions } from '../../typings'; +import { ILLMOptions } from '../../typings'; import { parseGPTResponse, requestGPT } from '../utils'; import { DataProcessPromptEnglish } from './prompts'; @@ -8,7 +8,7 @@ import { DataProcessPromptEnglish } from './prompts'; ** call GPT to parse csv data **get the fieldInfo from csv file */ -export const parseCSVDataWithGPT = async (csvFile: string, userInput: string, options: IGPTOptions | undefined) => { +export const parseCSVDataWithGPT = async (csvFile: string, userInput: string, options: ILLMOptions | undefined) => { const DATA_TOP_N = 5; //取csv文件的前多少条数据 const topNCSVFile = readTopNLine(csvFile, DATA_TOP_N); const dataProcessMessage = `CSV file content:\n${topNCSVFile}\nUser Input: ${userInput}`; @@ -31,6 +31,6 @@ export const parseCSVDataWithGPT = async (csvFile: string, userInput: string, op //传统方法做兜底 const { fieldInfo } = parseCSVData(csvFile); console.error('gpt parse data error!'); - return { FIELD_INFO: fieldInfo, dataset }; + return { fieldInfo, dataset }; } }; diff --git a/packages/vmind/src/gpt/utils.ts b/packages/vmind/src/gpt/utils.ts index 9e6755ef33..7a633828f4 100644 --- a/packages/vmind/src/gpt/utils.ts +++ b/packages/vmind/src/gpt/utils.ts @@ -1,15 +1,15 @@ -import { GPTDataProcessResult, IGPTOptions } from '../typings'; +import { GPTDataProcessResult, ILLMOptions } from '../typings'; import axios from 'axios'; import JSON5 from 'json5'; -export const requestGPT = async (prompt: string, userMessage: string, options: IGPTOptions | undefined) => { +export const requestGPT = async (prompt: string, userMessage: string, options: ILLMOptions | undefined) => { const OPENAI_API_URL = 'https://api.openai.com/v1/chat/completions'; const url: string = options?.url ?? OPENAI_API_URL; - const defaultHeaders: HeadersInit = { 'Content-Type': 'application/json' }; + const headers = { ...(options.headers ?? {}), 'Content-Type': 'application/json' }; const res = await axios(url, { method: options?.method ?? 'POST', - headers: options?.headers ?? (defaultHeaders as any), //must has Authorization: `Bearer ${openAIKey}` if use openai api + headers, //must has Authorization: `Bearer ${openAIKey}` if use openai api data: { model: options?.model ?? 'gpt-3.5-turbo', messages: [ diff --git a/packages/vmind/src/skylark/chart-generation/NLToChart.ts b/packages/vmind/src/skylark/chart-generation/NLToChart.ts new file mode 100644 index 0000000000..5eb8b8532e --- /dev/null +++ b/packages/vmind/src/skylark/chart-generation/NLToChart.ts @@ -0,0 +1,80 @@ +import axios from 'axios'; +import { chartAdvisorHandler } from '../../common/chartAdvisor'; +import { getSchemaFromFieldInfo } from '../../common/schema'; +import { SUPPORTED_CHART_LIST, checkChartTypeAndCell, vizDataToSpec } from '../../common/vizDataToSpec'; +import { DataItem, ILLMOptions, SimpleFieldInfo, VizSchema } from '../../typings'; +import { patchChartTypeAndCell, requestSkyLark } from './utils'; +import { ChartRecommendPrompt } from './prompts'; +import { parseSkylarkResponse } from '../utils'; +import { ChartRecommendResult } from '../typings'; +import { estimateVideoTime } from '../../common/vizDataToSpec/utils'; + +export const generateChartWithSkylark = async ( + userPrompt: string, //user's intent of visualization, usually aspect in data that they want to visualize + fieldInfo: SimpleFieldInfo[], + propsDataset: DataItem[], + options: ILLMOptions, + colorPalette?: string[], + animationDuration?: number +) => { + const schema = getSchemaFromFieldInfo(fieldInfo); + const colors = colorPalette; + let chartType; + let cell; + let dataset: DataItem[] = propsDataset; + try { + // throw 'test chartAdvisorHandler'; + const resJson: any = await chartAdvisorSkylark(schema, fieldInfo, userPrompt, options); + + const chartTypeRes = resJson.chartType.toUpperCase(); + //TODO: request skylark for cell according to chartType + const patchResult = patchChartTypeAndCell(chartTypeRes, cellRes, dataset); + if (checkChartTypeAndCell(patchResult.chartTypeNew, patchResult.cellNew)) { + chartType = patchResult.chartTypeNew; + cell = patchResult.cellNew; + } + } catch (err) { + console.warn(err); + console.warn('LLM generation error, use rule generation.'); + const advisorResult = chartAdvisorHandler(schema, dataset); + chartType = advisorResult.chartType; + cell = advisorResult.cell; + dataset = advisorResult.dataset as DataItem[]; + } + const spec = vizDataToSpec( + dataset, + chartType, + cell, + colors, + animationDuration ? animationDuration * 1000 : undefined + ); + spec.background = '#00000033'; + console.info(spec); + return { + spec, + time: estimateVideoTime(chartType, spec, animationDuration ? animationDuration * 1000 : undefined) + }; +}; + +export const chartAdvisorSkylark = async ( + schema: Partial, + fieldInfo: SimpleFieldInfo[], + userInput: string, + options: ILLMOptions | undefined +) => { + const chartAdvisorMessage = `User Input: ${userInput}\nData field description: ${JSON.stringify(schema.fields)}`; + console.log(chartAdvisorMessage); + + const recommendRes = await requestSkyLark(ChartRecommendPrompt, chartAdvisorMessage, options); + + const recommendResJson: ChartRecommendResult = parseSkylarkResponse(recommendRes); + + console.log(recommendResJson); + if (recommendResJson.error) { + throw Error('Network Error!'); + } + if (!SUPPORTED_CHART_LIST.includes(recommendResJson['chartType'])) { + throw Error('Unsupported Chart Type. Please Change User Input'); + } + return recommendResJson; +}; diff --git a/packages/vmind/src/skylark/chart-generation/index.ts b/packages/vmind/src/skylark/chart-generation/index.ts new file mode 100644 index 0000000000..9c86d94ce9 --- /dev/null +++ b/packages/vmind/src/skylark/chart-generation/index.ts @@ -0,0 +1 @@ +export * from './NLToChart'; diff --git a/packages/vmind/src/skylark/chart-generation/prompts.ts b/packages/vmind/src/skylark/chart-generation/prompts.ts new file mode 100644 index 0000000000..d5e204d2b8 --- /dev/null +++ b/packages/vmind/src/skylark/chart-generation/prompts.ts @@ -0,0 +1,19 @@ +import { SUPPORTED_CHART_LIST } from '../../common/vizDataToSpec/constants'; + +export const ChartRecommendPrompt = `You are an export in data visualization. +Your task is: +1. Based on the user's input, infer the user's intention, such as comparison, ranking, trend display, proportion, distribution, etc. If user did not show their intention, just ignore and do the next steps. +2. Select the single chart type that best suites the data from the list of supported charts: ${JSON.stringify( + SUPPORTED_CHART_LIST +)}. +3. Response in YAML format without any additional descriptions + +Knowledge: +1. Dynamic Bar Chart is a dynamic chart that is suitable for displaying changing data and can be used to show ranking, comparisons or data changes over time. It usually has a time field. It updates the data dynamically according to the time field and at each time point, the current data is displayed using a bar chart. + +Let's think step by step. Fill your thoughts in {thoughts}. + +Respone in the following format: +thoughts: //Your thoughts +chartType: //chartType you choose based on data and user's input. +`; diff --git a/packages/vmind/src/skylark/chart-generation/utils.ts b/packages/vmind/src/skylark/chart-generation/utils.ts new file mode 100644 index 0000000000..ecb588b5b2 --- /dev/null +++ b/packages/vmind/src/skylark/chart-generation/utils.ts @@ -0,0 +1,44 @@ +import axios from 'axios'; +import { ILLMOptions } from '../../typings'; + +export const patchChartTypeAndCell = (chartTypeRes, cellRes, dataset) => { + return { + chartTypeNew: chartTypeRes, + cellNew: cellRes + }; +}; + +/** + * + * @param prompt + * @param message + * @param options + */ +export const requestSkyLark = async (prompt: string, message: string, options: ILLMOptions) => { + const url: string = options?.url; + const headers = { ...(options.headers ?? {}), 'Content-Type': 'application/json' }; + + const res = await axios(url, { + method: options?.method ?? 'POST', + headers, //must has Authorization: `Bearer ${openAIKey}` if use openai api + data: { + model: options?.model ?? 'gpt-3.5-turbo', + messages: [ + { + role: 'system', + content: prompt + }, + { + role: 'user', + content: message + } + ], + max_tokens: options?.max_tokens ?? 2000, + temperature: options?.temperature ?? 0 + } + }) + .then(response => response.data) + .then(data => data.choices); + + return res; +}; diff --git a/packages/vmind/src/skylark/typings/index.ts b/packages/vmind/src/skylark/typings/index.ts new file mode 100644 index 0000000000..f5b5c21352 --- /dev/null +++ b/packages/vmind/src/skylark/typings/index.ts @@ -0,0 +1,5 @@ +export type ChartRecommendResult = { + chartType: string; + thoughts?: string; + error?: boolean; +}; diff --git a/packages/vmind/src/skylark/utils.ts b/packages/vmind/src/skylark/utils.ts new file mode 100644 index 0000000000..0603c78043 --- /dev/null +++ b/packages/vmind/src/skylark/utils.ts @@ -0,0 +1,13 @@ +import yaml from 'js-yaml'; +import { ChartRecommendResult } from './typings'; + +export const parseSkylarkResponse = (larkResponse: any): ChartRecommendResult => { + try { + const resJson = yaml.load(larkResponse[0].message.content) as ChartRecommendResult; + return { + chartType: resJson.chartType + }; + } catch (err) { + return { error: true, chartType: undefined }; + } +}; diff --git a/packages/vmind/src/typings/index.ts b/packages/vmind/src/typings/index.ts index 0f2d503fa7..e9d491c498 100644 --- a/packages/vmind/src/typings/index.ts +++ b/packages/vmind/src/typings/index.ts @@ -1,4 +1,4 @@ -export interface IGPTOptions { +export interface ILLMOptions { url?: string; /** gpt request header, which has higher priority */ headers?: HeadersInit; @@ -107,7 +107,7 @@ export type VizSchema = { export enum Model { GPT3_5 = 'gpt-3.5', GPT4 = 'gpt-3.5', - SKYLARK = 'skylark' + SKYLARK = 'skylark-pro' } export type ChartGenerationProps = {