diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index ae60f48..69ef7db 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -393,11 +393,13 @@ class _DPREncoderWrapper(nn.Module): # type: ignore def is_compatible() -> bool: is_dpr_model = config.model_type == "dpr" - has_architectures = len(config.architectures) == 1 - is_supported_architecture = ( + has_architectures = ( + config.architectures is not None and len(config.architectures) == 1 + ) + is_supported_architecture = has_architectures and ( config.architectures[0] in _DPREncoderWrapper._SUPPORTED_MODELS_NAMES ) - return is_dpr_model and has_architectures and is_supported_architecture + return is_dpr_model and is_supported_architecture if is_compatible(): model = getattr(transformers, config.architectures[0]).from_pretrained(