From 5a76f826df45f57cba67e095a2695221b22153ae Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 12 Aug 2024 14:40:12 +0100 Subject: [PATCH] Add note about using text_similarity for rerank to the CLI (#716) --- eland/cli/eland_import_hub_model.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/eland/cli/eland_import_hub_model.py b/eland/cli/eland_import_hub_model.py index 7980a3c..4ca8544 100755 --- a/eland/cli/eland_import_hub_model.py +++ b/eland/cli/eland_import_hub_model.py @@ -41,7 +41,9 @@ MODEL_HUB_URL = "https://huggingface.co" def get_arg_parser(): from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + exit_on_error=False + ) # throw exception rather than exit location_args = parser.add_mutually_exclusive_group(required=True) location_args.add_argument( "--url", @@ -97,7 +99,7 @@ def get_arg_parser(): "--task-type", required=False, choices=SUPPORTED_TASK_TYPES, - help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. " + help="The task type for the model usage. Use text_similarity for rerank tasks. Will attempt to auto-detect task type for the model if not provided. " "Default: auto", default="auto", ) @@ -159,6 +161,23 @@ def get_arg_parser(): return parser +def parse_args(): + parser = get_arg_parser() + try: + return parser.parse_args() + except argparse.ArgumentError as argument_error: + if argument_error.argument_name == "--task-type": + message = ( + argument_error.message + + "\n\nUse 'text_similarity' for rerank tasks in Elasticsearch" + ) + parser.error(message=message) + else: + parser.error(message=argument_error.message) + except argparse.ArgumentTypeError as type_error: + parser.error(str(type_error)) + + def get_es_client(cli_args, logger): try: es_args = { @@ -262,7 +281,7 @@ def main(): assert SUPPORTED_TASK_TYPES # Parse arguments - args = get_arg_parser().parse_args() + args = parse_args() # Connect to ES logger.info("Establishing connection to Elasticsearch")