From 6d345ae91dbe8d64b2f0575a51b2e8cfe81c601c Mon Sep 17 00:00:00 2001 From: ArvinLovegood Date: Wed, 2 Jul 2025 12:13:52 +0800 Subject: [PATCH] =?UTF-8?q?feat(app):=E9=9B=86=E6=88=90AI=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E5=B9=B6=E4=BC=98=E5=8C=96=E8=82=A1=E7=A5=A8=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=8E=B7=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 App 结构中添加 AiTools 字段,用于存储 AI 工具配置 - 新增 AddTools 函数,定义了两个 AI 工具:SearchStockByIndicators 和 GetStockKLine- 修改 NewApp 函数,初始化时加载 AI 工具配置- 更新相关函数,支持使用 AI 工具进行股票数据查询- 优化股票 K 线数据获取逻辑,增加对不同市场股票代码的支持 --- app.go | 73 ++++++++++++++++++++++++++------------ backend/data/openai_api.go | 72 +++++++++++++++++++++++++++++-------- 2 files changed, 108 insertions(+), 37 deletions(-) diff --git a/app.go b/app.go index faf10c0..57f11d4 100644 --- a/app.go +++ b/app.go @@ -36,6 +36,7 @@ type App struct { cache *freecache.Cache cron *cron.Cron cronEntrys map[string]cron.EntryID + AiTools []data.Tool } // NewApp creates a new App application struct @@ -44,13 +45,60 @@ func NewApp() *App { cache := freecache.NewCache(cacheSize) c := cron.New(cron.WithSeconds()) c.Start() + var tools []data.Tool + tools = AddTools(tools) return &App{ cache: cache, cron: c, cronEntrys: make(map[string]cron.EntryID), + AiTools: tools, } } +func AddTools(tools []data.Tool) []data.Tool { + tools = append(tools, data.Tool{ + Type: "function", + Function: data.ToolFunction{ + Name: "SearchStockByIndicators", + Description: "根据自然语言筛选股票,返回自然语言选股条件要求的股票所有相关数据。单独输入股票名称可以获取当前股票最新的股价交易数据和基础财务指标信息", + Parameters: data.FunctionParameters{ + Type: "object", + Properties: map[string]any{ + "words": map[string]any{ + "type": "string", + "description": "选股自然语言,并且条件使用;分隔,或者条件使用,分隔。例1:创新药;PE<30;净利润增长率>50%。 例2:上证指数(指数名称)。 例3:长电科技(股票名称)", + }, + }, + Required: []string{"words"}, + }, + }, + }) + + tools = append(tools, data.Tool{ + Type: "function", + Function: data.ToolFunction{ + Name: "GetStockKLine", + Description: "获取股票日K线数据", + Parameters: data.FunctionParameters{ + Type: "object", + Properties: map[string]any{ + "days": map[string]any{ + "type": "string", + "description": "日K数据条数", + }, + "stockCode": map[string]any{ + "type": "string", + "description": "股票代码(A股:sh,sz开头;港股hk开头,美股:us开头)", + }, + }, + Required: []string{"days", "stockCode"}, + }, + }, + }) + + return tools +} + // startup is called at application startup func (a *App) startup(ctx context.Context) { defer PanicHandler() @@ -311,7 +359,7 @@ func (a *App) AddCronTask(follow data.FollowedStock) func() { return func() { go runtime.EventsEmit(a.ctx, "warnMsg", "开始自动分析"+follow.Name+"_"+follow.StockCode) ai := data.NewDeepSeekOpenAi(a.ctx) - msgs := ai.NewChatStream(follow.Name, follow.StockCode, "", nil) + msgs := ai.NewChatStream(follow.Name, follow.StockCode, "", nil, a.AiTools) var res strings.Builder chatId := "" @@ -748,7 +796,7 @@ func (a *App) SendDingDingMessageByType(message string, stockCode string, msgTyp } func (a *App) NewChatStream(stock, stockCode, question string, sysPromptId *int) { - msgs := data.NewDeepSeekOpenAi(a.ctx).NewChatStream(stock, stockCode, question, sysPromptId) + msgs := data.NewDeepSeekOpenAi(a.ctx).NewChatStream(stock, stockCode, question, sysPromptId, a.AiTools) for msg := range msgs { runtime.EventsEmit(a.ctx, "newChatStream", msg) } @@ -1129,26 +1177,7 @@ func (a *App) GlobalStockIndexes() map[string]any { } func (a *App) SummaryStockNews(question string, sysPromptId *int) { - var tools []data.Tool - tools = append(tools, data.Tool{ - Type: "function", - Function: data.ToolFunction{ - Name: "SearchStockByIndicators", - Description: "根据自然语言筛选股票,返回自然语言选股条件要求的股票所有相关数据。单独输入股票名称可以获取当前股票最新的股价交易数据和基础财务指标信息", - Parameters: data.FunctionParameters{ - Type: "object", - Properties: map[string]any{ - "words": map[string]any{ - "type": "string", - "description": "选股自然语言,并且条件使用;分隔,或者条件使用,分隔。例1:创新药;PE<30;净利润增长率>50%。 例2:上证指数(指数名称)。 例3:长电科技(股票名称)", - }, - }, - Required: []string{"words"}, - }, - }, - }) - - msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, tools) + msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, a.AiTools) for msg := range msgs { runtime.EventsEmit(a.ctx, "summaryStockNews", msg) } diff --git a/backend/data/openai_api.go b/backend/data/openai_api.go index f486d6a..89fd91d 100644 --- a/backend/data/openai_api.go +++ b/backend/data/openai_api.go @@ -191,7 +191,7 @@ func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromp }() wg.Wait() - news := NewMarketNewsApi().GetNewsList("", 100) + news := NewMarketNewsApi().GetNewsList("财联社电报", 500) messageText := strings.Builder{} for _, telegraph := range *news { messageText.WriteString("## " + telegraph.Time + ":" + "\n") @@ -310,7 +310,7 @@ func (o OpenAi) NewSummaryStockNewsStream(userQuestion string, sysPromptId *int) return ch } -func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId *int) <-chan map[string]any { +func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId *int, tools []Tool) <-chan map[string]any { ch := make(chan map[string]any, 512) defer func() { @@ -647,7 +647,11 @@ func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId //reqJson, _ := json.Marshal(msg) //logger.SugaredLogger.Errorf("Stream request: \n%s\n", reqJson) - AskAi(o, err, msg, ch, question) + if tools != nil && len(tools) > 0 { + AskAiWithTools(o, err, msg, ch, question, tools) + } else { + AskAi(o, err, msg, ch, question) + } }() return ch } @@ -859,17 +863,30 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c for _, choice := range streamResponse.Choices { if content := choice.Delta.Content; content != "" { //ch <- content - ch <- map[string]any{ - "code": 1, - "question": question, - "chatId": streamResponse.Id, - "model": streamResponse.Model, - "content": content, - "time": time.Now().Format(time.DateTime), + //logger.SugaredLogger.Infof("Content data: %s", content) + + if content == "###" { + currentAIContent.WriteString("\r\n" + content) + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": "\r\n" + content, + "time": time.Now().Format(time.DateTime), + } + } else { + currentAIContent.WriteString(content) + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": content, + "time": time.Now().Format(time.DateTime), + } } - //logger.SugaredLogger.Infof("Content data: %s", content) - currentAIContent.WriteString(content) } if reasoningContent := choice.Delta.ReasoningContent; reasoningContent != "" { //ch <- reasoningContent @@ -913,16 +930,16 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c "question": question, "chatId": streamResponse.Id, "model": streamResponse.Model, - "content": "- ```开始调用工具:SearchStockByIndicators,\n参数:" + words + "``` " + "\n", + "content": "\r\n```\r\n开始调用工具:SearchStockByIndicators,\n参数:" + words + "\r\n```\r\n", "time": time.Now().Format(time.DateTime), } - res := NewSearchStockApi(words).SearchStock(10) + res := NewSearchStockApi(words).SearchStock(50) searchRes, _ := json.Marshal(res) content := gjson.Get(string(searchRes), "data.result").String() - logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content) + //logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content) //messages = append(messages, map[string]interface{}{ // "role": "assistant", @@ -935,6 +952,31 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c }) } + + if funcName == "GetStockKLine" { + stockCode := gjson.Get(funcArguments, "stockCode").String() + days := gjson.Get(funcArguments, "days").String() + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": "\r\n```\r\n开始调用工具:GetStockKLine,\n参数:" + stockCode + "," + days + "\r\n```\r\n", + "time": time.Now().Format(time.DateTime), + } + toIntDay, err := convertor.ToInt(days) + if err != nil { + toIntDay = 90 + } + res := NewStockDataApi().GetHK_KLineData(stockCode, "day", toIntDay) + searchRes, _ := json.Marshal(res) + messages = append(messages, map[string]interface{}{ + "role": "tool", + "content": stockCode + convertor.ToString(toIntDay) + "日K线数据:\n" + string(searchRes) + "\n", + "tool_call_id": currentCallId, + }) + } + AskAiWithTools(o, err, messages, ch, question, tools) } }