
The base class of all model descriptors.

This is mostly intended for isinstance checks in user code. Use ModelDescriptor for type hints instead.

Method __init__ Undocumented
Method bfloat16 Moves the parameters and buffers of the underlying module to bfloat16 precision.
Method cpu Moves the parameters and buffers of the underlying module to the CPU.
Method cuda Moves the parameters and buffers of the underlying module to the GPU.
Method eval Sets the underlying module in evaluation mode.
Method float Moves the parameters and buffers of the underlying module to single precision (fp32).
Method half Moves the parameters and buffers of the underlying module to half precision (fp16).
Method to Moves and casts the parameters and buffers of the underlying module to the given device and data type.
Method train Sets the underlying module in training mode.
Instance Variable input_channels The number of input image channels of the model. E.g. 3 for RGB, 1 for grayscale.
Instance Variable output_channels The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale.
Instance Variable scale The output scale of super resolution models. E.g. 4x, 2x, 1x.
Instance Variable size_requirements Size requirements for the input image. E.g. minimum size.
Instance Variable supports_bfloat16 Whether the model supports bfloat16 precision.
Instance Variable supports_half Whether the model supports half precision (fp16).
Instance Variable tags A list of tags for the model, usually describing the size or model parameters. E.g. "64nf" or "large".
Instance Variable tiling Whether the model supports tiling.
Property architecture The architecture of the model.
Property device The device of the underlying module.
Property dtype The data type of the underlying module.
Property model The model itself: a torch.nn.Module with weights loaded in.
Property purpose The purpose of this model.
Instance Variable _architecture The backing field of the architecture property.
Instance Variable _model The backing field of the model property.
def __init__(self, model: T, state_dict: StateDict, architecture: Architecture[T], tags: list[str], supports_half: bool, supports_bfloat16: bool, scale: int, input_channels: int, output_channels: int, size_requirements: SizeRequirements|None = None, tiling: ModelTiling = ModelTiling.SUPPORTED): (source)
def bfloat16(self) -> Self: (source)

Moves the parameters and buffers of the underlying module to bfloat16 precision.

Same as self.to(torch.bfloat16).

def cpu(self) -> Self: (source)

Moves the parameters and buffers of the underlying module to the CPU.

Same as self.to(torch.device("cpu")).

def cuda(self, device: int|None = None) -> Self: (source)

Moves the parameters and buffers of the underlying module to the GPU.

Same as self.to(torch.device("cuda")). If device is given, the parameters and buffers are moved to that CUDA device index.

def eval(self) -> Self: (source)

Sets the underlying module in evaluation mode.

Same as self.train(False).

def float(self) -> Self: (source)

Moves the parameters and buffers of the underlying module to single precision (fp32).

Same as self.to(torch.float).

def half(self) -> Self: (source)

Moves the parameters and buffers of the underlying module to half precision (fp16).

Same as self.to(torch.half).

@overload
def to(self, device: torch.device|str|None = None, dtype: torch.dtype|None = None) -> Self:
@overload
def to(self, dtype: torch.dtype) -> Self: (source)

Moves and casts the parameters and buffers of the underlying module to the given device and data type.

For more information, see https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to.

Use device to get the current device and dtype to get the current data type of the model.

Raises UnsupportedDtypeError if the model does not support the given data type. To force a dtype cast regardless, use .model.to(dtype) instead.
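The dtype guard described above can be pictured with a small pure-Python sketch. The function and its logic are assumptions for illustration only, not spandrel's actual implementation; only the supports_half/supports_bfloat16 flags and UnsupportedDtypeError come from this API.

```python
class UnsupportedDtypeError(Exception):
    """Raised when a cast to an unsupported data type is requested."""

def checked_dtype(dtype: str, supports_half: bool, supports_bfloat16: bool) -> str:
    # Assumed guard: reject fp16/bf16 casts when the descriptor's
    # supports_half / supports_bfloat16 flags say they are unsupported.
    if dtype == "float16" and not supports_half:
        raise UnsupportedDtypeError("model does not support fp16")
    if dtype == "bfloat16" and not supports_bfloat16:
        raise UnsupportedDtypeError("model does not support bfloat16")
    return dtype

print(checked_dtype("float16", supports_half=True, supports_bfloat16=False))  # -> float16
```

Bypassing such a check via .model.to(dtype) trades safety for control: the cast always happens, but inference may then produce NaNs on models that genuinely do not support the dtype.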

def train(self, mode: bool = True) -> Self: (source)

Sets the underlying module in training mode.

Same as self.model.train(mode).
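Because train and eval return Self, calls can be chained. A minimal stand-in class (all names here are hypothetical, not spandrel code) shows the fluent pattern these methods follow:

```python
class Sketch:
    # Stand-in for a model descriptor: train/eval toggle a training flag
    # and return self, so calls can be chained, e.g. Sketch().eval().
    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "Sketch":
        self.training = mode
        return self

    def eval(self) -> "Sketch":
        # Same as self.train(False), mirroring the docstring above.
        return self.train(False)

d = Sketch().eval()
print(d.training)  # -> False
```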

input_channels: int = (source)

The number of input image channels of the model. E.g. 3 for RGB, 1 for grayscale.

output_channels: int = (source)

The number of output image channels of the model. E.g. 3 for RGB, 1 for grayscale.

scale: int = (source)

The output scale of super resolution models. E.g. 4x, 2x, 1x.

Models that are not super resolution models (e.g. denoisers) have a scale of 1.
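The output resolution follows directly from scale, as plain arithmetic (no spandrel required):

```python
def output_size(width: int, height: int, scale: int) -> tuple[int, int]:
    # A model with scale N maps a WxH input to an (N*W)x(N*H) output;
    # 1x models (e.g. denoisers) leave the size unchanged.
    return width * scale, height * scale

print(output_size(128, 64, 4))  # -> (512, 256)
print(output_size(128, 64, 1))  # -> (128, 64)
```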

size_requirements: SizeRequirements = (source)

Size requirements for the input image. E.g. minimum size.

Requirements are specific to individual models and may be different for models of the same architecture.

Users of spandrel's call API can largely ignore size requirements, because the call API will automatically pad the input image to satisfy the requirements. Size requirements might still be useful for user code that tiles images by allowing it to pick an optimal tile size to avoid padding.
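A sketch of what satisfying a size requirement might look like. The fields minimum and multiple_of are assumptions about SizeRequirements made for illustration; consult the SizeRequirements documentation for the actual fields.

```python
import math

def padded_length(length: int, minimum: int, multiple_of: int) -> int:
    # Enforce an assumed minimum size, then round up to the nearest
    # multiple, mimicking what the call API's automatic padding would do.
    length = max(length, minimum)
    return math.ceil(length / multiple_of) * multiple_of

print(padded_length(30, 16, 8))  # -> 32  (rounded up to a multiple of 8)
print(padded_length(10, 16, 8))  # -> 16  (raised to the minimum)
```

Tiling code can invert this logic: picking a tile size that already satisfies the requirements avoids per-tile padding entirely.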

supports_bfloat16: bool = (source)

Whether the model supports bfloat16 precision.

supports_half: bool = (source)

Whether the model supports half precision (fp16).

tags: list[str] = (source)

A list of tags for the model, usually describing the size or model parameters. E.g. "64nf" or "large".

Tags are specific to the architecture of the model. Some architectures may not have any tags.

tiling: ModelTiling = (source)

Whether the model supports tiling.

Technically, all models support tiling. This is simply a recommendation on how to best use the model.
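When user code does tile, a schedule of overlapping tile start positions can be computed like this. This is a sketch; the tile and overlap parameters are hypothetical and not part of this API.

```python
def tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    # Start offsets of tiles of size `tile` covering `length` pixels, each
    # overlapping its neighbor by `overlap` pixels; the final tile is
    # shifted back so it ends exactly at the image border.
    step = tile - overlap
    starts = list(range(0, max(length - tile, 0) + 1, step))
    if starts[-1] + tile < length:
        starts.append(length - tile)
    return starts

print(tile_starts(10, 4, 1))  # -> [0, 3, 6]
```

Each axis of the image gets its own schedule; the overlapping regions are typically blended to hide seams.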

@property
architecture: Architecture[T] = (source)

The architecture of the model.

@property
device: torch.device = (source)

The device of the underlying module.

Use the to method to move the model to a different device.

@property
dtype: torch.dtype = (source)

The data type of the underlying module.

Use the to method to cast the model to a different data type.

@property
model: T = (source)

The model itself: a torch.nn.Module with weights loaded in.

The specific subclass of torch.nn.Module depends on the model architecture.

@property
@abstractmethod
purpose: Purpose = (source)

The purpose of this model.

_architecture: Architecture[T] = (source)

The backing field of the architecture property.

_model: T = (source)

The backing field of the model property.