最终注释

This commit is contained in:
administrator 2023-03-31 22:46:19 +08:00
parent cbb091efa4
commit 688f051a34

View File

@ -30,9 +30,10 @@ def get_query_similarity(input_query: str, df: pd.DataFrame):
Return the list of document sections, sorted by relevance in descending order. Return the list of document sections, sorted by relevance in descending order.
""" """
# 获取输入input_query的embedding向量
query_embedding = get_embedding(input_query) query_embedding = get_embedding(input_query)
# 计算每一行的embedding向量与输入input_query的embedding向量之间的相似度
df['similarities'] = df['embeddings'].apply(lambda x: vector_similarity(query_embedding, x)) df['similarities'] = df['embeddings'].apply(lambda x: vector_similarity(query_embedding, x))
# print(df) # print(df)
""" """
@ -85,6 +86,7 @@ def _decorate_query(input_query: str, df: pd.DataFrame) -> str:
def decorate_query(input_query: str, filepath) -> str: def decorate_query(input_query: str, filepath) -> str:
try: try:
df = pd.read_csv(filepath) df = pd.read_csv(filepath)
# 如果df为空，则直接返回原始的input_query
if df.empty: if df.empty:
return input_query return input_query
else: else:
@ -103,6 +105,7 @@ def decorate_query(input_query: str, filepath) -> str:
2 当有人问你们公司有多少人, 请回答亁颐堂有三十多个人 [-0.004695456940680742, -0.011140977963805199,... 2 当有人问你们公司有多少人, 请回答亁颐堂有三十多个人 [-0.004695456940680742, -0.011140977963805199,...
3 当有人问你们公司有多少个分部, 请回答亁颐堂有北京 上海和南京三个分部 [0.0038718082942068577, -0.003343536052852869,... 3 当有人问你们公司有多少个分部, 请回答亁颐堂有北京 上海和南京三个分部 [0.0038718082942068577, -0.003343536052852869,...
""" """
# 从CSV读出的embeddings列默认是字符串，需要转换成list
df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x)) df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x))
return _decorate_query(input_query, df) return _decorate_query(input_query, df)