From 7b625e2e80e80ffa2dbcfdb10739ea9bac4bd020 Mon Sep 17 00:00:00 2001 From: ArvinLovegood Date: Mon, 31 Mar 2025 14:49:44 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):AI=E5=88=86=E6=9E=90=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=A4=A7=E7=9B=98=E6=8C=87=E6=95=B0=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 getZSInfo 函数,用于获取指定股票代码的大盘指数信息 - 在处理用户问题时添加大盘指数信息查询功能 - 优化了代码结构,提高了可维护性 --- backend/data/openai_api.go | 15 ++++++++++++-- backend/data/stock_data_api.go | 31 +++++++++++++++++++++++++++++ backend/data/stock_data_api_test.go | 7 +++++-- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/backend/data/openai_api.go b/backend/data/openai_api.go index c274ac5..2b805cc 100644 --- a/backend/data/openai_api.go +++ b/backend/data/openai_api.go @@ -159,7 +159,19 @@ func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string, sysPromptId logger.SugaredLogger.Infof("final question:%s", question) wg := &sync.WaitGroup{} - wg.Add(6) + wg.Add(7) + + go func() { + defer wg.Done() + var market strings.Builder + market.WriteString(getZSInfo("创业板指数", "sz399006", 30) + "\n") + market.WriteString(getZSInfo("上证综合指数", "sh000001", 30) + "\n") + market.WriteString(getZSInfo("沪深300指数", "sh000300", 30) + "\n") + msg = append(msg, map[string]interface{}{ + "role": "user", + "content": "大盘指数情况如下:\n" + market.String(), + }) + }() go func() { defer wg.Done() @@ -517,7 +529,6 @@ func GetFinancialReports(stockCode string, crawlTimeOut int64) *[]string { logger.SugaredLogger.Infof("GetFinancialReports搜索股票-%s: %s", stockCode, url) - db.Init("../../data/stock.db") crawlerAPI := CrawlerApi{} crawlerBaseInfo := CrawlerBaseInfo{ Name: "TestCrawler", diff --git a/backend/data/stock_data_api.go b/backend/data/stock_data_api.go index 1bcb81e..76ede50 100644 --- a/backend/data/stock_data_api.go +++ b/backend/data/stock_data_api.go @@ -891,6 +891,37 @@ func getHKStockPriceInfo(stockCode string, crawlTimeOut int64) *[]string { return &messages } +func getZSInfo(name, stockCode string, crawlTimeOut int64) string { + url := "https://finance.sina.com.cn/realstock/company/" + stockCode + "/nc.shtml" + crawlerAPI := CrawlerApi{} + crawlerBaseInfo := CrawlerBaseInfo{ + Name: "TestCrawler", + Description: "Test Crawler Description", + BaseUrl: "https://finance.sina.com.cn", + Headers: map[string]string{"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36 Edg/133.0.0.0"}, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(crawlTimeOut)*time.Second) + defer cancel() + crawlerAPI = crawlerAPI.NewCrawler(ctx, crawlerBaseInfo) + html, ok := crawlerAPI.GetHtml(url, "div#hqDetails table", true) + if !ok { + return "" + } + document, err := goquery.NewDocumentFromReader(strings.NewReader(html)) + if err != nil { + logger.SugaredLogger.Error(err.Error()) + } + + //price + price := strutil.RemoveWhiteSpace(document.Find("div#price").First().Text(), false) + hqTime := strutil.RemoveWhiteSpace(document.Find("div#hqTime").First().Text(), false) + + var markdown strings.Builder + markdown.WriteString(fmt.Sprintf("### %s:%s 时间:%s\n", name, price, hqTime)) + GetTableMarkdown(document, "div#hqDetails table", &markdown) + return markdown.String() +} + func getSHSZStockPriceInfo(stockCode string, crawlTimeOut int64) *[]string { url := "https://finance.sina.com.cn/realstock/company/" + stockCode + "/nc.shtml" crawlerAPI := CrawlerApi{} diff --git a/backend/data/stock_data_api_test.go b/backend/data/stock_data_api_test.go index 132bbfe..1c55f84 100644 --- a/backend/data/stock_data_api_test.go +++ b/backend/data/stock_data_api_test.go @@ -49,8 +49,11 @@ func TestSearchStockPriceInfo(t *testing.T) { db.Init("../../data/stock.db") //SearchStockPriceInfo("hk06030", 30) //SearchStockPriceInfo("sh600171", 30) - SearchStockPriceInfo("gb_aapl", 30) - SearchStockPriceInfo("bj430198", 30) + //SearchStockPriceInfo("gb_aapl", 30) + //SearchStockPriceInfo("bj430198", 30) + getZSInfo("创业板指数", "sz399006", 30) + getZSInfo("上证综合指数", "sh000001", 30) + getZSInfo("沪深300指数", "sh000300", 30) }