From 688f051a346a8c8fc368bed2bf3a59e12b4cb079 Mon Sep 17 00:00:00 2001 From: administrator Date: Fri, 31 Mar 2023 22:46:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=80=E7=BB=88=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gpt_2_create_question.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gpt_2_create_question.py b/gpt_2_create_question.py index 2090897..9209f2b 100644 --- a/gpt_2_create_question.py +++ b/gpt_2_create_question.py @@ -30,9 +30,10 @@ def get_query_similarity(input_query: str, df: pd.DataFrame): Return the list of document sections, sorted by relevance in descending order. """ - + # 获取输入input_query的embedding向量 query_embedding = get_embedding(input_query) + # 算每一行的embedding向量和输入input_query的embedding向量的相似度 df['similarities'] = df['embeddings'].apply(lambda x: vector_similarity(query_embedding, x)) # print(df) """ @@ -85,6 +86,7 @@ def _decorate_query(input_query: str, df: pd.DataFrame) -> str: def decorate_query(input_query: str, filepath) -> str: try: df = pd.read_csv(filepath) + # 如果df为空,那么就返回input_query if df.empty: return input_query else: @@ -103,6 +105,7 @@ def decorate_query(input_query: str, filepath) -> str: 2 当有人问:你们公司有多少人, 请回答:亁颐堂有三十多个人 [-0.004695456940680742, -0.011140977963805199,... 3 当有人问:你们公司有多少个分部, 请回答:亁颐堂有北京 上海和南京三个分部 [0.0038718082942068577, -0.003343536052852869,... """ + # df默认读出的embeddings是字符串,需要转换成list df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x)) return _decorate_query(input_query, df)