class documentation

Class for automatically loading a model file (for example, a .pth file) into any architecture supported by its registry.

Method __init__: Undocumented
Method load_from_file: Load a model from the given file path.
Method load_from_state_dict: Load a model from the given state dict.
Method load_state_dict_from_file: Load the state dict of a model from the given file path.
Instance Variable device: Undocumented
Instance Variable registry: The architecture registry to use for loading models.
Method _load_pth: Undocumented
Method _load_safetensors: Undocumented
Method _load_torchscript: Undocumented
def __init__(self, device: str|torch.device|None = None, registry: ArchRegistry = MAIN_REGISTRY):

Undocumented

def load_from_file(self, path: str|Path) -> ModelDescriptor:

Load a model from the given file path.

Raises a ValueError if the file extension is not supported, and an UnsupportedModelError if the model architecture is not supported.
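A minimal sketch of how the two documented errors can arise, assuming load_from_file simply chains load_state_dict_from_file and load_from_state_dict. The matching exception lists suggest this, but the page does not state it. SketchLoader, describe, the stub state dicts, and the local UnsupportedModelError class are hypothetical stand-ins, not the library's API.

```python
from pathlib import Path
from typing import Dict, List, Tuple, Union

StateDict = Dict[str, List[float]]  # simplified stand-in for the real StateDict type


class UnsupportedModelError(Exception):
    """Hypothetical stand-in for the library's UnsupportedModelError."""


class SketchLoader:
    """Illustrative only: mirrors the documented error contract, not the real ModelLoader."""

    def load_state_dict_from_file(self, path: Union[str, Path]) -> StateDict:
        # Assumption for this sketch: only ".pth" files are accepted.
        suffix = Path(path).suffix.lower()
        if suffix != ".pth":
            raise ValueError(f"Unsupported model file extension {suffix!r}")
        # Pretend files named "unknown*" hold an unrecognized architecture.
        if Path(path).stem.startswith("unknown"):
            return {"other.weight": [1.0]}
        return {"conv.weight": [0.0]}

    def load_from_state_dict(self, state_dict: StateDict) -> Tuple[str, List[str]]:
        # Assumption: the only known architecture is recognized by one key.
        if "conv.weight" not in state_dict:
            raise UnsupportedModelError("no architecture matches this state dict")
        # A real loader would return a ModelDescriptor; a tuple stands in here.
        return ("descriptor", sorted(state_dict))

    def load_from_file(self, path: Union[str, Path]) -> Tuple[str, List[str]]:
        # Extension errors come from the first step, architecture errors from the second.
        return self.load_from_state_dict(self.load_state_dict_from_file(path))


def describe(loader: SketchLoader, path: str) -> str:
    """Map each documented failure mode to a distinct message."""
    try:
        loader.load_from_file(path)
    except ValueError:
        return "unsupported file type"
    except UnsupportedModelError:
        return "unsupported architecture"
    return "loaded"
```

Because the two errors come from different stages, a caller can handle them separately: a ValueError means the file type was rejected before any architecture detection, while an UnsupportedModelError means the file was read but no architecture matched.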

def load_from_state_dict(self, state_dict: StateDict) -> ModelDescriptor:

Load a model from the given state dict.

Raises an UnsupportedModelError if the model architecture is not supported.
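This page does not say how the registry decides which architecture a state dict belongs to. One common approach, sketched here purely as an assumption, is to run each registered architecture's detector over the state-dict keys in order and raise when none matches. RegistrySketch, ArchSketch, and the key names below are hypothetical, not the library's API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

StateDict = Dict[str, object]  # simplified stand-in for the real StateDict type


class UnsupportedModelError(Exception):
    """Hypothetical stand-in for the library's UnsupportedModelError."""


@dataclass
class ArchSketch:
    """Hypothetical architecture entry: a name plus a key-based detector."""
    name: str
    detect: Callable[[StateDict], bool]


class RegistrySketch:
    """Hypothetical registry that tries each architecture in registration order."""

    def __init__(self) -> None:
        self.archs: List[ArchSketch] = []

    def add(self, arch: ArchSketch) -> None:
        self.archs.append(arch)

    def load(self, state_dict: StateDict) -> str:
        for arch in self.archs:
            if arch.detect(state_dict):
                # A real registry would build and return a model descriptor here.
                return arch.name
        raise UnsupportedModelError(
            f"no registered architecture matches keys {sorted(state_dict)[:3]}"
        )
```

Order matters in this design: if two detectors could both match, the first registered one wins, so more specific checks would need to be registered earlier.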

def load_state_dict_from_file(self, path: str|Path) -> StateDict:

Load the state dict of a model from the given file path.

A state dict is typically only useful for passing to the load function of a specific architecture.

Raises a ValueError if the file extension is not supported.
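The ValueError above implies that the loader picks a reader based on the file extension. A minimal sketch of that dispatch follows. The three private method names come from this page, but which extension maps to which method (including ".pt" for TorchScript) is an assumption, and pick_loader is a hypothetical helper.

```python
from pathlib import Path
from typing import Union

# Assumed mapping from file extension to private loader method name.
# The method names appear on this page; the extension mapping is illustrative.
_LOADERS = {
    ".pth": "_load_pth",
    ".safetensors": "_load_safetensors",
    ".pt": "_load_torchscript",
}


def pick_loader(path: Union[str, Path]) -> str:
    """Return the name of the private loader for `path`, or raise ValueError."""
    suffix = Path(path).suffix.lower()  # case-insensitive match on the last suffix
    try:
        return _LOADERS[suffix]
    except KeyError:
        raise ValueError(f"Unsupported model file extension {suffix!r}") from None
```

Lower-casing the suffix makes the match case-insensitive, and Path.suffix only looks at the last extension, so a name like model.tar.pth still resolves to the .pth reader.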

device: torch.device

Undocumented

registry: ArchRegistry

The architecture registry to use for loading models.

Note: Unless initialized with a custom registry, this is the global main registry (MAIN_REGISTRY). Modifying this registry will affect all ModelLoader instances without a custom registry.
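The note above follows from how Python evaluates default arguments: MAIN_REGISTRY in the __init__ signature is bound once, so every loader created without a registry argument refers to the same object. The classes below are hypothetical stand-ins that show the effect without the real library.

```python
from typing import List


class RegistrySketch:
    """Hypothetical registry that just records architecture names."""

    def __init__(self) -> None:
        self.names: List[str] = []


MAIN = RegistrySketch()  # hypothetical stand-in for the global MAIN_REGISTRY


class LoaderSketch:
    """Hypothetical loader using the same default-registry pattern as ModelLoader."""

    def __init__(self, registry: RegistrySketch = MAIN) -> None:
        # The default is evaluated once, when the function is defined,
        # so all loaders built without a registry share the MAIN object.
        self.registry = registry


a = LoaderSketch()
b = LoaderSketch()
isolated = LoaderSketch(registry=RegistrySketch())  # custom registry, not shared

a.registry.names.append("MyArch")  # mutates MAIN, so b sees the change too
```

Passing a separate registry object, as with isolated here, keeps customizations local to one loader.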

def _load_pth(self, path: str|Path) -> StateDict:

Undocumented

def _load_safetensors(self, path: str|Path) -> StateDict:

Undocumented

def _load_torchscript(self, path: str|Path) -> StateDict:

Undocumented