mirror of
https://github.com/collinsctk/chatgpt_embeddings.git
synced 2025-07-19 00:00:05 +08:00
init
This commit is contained in:
parent
794a85d46b
commit
cc2bb88989
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
17
.idea/aws.xml
generated
Normal file
17
.idea/aws.xml
generated
Normal file
@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="accountSettings">
|
||||
<option name="activeProfile" value="profile:default" />
|
||||
<option name="activeRegion" value="us-east-1" />
|
||||
<option name="recentlyUsedProfiles">
|
||||
<list>
|
||||
<option value="profile:default" />
|
||||
</list>
|
||||
</option>
|
||||
<option name="recentlyUsedRegions">
|
||||
<list>
|
||||
<option value="us-east-1" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
10
.idea/chatgpt_embeddings.iml
generated
Normal file
10
.idea/chatgpt_embeddings.iml
generated
Normal file
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
50
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
50
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,50 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<Languages>
|
||||
<language minSize="245" name="Python" />
|
||||
</Languages>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="HttpUrlsUsage" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredUrls">
|
||||
<list>
|
||||
<option value="http://localhost" />
|
||||
<option value="http://127.0.0.1" />
|
||||
<option value="http://0.0.0.0" />
|
||||
<option value="http://www.w3.org/" />
|
||||
<option value="http://json-schema.org/draft" />
|
||||
<option value="http://java.sun.com/" />
|
||||
<option value="http://xmlns.jcp.org/" />
|
||||
<option value="http://javafx.com/javafx/" />
|
||||
<option value="http://javafx.com/fxml" />
|
||||
<option value="http://maven.apache.org/xsd/" />
|
||||
<option value="http://maven.apache.org/POM/" />
|
||||
<option value="http://www.springframework.org/schema/" />
|
||||
<option value="http://www.springframework.org/tags" />
|
||||
<option value="http://www.springframework.org/security/tags" />
|
||||
<option value="http://www.thymeleaf.org" />
|
||||
<option value="http://www.jboss.org/j2ee/schema/" />
|
||||
<option value="http://www.jboss.com/xml/ns/" />
|
||||
<option value="http://www.ibm.com/webservices/xsd" />
|
||||
<option value="http://activemq.apache.org/schema/" />
|
||||
<option value="http://schema.cloudfoundry.org/spring/" />
|
||||
<option value="http://schemas.xmlsoap.org/" />
|
||||
<option value="http://cxf.apache.org/schemas/" />
|
||||
<option value="http://primefaces.org/ui" />
|
||||
<option value="http://tiles.apache.org/" />
|
||||
<option value="http://www.mingjiao.org" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="1">
|
||||
<item index="0" class="java.lang.String" itemvalue="flask" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
4
.idea/misc.xml
generated
Normal file
4
.idea/misc.xml
generated
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (chatgpt_embeddings)" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/chatgpt_embeddings.iml" filepath="$PROJECT_DIR$/.idea/chatgpt_embeddings.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
52
gpt_1_embeddings_training.py
Normal file
52
gpt_1_embeddings_training.py
Normal file
@ -0,0 +1,52 @@
|
||||
# 参考文章
|
||||
# https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb
|
||||
|
||||
|
||||
import openai
|
||||
import pandas as pd
|
||||
import time
|
||||
from gpt_0_basic_info import api_key, excel_file_path, csv_file_path
|
||||
|
||||
COMPLETIONS_MODEL = "text-davinci-003"
|
||||
EMBEDDING_MODEL = "text-embedding-ada-002"
|
||||
|
||||
|
||||
def get_embedding(text: str, open_ai_api_key: str, model: str = EMBEDDING_MODEL) -> list[float]:
    """Return the embedding vector for *text* from the OpenAI Embeddings API.

    NOTE: sets the global ``openai.api_key`` as a side effect, so it
    affects every subsequent call made through the ``openai`` module.
    """
    openai.api_key = open_ai_api_key
    response = openai.Embedding.create(model=model, input=text)
    return response["data"][0]["embedding"]
|
||||
|
||||
|
||||
def compute_doc_embeddings(df: pd.DataFrame, open_ai_api_key: str):
    """Embed every row's ``QandA`` text via the OpenAI Embeddings API.

    Adds an object-dtype ``embeddings`` column holding one vector per row,
    printing each index as it goes and sleeping 1s between calls (rate
    limiting). Mutates *df* in place and returns it.
    """
    # Seed the column first, then widen it to object dtype so a Python
    # list can be stored per cell without pandas coercing it.
    df['embeddings'] = ''
    df['embeddings'] = df['embeddings'].astype('object')

    for row_index, row in df.iterrows():
        print(row_index)
        df.at[row_index, 'embeddings'] = get_embedding(row.QandA, open_ai_api_key)
        time.sleep(1)

    return df
|
||||
|
||||
|
||||
def getembeddings(api_key, excelfilepath, csvfilepath):
    """Read prompt/completion pairs from Excel, embed them, write a CSV.

    The output CSV holds two columns: the concatenated ``QandA`` text and
    its embedding vector. Written with a BOM (utf-8_sig) so Excel opens
    the Chinese text correctly.
    """
    frame = pd.read_excel(excelfilepath)
    frame['prompt'] = frame['prompt'].apply(lambda text: text.replace('\n', ''))
    frame['prompt'] = frame['prompt'].apply(lambda text: "当有人问:" + text + '')
    frame['completion'] = frame['completion'].apply(lambda text: "请回答:" + text)
    frame['QandA'] = frame['prompt'] + frame['completion']
    frame = compute_doc_embeddings(frame, api_key)[['QandA', 'embeddings']]
    frame.to_csv(csvfilepath, index=False, encoding='utf-8_sig')
|
||||
|
||||
|
||||
# Script entry point: build embeddings for the configured Excel workbook
# and persist them to the configured CSV path.
if __name__ == '__main__':
    getembeddings(api_key, excel_file_path, csv_file_path)
|
120
gpt_2_create_question.py
Normal file
120
gpt_2_create_question.py
Normal file
@ -0,0 +1,120 @@
|
||||
import numpy as np
|
||||
import openai
|
||||
import pandas as pd
|
||||
import ast
|
||||
from gpt_0_basic_info import api_key, excel_file_path, csv_file_path
|
||||
|
||||
MAXCONTEXTLEN = 1500
|
||||
|
||||
COMPLETIONS_MODEL = "text-davinci-003"
|
||||
EMBEDDING_MODEL = "text-embedding-ada-002"
|
||||
|
||||
|
||||
def get_embedding(text: str, open_ai_api_key: str, model: str = EMBEDDING_MODEL) -> list[float]:
    """Return the embedding vector for *text* from the OpenAI Embeddings API.

    NOTE: sets the global ``openai.api_key`` as a side effect.
    """
    openai.api_key = open_ai_api_key
    response = openai.Embedding.create(model=model, input=text)
    return response["data"][0]["embedding"]
|
||||
|
||||
|
||||
def compute_doc_embeddings(df: pd.DataFrame, open_ai_api_key: str):
    """Embed every row's ``content`` text via the OpenAI Embeddings API.

    Adds an object-dtype ``embeddings`` column and returns *df* (mutated
    in place).

    NOTE(review): this reads ``r.content``, while the training script and
    ``get_query_similarity`` use a ``QandA`` column — confirm what input
    frame this helper is meant to receive.
    """
    # Seed then widen to object dtype so each cell can hold a list.
    df['embeddings'] = ''
    df['embeddings'] = df['embeddings'].astype('object')

    for row_index, row in df.iterrows():
        df.at[row_index, 'embeddings'] = get_embedding(row.content, open_ai_api_key)

    return df
|
||||
|
||||
|
||||
def vector_similarity(x: list[float], y: list[float]) -> float:
    """Return the dot product of two vectors.

    OpenAI embeddings are normalized to unit length, so the dot product
    is equal to the cosine similarity.
    """
    vec_a = np.asarray(x)
    vec_b = np.asarray(y)
    return np.dot(vec_a, vec_b)
|
||||
|
||||
|
||||
def get_query_similarity(query: str, df: pd.DataFrame, open_ai_api_key: str):
    """Rank rows of *df* against *query* and build a context string.

    Embeds *query*, scores it against each row's precomputed embedding
    (dot product), and stores the scores in a new ``similarities`` column
    (*df* is mutated in place). The returned context is:
      - '' when even the best match scores below 0.8;
      - the best row's ``QandA`` when the runner-up scores below 0.8 or
        the two rows combined would reach MAXCONTEXTLEN characters;
      - runner-up + newline + best row otherwise.
    """
    openai.api_key = open_ai_api_key

    query_embedding = get_embedding(query, open_ai_api_key)

    df['similarities'] = df['embeddings'].apply(
        lambda emb: vector_similarity(query_embedding, emb)
    )

    # Indices of the two highest-scoring rows, best first.
    two_largest = df['similarities'].nlargest(2).index.tolist()

    if df.loc[two_largest[0]]['similarities'] < 0.8:
        context = ''
    else:
        best_qa = df.loc[two_largest[0]]['QandA']
        runner_up = df.loc[two_largest[1]]
        combined = runner_up['QandA'] + '\n' + best_qa
        if runner_up['similarities'] < 0.8 or len(combined) >= MAXCONTEXTLEN:
            context = best_qa
        else:
            context = combined

    return context
|
||||
|
||||
|
||||
def _decorate_query(query: str, df: pd.DataFrame, open_ai_api_key: str)-> str:
|
||||
|
||||
try:
|
||||
context = get_query_similarity(query, df, open_ai_api_key)
|
||||
if context != '':
|
||||
header = """请使用上下文尽可能真实、自然地回答问题,如果答案未包含在上下文中,请不要编造回答,并且不要在回答中包含”根据上下文”这个短语。\n\n上下文:\n"""
|
||||
#header = "上下文:\n"
|
||||
query = header + context + "\n\n 问题: " + query + "\n 回答:?"
|
||||
# print(query)
|
||||
return query
|
||||
except:
|
||||
# print('ERROR 444444')
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def decorate_query(query: str, open_ai_api_key, filename='foodsembeddings.csv') -> str:
    """Load precomputed embeddings from *filename* and decorate *query*.

    Parses the CSV's ``embeddings`` column (stringified lists) back into
    Python lists, then delegates to ``_decorate_query``. Returns *query*
    unchanged when the CSV is missing, unreadable, empty, or fails to
    parse; the triggering error is printed.
    """
    try:
        frame = pd.read_csv(filename)
    except Exception as err:
        print(err)
        return query

    if frame.empty:
        return query

    try:
        frame['embeddings'] = frame['embeddings'].apply(lambda cell: ast.literal_eval(cell))
        return _decorate_query(query, frame, open_ai_api_key)
    except Exception as err:
        print(err)
        return query
|
||||
|
||||
|
||||
# Manual smoke test: decorate a sample question with context pulled from
# the precomputed embeddings CSV and print the decorated prompt.
if __name__ == '__main__':
    query = '亁颐堂是做什么的'
    print(decorate_query(query, api_key, filename=csv_file_path))
|
||||
|
||||
|
27
gpt_3_query.py
Normal file
27
gpt_3_query.py
Normal file
@ -0,0 +1,27 @@
|
||||
import openai
|
||||
from gpt_2_create_question import decorate_query
|
||||
from gpt_0_basic_info import api_key, csv_file_path
|
||||
openai.api_key = api_key
|
||||
|
||||
|
||||
def question(query):
    """Send *query* to gpt-3.5-turbo and return the trimmed reply text.

    Uses a fixed system prompt, a single completion (n=1), a 100-token
    cap, and temperature 0.5.
    """
    chat_response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query},
        ],
        max_tokens=100,
        n=1,
        temperature=0.5,
    )
    return chat_response.choices[0].message['content'].strip()
|
||||
|
||||
|
||||
# Manual smoke test: decorate a sample question with retrieved context,
# show the decorated prompt, then print the model's answer to it.
if __name__ == '__main__':
    query = '亁颐堂是做什么的'
    new_query = decorate_query(query, api_key, filename=csv_file_path)
    print(new_query)
    print(question(new_query))
|
BIN
source_data/qa.xlsx
Executable file
BIN
source_data/qa.xlsx
Executable file
Binary file not shown.
5
trained_data/qa_embeddings.csv
Normal file
5
trained_data/qa_embeddings.csv
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user