Freeze the traced PyTorch model

This commit is contained in:
David Kyle 2022-06-21 13:43:18 +01:00 committed by GitHub
parent ec041ffdfd
commit 081c8efaa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -59,5 +59,6 @@ class TraceableModel(ABC):
def save(self, path: str) -> str:
model_path = os.path.join(path, "traced_pytorch_model.pt")
trace_model = self.trace()
trace_model = torch.jit.freeze(trace_model)
torch.jit.save(trace_model, model_path)
return model_path