From c81b1a730dfaff24219ff59c814eb04727f9f417 Mon Sep 17 00:00:00 2001 From: spark Date: Mon, 17 Feb 2025 17:33:17 +0800 Subject: [PATCH] =?UTF-8?q?feat(data):=E6=B7=BB=E5=8A=A0tushare=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=8E=A5=E5=8F=A3=E5=B9=B6=E4=BC=98=E5=8C=96=E8=82=A1?= =?UTF-8?q?=E7=A5=A8=E4=BB=A3=E7=A0=81=E8=BD=AC=E6=8D=A2=E5=8A=9F=E8=83=BD?= =?UTF-8?q?(=E8=AE=BE=E7=BD=AE=E5=A5=BD=E6=8F=90=E9=97=AE=E6=A8=A1?= =?UTF-8?q?=E6=9D=BF=E5=90=8E=E5=8F=AF=E8=BF=9B=E8=A1=8CK=E7=BA=BF?= =?UTF-8?q?=E5=88=86=E6=9E=90=E5=8A=9F=E8=83=BD)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 TushareApi 结构体和 GetDaily 方法,用于获取 A 股日线行情数据 - 在 openai_api.go 中添加获取股票日 K线数据的协程 - 在 utils.go 中添加股票代码与 tushare 代码相互转换的函数 - 更新相关测试文件以支持新功能 --- backend/data/openai_api.go | 14 +++++- backend/data/openai_api_test.go | 5 +-- backend/data/tushare_data_api.go | 65 +++++++++++++++++++++++++++ backend/data/tushare_data_api_test.go | 18 ++++++++ backend/data/utils.go | 25 ++++++++++- backend/data/utils_test.go | 5 +++ 6 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 backend/data/tushare_data_api.go create mode 100644 backend/data/tushare_data_api_test.go diff --git a/backend/data/openai_api.go b/backend/data/openai_api.go index ea15125..86d26fa 100644 --- a/backend/data/openai_api.go +++ b/backend/data/openai_api.go @@ -141,7 +141,19 @@ func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string) <-chan map[ logger.SugaredLogger.Infof("final question:%s", question) wg := &sync.WaitGroup{} - wg.Add(5) + wg.Add(6) + + go func() { + defer wg.Done() + endDate := time.Now().Format("20060102") + startDate := time.Now().Add(-time.Hour * 24 * 365).Format("20060102") + K := NewTushareApi(getConfig()).GetDaily(ConvertStockCodeToTushareCode(stockCode), startDate, endDate) + msg = append(msg, map[string]interface{}{ + "role": "assistant", + "content": stock + "日K数据如下:\n" + K, + }) + }() + go func() { defer wg.Done() messages := SearchStockPriceInfo(stockCode, o.CrawlTimeOut) diff --git a/backend/data/openai_api_test.go b/backend/data/openai_api_test.go index f097aa6..7bc9994 100644 --- a/backend/data/openai_api_test.go +++ b/backend/data/openai_api_test.go @@ -9,13 +9,10 @@ import ( func TestNewDeepSeekOpenAiConfig(t *testing.T) { db.Init("../../data/stock.db") ai := NewDeepSeekOpenAi(context.TODO()) - res := ai.NewChatStream("北京文化", "sz000802") + res := ai.NewChatStream("北京文化", "sz000802", "") for { select { case msg := <-res: - if msg == "" { - continue - } t.Log(msg) } } diff --git a/backend/data/tushare_data_api.go b/backend/data/tushare_data_api.go new file mode 100644 index 0000000..3eda060 --- /dev/null +++ b/backend/data/tushare_data_api.go @@ -0,0 +1,65 @@ +package data + +import ( + "github.com/duke-git/lancet/v2/convertor" + "github.com/duke-git/lancet/v2/slice" + "github.com/go-resty/resty/v2" + "go-stock/backend/logger" +) + +// @Author spark +// @Date 2025/2/17 12:33 +// @Desc +//----------------------------------------------------------------------------------- + +type TushareApi struct { + client *resty.Client + config *Settings +} + +func NewTushareApi(config *Settings) *TushareApi { + return &TushareApi{ + client: resty.New(), + config: config, + } +} + +// GetDaily tushare A股日线行情 +func (receiver TushareApi) GetDaily(tsCode, startDate, endDate string) string { + logger.SugaredLogger.Debugf("tushare daily request: ts_code=%s, start_date=%s, end_date=%s", tsCode, startDate, endDate) + fields := "ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount" + resp := &TushareStockBasicResponse{} + _, err := receiver.client.R(). + SetHeader("content-type", "application/json"). + SetBody(&TushareRequest{ + ApiName: "daily", + Token: receiver.config.TushareToken, + Params: map[string]any{ + "ts_code": tsCode, + "start_date": startDate, + "end_date": endDate, + }, + Fields: fields}). + SetResult(resp). + Post(tushareApiUrl) + if err != nil { + logger.SugaredLogger.Error(err) + return "" + } + res := "" + if resp.Data.Items != nil && len(resp.Data.Items) > 0 { + fieldsStr := slice.JoinFunc(resp.Data.Fields, ",", func(s string) string { + return "\"" + convertor.ToString(s) + "\"" + }) + res += fieldsStr + "\n" + for _, item := range resp.Data.Items { + //logger.SugaredLogger.Debugf("%s", slice.Join(item, ",")) + t := slice.JoinFunc(item, ",", func(s any) any { + return "\"" + convertor.ToString(s) + "\"" + }) + res += t + "\n" + } + } + logger.SugaredLogger.Debugf("tushare response: %s", res) + return res +} diff --git a/backend/data/tushare_data_api_test.go b/backend/data/tushare_data_api_test.go new file mode 100644 index 0000000..7973c66 --- /dev/null +++ b/backend/data/tushare_data_api_test.go @@ -0,0 +1,18 @@ +package data + +import ( + "go-stock/backend/db" + "testing" +) + +// @Author spark +// @Date 2025/2/17 12:44 +// @Desc +// ----------------------------------------------------------------------------------- +func TestGetDaily(t *testing.T) { + db.Init("../../data/stock.db") + tushareApi := NewTushareApi(getConfig()) + res := tushareApi.GetDaily("000802.SZ", "20250101", "20250217") + t.Log(res) + +} diff --git a/backend/data/utils.go b/backend/data/utils.go index ea145b5..05fa98f 100644 --- a/backend/data/utils.go +++ b/backend/data/utils.go @@ -1,6 +1,9 @@ package data -import "regexp" +import ( + "regexp" + "strings" +) // @Author spark // @Date 2025/2/13 13:08 @@ -21,3 +24,23 @@ func RemoveAllNonDigitChar(s string) string { re := regexp.MustCompile(`\D`) return re.ReplaceAllString(s, "") } + +// RemoveAllDigitChar 去除所有数字字符 +func RemoveAllDigitChar(s string) string { + re := regexp.MustCompile(`\d`) + return re.ReplaceAllString(s, "") +} + +// ConvertStockCodeToTushareCode 将股票代码转换为tushare的股票代码 +func ConvertStockCodeToTushareCode(stockCode string) string { + //提取非数字 + stockCode = RemoveAllNonDigitChar(stockCode) + "." + strings.ToUpper(RemoveAllDigitChar(stockCode)) + return stockCode +} + +// ConvertTushareCodeToStockCode 将tushare股票代码转换为的普通股票代码 +func ConvertTushareCodeToStockCode(stockCode string) string { + //提取非数字 + stockCode = strings.ToLower(RemoveAllDigitChar(stockCode)) + RemoveAllNonDigitChar(stockCode) + return strings.ReplaceAll(stockCode, ".", "") +} diff --git a/backend/data/utils_test.go b/backend/data/utils_test.go index 1dde4bf..ff358f2 100644 --- a/backend/data/utils_test.go +++ b/backend/data/utils_test.go @@ -31,3 +31,8 @@ func TestRemoveNonPrintable(t *testing.T) { logger.SugaredLogger.Infof("RemoveAllBlankChar(%s)", txt) } + +func TestConvertStockCodeToTushareCode(t *testing.T) { + logger.SugaredLogger.Infof("ConvertStockCodeToTushareCode(%s)", ConvertStockCodeToTushareCode("sz000802")) + logger.SugaredLogger.Infof("ConvertTushareCodeToStockCode(%s)", ConvertTushareCodeToStockCode("000802.SZ")) +}