feat(data):添加tushare数据接口并优化股票代码转换功能(设置好提问模板后可进行K线分析功能)

- 新增 TushareApi 结构体和 GetDaily 方法,用于获取 A 股日线行情数据
- 在 openai_api.go 中添加获取股票日 K线数据的协程
- 在 utils.go 中添加股票代码与 tushare 代码相互转换的函数
- 更新相关测试文件以支持新功能
This commit is contained in:
spark 2025-02-17 17:33:17 +08:00
parent 8d3cd7b151
commit c81b1a730d
6 changed files with 126 additions and 6 deletions

View File

@ -141,7 +141,19 @@ func (o OpenAi) NewChatStream(stock, stockCode, userQuestion string) <-chan map[
logger.SugaredLogger.Infof("final question:%s", question) logger.SugaredLogger.Infof("final question:%s", question)
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(5) wg.Add(6)
go func() {
defer wg.Done()
endDate := time.Now().Format("20060102")
startDate := time.Now().Add(-time.Hour * 24 * 365).Format("20060102")
K := NewTushareApi(getConfig()).GetDaily(ConvertStockCodeToTushareCode(stockCode), startDate, endDate)
msg = append(msg, map[string]interface{}{
"role": "assistant",
"content": stock + "日K数据如下\n" + K,
})
}()
go func() { go func() {
defer wg.Done() defer wg.Done()
messages := SearchStockPriceInfo(stockCode, o.CrawlTimeOut) messages := SearchStockPriceInfo(stockCode, o.CrawlTimeOut)

View File

@ -9,13 +9,10 @@ import (
func TestNewDeepSeekOpenAiConfig(t *testing.T) { func TestNewDeepSeekOpenAiConfig(t *testing.T) {
db.Init("../../data/stock.db") db.Init("../../data/stock.db")
ai := NewDeepSeekOpenAi(context.TODO()) ai := NewDeepSeekOpenAi(context.TODO())
res := ai.NewChatStream("北京文化", "sz000802") res := ai.NewChatStream("北京文化", "sz000802", "")
for { for {
select { select {
case msg := <-res: case msg := <-res:
if msg == "" {
continue
}
t.Log(msg) t.Log(msg)
} }
} }

View File

@ -0,0 +1,65 @@
package data
import (
"github.com/duke-git/lancet/v2/convertor"
"github.com/duke-git/lancet/v2/slice"
"github.com/go-resty/resty/v2"
"go-stock/backend/logger"
)
// @Author spark
// @Date 2025/2/17 12:33
// @Desc
//-----------------------------------------------------------------------------------
type TushareApi struct {
client *resty.Client
config *Settings
}
func NewTushareApi(config *Settings) *TushareApi {
return &TushareApi{
client: resty.New(),
config: config,
}
}
// GetDaily tushare A股日线行情
func (receiver TushareApi) GetDaily(tsCode, startDate, endDate string) string {
logger.SugaredLogger.Debugf("tushare daily request: ts_code=%s, start_date=%s, end_date=%s", tsCode, startDate, endDate)
fields := "ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount"
resp := &TushareStockBasicResponse{}
_, err := receiver.client.R().
SetHeader("content-type", "application/json").
SetBody(&TushareRequest{
ApiName: "daily",
Token: receiver.config.TushareToken,
Params: map[string]any{
"ts_code": tsCode,
"start_date": startDate,
"end_date": endDate,
},
Fields: fields}).
SetResult(resp).
Post(tushareApiUrl)
if err != nil {
logger.SugaredLogger.Error(err)
return ""
}
res := ""
if resp.Data.Items != nil && len(resp.Data.Items) > 0 {
fieldsStr := slice.JoinFunc(resp.Data.Fields, ",", func(s string) string {
return "\"" + convertor.ToString(s) + "\""
})
res += fieldsStr + "\n"
for _, item := range resp.Data.Items {
//logger.SugaredLogger.Debugf("%s", slice.Join(item, ","))
t := slice.JoinFunc(item, ",", func(s any) any {
return "\"" + convertor.ToString(s) + "\""
})
res += t + "\n"
}
}
logger.SugaredLogger.Debugf("tushare response: %s", res)
return res
}

View File

@ -0,0 +1,18 @@
package data
import (
"go-stock/backend/db"
"testing"
)
// @Author spark
// @Date 2025/2/17 12:44
// @Desc
// -----------------------------------------------------------------------------------
func TestGetDaily(t *testing.T) {
db.Init("../../data/stock.db")
tushareApi := NewTushareApi(getConfig())
res := tushareApi.GetDaily("000802.SZ", "20250101", "20250217")
t.Log(res)
}

View File

@ -1,6 +1,9 @@
package data package data
import "regexp" import (
"regexp"
"strings"
)
// @Author spark // @Author spark
// @Date 2025/2/13 13:08 // @Date 2025/2/13 13:08
@ -21,3 +24,23 @@ func RemoveAllNonDigitChar(s string) string {
re := regexp.MustCompile(`\D`) re := regexp.MustCompile(`\D`)
return re.ReplaceAllString(s, "") return re.ReplaceAllString(s, "")
} }
// RemoveAllDigitChar 去除所有数字字符
func RemoveAllDigitChar(s string) string {
re := regexp.MustCompile(`\d`)
return re.ReplaceAllString(s, "")
}
// ConvertStockCodeToTushareCode 将股票代码转换为tushare的股票代码
func ConvertStockCodeToTushareCode(stockCode string) string {
//提取非数字
stockCode = RemoveAllNonDigitChar(stockCode) + "." + strings.ToUpper(RemoveAllDigitChar(stockCode))
return stockCode
}
// ConvertTushareCodeToStockCode 将tushare股票代码转换为的普通股票代码
func ConvertTushareCodeToStockCode(stockCode string) string {
//提取非数字
stockCode = strings.ToLower(RemoveAllDigitChar(stockCode)) + RemoveAllNonDigitChar(stockCode)
return strings.ReplaceAll(stockCode, ".", "")
}

View File

@ -31,3 +31,8 @@ func TestRemoveNonPrintable(t *testing.T) {
logger.SugaredLogger.Infof("RemoveAllBlankChar(%s)", txt) logger.SugaredLogger.Infof("RemoveAllBlankChar(%s)", txt)
} }
func TestConvertStockCodeToTushareCode(t *testing.T) {
logger.SugaredLogger.Infof("ConvertStockCodeToTushareCode(%s)", ConvertStockCodeToTushareCode("sz000802"))
logger.SugaredLogger.Infof("ConvertTushareCodeToStockCode(%s)", ConvertTushareCodeToStockCode("000802.SZ"))
}