mirror of https://github.com/elastic/eland.git
Add PyTorch modules to noxfile
We added the `pytorch` module with type checking in mind, but its files were never listed in the noxfile. This change adds them and also addresses the type errors that surfaced once type checking actually ran over the module.
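For context, here is a minimal sketch of how a nox session can type check only the files named in a TYPED_FILES tuple like the one extended at the bottom of this diff. The session name, the mypy flags, and the trimmed-down file list are illustrative assumptions, not a copy of eland's actual noxfile.

import nox

# Hypothetical subset of the typed-files tuple; the real noxfile lists many more paths.
TYPED_FILES = (
    "eland/ml/pytorch/__init__.py",
    "eland/ml/pytorch/_pytorch_model.py",
    "eland/ml/pytorch/transformers.py",
)


@nox.session()
def lint(session):
    # Install mypy into the session venv and run it only over the opted-in files.
    session.install("mypy")
    session.run("mypy", "--strict", *TYPED_FILES)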
This commit is contained in:
parent 7209f61773
commit 5bc1a824a7
eland/ml/pytorch/_pytorch_model.py

@@ -21,7 +21,7 @@ import math
 import os
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Set, Tuple, Union
 
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm  # type: ignore
 
 from eland.common import ensure_es_client
 
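The `# type: ignore` added here (and on the torch, transformers and sentence_transformers imports later in this diff) is the standard way to quiet mypy for packages that ship without type stubs or a py.typed marker. A minimal illustration, not taken from eland's code:

# Without stubs, mypy reports something like:
#   error: Skipping analyzing "tqdm.auto": found module but no type hints or library stubs
# The per-line ignore keeps the import while suppressing that error for this line only.
from tqdm.auto import tqdm  # type: ignore

A project-wide alternative is mypy's ignore_missing_imports option, set per module in the mypy configuration.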
@@ -101,7 +101,7 @@ class PyTorchModel:
 
     def infer(
         self, body: Dict[str, Any], timeout: str = DEFAULT_TIMEOUT
-    ) -> Dict[str, Any]:
+    ) -> Union[bool, Any]:
         return self._client.transport.perform_request(
             "POST",
             f"/_ml/trained_models/{self.model_id}/deployment/_infer",
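A rough usage sketch of infer(), whose return annotation is loosened above, presumably to match what the underlying transport call is declared to return. The cluster URL, model id, request body, and the assumption that the constructor takes a client plus a model id are all for illustration only.

from elasticsearch import Elasticsearch

from eland.ml.pytorch import PyTorchModel

es = Elasticsearch("http://localhost:9200")  # assumed local cluster
model = PyTorchModel(es, "my_text_classification_model")  # hypothetical model id

# Hypothetical body; the _infer endpoint expects the documents to run inference on.
response = model.infer(body={"docs": [{"text_field": "This movie was awesome!"}]})
print(response)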
@@ -124,7 +124,7 @@ class PyTorchModel:
         )
 
     def delete(self) -> None:
-        self._client.ml.delete_trained_model(self.model_id, ignore=(404,))
+        self._client.ml.delete_trained_model(model_id=self.model_id, ignore=(404,))
 
     @classmethod
     def list(
eland/ml/pytorch/transformers.py

@@ -23,11 +23,11 @@ libraries such as sentence-transformers.
 import json
 import os.path
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import torch
-import transformers
-from sentence_transformers import SentenceTransformer
+import torch  # type: ignore
+import transformers  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
 from torch import Tensor, nn
 from transformers import (
     AutoConfig,
@@ -66,7 +66,7 @@ TracedModelTypes = Union[
 ]
 
 
-class _DistilBertWrapper(nn.Module):
+class _DistilBertWrapper(nn.Module):  # type: ignore
     """
     A simple wrapper around DistilBERT model which makes the model inputs
     conform to Elasticsearch's native inference processor interface.

@@ -96,7 +96,7 @@ class _DistilBertWrapper(nn.Module):
         return self._model(input_ids=input_ids, attention_mask=attention_mask)
 
 
-class _SentenceTransformerWrapper(nn.Module):
+class _SentenceTransformerWrapper(nn.Module):  # type: ignore
     """
     A wrapper around sentence-transformer models to provide pooling,
     normalization and other graph layers that are not defined in the base

@@ -122,7 +122,7 @@ class _SentenceTransformerWrapper(nn.Module):
         else:
             return None
 
-    def _remove_pooling_layer(self):
+    def _remove_pooling_layer(self) -> None:
         """
         Removes any last pooling layer which is not used to create embeddings.
         Leaving this layer in will cause it to return a NoneType which in turn

@@ -135,7 +135,7 @@ class _SentenceTransformerWrapper(nn.Module):
         if hasattr(self._hf_model, "pooler"):
             self._hf_model.pooler = None
 
-    def _replace_transformer_layer(self):
+    def _replace_transformer_layer(self) -> None:
         """
         Replaces the HuggingFace Transformer layer in the SentenceTransformer
         modules so we can set it with one that has pooling layer removed and

@@ -167,7 +167,7 @@ class _SentenceTransformerWrapper(nn.Module):
         return self._st_model(inputs)[self._output_key]
 
 
-class _DPREncoderWrapper(nn.Module):
+class _DPREncoderWrapper(nn.Module):  # type: ignore
     """
     AutoModel loading does not work for DPRContextEncoders, this only exists as
     a workaround. This may never be fixed so this is likely permanent.
@@ -240,10 +240,7 @@ class _TraceableModel(ABC):
         self._tokenizer = tokenizer
         self._model = model
 
-    def classification_labels(self) -> Optional[List[str]]:
-        return None
-
-    def quantize(self):
+    def quantize(self) -> None:
         torch.quantization.quantize_dynamic(
             self._model, {torch.nn.Linear}, dtype=torch.qint8
         )
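For reference, a self-contained sketch of the dynamic quantization call used in quantize(); the toy model below is an assumption purely to make the snippet runnable.

import torch

# Toy stand-in for a traced transformer; only torch.nn.Linear layers are targeted.
model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 2),
)

# Returns a copy of the model whose Linear layers use dynamically quantized int8 weights.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(quantized)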
@@ -275,11 +272,14 @@ class _TraceableModel(ABC):
     def _prepare_inputs(self) -> transformers.BatchEncoding:
         ...
 
+    def classification_labels(self) -> Optional[List[str]]:
+        return None
+
 
 class _TraceableClassificationModel(_TraceableModel, ABC):
     def classification_labels(self) -> Optional[List[str]]:
         id_label_items = self._model.config.id2label.items()
-        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]  # type: ignore
 
         # Make classes like I-PER into I_PER which fits Java enumerations
         return [label.replace("-", "_") for label in labels]
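A worked example of what classification_labels() produces, using a hypothetical id2label mapping of the kind found on Hugging Face model configs:

# Hypothetical NER-style label mapping.
id2label = {0: "O", 1: "B-PER", 2: "I-PER"}

labels = [v for _, v in sorted(id2label.items(), key=lambda kv: kv[0])]
# "-" becomes "_" so the names fit Java enumerations on the Elasticsearch side.
print([label.replace("-", "_") for label in labels])  # ['O', 'B_PER', 'I_PER']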
@@ -361,15 +361,15 @@ class TransformerModel:
         self._vocab = self._load_vocab()
         self._config = self._create_config()
 
-    def _load_vocab(self):
+    def _load_vocab(self) -> Dict[str, List[str]]:
         vocab_items = self._tokenizer.get_vocab().items()
-        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]
+        vocabulary = [k for k, _ in sorted(vocab_items, key=lambda kv: kv[1])]  # type: ignore
         return {
             "vocabulary": vocabulary,
         }
 
-    def _create_config(self):
-        inference_config = {
+    def _create_config(self) -> Dict[str, Any]:
+        inference_config: Dict[str, Dict[str, Any]] = {
             self._task_type: {
                 "tokenization": {
                     "bert": {
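Similarly, a small worked example of the structure _load_vocab() returns, with a made-up tokenizer vocabulary:

# Hypothetical token -> id mapping, as returned by tokenizer.get_vocab().
vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 2, "world": 3}

# Sort tokens by their ids so index positions line up with the tokenizer.
vocabulary = [k for k, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
print({"vocabulary": vocabulary})  # {'vocabulary': ['[PAD]', '[UNK]', 'hello', 'world']}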
@@ -448,7 +448,7 @@ class TransformerModel:
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"
             )
 
-    def elasticsearch_model_id(self):
+    def elasticsearch_model_id(self) -> str:
         # Elasticsearch model IDs need to be a specific format: no special chars, all lowercase, max 64 chars
         return self._model_id.replace("/", "__").lower()[:64]
 
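A quick worked example of the id normalization performed by elasticsearch_model_id(), using an illustrative Hugging Face model name:

model_id = "sentence-transformers/msmarco-MiniLM-L-12-v3"  # hypothetical hub id
print(model_id.replace("/", "__").lower()[:64])
# sentence-transformers__msmarco-minilm-l-12-v3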
noxfile.py

@@ -44,6 +44,9 @@ TYPED_FILES = (
     "eland/ml/_optional.py",
     "eland/ml/_model_serializer.py",
     "eland/ml/ml_model.py",
+    "eland/ml/pytorch/__init__.py",
+    "eland/ml/pytorch/_pytorch_model.py",
+    "eland/ml/pytorch/transformers.py",
     "eland/ml/transformers/__init__.py",
     "eland/ml/transformers/base.py",
     "eland/ml/transformers/lightgbm.py",