diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py index 3de9e9b..6d84d43 100644 --- a/eland/ml/pytorch/nlp_ml_model.py +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -123,6 +123,7 @@ class InferenceConfig: } } + class SlimInferenceOptions(InferenceConfig): def __init__( self, @@ -134,6 +135,7 @@ class SlimInferenceOptions(InferenceConfig): self.tokenization = tokenization self.results_field = results_field + class TextClassificationInferenceOptions(InferenceConfig): def __init__( self,