Add PyTorch modules to noxfile

We added the `pytorch` module, which is type checked, but it was not listed as such in
the noxfile. This change also addresses the type errors that arose once type checking
was enabled for the module.
Josh Devins 2021-11-29 17:03:25 +01:00 committed by GitHub
parent 7209f61773
commit 5bc1a824a7
3 changed files with 25 additions and 22 deletions
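
For context, the noxfile drives type checking from an explicit allow-list of files, so a new module is only checked once it is added to `TYPED_FILES`. The sketch below shows how such a session can work; the session name, mypy flags, and abbreviated tuple are illustrative and not necessarily the exact contents of eland's noxfile.

    import nox

    # Abbreviated allow-list; the real TYPED_FILES tuple in noxfile.py lists
    # every module that is expected to pass mypy.
    TYPED_FILES = (
        "eland/ml/pytorch/__init__.py",
        "eland/ml/pytorch/_pytorch_model.py",
        "eland/ml/pytorch/transformers.py",
    )

    @nox.session()
    def lint(session):
        session.install("mypy")
        # Only files named in TYPED_FILES are checked, so a module left out of
        # the tuple is silently skipped by mypy -- hence this change.
        session.run("mypy", "--strict", *TYPED_FILES)

With the three `eland/ml/pytorch/*` entries added, running the type-checking session (e.g. `nox -s lint`) now covers the PyTorch modules as well.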

eland/ml/pytorch/_pytorch_model.py

@@ -21,7 +21,7 @@ import math
import os
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm  # type: ignore
from eland.common import ensure_es_client
@@ -101,7 +101,7 @@ class PyTorchModel:
    def infer(
        self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
-    ) -> Dict[str, Any]:
+    ) -> Union[bool, Any]:
        return self._client.transport.perform_request(
            "POST",
            f"/_ml/trained_models/{self.model_id}/deployment/_infer",
@@ -124,7 +124,7 @@ class PyTorchModel:
        )
    def delete(self) -> None:
-        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))
+        self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,))
    @classmethod
    def list(

eland/ml/pytorch/transformers.py

@@ -23,11 +23,11 @@ libraries such as sentence-transformers.
import json
import os.path
from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
-import torch
-import transformers
-from sentence_transformers import SentenceTransformer
+import torch  # type: ignore
+import transformers  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
from torch import Tensor, nn
from transformers import (
    AutoConfig,
@@ -66,7 +66,7 @@ TracedModelTypes = Union[
]
-class _DistilBertWrapper(nn.Module):
+class _DistilBertWrapper(nn.Module):  # type: ignore
    """
    A simple wrapper around DistilBERT model which makes the model inputs
    conform to Elasticsearch's native inference processor interface.
@@ -96,7 +96,7 @@ class _DistilBertWrapper(nn.Module):
        return self._model(input_ids=input_ids, attention_mask=attention_mask)
-class _SentenceTransformerWrapper(nn.Module):
+class _SentenceTransformerWrapper(nn.Module):  # type: ignore
    """
    A wrapper around sentence-transformer models to provide pooling,
    normalization and other graph layers that are not defined in the base
@@ -122,7 +122,7 @@ class _SentenceTransformerWrapper(nn.Module):
        else:
            return None
-    def _remove_pooling_layer(self):
+    def _remove_pooling_layer(self) -> None:
        """
        Removes any last pooling layer which is not used to create embeddings.
        Leaving this layer in will cause it to return a NoneType which in turn
@@ -135,7 +135,7 @@ class _SentenceTransformerWrapper(nn.Module):
        if hasattr(self._hf_model, "pooler"):
            self._hf_model.pooler = None
-    def _replace_transformer_layer(self):
+    def _replace_transformer_layer(self) -> None:
        """
        Replaces the HuggingFace Transformer layer in the SentenceTransformer
        modules so we can set it with one that has pooling layer removed and
@@ -167,7 +167,7 @@ class _SentenceTransformerWrapper(nn.Module):
        return self._st_model(inputs)[self._output_key]
-class _DPREncoderWrapper(nn.Module):
+class _DPREncoderWrapper(nn.Module):  # type: ignore
    """
    AutoModel loading does not work for DPRContextEncoders, this only exists as
    a workaround. This may never be fixed so this is likely permanent.
@@ -240,10 +240,7 @@ class _TraceableModel(ABC):
        self._tokenizer = tokenizer
        self._model = model
-    def classification_labels(self) -> Optional[List[str]]:
-        return None
-    def quantize(self):
+    def quantize(self) -> None:
        torch.quantization.quantize_dynamic(
            self._model, {torch.nn.Linear}, dtype=torch.qint8
        )
@@ -275,11 +272,14 @@
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        ...
+    def classification_labels(self) -> Optional[List[str]]:
+        return None
class _TraceableClassificationModel(_TraceableModel, ABC):
    def classification_labels(self) -> Optional[List[str]]:
        id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
        # Make classes like I-PER into I_PER which fits Java enumerations
        return [label.replace("-", "_") for label in labels]
@@ -361,15 +361,15 @@ class TransformerModel:
        self._vocab = self._load_vocab()
        self._config = self._create_config()
-    def _load_vocab(self):
+    def _load_vocab(self) -> Dict[str, List[str]]:
        vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
        return {
            "vocabulary": vocabulary,
        }
-    def _create_config(self):
-        inference_config = {
+    def _create_config(self) -> Dict[str, Any]:
+        inference_config: Dict[str, Dict[str, Any]] = {
            self._task_type: {
                "tokenization": {
                    "bert": {
@@ -448,7 +448,7 @@ class TransformerModel:
                f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
            )
-    def elasticsearch_model_id(self):
+    def elasticsearch_model_id(self) -> str:
        # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
        return self._model_id.replace("/", "__").lower()[:64]
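
As a side note on the `elasticsearch_model_id` annotation above: the ID it returns is a pure string transformation of the Hugging Face model name, as the small example below shows. The model name used here is only illustrative.

    # Hypothetical hub model name, shown only to illustrate the normalisation
    # performed by elasticsearch_model_id: slashes become "__", everything is
    # lowercased, and the result is capped at 64 characters.
    model_id = "sentence-transformers/msmarco-MiniLM-L-12-v3"
    es_model_id = model_id.replace("/", "__").lower()[:64]
    print(es_model_id)  # -> sentence-transformers__msmarco-minilm-l-12-v3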

noxfile.py

@@ -44,6 +44,9 @@ TYPED_FILES = (
    "eland/ml/_optional.py",
    "eland/ml/_model_serializer.py",
    "eland/ml/ml_model.py",
+    "eland/ml/pytorch/__init__.py",
+    "eland/ml/pytorch/_pytorch_model.py",
+    "eland/ml/pytorch/transformers.py",
    "eland/ml/transformers/__init__.py",
    "eland/ml/transformers/base.py",
    "eland/ml/transformers/lightgbm.py",