mirror of
https://github.com/ArvinLovegood/go-stock.git
synced 2025-07-19 00:00:09 +08:00
feat(data):新增SummaryStockNewsStreamWithTools功能
- 在 OpenAi 结构中添加了新的方法 NewSummaryStockNewsStreamWithTools,支持使用工具进行股票分析 - 在 app.go 中调用了新方法,集成了股票搜索工具- 修改了 SearchStockApi 的 SearchStock 方法,增加了 pageSize 参数 - 更新了相关测试文件以适应新的功能
This commit is contained in:
parent
b945a0e0e1
commit
ebeaf104bb
21
app.go
21
app.go
@ -1129,7 +1129,26 @@ func (a *App) GlobalStockIndexes() map[string]any {
|
||||
}
|
||||
|
||||
func (a *App) SummaryStockNews(question string, sysPromptId *int) {
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStream(question, sysPromptId)
|
||||
var tools []data.Tool
|
||||
tools = append(tools, data.Tool{
|
||||
Type: "function",
|
||||
Function: data.ToolFunction{
|
||||
Name: "SearchStockByIndicators",
|
||||
Description: "按行业根据选股指标或策略,返回符合指标或策略的股票列表。多个行业的筛选需按行业顺序调用多次,不支持并行调用",
|
||||
Parameters: data.FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"words": map[string]any{
|
||||
"type": "string",
|
||||
"description": "行业选股指标或策略,并且条件使用;分隔,或者条件使用,分隔。例如:创新药;PE<30;净利润增长率>50%;",
|
||||
},
|
||||
},
|
||||
Required: []string{"words"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
msgs := data.NewDeepSeekOpenAi(a.ctx).NewSummaryStockNewsStreamWithTools(question, sysPromptId, tools)
|
||||
for msg := range msgs {
|
||||
runtime.EventsEmit(a.ctx, "summaryStockNews", msg)
|
||||
}
|
||||
|
@ -57,5 +57,5 @@ func (a App) ClsCalendar() []any {
|
||||
}
|
||||
|
||||
func (a App) SearchStock(words string) map[string]any {
|
||||
return data.NewSearchStockApi(words).SearchStock()
|
||||
return data.NewSearchStockApi(words).SearchStock(5000)
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/duke-git/lancet/v2/convertor"
|
||||
"github.com/duke-git/lancet/v2/strutil"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/wailsapp/wails/v2/pkg/runtime"
|
||||
"go-stock/backend/db"
|
||||
"go-stock/backend/logger"
|
||||
@ -75,11 +76,12 @@ type THSTokenResponse struct {
|
||||
}
|
||||
|
||||
type AiResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
@ -87,6 +89,19 @@ type AiResponse struct {
|
||||
} `json:"message"`
|
||||
Logprobs interface{} `json:"logprobs"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
ToolCalls []struct {
|
||||
Function struct {
|
||||
Arguments string `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
Id string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
@ -98,6 +113,112 @@ type AiResponse struct {
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunction `json:"function"`
|
||||
}
|
||||
type FunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]any `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters FunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
func (o OpenAi) NewSummaryStockNewsStreamWithTools(userQuestion string, sysPromptId *int, tools []Tool) <-chan map[string]any {
|
||||
ch := make(chan map[string]any, 512)
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.SugaredLogger.Error("NewSummaryStockNewsStream panic", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.SugaredLogger.Errorf("NewSummaryStockNewsStream goroutine panic :%s", err)
|
||||
logger.SugaredLogger.Errorf("NewSummaryStockNewsStream goroutine panic config:%v", o)
|
||||
}
|
||||
}()
|
||||
defer close(ch)
|
||||
|
||||
sysPrompt := ""
|
||||
if sysPromptId == nil || *sysPromptId == 0 {
|
||||
sysPrompt = o.Prompt
|
||||
} else {
|
||||
sysPrompt = NewPromptTemplateApi().GetPromptTemplateByID(*sysPromptId)
|
||||
}
|
||||
if sysPrompt == "" {
|
||||
sysPrompt = o.Prompt
|
||||
}
|
||||
|
||||
msg := []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
//"content": "作为一位专业的A股市场分析师和投资顾问,请你根据以下信息提供详细的技术分析和投资策略建议:",
|
||||
//"content": "【角色设定】\n你是一位拥有20年实战经验的顶级股票分析师,精通技术分析、基本面分析、市场心理学和量化交易。擅长发现成长股、捕捉行业轮动机会,在牛熊市中都能保持稳定收益。你的风格是价值投资与技术择时相结合,注重风险控制。\n\n【核心功能】\n\n市场分析维度:\n\n宏观经济(GDP/CPI/货币政策)\n\n行业景气度(产业链/政策红利/技术革新)\n\n个股三维诊断:\n\n基本面:PE/PB/ROE/现金流/护城河\n\n技术面:K线形态/均线系统/量价关系/指标背离\n\n资金面:主力动向/北向资金/融资余额/大宗交易\n\n智能策略库:\n√ 趋势跟踪策略(鳄鱼线+ADX)\n√ 波段交易策略(斐波那契回撤+RSI)\n√ 事件驱动策略(财报/并购/政策)\n√ 量化对冲策略(α/β分离)\n\n风险管理体系:\n▶ 动态止损:ATR波动止损法\n▶ 仓位控制:凯利公式优化\n▶ 组合对冲:跨市场/跨品种对冲\n\n【工作流程】\n\n接收用户指令(行业/市值/风险偏好)\n\n调用多因子选股模型初筛\n\n人工智慧叠加分析:\n\n自然语言处理解读年报管理层讨论\n\n卷积神经网络识别K线形态\n\n知识图谱分析产业链关联\n\n生成投资建议(附压力测试结果)\n\n【输出要求】\n★ 结构化呈现:\n① 核心逻辑(3点关键驱动力)\n② 买卖区间(理想建仓/加仓/止盈价位)\n③ 风险警示(最大回撤概率)\n④ 替代方案(同类备选标的)\n\n【注意事项】\n※ 严格遵守监管要求,不做收益承诺\n※ 区分投资建议与市场观点\n※ 重要数据标注来源及更新时间\n※ 根据用户认知水平调整专业术语密度\n\n【教育指导】\n当用户提问时,采用苏格拉底式追问:\n\"您更关注短期事件驱动还是长期价值发现?\"\n\"当前仓位是否超过总资产的30%?\"\n\"是否了解科创板与主板的交易规则差异?\"\n\n示例输出格式:\n📈 标的名称:XXXXXX\n⚖️ 多空信号:金叉确认/顶背离预警\n🎯 关键价位:支撑位XX.XX/压力位XX.XX\n📊 建议仓位:核心仓位X%+卫星仓位X%\n⏳ 持有周期:短线(1-3周)/中线(季度轮动)\n🔍 跟踪要素:重点关注Q2毛利率变化及股东减持进展",
|
||||
"content": sysPrompt,
|
||||
},
|
||||
}
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": "当前时间",
|
||||
})
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "当前本地时间是:" + time.Now().Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var market strings.Builder
|
||||
market.WriteString(getZSInfo("创业板指数", "sz399006", 30) + "\n")
|
||||
market.WriteString(getZSInfo("上证综合指数", "sh000001", 30) + "\n")
|
||||
market.WriteString(getZSInfo("沪深300指数", "sh000300", 30) + "\n")
|
||||
//logger.SugaredLogger.Infof("NewChatStream getZSInfo=\n%s", market.String())
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": "当前市场指数行情",
|
||||
})
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "当前市场指数行情情况如下:\n" + market.String(),
|
||||
})
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
news := NewMarketNewsApi().GetNewsList("", 100)
|
||||
messageText := strings.Builder{}
|
||||
for _, telegraph := range *news {
|
||||
messageText.WriteString("## " + telegraph.Time + ":" + "\n")
|
||||
messageText.WriteString("### " + telegraph.Content + "\n")
|
||||
}
|
||||
//logger.SugaredLogger.Infof("市场资讯 messageText=\n%s", messageText.String())
|
||||
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": "市场资讯",
|
||||
})
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": messageText.String(),
|
||||
})
|
||||
if userQuestion == "" {
|
||||
userQuestion = "请根据当前时间,总结和分析股票市场新闻中的投资机会"
|
||||
}
|
||||
msg = append(msg, map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": userQuestion,
|
||||
})
|
||||
AskAiWithTools(o, errors.New(""), msg, ch, userQuestion, tools)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (o OpenAi) NewSummaryStockNewsStream(userQuestion string, sysPromptId *int) <-chan map[string]any {
|
||||
ch := make(chan map[string]any, 512)
|
||||
defer func() {
|
||||
@ -569,7 +690,7 @@ func AskAi(o OpenAi, err error, messages []map[string]interface{}, ch chan map[s
|
||||
scanner := bufio.NewScanner(body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
//logger.SugaredLogger.Infof("Received data: %s", line)
|
||||
logger.SugaredLogger.Infof("Received data: %s", line)
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
data := strutil.Trim(strings.TrimPrefix(line, "data:"))
|
||||
if data == "[DONE]" {
|
||||
@ -657,7 +778,204 @@ func AskAi(o OpenAi, err error, messages []map[string]interface{}, ch chan map[s
|
||||
|
||||
}
|
||||
}
|
||||
func AskAiWithTools(o OpenAi, err error, messages []map[string]interface{}, ch chan map[string]any, question string, tools []Tool) {
|
||||
client := resty.New()
|
||||
client.SetBaseURL(strutil.Trim(o.BaseUrl))
|
||||
client.SetHeader("Authorization", "Bearer "+o.ApiKey)
|
||||
client.SetHeader("Content-Type", "application/json")
|
||||
//client.SetRetryCount(3)
|
||||
if o.TimeOut <= 0 {
|
||||
o.TimeOut = 300
|
||||
}
|
||||
client.SetTimeout(time.Duration(o.TimeOut) * time.Second)
|
||||
resp, err := client.R().
|
||||
SetDoNotParseResponse(true).
|
||||
SetBody(map[string]interface{}{
|
||||
"model": o.Model,
|
||||
"max_tokens": o.MaxTokens,
|
||||
"temperature": o.Temperature,
|
||||
"stream": true,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
}).
|
||||
Post("/chat/completions")
|
||||
|
||||
body := resp.RawBody()
|
||||
defer body.Close()
|
||||
if err != nil {
|
||||
logger.SugaredLogger.Infof("Stream error : %s", err.Error())
|
||||
//ch <- err.Error()
|
||||
ch <- map[string]any{
|
||||
"code": 0,
|
||||
"question": question,
|
||||
"content": err.Error(),
|
||||
}
|
||||
return
|
||||
}
|
||||
//location, _ := time.LoadLocation("Asia/Shanghai")
|
||||
|
||||
scanner := bufio.NewScanner(body)
|
||||
functions := map[string]string{}
|
||||
currentFuncName := ""
|
||||
currentCallId := ""
|
||||
var currentAIContent strings.Builder
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
logger.SugaredLogger.Infof("Received data: %s", line)
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
data := strutil.Trim(strings.TrimPrefix(line, "data:"))
|
||||
if data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var streamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Role string `json:"role"`
|
||||
ToolCalls []struct {
|
||||
Function struct {
|
||||
Arguments string `json:"arguments"`
|
||||
Name string `json:"name"`
|
||||
} `json:"function"`
|
||||
Id string `json:"id"`
|
||||
Index int `json:"index"`
|
||||
Type string `json:"type"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(data), &streamResponse); err == nil {
|
||||
for _, choice := range streamResponse.Choices {
|
||||
if content := choice.Delta.Content; content != "" {
|
||||
//ch <- content
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": content,
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
|
||||
//logger.SugaredLogger.Infof("Content data: %s", content)
|
||||
currentAIContent.WriteString(content)
|
||||
}
|
||||
if reasoningContent := choice.Delta.ReasoningContent; reasoningContent != "" {
|
||||
//ch <- reasoningContent
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": reasoningContent,
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
|
||||
//logger.SugaredLogger.Infof("ReasoningContent data: %s", reasoningContent)
|
||||
currentAIContent.WriteString(reasoningContent)
|
||||
|
||||
}
|
||||
if choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0 {
|
||||
for _, call := range choice.Delta.ToolCalls {
|
||||
if call.Type == "function" {
|
||||
functions[call.Function.Name] = ""
|
||||
currentFuncName = call.Function.Name
|
||||
currentCallId = call.Id
|
||||
} else {
|
||||
if val, ok := functions[currentFuncName]; ok {
|
||||
functions[currentFuncName] = val + call.Function.Arguments
|
||||
} else {
|
||||
functions[currentFuncName] = call.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if choice.FinishReason == "tool_calls" {
|
||||
logger.SugaredLogger.Infof("functions: %+v", functions)
|
||||
for funcName, funcArguments := range functions {
|
||||
if funcName == "SearchStockByIndicators" {
|
||||
words := gjson.Get(funcArguments, "words").String()
|
||||
|
||||
ch <- map[string]any{
|
||||
"code": 1,
|
||||
"question": question,
|
||||
"chatId": streamResponse.Id,
|
||||
"model": streamResponse.Model,
|
||||
"content": "- ```开始调用工具:SearchStockByIndicators,\n参数:" + words + "``` " + "\n",
|
||||
"time": time.Now().Format(time.DateTime),
|
||||
}
|
||||
|
||||
res := NewSearchStockApi(words).SearchStock(10)
|
||||
searchRes, _ := json.Marshal(res)
|
||||
|
||||
content := gjson.Get(string(searchRes), "data.result").String()
|
||||
|
||||
logger.SugaredLogger.Infof("SearchStockByIndicators:words:%s --> %s", words, content)
|
||||
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": currentAIContent.String(),
|
||||
})
|
||||
messages = append(messages, map[string]interface{}{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
"tool_call_id": currentCallId,
|
||||
})
|
||||
|
||||
}
|
||||
AskAiWithTools(o, err, messages, ch, question, tools)
|
||||
}
|
||||
}
|
||||
|
||||
if choice.FinishReason == "stop" {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
logger.SugaredLogger.Infof("Stream data error : %s", err.Error())
|
||||
//ch <- err.Error()
|
||||
ch <- map[string]any{
|
||||
"code": 0,
|
||||
"question": question,
|
||||
"content": err.Error(),
|
||||
}
|
||||
} else {
|
||||
logger.SugaredLogger.Infof("Stream data error : %s", data)
|
||||
//ch <- data
|
||||
ch <- map[string]any{
|
||||
"code": 0,
|
||||
"question": question,
|
||||
"content": data,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if strutil.RemoveNonPrintable(line) != "" {
|
||||
logger.SugaredLogger.Infof("Stream data error : %s", line)
|
||||
res := &models.Resp{}
|
||||
if err := json.Unmarshal([]byte(line), res); err == nil {
|
||||
//ch <- line
|
||||
ch <- map[string]any{
|
||||
"code": 0,
|
||||
"question": question,
|
||||
"content": res.Message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
func checkIsIndexBasic(stock string) bool {
|
||||
count := int64(0)
|
||||
db.Dao.Model(&IndexBasic{}).Where("name = ?", stock).Count(&count)
|
||||
|
@ -8,8 +8,30 @@ import (
|
||||
|
||||
func TestNewDeepSeekOpenAiConfig(t *testing.T) {
|
||||
db.Init("../../data/stock.db")
|
||||
|
||||
var tools []Tool
|
||||
tools = append(tools, Tool{
|
||||
Type: "function",
|
||||
Function: ToolFunction{
|
||||
Name: "SearchStockByIndicators",
|
||||
Description: "通过解析自然语言,形成选股指标或策略,返回符合指标或策略的股票列表",
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"words": map[string]any{
|
||||
"type": "string",
|
||||
"description": "选股指标或策略的自然语言",
|
||||
},
|
||||
},
|
||||
Required: []string{"words"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
ai := NewDeepSeekOpenAi(context.TODO())
|
||||
res := ai.NewChatStream("长电科技", "sh600584", "长电科技分析和总结", nil)
|
||||
//res := ai.NewChatStream("长电科技", "sh600584", "长电科技分析和总结", nil)
|
||||
res := ai.NewSummaryStockNewsStreamWithTools("总结市场资讯,发掘潜力标的/行业/板块/概念,控制风险,最后按风险登记生成指标选股策略汇总表,每个策略中的指标分号分隔,写成一行", nil, tools)
|
||||
|
||||
for {
|
||||
select {
|
||||
case msg := <-res:
|
||||
|
@ -19,7 +19,7 @@ type SearchStockApi struct {
|
||||
func NewSearchStockApi(words string) *SearchStockApi {
|
||||
return &SearchStockApi{words: words}
|
||||
}
|
||||
func (s SearchStockApi) SearchStock() map[string]any {
|
||||
func (s SearchStockApi) SearchStock(pageSize int) map[string]any {
|
||||
url := "https://np-tjxg-g.eastmoney.com/api/smart-tag/stock/v3/pw/search-code"
|
||||
resp, err := resty.New().SetTimeout(time.Duration(30)*time.Second).R().
|
||||
SetHeader("Host", "np-tjxg-g.eastmoney.com").
|
||||
@ -29,7 +29,7 @@ func (s SearchStockApi) SearchStock() map[string]any {
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(fmt.Sprintf(`{
|
||||
"keyWord": "%s",
|
||||
"pageSize": 50000,
|
||||
"pageSize": %d,
|
||||
"pageNo": 1,
|
||||
"fingerprint": "e38b5faabf9378c8238e57219f0ebc9b",
|
||||
"gids": [],
|
||||
@ -43,7 +43,7 @@ func (s SearchStockApi) SearchStock() map[string]any {
|
||||
"ownSelectAll": false,
|
||||
"dxInfo": [],
|
||||
"extraCondition": ""
|
||||
}`, s.words)).Post(url)
|
||||
}`, s.words, pageSize)).Post(url)
|
||||
if err != nil {
|
||||
logger.SugaredLogger.Errorf("SearchStock-err:%+v", err)
|
||||
return map[string]any{}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
func TestSearchStock(t *testing.T) {
|
||||
db.Init("../../data/stock.db")
|
||||
|
||||
res := NewSearchStockApi("算力股;净利润连续3年增长").SearchStock()
|
||||
res := NewSearchStockApi("算力股;净利润连续3年增长").SearchStock(10)
|
||||
data := res["data"].(map[string]any)
|
||||
result := data["result"].(map[string]any)
|
||||
dataList := result["dataList"].([]any)
|
||||
|
Loading…
x
Reference in New Issue
Block a user