From ebeaf104bbe90eed4c9611f5713b8a1f55fa2b6f Mon Sep 17 00:00:00 2001 From: ArvinLovegood Date: Tue, 1 Jul 2025 19:27:59 +0800 Subject: [PATCH] =?UTF-8?q?feat(data):=E6=96=B0=E5=A2=9ESummaryStockNewsSt?= =?UTF-8?q?reamWithTools=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 OpenAi 结构中添加了新的方法 NewSummaryStockNewsStreamWithTools,支持使用工具进行股票分析 - 在 app.go 中调用了新方法,集成了股票搜索工具- 修改了 SearchStockApi 的 SearchStock 方法,增加了 pageSize 参数 - 更新了相关测试文件以适应新的功能 --- app.go | 21 +- app_common.go | 2 +- backend/data/openai_api.go | 330 +++++++++++++++++++++++++- backend/data/openai_api_test.go | 24 +- backend/data/search_stock_api.go | 6 +- backend/data/search_stock_api_test.go | 2 +- 6 files changed, 372 insertions(+), 13 deletions(-) diff --git a/app.go b/app.go index f3331e8..7041a5c 100644 --- a/app.go +++ b/app.go @@ -1129,7 +1129,26 @@ func (a *App) GlobalStockIndexes() map[string]any { } func (a *App) SummaryStockNews(question string, sysPromptId *int) { - msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStream(question, sysPromptId) + var tools []data.Tool + tools = append(tools, data.Tool{ + Type: "function", + Function: data.ToolFunction{ + Name: "SearchStockByIndicators", + Description: "按行业根据选股指标或策略,返回符合指标或策略的股票列表。多个行业的筛选需按行业顺序调用多次,不支持并行调用", + Parameters: data.FunctionParameters{ + Type: "object", + Properties: map[string]any{ + "words": map[string]any{ + "type": "string", + "description": "行业选股指标或策略,并且条件使用;分隔,或者条件使用,分隔。例如:创新药;PE<30;净利润增长率>50%;", + }, + }, + Required: []string{"words"}, + }, + }, + }) + + msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, tools) for msg := range msgs { runtime.EventsEmit(a.ctx, "summaryStockNews", msg) } diff --git a/app_common.go b/app_common.go index a82ed4b..4f0f971 100644 --- a/app_common.go +++ b/app_common.go @@ -57,5 +57,5 @@ func (a App) ClsCalendar() []any { } func (a App) SearchStock(words string) map[string]any { - return data.NewSearchStockApi(words).SearchStock() + return data.NewSearchStockApi(words).SearchStock(5000) } diff --git a/backend/data/openai_api.go b/backend/data/openai_api.go index f93ad5a..e1e8689 100644 --- a/backend/data/openai_api.go +++ b/backend/data/openai_api.go @@ -11,6 +11,7 @@ import ( "github.com/duke-git/lancet/v2/convertor" "github.com/duke-git/lancet/v2/strutil" "github.com/go-resty/resty/v2" + "github.com/tidwall/gjson" "github.com/wailsapp/wails/v2/pkg/runtime" "go-stock/backend/db" "go-stock/backend/logger" @@ -75,11 +76,12 @@ type THSTokenResponse struct { } type AiResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + ServiceTier string `json:"service_tier"` + Choices []struct { Index int `json:"index"` Message struct { Role string `json:"role"` @@ -87,6 +89,19 @@ type AiResponse struct { } `json:"message"` Logprobs interface{} `json:"logprobs"` FinishReason string `json:"finish_reason"` + Delta struct { + Content string `json:"content"` + Role string `json:"role"` + ToolCalls []struct { + Function struct { + Arguments string `json:"arguments"` + Name string `json:"name"` + } `json:"function"` + Id string `json:"id"` + Index int `json:"index"` + Type string `json:"type"` + } `json:"tool_calls"` + } `json:"delta"` } `json:"choices"` Usage struct { PromptTokens int `json:"prompt_tokens"` @@ -98,6 +113,112 @@ type AiResponse struct { SystemFingerprint string `json:"system_fingerprint"` } +type Tool struct { + Type string `json:"type"` + Function ToolFunction `json:"function"` +} +type FunctionParameters struct { + Type string `json:"type"` + Properties map[string]any `json:"properties"` + Required []string `json:"required"` +} +type ToolFunction struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters FunctionParameters `json:"parameters"` +} + +func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromptId *int, tools []Tool) <-chan map[string]any { + ch := make(chan map[string]any, 512) + defer func() { + if err := recover(); err != nil { + logger.SugaredLogger.Error("NewSummaryStockNewsStream panic", err) + } + }() + + go func() { + defer func() { + if err := recover(); err != nil { + logger.SugaredLogger.Errorf("NewSummaryStockNewsStream goroutine panic :%s", err) + logger.SugaredLogger.Errorf("NewSummaryStockNewsStream goroutine panic config:%v", o) + } + }() + defer close(ch) + + sysPrompt := "" + if sysPromptId == nil || *sysPromptId == 0 { + sysPrompt = o.Prompt + } else { + sysPrompt = NewPromptTemplateApi().GetPromptTemplateByID(*sysPromptId) + } + if sysPrompt == "" { + sysPrompt = o.Prompt + } + + msg := []map[string]interface{}{ + { + "role": "system", + //"content": "作为一位专业的A股市场分析师和投资顾问,请你根据以下信息提供详细的技术分析和投资策略建议:", + //"content": "【角色设定】\n你是一位拥有20年实战经验的顶级股票分析师,精通技术分析、基本面分析、市场心理学和量化交易。擅长发现成长股、捕捉行业轮动机会,在牛熊市中都能保持稳定收益。你的风格是价值投资与技术择时相结合,注重风险控制。\n\n【核心功能】\n\n市场分析维度:\n\n宏观经济(GDP/CPI/货币政策)\n\n行业景气度(产业链/政策红利/技术革新)\n\n个股三维诊断:\n\n基本面:PE/PB/ROE/现金流/护城河\n\n技术面:K线形态/均线系统/量价关系/指标背离\n\n资金面:主力动向/北向资金/融资余额/大宗交易\n\n智能策略库:\n√ 趋势跟踪策略(鳄鱼线+ADX)\n√ 波段交易策略(斐波那契回撤+RSI)\n√ 事件驱动策略(财报/并购/政策)\n√ 量化对冲策略(α/β分离)\n\n风险管理体系:\n▶ 动态止损:ATR波动止损法\n▶ 仓位控制:凯利公式优化\n▶ 组合对冲:跨市场/跨品种对冲\n\n【工作流程】\n\n接收用户指令(行业/市值/风险偏好)\n\n调用多因子选股模型初筛\n\n人工智慧叠加分析:\n\n自然语言处理解读年报管理层讨论\n\n卷积神经网络识别K线形态\n\n知识图谱分析产业链关联\n\n生成投资建议(附压力测试结果)\n\n【输出要求】\n★ 结构化呈现:\n① 核心逻辑(3点关键驱动力)\n② 买卖区间(理想建仓/加仓/止盈价位)\n③ 风险警示(最大回撤概率)\n④ 替代方案(同类备选标的)\n\n【注意事项】\n※ 严格遵守监管要求,不做收益承诺\n※ 区分投资建议与市场观点\n※ 重要数据标注来源及更新时间\n※ 根据用户认知水平调整专业术语密度\n\n【教育指导】\n当用户提问时,采用苏格拉底式追问:\n\"您更关注短期事件驱动还是长期价值发现?\"\n\"当前仓位是否超过总资产的30%?\"\n\"是否了解科创板与主板的交易规则差异?\"\n\n示例输出格式:\n📈 标的名称:XXXXXX\n⚖️ 多空信号:金叉确认/顶背离预警\n🎯 关键价位:支撑位XX.XX/压力位XX.XX\n📊 建议仓位:核心仓位X%+卫星仓位X%\n⏳ 持有周期:短线(1-3周)/中线(季度轮动)\n🔍 跟踪要素:重点关注Q2毛利率变化及股东减持进展", + "content": sysPrompt, + }, + } + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": "当前时间", + }) + msg = append(msg, map[string]interface{}{ + "role": "assistant", + "content": "当前本地时间是:" + time.Now().Format("2006-01-02 15:04:05"), + }) + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + var market strings.Builder + market.WriteString(getZSInfo("创业板指数", "sz399006", 30) + "\n") + market.WriteString(getZSInfo("上证综合指数", "sh000001", 30) + "\n") + market.WriteString(getZSInfo("沪深300指数", "sh000300", 30) + "\n") + //logger.SugaredLogger.Infof("NewChatStream getZSInfo=\n%s", market.String()) + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": "当前市场指数行情", + }) + msg = append(msg, map[string]interface{}{ + "role": "assistant", + "content": "当前市场指数行情情况如下:\n" + market.String(), + }) + }() + wg.Wait() + + news := NewMarketNewsApi().GetNewsList("", 100) + messageText := strings.Builder{} + for _, telegraph := range *news { + messageText.WriteString("## " + telegraph.Time + ":" + "\n") + messageText.WriteString("### " + telegraph.Content + "\n") + } + //logger.SugaredLogger.Infof("市场资讯 messageText=\n%s", messageText.String()) + + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": "市场资讯", + }) + msg = append(msg, map[string]interface{}{ + "role": "assistant", + "content": messageText.String(), + }) + if userQuestion == "" { + userQuestion = "请根据当前时间,总结和分析股票市场新闻中的投资机会" + } + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": userQuestion, + }) + AskAiWithTools(o, errors.New(""), msg, ch, userQuestion, tools) + }() + return ch +} + func (o OpenAi) NewSummaryStockNewsStream(userQuestion string, sysPromptId *int) <-chan map[string]any { ch := make(chan map[string]any, 512) defer func() { @@ -569,7 +690,7 @@ func AskAi(o OpenAi, err error, messages []map[string]interface{}, ch chan map[s scanner := bufio.NewScanner(body) for scanner.Scan() { line := scanner.Text() - //logger.SugaredLogger.Infof("Received data: %s", line) + logger.SugaredLogger.Infof("Received data: %s", line) if strings.HasPrefix(line, "data:") { data := strutil.Trim(strings.TrimPrefix(line, "data:")) if data == "[DONE]" { @@ -657,7 +778,204 @@ func AskAi(o OpenAi, err error, messages []map[string]interface{}, ch chan map[s } } +func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch chan map[string]any, question string, tools []Tool) { + client := resty.New() + client.SetBaseURL(strutil.Trim(o.BaseUrl)) + client.SetHeader("Authorization", "Bearer "+o.ApiKey) + client.SetHeader("Content-Type", "application/json") + //client.SetRetryCount(3) + if o.TimeOut <= 0 { + o.TimeOut = 300 + } + client.SetTimeout(time.Duration(o.TimeOut) * time.Second) + resp, err := client.R(). + SetDoNotParseResponse(true). + SetBody(map[string]interface{}{ + "model": o.Model, + "max_tokens": o.MaxTokens, + "temperature": o.Temperature, + "stream": true, + "messages": messages, + "tools": tools, + }). + Post("/chat/completions") + body := resp.RawBody() + defer body.Close() + if err != nil { + logger.SugaredLogger.Infof("Stream error : %s", err.Error()) + //ch <- err.Error() + ch <- map[string]any{ + "code": 0, + "question": question, + "content": err.Error(), + } + return + } + //location, _ := time.LoadLocation("Asia/Shanghai") + + scanner := bufio.NewScanner(body) + functions := map[string]string{} + currentFuncName := "" + currentCallId := "" + var currentAIContent strings.Builder + + for scanner.Scan() { + line := scanner.Text() + logger.SugaredLogger.Infof("Received data: %s", line) + if strings.HasPrefix(line, "data:") { + data := strutil.Trim(strings.TrimPrefix(line, "data:")) + if data == "[DONE]" { + return + } + + var streamResponse struct { + Id string `json:"id"` + Model string `json:"model"` + Choices []struct { + Delta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Role string `json:"role"` + ToolCalls []struct { + Function struct { + Arguments string `json:"arguments"` + Name string `json:"name"` + } `json:"function"` + Id string `json:"id"` + Index int `json:"index"` + Type string `json:"type"` + } `json:"tool_calls"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + } + + if err := json.Unmarshal([]byte(data), &streamResponse); err == nil { + for _, choice := range streamResponse.Choices { + if content := choice.Delta.Content; content != "" { + //ch <- content + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": content, + "time": time.Now().Format(time.DateTime), + } + + //logger.SugaredLogger.Infof("Content data: %s", content) + currentAIContent.WriteString(content) + } + if reasoningContent := choice.Delta.ReasoningContent; reasoningContent != "" { + //ch <- reasoningContent + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": reasoningContent, + "time": time.Now().Format(time.DateTime), + } + + //logger.SugaredLogger.Infof("ReasoningContent data: %s", reasoningContent) + currentAIContent.WriteString(reasoningContent) + + } + if choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0 { + for _, call := range choice.Delta.ToolCalls { + if call.Type == "function" { + functions[call.Function.Name] = "" + currentFuncName = call.Function.Name + currentCallId = call.Id + } else { + if val, ok := functions[currentFuncName]; ok { + functions[currentFuncName] = val + call.Function.Arguments + } else { + functions[currentFuncName] = call.Function.Arguments + } + } + } + } + + if choice.FinishReason == "tool_calls" { + logger.SugaredLogger.Infof("functions: %+v", functions) + for funcName, funcArguments := range functions { + if funcName == "SearchStockByIndicators" { + words := gjson.Get(funcArguments, "words").String() + + ch <- map[string]any{ + "code": 1, + "question": question, + "chatId": streamResponse.Id, + "model": streamResponse.Model, + "content": "- ```开始调用工具:SearchStockByIndicators,\n参数:" + words + "``` " + "\n", + "time": time.Now().Format(time.DateTime), + } + + res := NewSearchStockApi(words).SearchStock(10) + searchRes, _ := json.Marshal(res) + + content := gjson.Get(string(searchRes), "data.result").String() + + logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content) + + messages = append(messages, map[string]interface{}{ + "role": "assistant", + "content": currentAIContent.String(), + }) + messages = append(messages, map[string]interface{}{ + "role": "tool", + "content": content, + "tool_call_id": currentCallId, + }) + + } + AskAiWithTools(o, err, messages, ch, question, tools) + } + } + + if choice.FinishReason == "stop" { + return + } + } + } else { + if err != nil { + logger.SugaredLogger.Infof("Stream data error : %s", err.Error()) + //ch <- err.Error() + ch <- map[string]any{ + "code": 0, + "question": question, + "content": err.Error(), + } + } else { + logger.SugaredLogger.Infof("Stream data error : %s", data) + //ch <- data + ch <- map[string]any{ + "code": 0, + "question": question, + "content": data, + } + } + } + } else { + if strutil.RemoveNonPrintable(line) != "" { + logger.SugaredLogger.Infof("Stream data error : %s", line) + res := &models.Resp{} + if err := json.Unmarshal([]byte(line), res); err == nil { + //ch <- line + ch <- map[string]any{ + "code": 0, + "question": question, + "content": res.Message, + } + } + } + + } + + } +} func checkIsIndexBasic(stock string) bool { count := int64(0) db.Dao.Model(&IndexBasic{}).Where("name = ?", stock).Count(&count) diff --git a/backend/data/openai_api_test.go b/backend/data/openai_api_test.go index 46880ed..3856425 100644 --- a/backend/data/openai_api_test.go +++ b/backend/data/openai_api_test.go @@ -8,8 +8,30 @@ import ( func TestNewDeepSeekOpenAiConfig(t *testing.T) { db.Init("../../data/stock.db") + + var tools []Tool + tools = append(tools, Tool{ + Type: "function", + Function: ToolFunction{ + Name: "SearchStockByIndicators", + Description: "通过解析自然语言,形成选股指标或策略,返回符合指标或策略的股票列表", + Parameters: FunctionParameters{ + Type: "object", + Properties: map[string]any{ + "words": map[string]any{ + "type": "string", + "description": "选股指标或策略的自然语言", + }, + }, + Required: []string{"words"}, + }, + }, + }) + ai := NewDeepSeekOpenAi(context.TODO()) - res := ai.NewChatStream("长电科技", "sh600584", "长电科技分析和总结", nil) + //res := ai.NewChatStream("长电科技", "sh600584", "长电科技分析和总结", nil) + res := ai.NewSummaryStockNewsStreamWithTools("总结市场资讯,发掘潜力标的/行业/板块/概念,控制风险,最后按风险登记生成指标选股策略汇总表,每个策略中的指标分号分隔,写成一行", nil, tools) + for { select { case msg := <-res: diff --git a/backend/data/search_stock_api.go b/backend/data/search_stock_api.go index cea7d58..a07e81d 100644 --- a/backend/data/search_stock_api.go +++ b/backend/data/search_stock_api.go @@ -19,7 +19,7 @@ type SearchStockApi struct { func NewSearchStockApi(words string) *SearchStockApi { return &SearchStockApi{words: words} } -func (s SearchStockApi) SearchStock() map[string]any { +func (s SearchStockApi) SearchStock(pageSize int) map[string]any { url := "https://np-tjxg-g.eastmoney.com/api/smart-tag/stock/v3/pw/search-code" resp, err := resty.New().SetTimeout(time.Duration(30)*time.Second).R(). SetHeader("Host", "np-tjxg-g.eastmoney.com"). @@ -29,7 +29,7 @@ func (s SearchStockApi) SearchStock() map[string]any { SetHeader("Content-Type", "application/json"). SetBody(fmt.Sprintf(`{ "keyWord": "%s", - "pageSize": 50000, + "pageSize": %d, "pageNo": 1, "fingerprint": "e38b5faabf9378c8238e57219f0ebc9b", "gids": [], @@ -43,7 +43,7 @@ func (s SearchStockApi) SearchStock() map[string]any { "ownSelectAll": false, "dxInfo": [], "extraCondition": "" - }`, s.words)).Post(url) + }`, s.words, pageSize)).Post(url) if err != nil { logger.SugaredLogger.Errorf("SearchStock-err:%+v", err) return map[string]any{} diff --git a/backend/data/search_stock_api_test.go b/backend/data/search_stock_api_test.go index 4ca5646..5102f3b 100644 --- a/backend/data/search_stock_api_test.go +++ b/backend/data/search_stock_api_test.go @@ -9,7 +9,7 @@ import ( func TestSearchStock(t *testing.T) { db.Init("../../data/stock.db") - res := NewSearchStockApi("算力股;净利润连续3年增长").SearchStock() + res := NewSearchStockApi("算力股;净利润连续3年增长").SearchStock(10) data := res["data"].(map[string]any) result := data["result"].(map[string]any) dataList := result["dataList"].([]any)