From 55967a7324d622cc41736d07599e0658b8d06502 Mon Sep 17 00:00:00 2001
From: Youhei Sakurai
Date: Wed, 5 Jul 2023 18:49:16 +0900
Subject: [PATCH] Minimize if main section (#554)

For the migration from scripts to console_scripts in setup.py, the
current long if __name__ == "__main__": section is a blocker because
console_scripts requires a function to be specified as the entry point.
Move the logic into a main() function.
---
 bin/eland_import_hub_model       | 16 +++++++++++-----
 eland/ml/pytorch/transformers.py |  4 ++--
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model
index d06732a..a4b6197 100755
--- a/bin/eland_import_hub_model
+++ b/bin/eland_import_hub_model
@@ -41,6 +41,8 @@ MODEL_HUB_URL = "https://huggingface.co"
 
 
 def get_arg_parser():
+    from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
+
     parser = argparse.ArgumentParser()
     location_args = parser.add_mutually_exclusive_group(required=True)
     location_args.add_argument(
@@ -127,7 +129,7 @@ def get_arg_parser():
     return parser
 
 
-def get_es_client(cli_args):
+def get_es_client(cli_args, logger):
     try:
         es_args = {
             'request_timeout': 300,
@@ -159,7 +161,7 @@ def get_es_client(cli_args):
         exit(1)
 
 
-def check_cluster_version(es_client):
+def check_cluster_version(es_client, logger):
     es_info = es_client.info()
     logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
 
@@ -180,7 +182,8 @@ def check_cluster_version(es_client):
     return sem_ver
 
 
-if __name__ == "__main__":
+
+def main():
     # Configure logging
     logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
     logger = logging.getLogger(__name__)
@@ -202,14 +205,15 @@ if __name__ == "__main__":
 
             \033[1m{sys.executable} -m pip install 'eland[pytorch]' \033[0m"""))
         exit(1)
+    assert SUPPORTED_TASK_TYPES
 
     # Parse arguments
     args = get_arg_parser().parse_args()
 
     # Connect to ES
     logger.info("Establishing connection to Elasticsearch")
-    es = get_es_client(args)
-    cluster_version = check_cluster_version(es)
+    es = get_es_client(args, logger)
+    cluster_version = check_cluster_version(es, logger)
 
     # Trace and save model, then upload it from temp file
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -254,3 +258,5 @@ if __name__ == "__main__":
     logger.info(f"Model successfully imported with id '{ptm.model_id}'")
 
 
+if __name__ == "__main__":
+    main()
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 4f1e312..7b624d9 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -494,7 +494,7 @@ class _TransformerTraceableModel(TraceableModel):
 class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
 
         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
@@ -636,7 +636,7 @@ class TransformerModel:
 
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
         vocab_obj = {
             "vocabulary": vocabulary,
         }
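
For context, a console_scripts entry maps a command name to a module:function target, which is why this patch needs a callable main() rather than a long if __name__ == "__main__": block. Below is a minimal sketch of what such a setup.py entry could look like; the module path eland.scripts.eland_import_hub_model and the package metadata are assumptions for illustration only and are not part of this patch.

    # Hypothetical setup.py sketch; the module path is assumed, this patch only introduces main().
    from setuptools import find_packages, setup

    setup(
        name="eland",
        packages=find_packages(),
        entry_points={
            "console_scripts": [
                # console_scripts requires "command = package.module:function",
                # so the script body must live in an importable module exposing main().
                "eland_import_hub_model = eland.scripts.eland_import_hub_model:main",
            ],
        },
    )

With an entry like this, pip generates the eland_import_hub_model command at install time, and that wrapper simply imports the module and calls main().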