class documentation

Image restoration transformer with global, non-local, and local connections

Parameters
img_size: Input image size. Default: 64.
in_channels: Number of input image channels. Default: 3.
out_channels: Number of output image channels. Default: 3.
embed_dim: Patch embedding dimension. Default: 96.
upscale: Upscale factor: 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction. Default: 1.
img_range: Image range, 1.0 or 255.0. Default: 1.0.
upsampler: The reconstruction module. Choices: 'pixelshuffle', 'pixelshuffledirect', 'nearest+conv', None.
depths: Depth of each Swin Transformer layer.
num_heads_window: Number of window attention heads in different layers.
num_heads_stripe: Number of stripe attention heads in different layers.
window_size: Window size. Default: 8.
stripe_size: Stripe size. Default: [8, 8].
stripe_groups: Number of stripe groups. Default: [None, None].
stripe_shift: Whether to shift the stripes. Used for ablation studies.
mlp_ratio: Ratio of MLP hidden dim to embedding dim. Default: 4.
qkv_bias: If True, add a learnable bias to query, key, value. Default: True.
qkv_proj_type: QKV projection type. Default: linear. Choices: linear, separable_conv.
anchor_proj_type: Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging.
anchor_one_stage: Whether to use one operator or multiple progressive operators to reduce the feature map resolution. Default: True.
anchor_window_down_factor: The downscale factor used to get the anchors. Default: 1.
out_proj_type: Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d.
local_connection: Whether to enable the local modelling module (two convolutions followed by channel attention). Used by the GRL base model. Default: False.
drop_rate: Dropout rate. Default: 0.
attn_drop_rate: Attention dropout rate. Default: 0.
drop_path_rate: Stochastic depth rate. Default: 0.1.
pretrained_window_size: Pretrained window size. Currently unused. Default: [0, 0].
pretrained_stripe_size: Pretrained stripe size. Currently unused. Default: [0, 0].
norm_layer: Normalization layer. Default: nn.LayerNorm.
conv_type: The convolutional block before the residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear.
init_method: Initialization method of the weight parameters, used to train large-scale models.

Choices:
  • n, normal -- Swin V1 init method.
  • l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
  • r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
  • w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
  • t, trunc_normal_ -- trunc_normal_ init for nn.Linear; weight_rescale for nn.Conv2d.
fairscale_checkpoint: Whether to use fairscale activation checkpointing.
offload_to_cpu: Used by fairscale_checkpoint.
euclidean_dist: Use Euclidean distance instead of the inner product as the similarity metric. Used for ablation studies.
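As a concrete illustration of the `w`/`weight_rescale` init choice above, the MSRResNet-style trick multiplies the weights of residual blocks by a small constant (0.1) so the residual branch starts near identity. The helper below is a hypothetical sketch using plain Python lists as stand-ins for weight tensors; it is not the class's actual `_init_weights` code.

```python
def rescale_residual_weights(params, factor=0.1):
    """Scale weights belonging to residual blocks by `factor` (MSRResNet-style).

    params: mapping of parameter name -> flat list of weight values.
    Parameters whose name does not mention a residual block are left untouched.
    """
    return {
        name: [v * factor for v in values] if "res" in name else values
        for name, values in params.items()
    }

params = {"res_block.0.conv.weight": [1.0, -2.0], "head.weight": [0.5]}
scaled = rescale_residual_weights(params)
```

In the real model the same idea is applied to `nn.Conv2d` weights inside residual blocks at initialization time, leaving other layers at their default init.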
Method __init__ Undocumented
Method check_image_size Undocumented
Method forward Undocumented
Method forward_features Undocumented
Method get_table_index_mask Undocumented
Method no_weight_decay Undocumented
Method no_weight_decay_keywords Undocumented
Method set_table_index_mask Two use cases: 1) at initialization, set the shared buffers; 2) during the forward pass, recompute the buffers if the input resolution changes.
Class Variable hyperparameters Undocumented
Instance Variable anchor_window_down_factor Undocumented
Instance Variable conv_after_body Undocumented
Instance Variable conv_before_upsample Undocumented
Instance Variable conv_first Undocumented
Instance Variable conv_last Undocumented
Instance Variable embed_dim Undocumented
Instance Variable img_range Undocumented
Instance Variable in_channels Undocumented
Instance Variable input_resolution Undocumented
Instance Variable layers Undocumented
Instance Variable mean Undocumented
Instance Variable norm_end Undocumented
Instance Variable norm_start Undocumented
Instance Variable out_channels Undocumented
Instance Variable pad_size Undocumented
Instance Variable pos_drop Undocumented
Instance Variable pretrained_stripe_size Undocumented
Instance Variable pretrained_window_size Undocumented
Instance Variable shift_size Undocumented
Instance Variable stripe_groups Undocumented
Instance Variable stripe_size Undocumented
Instance Variable upsample Undocumented
Instance Variable upsampler Undocumented
Instance Variable upscale Undocumented
Instance Variable window_size Undocumented
Method _init_weights Undocumented
def __init__(self, *, img_size=64, in_channels: int = 3, out_channels: int = 3, embed_dim=96, upscale=1, img_range=1.0, upsampler='', depths: list[int] = [6, 6, 6, 6, 6, 6], num_heads_window: list[int] = [3, 3, 3, 3, 3, 3], num_heads_stripe: list[int] = [3, 3, 3, 3, 3, 3], window_size=8, stripe_size: list[int] = [8, 8], stripe_groups: list[int | None] = [None, None], stripe_shift=False, mlp_ratio=4.0, qkv_bias=True, qkv_proj_type='linear', anchor_proj_type='avgpool', anchor_one_stage=True, anchor_window_down_factor=1, out_proj_type: Literal['linear', 'conv2d'] = 'linear', local_connection=False, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, norm_layer=nn.LayerNorm, pretrained_window_size: list[int] = [0, 0], pretrained_stripe_size: list[int] = [0, 0], conv_type='1conv', init_method='n', fairscale_checkpoint=False, offload_to_cpu=False, euclidean_dist=False): (source)

Undocumented

def check_image_size(self, x): (source)

Undocumented
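In window-attention restoration models of this family, `check_image_size` typically pads the input so its height and width become multiples of `pad_size` before windows are formed. The helper below is a hypothetical sketch of just the padding arithmetic; the actual method additionally applies the padding to the tensor.

```python
def mod_padding(h: int, w: int, pad_size: int) -> tuple[int, int]:
    """Bottom/right padding needed so h and w become multiples of pad_size."""
    mod_pad_h = (pad_size - h % pad_size) % pad_size
    mod_pad_w = (pad_size - w % pad_size) % pad_size
    return mod_pad_h, mod_pad_w
```

For example, a 65x70 input with `pad_size=8` needs 7 extra rows and 2 extra columns, while an already aligned 64x64 input needs none.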

def forward(self, x): (source)

Undocumented
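`forward` in restoration transformers of this kind typically normalizes the input by a channel `mean` and the `img_range` parameter before the transformer body and inverts the transform on the output. A self-contained sketch of that arithmetic, with scalars standing in for image tensors and an assumed mean value chosen purely for illustration:

```python
def normalize(x: float, mean: float, img_range: float) -> float:
    # Shift by the dataset mean, then scale into the training range.
    return (x - mean) * img_range

def denormalize(y: float, mean: float, img_range: float) -> float:
    # Inverse transform applied to the network output.
    return y / img_range + mean

x = 0.5
y = denormalize(normalize(x, mean=0.4488, img_range=255.0),
                mean=0.4488, img_range=255.0)
```

Because the two functions are exact inverses (up to floating-point rounding), a model whose body were the identity would reproduce its input.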

def forward_features(self, x): (source)

Undocumented

def get_table_index_mask(self, device, input_resolution: tuple[int, int]): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay(self): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay_keywords(self): (source)

Undocumented

def set_table_index_mask(self, x_size: tuple[int, int]): (source)

Two use cases: 1) at initialization, set the shared buffers; 2) during the forward pass, recompute the buffers if the input resolution changes.
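The two use cases above amount to a resolution-keyed cache: compute the attention tables/masks once, and recompute them only when the input resolution actually changes. A minimal sketch of that pattern (hypothetical class, with a placeholder standing in for the real table/index/mask buffers):

```python
class BufferCache:
    """Recompute shared buffers only when the input resolution changes."""

    def __init__(self, init_resolution: tuple[int, int]):
        self._resolution: tuple[int, int] | None = None
        self._buffers: dict | None = None
        self.get(init_resolution)  # case 1: set the shared buffers at init

    def get(self, x_size: tuple[int, int]) -> dict:
        # Case 2: during the forward pass, rebuild only on a resolution change.
        if x_size != self._resolution:
            self._resolution = x_size
            self._buffers = self._compute(x_size)
        return self._buffers

    def _compute(self, x_size: tuple[int, int]) -> dict:
        h, w = x_size
        return {"num_tokens": h * w}  # placeholder for tables, indices, masks

cache = BufferCache((64, 64))
```

Repeated calls at the same resolution return the identical buffer object, which keeps the common fixed-resolution forward pass allocation-free.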

hyperparameters: dict = (source)

Undocumented

anchor_window_down_factor: int = (source)

Undocumented

conv_after_body = (source)

Undocumented

conv_before_upsample = (source)

Undocumented

conv_first = (source)

Undocumented

conv_last = (source)

Undocumented

embed_dim: int = (source)

Undocumented

img_range: float = (source)

Undocumented

in_channels: int = (source)

Undocumented

input_resolution = (source)

Undocumented

layers = (source)

Undocumented

mean = (source)

Undocumented

norm_end = (source)

Undocumented

norm_start = (source)

Undocumented

out_channels: int = (source)

Undocumented

pad_size = (source)

Undocumented

pos_drop = (source)

Undocumented

pretrained_stripe_size: list[int] = (source)

Undocumented

pretrained_window_size: list[int] = (source)

Undocumented

shift_size = (source)

Undocumented

stripe_groups: list[int | None] = (source)

Undocumented

stripe_size: list[int] = (source)

Undocumented

upsample = (source)

Undocumented

upsampler: str = (source)

Undocumented

upscale: int = (source)

Undocumented

window_size: int = (source)

Undocumented

def _init_weights(self, m): (source)

Undocumented