mirror of
https://github.com/elastic/eland.git
synced 2025-07-11 00:02:14 +08:00
Freeze the traced PyTorch model
This commit is contained in:
parent
ec041ffdfd
commit
081c8efaa0
@ -59,5 +59,6 @@ class TraceableModel(ABC):
|
||||
def save(self, path: str) -> str:
|
||||
model_path = os.path.join(path, "traced_pytorch_model.pt")
|
||||
trace_model = self.trace()
|
||||
trace_model = torch.jit.freeze(trace_model)
|
||||
torch.jit.save(trace_model, model_path)
|
||||
return model_path
|
||||
|
Loading…
x
Reference in New Issue
Block a user