Set embedding_size config parameter for Text Embedding models (#532)

parent 940f2a9bad
commit 50d301f7cb
@@ -238,10 +238,12 @@ class TextEmbeddingInferenceOptions(InferenceConfig):
         *,
         tokenization: NlpTokenizationConfig,
         results_field: t.Optional[str] = None,
+        embedding_size: t.Optional[int] = None,
     ):
         super().__init__(configuration_type="text_embedding")
         self.tokenization = tokenization
         self.results_field = results_field
+        self.embedding_size = embedding_size


 class TextExpansionInferenceOptions(InferenceConfig):
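Note: a minimal usage sketch, not part of this commit. The import path and the NlpBertTokenizationConfig default constructor are assumptions based on the surrounding eland code (the file name is not shown in this diff), and 384 merely stands in for a real model's embedding dimension.

# Hedged sketch: passing the new keyword by hand.
# Import path and tokenization class are assumed, not taken from this diff.
from eland.ml.pytorch.nlp_ml_model import (
    NlpBertTokenizationConfig,
    TextEmbeddingInferenceOptions,
)

options = TextEmbeddingInferenceOptions(
    tokenization=NlpBertTokenizationConfig(),
    embedding_size=384,  # new optional keyword added by this commit
)
assert options.embedding_size == 384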
@@ -49,6 +49,10 @@ class TraceableModel(ABC):
         self._model.eval()
         return self._trace()

+    @abstractmethod
+    def sample_output(self) -> torch.Tensor:
+        ...
+
     @abstractmethod
     def _trace(self) -> TracedModelTypes:
         ...
@@ -443,6 +443,14 @@ class _TransformerTraceableModel(TraceableModel):
         self._tokenizer = tokenizer

     def _trace(self) -> TracedModelTypes:
+        inputs = self._compatible_inputs()
+        return torch.jit.trace(self._model, inputs)
+
+    def sample_output(self) -> Tensor:
+        inputs = self._compatible_inputs()
+        return self._model(*inputs)
+
+    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
         inputs = self._prepare_inputs()

         # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
@@ -458,21 +466,16 @@
                 transformers.BartConfig,
             ),
         ):
-            return torch.jit.trace(
-                self._model,
-                (inputs["input_ids"], inputs["attention_mask"]),
-            )
+            del inputs["token_type_ids"]
+            return (inputs["input_ids"], inputs["attention_mask"])

         position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
-        return torch.jit.trace(
-            self._model,
-            (
-                inputs["input_ids"],
-                inputs["attention_mask"],
-                inputs["token_type_ids"],
-                position_ids,
-            ),
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
         )

     @abstractmethod
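For intuition, a self-contained sketch of the pattern the two hunks above introduce: build one tuple of model-compatible inputs, then reuse it both for torch.jit.trace and for a plain forward pass. This uses a toy module, not eland's wrapper classes.

# Standalone sketch (toy module only): one "compatible inputs" tuple feeds both
# tracing (what _trace() now does) and a plain call (what sample_output() now does).
import torch
from torch import nn


class ToyEmbedder(nn.Module):
    def __init__(self, vocab: int = 100, dim: int = 8) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Mean-pool token embeddings over the unmasked positions.
        vectors = self.embed(input_ids) * attention_mask.unsqueeze(-1)
        return vectors.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


model = ToyEmbedder().eval()
inputs = (torch.randint(0, 100, (1, 6)), torch.ones(1, 6))  # the shared inputs tuple

traced = torch.jit.trace(model, inputs)  # tracing path
sample = model(*inputs)                  # sample-output path
assert torch.allclose(traced(*inputs), sample)
print(sample.size(-1))  # 8 here; this last dimension is the embedding size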
@@ -640,16 +643,23 @@ class TransformerModel:
             tokenization_config.max_sequence_length = 386
             tokenization_config.span = 128
             tokenization_config.truncate = "none"
-        inference_config = (
-            TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+
+        if self._traceable_model.classification_labels():
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config,
                 classification_labels=self._traceable_model.classification_labels(),
             )
-            if self._traceable_model.classification_labels()
-            else TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+        elif self._task_type == "text_embedding":
+            sample_embedding, _ = self._traceable_model.sample_output()
+            embedding_size = sample_embedding.size(-1)
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
+                tokenization=tokenization_config,
+                embedding_size=embedding_size,
+            )
+        else:
+            inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
                 tokenization=tokenization_config
             )
-        )

         return NlpTrainedModelConfig(
             description=f"Model {self._model_id} for task type '{self._task_type}'",
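A rough standalone analogue of the new text_embedding branch, using plain Hugging Face transformers rather than eland's traceable-model wrappers: run sample inputs through the model once and read the embedding size off the last dimension of the output. The checkpoint name is only an example.

# Not eland code -- an analogue of "sample_embedding.size(-1)" above.
import torch
import transformers

model_id = "sentence-transformers/all-MiniLM-L6-v2"  # assumed example checkpoint
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
model = transformers.AutoModel.from_pretrained(model_id).eval()

inputs = tokenizer("a sample sentence", return_tensors="pt")
with torch.no_grad():
    sample_output = model(**inputs)

# last_hidden_state has shape (batch, seq_len, hidden_size); hidden_size is what
# the commit stores as embedding_size in the text_embedding inference config.
embedding_size = sample_output.last_hidden_state.size(-1)
print(embedding_size)  # 384 for this particular checkpoint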
@@ -136,6 +136,13 @@ class TestTraceableModel(TraceableModel, ABC):
             ),
         )

+    def sample_output(self) -> torch.Tensor:
+        input_ids = torch.tensor(np.array(range(0, len(TEST_BERT_VOCAB))))
+        attention_mask = torch.tensor([1] * len(TEST_BERT_VOCAB))
+        token_type_ids = torch.tensor([0] * len(TEST_BERT_VOCAB))
+        position_ids = torch.arange(len(TEST_BERT_VOCAB), dtype=torch.long)
+        return self._model(input_ids, attention_mask, token_type_ids, position_ids)
+

 class NerModule(nn.Module):
     def forward(