diff --git a/openai/chatgpt.go b/openai/chatgpt.go index 0649530..1904846 100644 --- a/openai/chatgpt.go +++ b/openai/chatgpt.go @@ -85,6 +85,8 @@ curl https://api.openai.com/v1/chat/completions \ */ +var contextMgr ContextMgr + // Completions sendMsg func Completions(msg string) (*string, error) { apiKey := config.GetOpenAiApiKey() @@ -93,6 +95,24 @@ func Completions(msg string) (*string, error) { } var messages []ChatMessage + messages = append(messages, ChatMessage{ + Role: "system", + Content: "You are a helpful assistant.", + }) + + list := contextMgr.GetData() + for i := 0; i < len(list); i++ { + messages = append(messages, ChatMessage{ + Role: "user", + Content: list[i].Request, + }) + + messages = append(messages, ChatMessage{ + Role: "assistant", + Content: list[i].Response, + }) + } + messages = append(messages, ChatMessage{ Role: "user", Content: msg, @@ -150,6 +170,8 @@ func Completions(msg string) (*string, error) { reply += "\n" reply += v.Message.Content } + + contextMgr.AppendMsg(msg, reply) } if len(reply) == 0 { diff --git a/openai/context_mgr.go b/openai/context_mgr.go new file mode 100644 index 0000000..2fc709a --- /dev/null +++ b/openai/context_mgr.go @@ -0,0 +1,45 @@ +package openai + +import ( + "time" +) + +type Context struct { + Request string + Response string + Time int64 +} + +type ContextMgr struct { + contextList []*Context +} + +func (m *ContextMgr) Init() { + m.contextList = make([]*Context, 10) +} + +func (m *ContextMgr) checkExpire() { + timeNow := time.Now().Unix() + if len(m.contextList) > 0 { + startPos := len(m.contextList) - 1 + for i := 0; i < len(m.contextList); i++ { + if timeNow-m.contextList[i].Time < 1*60 { + startPos = i + break + } + } + + m.contextList = m.contextList[startPos:] + } +} + +func (m *ContextMgr) AppendMsg(request string, response string) { + m.checkExpire() + context := &Context{Request: request, Response: response, Time: time.Now().Unix()} + m.contextList = append(m.contextList, context) +} + +func (m *ContextMgr) GetData() []*Context { + m.checkExpire() + return m.contextList +}