diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index 13d1e6d..d21b9f4 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -238,10 +238,12 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
         *,
         tokenization: NlpTokenizationConfig,
         results_field: t.Optional[str] = None,
+        embedding_size: t.Optional[int] = None,
     ):
         super().__init__(configuration_type="text_embedding")
         self.tokenization = tokenization
         self.results_field = results_field
+        self.embedding_size = embedding_size
 
 
 class TextExpansionInferenceOptions(InferenceConfig):
diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py
index fa6eb12..2f906fa 100644
--- a/eland/ml/pytorch/traceable_model.py
+++ b/eland/ml/pytorch/traceable_model.py
@@ -49,6 +49,10 @@ class TraceableModel(ABC):
         self._model.eval()
         return self._trace()
 
+    @abstractmethod
+    def sample_output(self) -> torch.Tensor:
+        ...
+
     @abstractmethod
     def _trace(self) -> TracedModelTypes:
         ...
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index f48f7be..e8f61ac 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -443,6 +443,14 @@ class _TransformerTraceableModel(TraceableModel):
         self._tokenizer = tokenizer
 
     def _trace(self) -> TracedModelTypes:
+        inputs = self._compatible_inputs()
+        return torch.jit.trace(self._model, inputs)
+
+    def sample_output(self) -> Tensor:
+        inputs = self._compatible_inputs()
+        return self._model(*inputs)
+
+    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
         inputs = self._prepare_inputs()
 
         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@@ -458,21 +466,16 @@
                 transformers.BartConfig,
             ),
         ):
-            return torch.jit.trace(
-                self._model,
-                (inputs["input_ids"], inputs["attention_mask"]),
-            )
+            del inputs["token_type_ids"]
+            return (inputs["input_ids"], inputs["attention_mask"])
 
         position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
-
-        return torch.jit.trace(
-            self._model,
-            (
-                inputs["input_ids"],
-                inputs["attention_mask"],
-                inputs["token_type_ids"],
-                position_ids,
-            ),
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
         )
 
     @abstractmethod
@@ -640,16 +643,23 @@
             tokenization_config.max_sequence_length = 386
             tokenization_config.span = 128
             tokenization_config.truncate = "none"
-        inference_config = (
-            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+
+        if self._traceable_model.classification_labels():
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
                 classification_labels=self._traceable_model.classification_labels(),
             )
-            if self._traceable_model.classification_labels()
-            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+        elif self._task_type == "text_embedding":
+            sample_embedding, _ = self._traceable_model.sample_output()
+            embedding_size = sample_embedding.size(-1)
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                embedding_size=embedding_size,
+            )
+        else:
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config
             )
-        )
 
         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
diff --git a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
index 48b351a..4a81be9 100644
--- a/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
+++ b/tests/ml/pytorch/test_transformer_pytorch_model_pytest.py
@@ -136,6 +136,13 @@ class TestTraceableModel(TraceableModel, ABC):
             ),
         )
 
+    def sample_output(self) -> torch.Tensor:
+        input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
+        attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
+        token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
+        position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
+        return self._model(input_ids, attention_mask, token_type_ids, position_ids)
+
 
 class NerModule(nn.Module):
     def forward(
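
For context, a minimal sketch of what this change enables: the traced model's `sample_output()` is run once, the last dimension of the returned embedding tensor is read with `size(-1)`, and that value is passed to `TextEmbeddingInferenceOptions` via the new `embedding_size` keyword. The `DummyEmbeddingModel`, its 384-dimension output, and the default-constructed `NlpBertTokenizationConfig` below are illustrative assumptions, not part of the diff.

```python
# Illustrative sketch only; mirrors the embedding_size derivation added in this diff.
# DummyEmbeddingModel and the 384-dim output are hypothetical stand-ins.
import torch
from torch import nn

from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    TextEmbeddingInferenceOptions,
)


class DummyEmbeddingModel(nn.Module):
    """Hypothetical text-embedding model returning (embedding, extra_output)."""

    def forward(self, input_ids: torch.Tensor):
        return torch.zeros(input_ids.size(0), 384), None


model = DummyEmbeddingModel()

# Equivalent of TransformerModel's new branch: one sample forward pass,
# then read the embedding dimension from the output tensor.
sample_embedding, _ = model(torch.ones(1, 8, dtype=torch.long))
embedding_size = sample_embedding.size(-1)

# The new keyword argument introduced by this change.
config = TextEmbeddingInferenceOptions(
    tokenization=NlpBertTokenizationConfig(),  # assumed: default tokenization settings suffice here
    embedding_size=embedding_size,
)
print(config.embedding_size)  # 384
```

Deriving the value from a sample forward pass means the generated model configuration can advertise the vector dimension up front instead of leaving `embedding_size` unset.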