class documentation

Cross-Refinement Adaptive Fusion Transformer
Some of the code is based on SwinIR.

Parameters
in_chans: Number of input image channels. Default: 3
embed_dim: Patch embedding dimension. Default: 48
depths: Depth of each Swin Transformer layer.
num_heads: Number of attention heads in different layers.
mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 2
qkv_bias: If True, add a learnable bias to query, key, value. Default: True
qk_scale: Override default qk scale of head_dim ** -0.5 if set. Default: None
norm_layer: Normalization layer. Default: nn.LayerNorm.
upscale: Upscale factor. 2/3/4
img_range: Image range. 1. or 255.
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
Method __init__ Undocumented
Method calculate_rpi_v_sa Undocumented
Method forward Undocumented
Method forward_features Undocumented
Method no_weight_decay Undocumented
Method no_weight_decay_keywords Undocumented
Class Variable hyperparameters Undocumented
Instance Variable conv_after_body Undocumented
Instance Variable conv_first Undocumented
Instance Variable embed_dim Undocumented
Instance Variable h Undocumented
Instance Variable img_range Undocumented
Instance Variable layers Undocumented
Instance Variable mean Undocumented
Instance Variable mlp_ratio Undocumented
Instance Variable norm Undocumented
Instance Variable num_feat Undocumented
Instance Variable num_features Undocumented
Instance Variable num_layers Undocumented
Instance Variable num_out_ch Undocumented
Instance Variable split_size Undocumented
Instance Variable upsample Undocumented
Instance Variable upscale Undocumented
Instance Variable w Undocumented
Instance Variable window_size Undocumented
Method _init_weights Undocumented
def __init__(self, *, in_chans=3, window_size=16, embed_dim=48, depths=[2, 2, 2, 2], num_heads=[6, 6, 6, 6], split_size_0=4, split_size_1=16, mlp_ratio=2.0, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm, upscale=4, img_range=1.0, resi_connection='1conv'): (source)

Undocumented

def calculate_rpi_v_sa(self): (source)

Undocumented

def forward(self, x): (source)

Undocumented

def forward_features(self, x): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay(self): (source)

Undocumented

@torch.jit.ignore
def no_weight_decay_keywords(self): (source)

Undocumented

hyperparameters: dict = (source)

Undocumented

conv_after_body = (source)

Undocumented

conv_first = (source)

Undocumented

embed_dim: int = (source)

Undocumented

Undocumented

img_range = (source)

Undocumented

Undocumented

Undocumented

mlp_ratio: float = (source)

Undocumented

Undocumented

num_feat = (source)

Undocumented

num_features = (source)

Undocumented

num_layers = (source)

Undocumented

num_out_ch = (source)

Undocumented

split_size = (source)

Undocumented

upsample = (source)

Undocumented

Undocumented

Undocumented

window_size = (source)

Undocumented

def _init_weights(self, m): (source)

Undocumented