diff --git a/bin/eland_import_hub_model b/bin/eland_import_hub_model
index d06732a..a4b6197 100755
--- a/bin/eland_import_hub_model
+++ b/bin/eland_import_hub_model
@@ -41,6 +41,8 @@ MODEL_HUB_URL = "https://huggingface.co"
 
 
 def get_arg_parser():
+    from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
+
     parser = argparse.ArgumentParser()
     location_args = parser.add_mutually_exclusive_group(required=True)
     location_args.add_argument(
@@ -127,7 +129,7 @@ def get_arg_parser():
     return parser
 
 
-def get_es_client(cli_args):
+def get_es_client(cli_args, logger):
     try:
         es_args = {
             'request_timeout': 300,
@@ -159,7 +161,7 @@ def get_es_client(cli_args):
         exit(1)
 
 
-def check_cluster_version(es_client):
+def check_cluster_version(es_client, logger):
     es_info = es_client.info()
     logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
 
@@ -180,7 +182,8 @@ def check_cluster_version(es_client):
     return sem_ver
 
 
-if __name__ == "__main__":
+
+def main():
     # Configure logging
     logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
     logger = logging.getLogger(__name__)
@@ -202,14 +205,15 @@ if __name__ == "__main__":
             \033[1m{sys.executable} -m pip install 'eland[pytorch]'
             \033[0m"""))
         exit(1)
 
+    assert SUPPORTED_TASK_TYPES
     # Parse arguments
     args = get_arg_parser().parse_args()
 
     # Connect to ES
     logger.info("Establishing connection to Elasticsearch")
-    es = get_es_client(args)
-    cluster_version = check_cluster_version(es)
+    es = get_es_client(args, logger)
+    cluster_version = check_cluster_version(es, logger)
 
     # Trace and save model, then upload it from temp file
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -254,3 +258,5 @@ if __name__ == "__main__":
         logger.info(f"Model successfully imported with id '{ptm.model_id}'")
 
 
+if __name__ == "__main__":
+    main()
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index 4f1e312..7b624d9 100755
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -494,7 +494,7 @@ class _TransformerTraceableModel(TraceableModel):
 
 class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
@@ -636,7 +636,7 @@ class TransformerModel:
 
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
         vocab_obj = {
             "vocabulary": vocabulary,
         }