mirror of
https://github.com/elastic/eland.git
synced 2025-07-24 00:00:39 +08:00
Add note about using text_similarity for rerank to the CLI (#716)
This commit is contained in:
parent
fd8886da6a
commit
5a76f826df
@ -41,7 +41,9 @@ MODEL_HUB_URL = "https://huggingface.co"
|
|||||||
def get_arg_parser():
|
def get_arg_parser():
|
||||||
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
|
from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(
|
||||||
|
exit_on_error=False
|
||||||
|
) # throw exception rather than exit
|
||||||
location_args = parser.add_mutually_exclusive_group(required=True)
|
location_args = parser.add_mutually_exclusive_group(required=True)
|
||||||
location_args.add_argument(
|
location_args.add_argument(
|
||||||
"--url",
|
"--url",
|
||||||
@ -97,7 +99,7 @@ def get_arg_parser():
|
|||||||
"--task-type",
|
"--task-type",
|
||||||
required=False,
|
required=False,
|
||||||
choices=SUPPORTED_TASK_TYPES,
|
choices=SUPPORTED_TASK_TYPES,
|
||||||
help="The task type for the model usage. Will attempt to auto-detect task type for the model if not provided. "
|
help="The task type for the model usage. Use text_similarity for rerank tasks. Will attempt to auto-detect task type for the model if not provided. "
|
||||||
"Default: auto",
|
"Default: auto",
|
||||||
default="auto",
|
default="auto",
|
||||||
)
|
)
|
||||||
@ -159,6 +161,23 @@ def get_arg_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = get_arg_parser()
|
||||||
|
try:
|
||||||
|
return parser.parse_args()
|
||||||
|
except argparse.ArgumentError as argument_error:
|
||||||
|
if argument_error.argument_name == "--task-type":
|
||||||
|
message = (
|
||||||
|
argument_error.message
|
||||||
|
+ "\n\nUse 'text_similarity' for rerank tasks in Elasticsearch"
|
||||||
|
)
|
||||||
|
parser.error(message=message)
|
||||||
|
else:
|
||||||
|
parser.error(message=argument_error.message)
|
||||||
|
except argparse.ArgumentTypeError as type_error:
|
||||||
|
parser.error(str(type_error))
|
||||||
|
|
||||||
|
|
||||||
def get_es_client(cli_args, logger):
|
def get_es_client(cli_args, logger):
|
||||||
try:
|
try:
|
||||||
es_args = {
|
es_args = {
|
||||||
@ -262,7 +281,7 @@ def main():
|
|||||||
assert SUPPORTED_TASK_TYPES
|
assert SUPPORTED_TASK_TYPES
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
args = get_arg_parser().parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Connect to ES
|
# Connect to ES
|
||||||
logger.info("Establishing connection to Elasticsearch")
|
logger.info("Establishing connection to Elasticsearch")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user