diff --git a/backend/data/market_news_api_test.go b/backend/data/market_news_api_test.go index 9f24c16..3770b92 100644 --- a/backend/data/market_news_api_test.go +++ b/backend/data/market_news_api_test.go @@ -3,8 +3,10 @@ package data import ( "encoding/json" "github.com/coocood/freecache" + "github.com/tidwall/gjson" "go-stock/backend/db" "go-stock/backend/logger" + "strings" "testing" ) @@ -140,14 +142,36 @@ func TestInvestCalendar(t *testing.T) { db.Init("../../data/stock.db") res := NewMarketNewsApi().InvestCalendar("2025-06") for _, a := range res { - logger.SugaredLogger.Debugf("value: %+v", a) + bytes, err := json.Marshal(a) + if err != nil { + continue + } + date := gjson.Get(string(bytes), "date") + list := gjson.Get(string(bytes), "list") + + logger.SugaredLogger.Debugf("value: %+v,list: %+v", date.String(), list) } } func TestClsCalendar(t *testing.T) { db.Init("../../data/stock.db") res := NewMarketNewsApi().ClsCalendar() + md := strings.Builder{} for _, a := range res { - logger.SugaredLogger.Debugf("value: %+v", a) + bytes, err := json.Marshal(a) + if err != nil { + continue + } + //logger.SugaredLogger.Debugf("value: %+v", string(bytes)) + date := gjson.Get(string(bytes), "calendar_day") + md.WriteString("\n### 事件/会议日期:" + date.String()) + list := gjson.Get(string(bytes), "items") + //logger.SugaredLogger.Debugf("value: %+v,list: %+v", date.String(), list) + list.ForEach(func(key, value gjson.Result) bool { + logger.SugaredLogger.Debugf("key: %+v,value: %+v", key.String(), gjson.Get(value.String(), "title")) + md.WriteString("\n- " + gjson.Get(value.String(), "title").String()) + return true + }) } + logger.SugaredLogger.Debugf("md:\n %s", md.String()) } diff --git a/backend/data/openai_api.go b/backend/data/openai_api.go index 459792a..65dcc1e 100644 --- a/backend/data/openai_api.go +++ b/backend/data/openai_api.go @@ -173,7 +173,7 @@ func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromp "content": "当前本地时间是:" + time.Now().Format("2006-01-02 15:04:05"), }) wg := &sync.WaitGroup{} - wg.Add(1) + wg.Add(2) go func() { defer wg.Done() var market strings.Builder @@ -190,6 +190,38 @@ func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromp "content": "当前市场指数行情情况如下:\n" + market.String(), }) }() + + go func() { + defer wg.Done() + md := strings.Builder{} + res := NewMarketNewsApi().ClsCalendar() + for _, a := range res { + bytes, err := json.Marshal(a) + if err != nil { + continue + } + //logger.SugaredLogger.Debugf("value: %+v", string(bytes)) + date := gjson.Get(string(bytes), "calendar_day") + md.WriteString("\n### 事件/会议日期:" + date.String()) + list := gjson.Get(string(bytes), "items") + //logger.SugaredLogger.Debugf("value: %+v,list: %+v", date.String(), list) + list.ForEach(func(key, value gjson.Result) bool { + logger.SugaredLogger.Debugf("key: %+v,value: %+v", key.String(), gjson.Get(value.String(), "title")) + md.WriteString("\n- " + gjson.Get(value.String(), "title").String()) + return true + }) + } + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": "近期重大事件/会议", + }) + msg = append(msg, map[string]interface{}{ + "role": "assistant", + "content": "近期重大事件/会议如下:\n" + md.String(), + }) + + }() + wg.Wait() news := NewMarketNewsApi().GetNewsList("财联社电报", random.RandInt(50, 150)) @@ -946,12 +978,42 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c "time": time.Now().Format(time.DateTime), } + content := "无符合条件的数据" res := NewSearchStockApi(words).SearchStock(random.RandInt(5, 10)) - searchRes, _ := json.Marshal(res) - - content := gjson.Get(string(searchRes), "data.result").String() - - //logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content) + if convertor.ToString(res["code"]) == "100" { + resData := res["data"].(map[string]any) + result := resData["result"].(map[string]any) + dataList := result["dataList"].([]any) + columns := result["columns"].([]any) + headers := map[string]string{} + for _, v := range columns { + //logger.SugaredLogger.Infof("v:%+v", v) + d := v.(map[string]any) + //logger.SugaredLogger.Infof("key:%s title:%s dateMsg:%s unit:%s", d["key"], d["title"], d["dateMsg"], d["unit"]) + title := convertor.ToString(d["title"]) + if convertor.ToString(d["dateMsg"]) != "" { + title = title + "[" + convertor.ToString(d["dateMsg"]) + "]" + } + if convertor.ToString(d["unit"]) != "" { + title = title + "(" + convertor.ToString(d["unit"]) + ")" + } + headers[d["key"].(string)] = title + } + table := &[]map[string]any{} + for _, v := range dataList { + d := v.(map[string]any) + tmp := map[string]any{} + for key, title := range headers { + tmp[title] = convertor.ToString(d[key]) + } + *table = append(*table, tmp) + } + jsonData, _ := json.Marshal(*table) + markdownTable, _ := JSONToMarkdownTable(jsonData) + //logger.SugaredLogger.Infof("markdownTable=\n%s", markdownTable) + content = "\r\n### 工具筛选出的股票数据:\r\n" + markdownTable + "\r\n" + } + logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> \n%s", words, content) messages = append(messages, map[string]interface{}{ "role": "assistant", @@ -975,6 +1037,15 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c "tool_call_id": currentCallId, }) + //ch <- map[string]any{ + // "code": 1, + // "question": question, + // "chatId": streamResponse.Id, + // "model": streamResponse.Model, + // "content": "\r\n```\r\n调用工具:SearchStockByIndicators,\n结果:" + content + "\r\n```\r\n", + // "time": time.Now().Format(time.DateTime), + //} + } if funcName == "GetStockKLine" { @@ -992,30 +1063,86 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c if err != nil { toIntDay = 90 } - res := NewStockDataApi().GetHK_KLineData(stockCode, "day", toIntDay) - searchRes, _ := json.Marshal(res) - messages = append(messages, map[string]interface{}{ - "role": "assistant", - "content": currentAIContent.String(), - "tool_calls": []map[string]any{ - { - "id": currentCallId, - "tool_call_id": currentCallId, - "type": "function", - "function": map[string]string{ - "name": funcName, - "arguments": funcArguments, - "parameters": funcArguments, + if strutil.HasPrefixAny(stockCode, []string{"sz", "sh", "hk", "us", "gb_"}) { + K := &[]KLineData{} + if strutil.HasPrefixAny(stockCode, []string{"sz", "sh"}) { + K = NewStockDataApi().GetKLineData(stockCode, "240", o.KDays) + } + if strutil.HasPrefixAny(stockCode, []string{"hk", "us", "gb_"}) { + K = NewStockDataApi().GetHK_KLineData(stockCode, "day", o.KDays) + } + Kmap := &[]map[string]any{} + for _, kline := range *K { + mapk := make(map[string]any, 6) + mapk["日期"] = kline.Day + mapk["开盘价"] = kline.Open + mapk["最高价"] = kline.High + mapk["最低价"] = kline.Low + mapk["收盘价"] = kline.Close + Volume, _ := convertor.ToFloat(kline.Volume) + mapk["成交量(万手)"] = Volume / 10000.00 / 100.00 + *Kmap = append(*Kmap, mapk) + } + jsonData, _ := json.Marshal(Kmap) + markdownTable, _ := JSONToMarkdownTable(jsonData) + logger.SugaredLogger.Infof("getKLineData=\n%s", markdownTable) + + messages = append(messages, map[string]interface{}{ + "role": "assistant", + "content": currentAIContent.String(), + "tool_calls": []map[string]any{ + { + "id": currentCallId, + "tool_call_id": currentCallId, + "type": "function", + "function": map[string]string{ + "name": funcName, + "arguments": funcArguments, + "parameters": funcArguments, + }, }, }, - }, - }) - messages = append(messages, map[string]interface{}{ - "role": "tool", - "content": stockCode + convertor.ToString(toIntDay) + "日K线数据:\n" + string(searchRes) + "\n", - "tool_call_id": currentCallId, - }) + }) + res := "\r\n ### " + stockCode + convertor.ToString(toIntDay) + "日K线数据:\r\n" + markdownTable + "\r\n" + messages = append(messages, map[string]interface{}{ + "role": "tool", + "content": res, + "tool_call_id": currentCallId, + }) + logger.SugaredLogger.Infof("GetStockKLine:stockCode:%s days:%s --> \n%s", stockCode, days, res) + + //ch <- map[string]any{ + // "code": 1, + // "question": question, + // "chatId": streamResponse.Id, + // "model": streamResponse.Model, + // "content": "\r\n```\r\n调用工具:GetStockKLine,\n结果:" + res + "\r\n```\r\n", + // "time": time.Now().Format(time.DateTime), + //} + } else { + messages = append(messages, map[string]interface{}{ + "role": "assistant", + "content": currentAIContent.String(), + "tool_calls": []map[string]any{ + { + "id": currentCallId, + "tool_call_id": currentCallId, + "type": "function", + "function": map[string]string{ + "name": funcName, + "arguments": funcArguments, + "parameters": funcArguments, + }, + }, + }, + }) + messages = append(messages, map[string]interface{}{ + "role": "tool", + "content": "无数据,可能股票代码错误。(A股:sh,sz开头;港股hk开头,美股:us开头)", + "tool_call_id": currentCallId, + }) + } } } diff --git a/backend/data/openai_api_test.go b/backend/data/openai_api_test.go index 8c97c6d..a84704a 100644 --- a/backend/data/openai_api_test.go +++ b/backend/data/openai_api_test.go @@ -30,13 +30,16 @@ func TestNewDeepSeekOpenAiConfig(t *testing.T) { ai := NewDeepSeekOpenAi(context.TODO()) //res := ai.NewChatStream("长电科技", "sh600584", "长电科技分析和总结", nil) - res := ai.NewSummaryStockNewsStreamWithTools("总结市场资讯,发掘潜力标的/行业/板块/概念,控制风险,最后按风险登记生成指标选股策略汇总表,每个策略中的指标分号分隔,写成一行", nil, tools) + res := ai.NewSummaryStockNewsStreamWithTools("总结市场资讯,发掘潜力标的/行业/板块/概念,控制风险。调用工具函数验证", nil, tools) for { select { case msg := <-res: if len(msg) > 0 { t.Log(msg) + if msg["content"] == "DONE" { + return + } } } } diff --git a/backend/data/search_stock_api_test.go b/backend/data/search_stock_api_test.go index 441c382..aa627d0 100644 --- a/backend/data/search_stock_api_test.go +++ b/backend/data/search_stock_api_test.go @@ -1,6 +1,8 @@ package data import ( + "encoding/json" + "github.com/duke-git/lancet/v2/convertor" "go-stock/backend/db" "go-stock/backend/logger" "testing" @@ -13,15 +15,36 @@ func TestSearchStock(t *testing.T) { data := res["data"].(map[string]any) result := data["result"].(map[string]any) dataList := result["dataList"].([]any) - for _, v := range dataList { + columns := result["columns"].([]any) + headers := map[string]string{} + for _, v := range columns { + //logger.SugaredLogger.Infof("v:%+v", v) d := v.(map[string]any) - logger.SugaredLogger.Infof("%s:%s", d["INDUSTRY"], d["SECURITY_SHORT_NAME"]) + //logger.SugaredLogger.Infof("key:%s title:%s dateMsg:%s unit:%s", d["key"], d["title"], d["dateMsg"], d["unit"]) + title := convertor.ToString(d["title"]) + if convertor.ToString(d["dateMsg"]) != "" { + title = title + "[" + convertor.ToString(d["dateMsg"]) + "]" + } + if convertor.ToString(d["unit"]) != "" { + title = title + "(" + convertor.ToString(d["unit"]) + ")" + } + headers[d["key"].(string)] = title } - //columns := result["columns"].([]any) - //for _, v := range columns { - // logger.SugaredLogger.Infof("v:%+v", v) - //} - + table := &[]map[string]any{} + for _, v := range dataList { + //logger.SugaredLogger.Infof("v:%+v", v) + d := v.(map[string]any) + tmp := map[string]any{} + for key, title := range headers { + //logger.SugaredLogger.Infof("%s:%s", title, convertor.ToString(d[key])) + tmp[title] = convertor.ToString(d[key]) + } + *table = append(*table, tmp) + //logger.SugaredLogger.Infof("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------") + } + jsonData, _ := json.Marshal(*table) + markdownTable, _ := JSONToMarkdownTable(jsonData) + logger.SugaredLogger.Infof("markdownTable=\n%s", markdownTable) } func TestSearchStockApi_HotStrategy(t *testing.T) {