Add note about using text_similarity for rerank to the CLI (#716)

This commit is contained in:
David Kyle 2024-08-12 14:40:12 +01:00 committed by GitHub
parent fd8886da6a
commit 5a76f826df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,7 +41,9 @@ MODEL_HUB_URL = "https://huggingface.co"
def get_arg_parser():
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser(
exit_on_error=False
) # throw exception rather than exit
location_args = parser.add_mutually_exclusive_group(required=True)
location_args.add_argument(
"--url",
@ -97,7 +99,7 @@ def get_arg_parser():
"--task-type",
required=False,
choices=SUPPORTED_TASK_TYPES,
help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
help="The task type for the model usage. Use text_similarity for rerank tasks. Will attempt to auto-detect task type for the model if not provided. "
"Default: auto",
default="auto",
)
@ -159,6 +161,23 @@ def get_arg_parser():
return parser
def parse_args():
    """Parse command-line arguments, exiting with a readable error on bad input.

    Builds the parser via ``get_arg_parser()`` and parses ``sys.argv``. This
    relies on that parser raising ``argparse.ArgumentError`` rather than
    exiting by itself (``exit_on_error=False``), so the error text can be
    extended before it is reported.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.

    Exits:
        On an invalid argument, ``parser.error`` is called; it prints the
        usage plus the message and terminates the process.
    """
    parser = get_arg_parser()
    try:
        return parser.parse_args()
    except argparse.ArgumentError as argument_error:
        # str() keeps argparse's "argument <name>: " prefix so the user can
        # see which option was rejected; ``.message`` alone drops it.
        message = str(argument_error)
        if argument_error.argument_name == "--task-type":
            # Users commonly look for a "rerank" task type, which does not
            # exist; point them at the one to use instead.
            message += "\n\nUse 'text_similarity' for rerank tasks in Elasticsearch"
        parser.error(message=message)
    except argparse.ArgumentTypeError as type_error:
        # Defensive: argparse normally converts these into ArgumentError.
        parser.error(str(type_error))
def get_es_client(cli_args, logger):
try:
es_args = {
@ -262,7 +281,7 @@ def main():
assert SUPPORTED_TASK_TYPES
# Parse arguments
args = get_arg_parser().parse_args()
args = parse_args()
# Connect to ES
logger.info("Establishing connection to Elasticsearch")