ordeq_torch
TorchModel
dataclass
Bases: IO[Module]
IO class for reading and writing PyTorch models using state dictionaries.
This class provides safe loading by calling `torch.load` with
`weights_only=True`. Because only the state dict is stored, the model class
must be specified so the model can be instantiated before its weights are loaded.
Example usage:
```python
>>> from pathlib import Path
>>> import torch.nn as nn
>>> from ordeq_torch import TorchModel
>>>
>>> # Define a simple model class
>>> class SimpleModel(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.linear = nn.Linear(input_size, hidden_size)
...
>>> model_path = Path("model_state.pth")
>>> io = TorchModel(
...     path=model_path,
...     model_class=SimpleModel,
...     model_args=(10, 5)
... )
>>> # Load a model
>>> model = io.load()  # doctest: +SKIP
>>> # Save a model's state dict
>>> io.save(model)  # doctest: +SKIP
```
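The sketch below shows, with plain PyTorch only (not `TorchModel` itself), why loading with `weights_only=True` goes together with specifying the model class. It assumes PyTorch 1.13 or newer, where the `weights_only` flag exists. A state dict written to an in-memory buffer loads under the restricted unpickler, while a pickled `nn.Linear` module is rejected because its class is not on the default allowlist.

```python
import io
import pickle

import torch
import torch.nn as nn

model = nn.Linear(3, 2)

# A state dict is an ordered mapping of names to tensors,
# which the restricted weights-only unpickler accepts.
state_buffer = io.BytesIO()
torch.save(model.state_dict(), state_buffer)
state_buffer.seek(0)
state_dict = torch.load(state_buffer, weights_only=True)

# Pickling the whole module records a reference to the nn.Linear class,
# which the weights-only unpickler refuses to reconstruct by default.
module_buffer = io.BytesIO()
torch.save(model, module_buffer)
module_buffer.seek(0)
try:
    torch.load(module_buffer, weights_only=True)
    whole_module_loaded = True
except pickle.UnpicklingError:
    whole_module_loaded = False

print(sorted(state_dict), whole_module_loaded)  # ['bias', 'weight'] False
```

Since the file holds only tensors, something else has to supply the architecture, which is the role of the model class and its arguments.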
load(**load_options)
Load a PyTorch model by instantiating the model class and loading its state dict.
Creates a new instance of the specified model class from the provided positional and keyword arguments, loads the state dictionary from the file into it, and sets the model to evaluation mode.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**load_options` | `Any` | Additional options to pass to `torch.load`. | `{}` |
Returns:

| Type | Description |
|---|---|
| `Module` | The instantiated PyTorch model with loaded weights in evaluation mode. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `model_class` is `None`. |
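The load steps described above can be sketched with plain PyTorch as follows. `load_model` is a hypothetical helper, not the `ordeq_torch` implementation; its `model_args`/`model_kwargs` parameters and the exact `ValueError` message are assumptions chosen to mirror the documented behavior.

```python
import tempfile
from pathlib import Path
from typing import Any, Optional

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def load_model(
    path: Path,
    model_class: Optional[type],
    model_args: tuple = (),
    model_kwargs: Optional[dict] = None,
    **load_options: Any,
) -> nn.Module:
    """Hypothetical helper that follows the documented load steps."""
    if model_class is None:
        raise ValueError("model_class must be provided to rebuild the model")
    # 1. Instantiate the architecture from the class and its arguments.
    model = model_class(*model_args, **(model_kwargs or {}))
    # 2. Read only tensors and plain containers from disk.
    state_dict = torch.load(path, weights_only=True, **load_options)
    # 3. Copy the stored weights into the fresh instance.
    model.load_state_dict(state_dict)
    # 4. Switch to evaluation mode (affects layers such as dropout and batch norm).
    model.eval()
    return model


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model_state.pth"
    original = SimpleModel(10, 5)
    torch.save(original.state_dict(), path)
    restored = load_model(path, SimpleModel, model_args=(10, 5), map_location="cpu")

same_weights = torch.equal(original.linear.weight, restored.linear.weight)
print(same_weights, restored.training)  # True False

# The class check runs before any file access, so the path is irrelevant here.
error_message = ""
try:
    load_model(path, None)
except ValueError as exc:
    error_message = str(exc)
```

In this sketch `weights_only=True` is fixed, so passing `weights_only` again through `load_options` would raise a `TypeError` for a duplicate keyword argument; other `torch.load` options such as `map_location` pass through.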
save(model, **save_options)
Save a PyTorch model's state dictionary to the specified path.
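Extracts the state dictionary from the model and saves it to the file. Storing only the state dictionary, rather than pickling the entire model object, means the file can later be loaded with `weights_only=True`, which is safer than unpickling arbitrary objects.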
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The PyTorch model whose state dict should be saved. | required |
| `**save_options` | `Any` | Additional options to pass to `torch.save`. | `{}` |
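Saving reduces to a single `torch.save` call on the model's state dict. The sketch below uses a hypothetical `save_model` helper (an assumption, not the library code) and then inspects the written file to show that it contains only named tensors.

```python
import tempfile
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn


def save_model(model: nn.Module, path: Path, **save_options: Any) -> None:
    """Hypothetical helper: persist only the model's parameters and buffers."""
    torch.save(model.state_dict(), path, **save_options)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "linear_state.pth"
    model = nn.Linear(4, 2)
    save_model(model, path)
    # The file holds a plain mapping of tensors, so the restricted loader accepts it.
    stored = torch.load(path, weights_only=True)

print(sorted(stored))  # ['bias', 'weight']
```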
TorchObject
dataclass
Bases: IO[Any]
IO class for reading and writing arbitrary PyTorch objects using pickle.
This class can load and save any object that `torch.save` can serialize, including tensors and Python data structures. It uses PyTorch's native serialization, which is based on Python's pickle protocol.
Example usage:
```python
>>> from pathlib import Path
>>> import torch
>>> from ordeq_torch import TorchObject
>>>
>>> # Save and load a tensor
>>> tensor_path = Path("tensor.pt")
>>> io = TorchObject(path=tensor_path)
>>> tensor = torch.randn(3, 4)
>>> io.save(tensor)  # doctest: +SKIP
>>> loaded_tensor = io.load()  # doctest: +SKIP
>>>
>>> # Can also save other objects like lists of tensors
>>> data = [torch.randn(2, 2), {"key": torch.ones(5)}]
>>> io.save(data)  # doctest: +SKIP
```
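Because `torch.save` pickles the whole object graph, relationships between tensors survive a round trip. This sketch calls `torch.save`/`torch.load` directly on an in-memory buffer rather than going through `TorchObject`; it shows that a tensor and a view of it still share storage after loading, so an in-place change to the restored view is visible in the restored base tensor.

```python
import io

import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]  # a view sharing storage with `numbers`

buffer = io.BytesIO()
torch.save([numbers, evens], buffer)
buffer.seek(0)
# Plain lists of tensors are accepted by the restricted loader.
loaded_numbers, loaded_evens = torch.load(buffer, weights_only=True)

loaded_evens *= 2  # in-place change through the restored view
result = loaded_numbers.tolist()  # [1, 4, 3, 8, 5, 12, 7, 16, 9]
```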
load(**load_options)
Load a PyTorch object from the file specified by the path attribute.
Uses `torch.load` to deserialize the object. The loaded object can be
a tensor, a Python data structure, or any other picklable object.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `**load_options` | `Any` | Additional options to pass to `torch.load`. | `{}` |
Returns:

| Type | Description |
|---|---|
| `Any` | The deserialized PyTorch object. |
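Assuming `load_options` are forwarded unchanged to `torch.load`, the plain call below shows the kind of options one might pass. Since PyTorch 2.6 the default of `weights_only` is `True`, so objects of arbitrary classes (here a `datetime.date`, chosen only for illustration) need `weights_only=False`, which should be reserved for trusted files. `map_location="cpu"` remaps any GPU tensors onto the CPU.

```python
import datetime
import io

import torch

record = {"created": datetime.date(2024, 1, 1), "weights": torch.ones(3)}

buffer = io.BytesIO()
torch.save(record, buffer)
buffer.seek(0)

# weights_only=False permits arbitrary pickled classes such as datetime.date;
# only use it for files from a trusted source.
restored = torch.load(buffer, map_location="cpu", weights_only=False)
print(restored["created"])  # 2024-01-01
```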
save(data, **save_options)
Save a PyTorch object to the file specified by the path attribute.
Serializes any picklable object, including tensors and Python data
structures, to the file using `torch.save`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `Any` | The PyTorch object to be saved. | required |
| `**save_options` | `Any` | Additional options to pass to `torch.save`. | `{}` |