Mirror of https://github.com/elastic/eland.git (synced 2025-07-11 00:02:14 +08:00)
Set embedding_size config parameter for Text Embedding models (#532)
This commit is contained in:
parent
940f2a9bad
commit
50d301f7cb
@@ -238,10 +238,12 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
         *,
         tokenization: NlpTokenizationConfig,
         results_field: t.Optional[str] = None,
+        embedding_size: t.Optional[int] = None,
     ):
         super().__init__(configuration_type="text_embedding")
         self.tokenization = tokenization
         self.results_field = results_field
         self.embedding_size = embedding_size


 class TextExpansionInferenceOptions(InferenceConfig):
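Illustrative only: a minimal sketch of constructing the options object with the new parameter. The constructor signature is taken from the hunk above; the import path and the tokenization class defaults are assumptions, not part of this commit.

# Sketch: the import path below is assumed; only the constructor
# signature comes from the hunk above.
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    TextEmbeddingInferenceOptions,
)

options = TextEmbeddingInferenceOptions(
    tokenization=NlpBertTokenizationConfig(),  # defaults assumed to be acceptable
    embedding_size=768,                        # e.g. a BERT-base sized embedding
)
print(options.embedding_size)  # 768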
@@ -49,6 +49,10 @@ class TraceableModel(ABC):
         self._model.eval()
         return self._trace()

+    @abstractmethod
+    def sample_output(self) -> torch.Tensor:
+        ...
+
     @abstractmethod
     def _trace(self) -> TracedModelTypes:
         ...
@@ -443,6 +443,14 @@ class _TransformerTraceableModel(TraceableModel):
         self._tokenizer = tokenizer

     def _trace(self) -> TracedModelTypes:
+        inputs = self._compatible_inputs()
+        return torch.jit.trace(self._model, inputs)
+
+    def sample_output(self) -> Tensor:
+        inputs = self._compatible_inputs()
+        return self._model(*inputs)
+
+    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
         inputs = self._prepare_inputs()

         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@@ -458,21 +466,16 @@ class _TransformerTraceableModel(TraceableModel):
                 transformers.BartConfig,
             ),
         ):
-            return torch.jit.trace(
-                self._model,
-                (inputs["input_ids"], inputs["attention_mask"]),
-            )
+            del inputs["token_type_ids"]
+            return (inputs["input_ids"], inputs["attention_mask"])

         position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
-
-        return torch.jit.trace(
-            self._model,
-            (
-                inputs["input_ids"],
-                inputs["attention_mask"],
-                inputs["token_type_ids"],
-                position_ids,
-            ),
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
         )

     @abstractmethod
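For context, a self-contained sketch (with a toy module standing in for the real transformer, not eland code) of why _compatible_inputs now returns a plain tuple: the same inputs can be passed to torch.jit.trace for tracing and to a direct forward pass for sample_output.

# Toy stand-in for the traced model; shapes and logic are illustrative only.
import torch
from torch import nn

class ToyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Fake "embedding": mean over the masked token positions.
        return (input_ids.float() * attention_mask).mean(dim=1, keepdim=True)

model = ToyModel().eval()
inputs = (torch.arange(8).unsqueeze(0), torch.ones(1, 8))

traced = torch.jit.trace(model, inputs)  # what _trace() does with the tuple
sample = model(*inputs)                  # what sample_output() does with the same tuple
print(sample.shape, traced(*inputs).shape)  # torch.Size([1, 1]) torch.Size([1, 1])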
@@ -640,16 +643,23 @@ class TransformerModel:
             tokenization_config.max_sequence_length = 386
             tokenization_config.span = 128
             tokenization_config.truncate = "none"
-        inference_config = (
-            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+
+        if self._traceable_model.classification_labels():
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
                 classification_labels=self._traceable_model.classification_labels(),
             )
-            if self._traceable_model.classification_labels()
-            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+        elif self._task_type == "text_embedding":
+            sample_embedding, _ = self._traceable_model.sample_output()
+            embedding_size = sample_embedding.size(-1)
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                embedding_size=embedding_size,
+            )
+        else:
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config
             )
-        )

         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
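A hedged sketch of the new text_embedding branch with a stand-in model: the embedding size is read from the last dimension of a sample forward pass, so the user no longer has to supply it. The (pooled, per-token) return shape mirrors the tuple unpacking in the hunk above; the 384 dimension and the stand-in class are arbitrary examples, not part of this commit.

# Stand-in model, not eland's API; only the size(-1) pattern comes from the hunk.
import torch
from torch import nn

class StandInEmbeddingModel(nn.Module):
    def forward(self, input_ids: torch.Tensor):
        pooled = torch.zeros(input_ids.size(0), 384)
        per_token = torch.zeros(input_ids.size(0), input_ids.size(1), 384)
        return pooled, per_token

model = StandInEmbeddingModel().eval()
sample_embedding, _ = model(torch.ones(1, 16, dtype=torch.long))
embedding_size = sample_embedding.size(-1)
print(embedding_size)  # 384, the value passed as embedding_size= to the inference config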
@@ -136,6 +136,13 @@ class TestTraceableModel(TraceableModel, ABC):
             ),
         )

+    def sample_output(self) -> torch.Tensor:
+        input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
+        attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
+        token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
+        position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
+        return self._model(input_ids, attention_mask, token_type_ids, position_ids)
+

 class NerModule(nn.Module):
     def forward(