
ordeq_torch

TorchModel dataclass

Bases: IO[Module]

IO class for reading and writing PyTorch models using state dictionaries.

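This class loads models safely by calling torch.load with weights_only=True. Because a state dictionary contains only the weights, the model class must be specified so that the model can be instantiated before the weights are loaded into it.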
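Why the model class is needed becomes clear from what a state dict holds: parameter names mapped to tensors, with no record of the class that produced them. The following is a minimal sketch using plain torch calls rather than ordeq_torch. SimpleModel mirrors the toy class from the usage example, and an in-memory buffer stands in for the file on disk.

```python
import io

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    # Same toy model as in the usage example, plus a forward method.
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# A state dict maps parameter names to tensors; no class information is stored.
original = SimpleModel(10, 5)
state = original.state_dict()
print(list(state.keys()))  # ['linear.weight', 'linear.bias']

# Serialize to an in-memory buffer (stand-in for a .pth file).
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

# weights_only=True restricts unpickling to tensors and basic containers,
# so the state dict loads, but the model itself must be rebuilt from its class.
restored_state = torch.load(buffer, weights_only=True)
rebuilt = SimpleModel(10, 5)
rebuilt.load_state_dict(restored_state)
```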

Example usage:

>>> from pathlib import Path
>>> import torch.nn as nn
>>> from ordeq_torch import TorchModel
>>>
>>> # Define a simple model class
>>> class SimpleModel(nn.Module):
...     def __init__(self, input_size, hidden_size):
...         super().__init__()
...         self.linear = nn.Linear(input_size, hidden_size)
...
...     def forward(self, x):
...         return self.linear(x)
...
>>> model_path = Path("model_state.pth")
>>> io = TorchModel(
...     path=model_path,
...     model_class=SimpleModel,
...     model_args=(10, 5)
... )
>>> # Load a model
>>> model = io.load()  # doctest: +SKIP
>>> # Save a model's state dict
>>> io.save(model)  # doctest: +SKIP

load(**load_options)

Load a PyTorch model by instantiating the model class and loading its state dict.

Creates a new instance of the specified model class using the provided positional and keyword arguments, then loads the state dictionary from the file into it and sets the model to evaluation mode.

Parameters:

Name            Type  Description                                Default
**load_options  Any   Additional options to pass to torch.load.  {}

Returns:

Type    Description
Module  The instantiated PyTorch model with loaded weights, in evaluation mode.

Raises:

Type        Description
ValueError  If model_class is None.
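The load steps described above can be sketched with plain torch calls. This is not the ordeq_torch source: load_model and model_kwargs are hypothetical names chosen for illustration (model_args mirrors the usage example), and letting load_options override weights_only is an assumption of this sketch, not documented behavior.

```python
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn


def load_model(
    path: Path,
    model_class: type[nn.Module] | None,
    model_args: tuple[Any, ...] = (),
    model_kwargs: dict[str, Any] | None = None,
    **load_options: Any,
) -> nn.Module:
    """Hypothetical sketch of the load steps; not the ordeq_torch source."""
    if model_class is None:
        raise ValueError("model_class must be specified to load a state dict")
    # 1. Instantiate a fresh model from its class and constructor arguments.
    model = model_class(*model_args, **(model_kwargs or {}))
    # 2. Read the state dict; weights_only=True unless the caller overrides it.
    state = torch.load(path, **{"weights_only": True, **load_options})
    # 3. Copy the weights into the model and switch it to evaluation mode.
    model.load_state_dict(state)
    model.eval()
    return model


class SimpleModel(nn.Module):
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model_state.pth"
    torch.save(SimpleModel(10, 5).state_dict(), path)
    # map_location is a standard torch.load option, forwarded via **load_options.
    loaded = load_model(path, SimpleModel, model_args=(10, 5), map_location="cpu")
```

Checking model_class before anything else means a misconfigured IO fails fast, without touching the file.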

save(model, **save_options)

Save a PyTorch model's state dictionary to the specified path.

Extracts the state dictionary from the model and saves it to the file. Because only the weights are stored, not a pickle of the entire model object, the file can later be loaded with weights_only=True, which is safer than unpickling a full model.

Parameters:

Name            Type    Description                                          Default
model           Module  The PyTorch model whose state dict should be saved.  required
**save_options  Any     Additional options to pass to torch.save.            {}
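The save step amounts to passing model.state_dict() to torch.save. Below is a hedged sketch with a hypothetical save_model function (not the ordeq_torch source) that forwards save_options unchanged; the file is then read back with weights_only=True to show that only named tensors were written.

```python
import tempfile
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn


def save_model(path: Path, model: nn.Module, **save_options: Any) -> None:
    """Hypothetical sketch of the save step; not the ordeq_torch source."""
    # Only the state dict (parameter and buffer tensors) is written,
    # not a pickle of the model object itself.
    torch.save(model.state_dict(), path, **save_options)


# ReLU has no parameters, so only layers 0 and 2 appear in the state dict.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model_state.pth"
    save_model(path, model)
    saved = torch.load(path, weights_only=True)
```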

TorchObject dataclass

Bases: IO[Any]

IO class for reading and writing arbitrary PyTorch objects using pickle.

This class can load and save any PyTorch-compatible object, including tensors and Python data structures. It uses PyTorch's native serialization, which is based on Python's pickle protocol.

Example usage:

>>> from pathlib import Path
>>> import torch
>>> from ordeq_torch import TorchObject
>>>
>>> # Save and load a tensor
>>> tensor_path = Path("tensor.pt")
>>> io = TorchObject(path=tensor_path)
>>> tensor = torch.randn(3, 4)
>>> io.save(tensor)  # doctest: +SKIP
>>> loaded_tensor = io.load()  # doctest: +SKIP
>>>
>>> # Can also save other objects like lists of tensors
>>> data = [torch.randn(2, 2), {"key": torch.ones(5)}]
>>> io.save(data)  # doctest: +SKIP

load(**load_options)

Load a PyTorch object from the file specified by the path attribute.

Uses torch.load to deserialize the object. The loaded object can be a tensor, a Python data structure, or any other pickle-compatible object.

Parameters:

Name            Type  Description                                Default
**load_options  Any   Additional options to pass to torch.load.  {}

Returns:

Type  Description
Any   The deserialized PyTorch object.
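Because load_options are passed to torch.load, any of its keyword arguments can be used. The sketch below uses a hypothetical load_object helper (not the TorchObject source) with two real torch.load options: map_location, which controls the device tensors are loaded onto, and weights_only, which restricts unpickling to tensors and basic Python containers. Since PyTorch 2.6, weights_only defaults to True; passing weights_only=False is only advisable for files from trusted sources, because unpickling can execute arbitrary code.

```python
import io
from typing import Any

import torch


def load_object(source: Any, **load_options: Any) -> Any:
    """Hypothetical sketch of the load step; not the TorchObject source."""
    # Every keyword option is forwarded unchanged to torch.load.
    return torch.load(source, **load_options)


# Write a small nested structure to an in-memory buffer (stand-in for a file).
buffer = io.BytesIO()
torch.save({"step": 3, "weights": [torch.ones(2, 2), torch.zeros(3)]}, buffer)
buffer.seek(0)

# map_location="cpu" places every loaded tensor on the CPU, which helps
# when a file written on a GPU machine is read on a CPU-only one.
obj = load_object(buffer, map_location="cpu", weights_only=True)
```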

save(data, **save_options)

Save a PyTorch object to the file specified by the path attribute.

Serializes any pickle-compatible object, such as a tensor or a Python data structure, to the file using torch.save.

Parameters:

Name            Type  Description                                Default
data            Any   The PyTorch object to be saved.            required
**save_options  Any   Additional options to pass to torch.save.  {}
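A typical non-tensor payload for save is a training checkpoint that bundles state dicts with plain Python values. This sketch uses torch.save and torch.load directly on an in-memory buffer rather than TorchObject itself; the checkpoint keys (epoch, model_state, optimizer_state) are hypothetical choices for illustration.

```python
import io

import torch
import torch.nn as nn

# A hypothetical training checkpoint: a plain dict mixing a Python int
# with model and optimizer state dicts.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint = {
    "epoch": 5,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

buffer = io.BytesIO()  # stand-in for the file at the path attribute
torch.save(checkpoint, buffer)
buffer.seek(0)

# The checkpoint holds only tensors, numbers, strings and containers,
# so it can be read back with the restricted weights_only unpickler.
restored = torch.load(buffer, weights_only=True)
```

Keeping checkpoints to tensors and built-in types, rather than custom classes, keeps them loadable with weights_only=True.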