Set embedding_size config parameter for Text Embedding models (#532)

This commit is contained in:
David Kyle 2023-04-25 11:41:14 +01:00 committed by GitHub
parent 940f2a9bad
commit 50d301f7cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 18 deletions

View File

@ -238,10 +238,12 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
*, *,
tokenization: NlpTokenizationConfig, tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None, results_field: t.Optional[str] = None,
embedding_size: t.Optional[int] = None,
): ):
super().__init__(configuration_type="text_embedding") super().__init__(configuration_type="text_embedding")
self.tokenization = tokenization self.tokenization = tokenization
self.results_field = results_field self.results_field = results_field
self.embedding_size = embedding_size
class TextExpansionInferenceOptions(InferenceConfig): class TextExpansionInferenceOptions(InferenceConfig):

View File

@ -49,6 +49,10 @@ class TraceableModel(ABC):
self._model.eval() self._model.eval()
return self._trace() return self._trace()
@abstractmethod
def sample_output(self) -> torch.Tensor:
...
@abstractmethod @abstractmethod
def _trace(self) -> TracedModelTypes: def _trace(self) -> TracedModelTypes:
... ...

View File

@ -443,6 +443,14 @@ class _TransformerTraceableModel(TraceableModel):
self._tokenizer = tokenizer self._tokenizer = tokenizer
def _trace(self) -> TracedModelTypes: def _trace(self) -> TracedModelTypes:
inputs = self._compatible_inputs()
return torch.jit.trace(self._model, inputs)
def sample_output(self) -> Tensor:
inputs = self._compatible_inputs()
return self._model(*inputs)
def _compatible_inputs(self) -> Tuple[Tensor, ...]:
inputs = self._prepare_inputs() inputs = self._prepare_inputs()
# Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@ -458,21 +466,16 @@ class _TransformerTraceableModel(TraceableModel):
transformers.BartConfig, transformers.BartConfig,
), ),
): ):
return torch.jit.trace( del inputs["token_type_ids"]
self._model, return (inputs["input_ids"], inputs["attention_mask"])
(inputs["input_ids"], inputs["attention_mask"]),
)
position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long) position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
inputs["position_ids"] = position_ids
return torch.jit.trace( return (
self._model, inputs["input_ids"],
( inputs["attention_mask"],
inputs["input_ids"], inputs["token_type_ids"],
inputs["attention_mask"], inputs["position_ids"],
inputs["token_type_ids"],
position_ids,
),
) )
@abstractmethod @abstractmethod
@ -640,16 +643,23 @@ class TransformerModel:
tokenization_config.max_sequence_length = 386 tokenization_config.max_sequence_length = 386
tokenization_config.span = 128 tokenization_config.span = 128
tokenization_config.truncate = "none" tokenization_config.truncate = "none"
inference_config = (
TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( if self._traceable_model.classification_labels():
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config, tokenization=tokenization_config,
classification_labels=self._traceable_model.classification_labels(), classification_labels=self._traceable_model.classification_labels(),
) )
if self._traceable_model.classification_labels() elif self._task_type == "text_embedding":
else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( sample_embedding, _ = self._traceable_model.sample_output()
embedding_size = sample_embedding.size(-1)
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config,
embedding_size=embedding_size,
)
else:
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config tokenization=tokenization_config
) )
)
return NlpTrainedModelConfig( return NlpTrainedModelConfig(
description=f"Model {self._model_id} for task type '{self._task_type}'", description=f"Model {self._model_id} for task type '{self._task_type}'",

View File

@ -136,6 +136,13 @@ class TestTraceableModel(TraceableModel, ABC):
), ),
) )
def sample_output(self) -> torch.Tensor:
input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
return self._model(input_ids, attention_mask, token_type_ids, position_ids)
class NerModule(nn.Module): class NerModule(nn.Module):
def forward( def forward(