Mirror of https://github.com/elastic/eland.git, synced 2025-07-11 00:02:14 +08:00
Minimize if main section (#554)
For the migration from scripts to console_scripts in setup.py, the current long if __name__ == "__main__": section is a blocker, because console_scripts requires a function to be specified as the entry point. Move the logic into a main() function.
This commit is contained in:
parent bf3b092ed4 / commit 55967a7324
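For context on the motivation: a setuptools console_scripts entry point must name an importable function, which is why the body of the if __name__ == "__main__": block has to become a main() function. Below is a minimal sketch of what such an entry in setup.py could look like; the command name and the module path "eland.cli.eland_import_hub_model:main" are illustrative assumptions, not part of this commit.

# Hypothetical setup.py excerpt declaring a console_scripts entry point.
# The command name and module path are assumptions for illustration only;
# this commit only prepares the main() function that such an entry needs.
from setuptools import setup

setup(
    name="eland",
    # ... other packaging arguments ...
    entry_points={
        "console_scripts": [
            # "<command> = <importable.module>:<function>"
            "eland_import_hub_model = eland.cli.eland_import_hub_model:main",
        ],
    },
)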
@@ -41,6 +41,8 @@ MODEL_HUB_URL = "https://huggingface.co"
 
 
 def get_arg_parser():
+    from eland.ml.pytorch.transformers import SUPPORTED_TASK_TYPES
+
     parser = argparse.ArgumentParser()
     location_args = parser.add_mutually_exclusive_group(required=True)
     location_args.add_argument(
@@ -127,7 +129,7 @@ def get_arg_parser():
     return parser
 
 
-def get_es_client(cli_args):
+def get_es_client(cli_args, logger):
     try:
         es_args = {
             'request_timeout': 300,
@@ -159,7 +161,7 @@ def get_es_client(cli_args):
         exit(1)
 
 
-def check_cluster_version(es_client):
+def check_cluster_version(es_client, logger):
     es_info = es_client.info()
     logger.info(f"Connected to cluster named '{es_info['cluster_name']}' (version: {es_info['version']['number']})")
 
@@ -180,7 +182,8 @@ def check_cluster_version(es_client):
 
     return sem_ver
 
-if __name__ == "__main__":
+
+def main():
     # Configure logging
     logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s')
     logger = logging.getLogger(__name__)
@@ -202,14 +205,15 @@ if __name__ == "__main__":
            \033[1m{sys.executable} -m pip install 'eland[pytorch]'
            \033[0m"""))
         exit(1)
+    assert SUPPORTED_TASK_TYPES
 
     # Parse arguments
     args = get_arg_parser().parse_args()
 
     # Connect to ES
     logger.info("Establishing connection to Elasticsearch")
-    es = get_es_client(args)
-    cluster_version = check_cluster_version(es)
+    es = get_es_client(args, logger)
+    cluster_version = check_cluster_version(es, logger)
 
     # Trace and save model, then upload it from temp file
     with tempfile.TemporaryDirectory() as tmp_dir:
@@ -254,3 +258,5 @@ if __name__ == "__main__":
         logger.info(f"Model successfully imported with id '{ptm.model_id}'")
 
 
+if __name__ == "__main__":
+    main()
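Condensed from the hunks above, the net shape of the import script after this change is roughly the following; the body is elided here and only the structure is shown.

def main():
    # logging setup, argument parsing, Elasticsearch connection, and
    # model trace/upload, as in the diff above (logger is now passed
    # explicitly to get_es_client and check_cluster_version)
    ...


if __name__ == "__main__":
    main()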
@@ -494,7 +494,7 @@ class _TransformerTraceableModel(TraceableModel):
 class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
 
         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
@@ -636,7 +636,7 @@ class TransformerModel:
 
     def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
         vocab_obj = {
             "vocabulary": vocabulary,
         }