mirror of
https://github.com/ArvinLovegood/go-stock.git
synced 2025-07-19 00:00:09 +08:00
feat(app):集成AI工具并优化股票数据获取
- 在 App 结构中添加 AiTools 字段,用于存储 AI 工具配置 - 新增 AddTools 函数,定义了两个 AI 工具:SearchStockByIndicators 和 GetStockKLine- 修改 NewApp 函数,初始化时加载 AI 工具配置- 更新相关函数,支持使用 AI 工具进行股票数据查询- 优化股票 K 线数据获取逻辑,增加对不同市场股票代码的支持
This commit is contained in:
parent
888a97e4d3
commit
6d345ae91d
73
app.go
73
app.go
@ -36,6 +36,7 @@ type App struct {
|
||||
cache *freecache.Cache
|
||||
cron *cron.Cron
|
||||
cronEntrys map[string]cron.EntryID
|
||||
AiTools []data.Tool
|
||||
}
|
||||
|
||||
// NewApp creates a new App application struct
|
||||
@ -44,13 +45,60 @@ func NewApp() *App {
|
||||
cache := freecache.NewCache(cacheSize)
|
||||
c := cron.New(cron.WithSeconds())
|
||||
c.Start()
|
||||
var tools []data.Tool
|
||||
tools = AddTools(tools)
|
||||
return &App{
|
||||
cache: cache,
|
||||
cron: c,
|
||||
cronEntrys: make(map[string]cron.EntryID),
|
||||
AiTools: tools,
|
||||
}
|
||||
}
|
||||
|
||||
func AddTools(tools []data.Tool) []data.Tool {
|
||||
tools = append(tools, data.Tool{
|
||||
Type: "function",
|
||||
Function: data.ToolFunction{
|
||||
Name: "SearchStockByIndicators",
|
||||
Description: "根据自然语言筛选股票,返回自然语言选股条件要求的股票所有相关数据。单独输入股票名称可以获取当前股票最新的股价交易数据和基础财务指标信息",
|
||||
Parameters: data.FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"words": map[string]any{
|
||||
"type": "string",
|
||||
"description": "选股自然语言,并且条件使用;分隔,或者条件使用,分隔。例1:创新药;PE<30;净利润增长率>50%。 例2:上证指数(指数名称)。 例3:长电科技(股票名称)",
|
||||
},
|
||||
},
|
||||
Required: []string{"words"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
tools = append(tools, data.Tool{
|
||||
Type: "function",
|
||||
Function: data.ToolFunction{
|
||||
Name: "GetStockKLine",
|
||||
Description: "获取股票日K线数据",
|
||||
Parameters: data.FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"days": map[string]any{
|
||||
"type": "string",
|
||||
"description": "日K数据条数",
|
||||
},
|
||||
"stockCode": map[string]any{
|
||||
"type": "string",
|
||||
"description": "股票代码(A股:sh,sz开头;港股hk开头,美股:us开头)",
|
||||
},
|
||||
},
|
||||
Required: []string{"days", "stockCode"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
// startup is called at application startup
|
||||
func (a *App) startup(ctx context.Context) {
|
||||
defer PanicHandler()
|
||||
@ -311,7 +359,7 @@ func (a *App) AddCronTask(follow data.FollowedStock) func() {
|
||||
return func() {
|
||||
go runtime.EventsEmit(a.ctx, "warnMsg", "开始自动分析"+follow.Name+"_"+follow.StockCode)
|
||||
ai := data.NewDeepSeekOpenAi(a.ctx)
|
||||
msgs := ai.NewChatStream(follow.Name, follow.StockCode, "", nil)
|
||||
msgs := ai.NewChatStream(follow.Name, follow.StockCode, "", nil, a.AiTools)
|
||||
var res strings.Builder
|
||||
|
||||
chatId := ""
|
||||
@ -748,7 +796,7 @@ func (a *App) SendDingDingMessageByType(message string, stockCode string, msgTyp
|
||||
}
|
||||
|
||||
func (a *App) NewChatStream(stock, stockCode, question string, sysPromptId *int) {
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewChatStream(stock, stockCode, question, sysPromptId)
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewChatStream(stock, stockCode, question, sysPromptId, a.AiTools)
|
||||
for msg := range msgs {
|
||||
runtime.EventsEmit(a.ctx, "newChatStream", msg)
|
||||
}
|
||||
@ -1129,26 +1177,7 @@ func (a *App) GlobalStockIndexes() map[string]any {
|
||||
}
|
||||
|
||||
func (a *App) SummaryStockNews(question string, sysPromptId *int) {
|
||||
var tools []data.Tool
|
||||
tools = append(tools, data.Tool{
|
||||
Type: "function",
|
||||
Function: data.ToolFunction{
|
||||
Name: "SearchStockByIndicators",
|
||||
Description: "根据自然语言筛选股票,返回自然语言选股条件要求的股票所有相关数据。单独输入股票名称可以获取当前股票最新的股价交易数据和基础财务指标信息",
|
||||
Parameters: data.FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"words": map[string]any{
|
||||
"type": "string",
|
||||
"description": "选股自然语言,并且条件使用;分隔,或者条件使用,分隔。例1:创新药;PE<30;净利润增长率>50%。 例2:上证指数(指数名称)。 例3:长电科技(股票名称)",
|
||||
},
|
||||
},
|
||||
Required: []string{"words"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, tools)
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, a.AiTools)
|
||||
for msg := range msgs {
|
||||
runtime.EventsEmit(a.ctx, "summaryStockNews", msg)
|
||||
}
|
||||
|
@ -191,7 +191,7 @@ func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromp
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
news := NewMarketNewsApi().GetNewsList("", 100)
|
||||
news := NewMarketNewsApi().GetNewsList("财联社电报", 500)
|
||||
messageText := strings.Builder{}
|
||||
for _, telegraph := range *news {
|
||||
messageText.WriteString("## " + telegraph.Time + ":" + "\n")
|
||||
@ -310,7 +310,7 @@ func (o OpenAi) NewSummaryStockNewsStream(userQuestion string, sysPromptId *int)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId *int) <-chan map[string]any {
|
||||
func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId *int, tools []Tool) <-chan map[string]any {
|
||||
ch := make(chan map[string]any, 512)
|
||||
|
||||
defer func() {
|
||||
@ -647,7 +647,11 @@ func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId
|
||||
|
||||
//reqJson, _ := json.Marshal(msg)
|
||||
//logger.SugaredLogger.Errorf("Stream request: \n%s\n", reqJson)
|
||||
AskAi(o, err, msg, ch, question)
|
||||
if tools != nil && len(tools) > 0 {
|
||||
AskAiWithTools(o, err, msg, ch, question, tools)
|
||||
} else {
|
||||
AskAi(o, err, msg, ch, question)
|
||||
}
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
@ -859,17 +863,30 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c
|
||||
for _, choice := range streamResponse.Choices {
|
||||
if content := choice.Delta.Content; content != "" {
|
||||
//ch <- content
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": content,
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
//logger.SugaredLogger.Infof("Content data: %s", content)
|
||||
|
||||
if content == "###" {
|
||||
currentAIContent.WriteString("\r\n" + content)
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": "\r\n" + content,
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
} else {
|
||||
currentAIContent.WriteString(content)
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": content,
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
}
|
||||
|
||||
//logger.SugaredLogger.Infof("Content data: %s", content)
|
||||
currentAIContent.WriteString(content)
|
||||
}
|
||||
if reasoningContent := choice.Delta.ReasoningContent; reasoningContent != "" {
|
||||
//ch <- reasoningContent
|
||||
@ -913,16 +930,16 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": "- ```开始调用工具:SearchStockByIndicators,\n参数:" + words + "``` " + "\n",
|
||||
"content": "\r\n```\r\n开始调用工具:SearchStockByIndicators,\n参数:" + words + "\r\n```\r\n",
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
|
||||
res := NewSearchStockApi(words).SearchStock(10)
|
||||
res := NewSearchStockApi(words).SearchStock(50)
|
||||
searchRes, _ := json.Marshal(res)
|
||||
|
||||
content := gjson.Get(string(searchRes), "data.result").String()
|
||||
|
||||
logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content)
|
||||
//logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content)
|
||||
|
||||
//messages = append(messages, map[string]interface{}{
|
||||
// "role": "assistant",
|
||||
@ -935,6 +952,31 @@ func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch c
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
if funcName == "GetStockKLine" {
|
||||
stockCode := gjson.Get(funcArguments, "stockCode").String()
|
||||
days := gjson.Get(funcArguments, "days").String()
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": "\r\n```\r\n开始调用工具:GetStockKLine,\n参数:" + stockCode + "," + days + "\r\n```\r\n",
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
toIntDay, err := convertor.ToInt(days)
|
||||
if err != nil {
|
||||
toIntDay = 90
|
||||
}
|
||||
res := NewStockDataApi().GetHK_KLineData(stockCode, "day", toIntDay)
|
||||
searchRes, _ := json.Marshal(res)
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": "tool",
|
||||
"content": stockCode + convertor.ToString(toIntDay) + "日K线数据:\n" + string(searchRes) + "\n",
|
||||
"tool_call_id": currentCallId,
|
||||
})
|
||||
}
|
||||
|
||||
AskAiWithTools(o, err, messages, ch, question, tools)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user