Add PyTorch modules to noxfile
We added the `pytorch` module, which is type checked, but it was never listed in the noxfile. This change adds it there and also addresses the type errors that surfaced once type checking was enabled.
parent 7209f61773
commit 5bc1a824a7
eland/ml/pytorch/_pytorch_model.py

@@ -21,7 +21,7 @@ import math
 import os
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union

-from tqdm.auto import tqdm
+from tqdm.auto import tqdm  # type: ignore

 from eland.common import ensure_es_client

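tqdm ships without type stubs, so once this file is type checked mypy refuses to analyze the import; a per-line `# type: ignore` keeps the import while silencing the error. A minimal sketch of the situation (the script itself is illustrative):

# tqdm_demo.py -- illustrative only. Running mypy over this file without the
# "# type: ignore" comment produces an error along the lines of:
#   error: Skipping analyzing "tqdm.auto": found module but no type hints
#   or library stubs
from tqdm.auto import tqdm  # type: ignore

for _ in tqdm(range(3)):
    pass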
@@ -101,7 +101,7 @@ class PyTorchModel:

     def infer(
         self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
-    ) -> Dict[str, Any]:
+    ) -> Union[bool, Any]:
         return self._client.transport.perform_request(
             "POST",
             f"/_ml/trained_models/{self.model_id}/deployment/_infer",
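The widened return annotation mirrors `Transport.perform_request`, which the 7.x Python client types as returning `Union[bool, Any]` (HEAD requests yield a bool). A hypothetical usage sketch, assuming a local cluster and a deployed model named `my_text_classifier` (both made up; the request body format is also illustrative):

from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel

es = Elasticsearch("http://localhost:9200")  # assumed local cluster
model = PyTorchModel(es, "my_text_classifier")  # hypothetical model id

# The docs payload below is illustrative; consult the _infer API docs
# for the exact body your model's task type expects.
response = model.infer(body={"docs": [{"text_field": "some example text"}]})
print(response)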
@@ -124,7 +124,7 @@ class PyTorchModel:
         )

     def delete(self) -> None:
-        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))
+        self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,))

     @classmethod
     def list(
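Passing `model_id` by keyword rather than positionally tracks the Python client's direction: elasticsearch-py deprecated positional API arguments during the 7.x series and made them keyword-only in 8.0. A sketch of the call style (cluster URL and model id are made up):

from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local cluster
# ignore=(404,) makes deletion idempotent: a missing model is not an error.
es.ml.delete_trained_model(model_id="my_text_classifier", ignore=(404,))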
eland/ml/pytorch/transformers.py

@@ -23,11 +23,11 @@ libraries such as sentence-transformers.
 import json
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

-import torch
-import transformers
-from sentence_transformers import SentenceTransformer
+import torch  # type: ignore
+import transformers  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
 from torch import Tensor, nn
 from transformers import (
     AutoConfig,
@@ -66,7 +66,7 @@ TracedModelTypes = Union[
 ]


-class _DistilBertWrapper(nn.Module):
+class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
     conform to Elasticsearch's native inference processor interface.
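The class-level `# type: ignore` is needed for a different reason than the import-level ones: because `import torch` is itself ignored, `nn.Module` resolves to `Any`, and mypy's `disallow_subclassing_any` flag (assuming it is enabled for these files, as the strict settings suggest) rejects subclassing it. A minimal sketch with a hypothetical wrapper class:

from torch import Tensor, nn  # type: ignore


# Without the ignore below, mypy reports an error along the lines of:
#   error: Class cannot subclass "Module" (has type "Any")
class _IdentityWrapper(nn.Module):  # type: ignore
    def forward(self, x: Tensor) -> Tensor:
        return x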
@@ -96,7 +96,7 @@ class _DistilBertWrapper(nn.Module):
         return self._model(input_ids=input_ids, attention_mask=attention_mask)


-class _SentenceTransformerWrapper(nn.Module):
+class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base
@@ -122,7 +122,7 @@ class _SentenceTransformerWrapper(nn.Module):
         else:
             return None

-    def _remove_pooling_layer(self):
+    def _remove_pooling_layer(self) -> None:
         """
         Removes any last pooling layer which is not used to create embeddings.
         Leaving this layer in will cause it to return a NoneType which in turn
@@ -135,7 +135,7 @@ class _SentenceTransformerWrapper(nn.Module):
         if hasattr(self._hf_model, "pooler"):
             self._hf_model.pooler = None

-    def _replace_transformer_layer(self):
+    def _replace_transformer_layer(self) -> None:
         """
         Replaces the HuggingFace Transformer layer in the SentenceTransformer
         modules so we can set it with one that has pooling layer removed and
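The added `-> None` annotations here and below satisfy mypy's `disallow_untyped_defs` option, under which a function with no return annotation is an error even if it returns nothing. A trivial illustration (the function name is made up):

# Without "-> None", mypy --disallow-untyped-defs reports:
#   error: Function is missing a return type annotation
def _reset_state() -> None:
    pass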
@@ -167,7 +167,7 @@ class _SentenceTransformerWrapper(nn.Module):
         return self._st_model(inputs)[self._output_key]


-class _DPREncoderWrapper(nn.Module):
+class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
     a workaround. This may never be fixed so this is likely permanent.
@@ -240,10 +240,7 @@ class _TraceableModel(ABC):
         self._tokenizer = tokenizer
         self._model = model

-    def classification_labels(self) -> Optional[List[str]]:
-        return None
-
-    def quantize(self):
+    def quantize(self) -> None:
         torch.quantization.quantize_dynamic(
             self._model, {torch.nn.Linear}, dtype=torch.qint8
         )
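For reference, a standalone sketch of what `torch.quantization.quantize_dynamic` does with these arguments, using a toy model (note that with the default `inplace=False` it returns a quantized copy rather than modifying its argument):

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# The Linear layer is replaced with a dynamically quantized equivalent;
# weights are stored as int8 and dequantized on the fly during inference.
print(quantized)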
@@ -275,11 +272,14 @@ class _TraceableModel(ABC):
     def _prepare_inputs(self) -> transformers.BatchEncoding:
         ...

+    def classification_labels(self) -> Optional[List[str]]:
+        return None
+

 class _TraceableClassificationModel(_TraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore

         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
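The `# type: ignore` lands on the sort because `self._model` comes from an untyped library, so mypy cannot verify `config.id2label`. The label logic itself is easy to check in isolation; a sketch with toy data:

# Toy stand-in for a transformers config's id2label mapping.
id2label = {1: "I-PER", 0: "O", 2: "B-PER"}

# Order labels by class id, then replace "-" with "_" for Java enum names.
labels = [v for _, v in sorted(id2label.items(), key=lambda kv: kv[0])]
print([label.replace("-", "_") for label in labels])  # ['O', 'I_PER', 'B_PER']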
@@ -361,15 +361,15 @@ class TransformerModel:
         self._vocab = self._load_vocab()
         self._config = self._create_config()

-    def _load_vocab(self):
+    def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
         return {
             "vocabulary": vocabulary,
         }

-    def _create_config(self):
-        inference_config = {
+    def _create_config(self) -> Dict[str, Any]:
+        inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
                     "bert": {
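`_load_vocab` relies on the tokenizer's vocab mapping tokens to integer ids; sorting by id produces a list whose index equals the id. A sketch with a toy vocabulary:

# Toy stand-in for tokenizer.get_vocab(): token -> id.
vocab = {"[PAD]": 0, "hello": 2, "[CLS]": 1}

# Sort by id so that list position i holds the token with id i.
vocabulary = [k for k, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
print({"vocabulary": vocabulary})  # {'vocabulary': ['[PAD]', '[CLS]', 'hello']}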
@@ -448,7 +448,7 @@ class TransformerModel:
             f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
         )

-    def elasticsearch_model_id(self):
+    def elasticsearch_model_id(self) -> str:
         # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
         return self._model_id.replace("/", "__").lower()[:64]
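The normalization is mechanical and easy to demonstrate; a sketch with a Hugging Face model id chosen purely for illustration:

hub_id = "sentence-transformers/msmarco-MiniLM-L-12-v3"
# Replace "/" (invalid in Elasticsearch model ids), lowercase, cap at 64 chars.
print(hub_id.replace("/", "__").lower()[:64])
# -> sentence-transformers__msmarco-minilm-l-12-v3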
noxfile.py

@@ -44,6 +44,9 @@ TYPED_FILES = (
     "eland/ml/_optional.py",
     "eland/ml/_model_serializer.py",
     "eland/ml/ml_model.py",
+    "eland/ml/pytorch/__init__.py",
+    "eland/ml/pytorch/_pytorch_model.py",
+    "eland/ml/pytorch/transformers.py",
     "eland/ml/transformers/__init__.py",
     "eland/ml/transformers/base.py",
     "eland/ml/transformers/lightgbm.py",
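For context, a minimal sketch of how a noxfile can run mypy over an explicit file tuple like `TYPED_FILES`; the session name and flags here are illustrative, not eland's exact configuration:

import nox

TYPED_FILES = (
    "eland/ml/pytorch/__init__.py",
    "eland/ml/pytorch/_pytorch_model.py",
    "eland/ml/pytorch/transformers.py",
)


@nox.session()
def lint(session: nox.Session) -> None:
    # Install mypy into the session's virtualenv and check only the
    # files that have been brought up to full type coverage.
    session.install("mypy")
    session.run("mypy", "--strict", *TYPED_FILES)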