class documentation

ATD

A PyTorch implementation of "Transcending the Limit of Local Window: Advanced Super-Resolution Transformer with Adaptive Token Dictionary".

Parameters
img_size: Input image size. Default: 64
patch_size: Patch size. Default: 1
in_chans: Number of input image channels. Default: 3
embed_dim: Patch embedding dimension. Default: 90
depths: Depth of each Swin Transformer layer. Default: (6, 6, 6, 6)
num_heads: Number of attention heads in different layers. Default: (6, 6, 6, 6)
window_size: Window size. Default: 8
mlp_ratio: Ratio of MLP hidden dim to embedding dim. Default: 2.0
qkv_bias: If True, add a learnable bias to query, key, value. Default: True
norm_layer: Normalization layer. Default: nn.LayerNorm
ape: If True, add absolute position embedding to the patch embedding. Default: False
patch_norm: If True, add normalization after patch embedding. Default: True
use_checkpoint: Whether to use checkpointing to save memory. Default: False (Note: this parameter does not appear in the __init__ signature documented on this page.)
upscale: Upscale factor: 2/3/4/8 for image super-resolution, 1 for denoising and compression artifact reduction. Default: 1
img_range: Image range, 1.0 or 255.0. Default: 1.0
upsampler: The reconstruction module: 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None. Default: ''
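resi_connection: The convolutional block before the residual connection: '1conv'/'3conv'. Default: '1conv'
The __init__ signature also accepts the following keyword arguments, which are not documented here: category_size (default 256), num_tokens (default 64), reducted_dim (default 4), convffn_kernel_size (default 5), and norm (default True).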
Method __init__ Undocumented
Method calculate_mask Undocumented
Method calculate_rpi_sa Undocumented
Method forward Undocumented
Method forward_features Undocumented
Method no_weight_decay Undocumented
Method no_weight_decay_keywords Undocumented
Class Variable hyperparameters Undocumented
Instance Variable absolute_pos_embed Undocumented
Instance Variable ape Undocumented
Instance Variable conv_after_body Undocumented
Instance Variable conv_before_upsample Undocumented
Instance Variable conv_first Undocumented
Instance Variable conv_last Undocumented
Instance Variable embed_dim Undocumented
Instance Variable img_range Undocumented
Instance Variable layers Undocumented
Instance Variable mean Undocumented
Instance Variable mlp_ratio Undocumented
Instance Variable no_norm Undocumented
Instance Variable norm Undocumented
Instance Variable num_features Undocumented
Instance Variable num_layers Undocumented
Instance Variable patch_embed Undocumented
Instance Variable patch_norm Undocumented
Instance Variable patch_unembed Undocumented
Instance Variable patches_resolution Undocumented
Instance Variable upsample Undocumented
Instance Variable upsampler Undocumented
Instance Variable upscale Undocumented
Instance Variable window_size Undocumented
Property is_norm Undocumented
Method _init_weights Undocumented
def __init__(self, *, img_size=64, patch_size=1, in_chans=3, embed_dim=90, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=8, category_size=256, num_tokens=64, reducted_dim=4, convffn_kernel_size=5, mlp_ratio=2.0, qkv_bias=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, upscale=1, img_range=1.0, upsampler='', resi_connection='1conv', norm=True): (source)

Undocumented

def calculate_mask(self, x_size): (source)

Undocumented

def calculate_rpi_sa(self): (source)

Undocumented

def forward(self, x): (source)

Undocumented

def forward_features(self, x, params): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay(self): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay_keywords(self): (source)

Undocumented

hyperparameters: dict = (source)

Undocumented

absolute_pos_embed = (source)

Undocumented

ape: bool = (source)

Undocumented

conv_after_body = (source)

Undocumented

conv_before_upsample = (source)

Undocumented

conv_first = (source)

Undocumented

conv_last = (source)

Undocumented

embed_dim: int = (source)

Undocumented

img_range = (source)

Undocumented

layers = (source)

Undocumented

mean = (source)

Undocumented

mlp_ratio: float = (source)

Undocumented

no_norm: torch.Tensor|None = (source)

Undocumented

norm = (source)

Undocumented

num_features = (source)

Undocumented

num_layers = (source)

Undocumented

patch_embed = (source)

Undocumented

patch_norm: bool = (source)

Undocumented

patch_unembed = (source)

Undocumented

patches_resolution = (source)

Undocumented

upsample = (source)

Undocumented

upsampler = (source)

Undocumented

upscale = (source)

Undocumented

window_size: int = (source)

Undocumented

@property
is_norm = (source)

Undocumented

def _init_weights(self, m): (source)

Undocumented