class documentation
class HAT(nn.Module): (source)
Constructor: HAT(img_size, patch_size, in_chans, embed_dim, ...)
- Hybrid Attention Transformer
- A PyTorch implementation of *Activating More Pixels in Image Super-Resolution Transformer*. Some code is based on SwinIR.
Parameters
img_size | Input image size. Default: 64 |
patch_size | Patch size. Default: 1 |
in_chans | Number of input image channels. Default: 3 |
embed_dim | Patch embedding dimension. Default: 96 |
depths | Depth of each Swin Transformer layer. |
num_heads | Number of attention heads in different layers. |
window_size | Window size. Default: 7 |
mlp_ratio | Ratio of MLP hidden dim to embedding dim. Default: 4 |
qkv_bias | If True, add a learnable bias to query, key, value. Default: True |
qk_scale | Override the default qk scale of head_dim ** -0.5 if set. Default: None |
drop_rate | Dropout rate. Default: 0 |
attn_drop_rate | Attention dropout rate. Default: 0 |
drop_path_rate | Stochastic depth rate. Default: 0.1 |
norm_layer | Normalization layer. Default: nn.LayerNorm |
ape | If True, add absolute position embedding to the patch embedding. Default: False |
patch_norm | If True, add normalization after patch embedding. Default: True |
use_checkpoint | Whether to use checkpointing to save memory. Default: False |
upscale | Upscale factor: 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction. |
img_range | Image range: 1. or 255. |
upsampler | The reconstruction module: 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None |
resi_connection | The convolutional block before the residual connection: '1conv'/'3conv' |
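Several of these parameters interact at runtime: before window attention can be applied, the input's height and width must be multiples of window_size, so the model pads inputs up to the next multiple. A minimal pure-Python sketch of that size arithmetic (pad_to_window_multiple is a hypothetical name; the real model pads the tensor itself, typically with reflect padding):

```python
def pad_to_window_multiple(h: int, w: int, window_size: int) -> tuple[int, int]:
    """Round spatial dims up to the next multiple of window_size.

    Hypothetical helper mirroring the size arithmetic HAT performs
    before partitioning the feature map into attention windows.
    """
    mod_pad_h = (window_size - h % window_size) % window_size
    mod_pad_w = (window_size - w % window_size) % window_size
    return h + mod_pad_h, w + mod_pad_w

# e.g. a 70x65 input with the default window_size=7
print(pad_to_window_multiple(70, 65, 7))  # → (70, 70)
```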
Method | __init__ |
Undocumented |
Method | calculate_mask |
Undocumented |
Method | calculate_rpi_oca |
Undocumented |
Method | calculate_rpi_sa |
Undocumented |
Method | check_image_size |
Undocumented |
Method | forward |
Undocumented |
Method | forward_features |
Undocumented |
Class Variable | hyperparameters |
Undocumented |
Instance Variable | absolute_pos_embed |
Undocumented |
Instance Variable | ape |
Undocumented |
Instance Variable | conv_after_body |
Undocumented |
Instance Variable | conv_before_upsample |
Undocumented |
Instance Variable | conv_first |
Undocumented |
Instance Variable | conv_last |
Undocumented |
Instance Variable | embed_dim |
Undocumented |
Instance Variable | img_range |
Undocumented |
Instance Variable | layers |
Undocumented |
Instance Variable | mean |
Undocumented |
Instance Variable | mlp_ratio |
Undocumented |
Instance Variable | norm |
Undocumented |
Instance Variable | num_features |
Undocumented |
Instance Variable | num_layers |
Undocumented |
Instance Variable | overlap_ratio |
Undocumented |
Instance Variable | patch_embed |
Undocumented |
Instance Variable | patch_norm |
Undocumented |
Instance Variable | patch_unembed |
Undocumented |
Instance Variable | patches_resolution |
Undocumented |
Instance Variable | pos_drop |
Undocumented |
Instance Variable | shift_size |
Undocumented |
Instance Variable | upsample |
Undocumented |
Instance Variable | upsampler |
Undocumented |
Instance Variable | upscale |
Undocumented |
Instance Variable | window_size |
Undocumented |
Method | _init_weights |
Undocumented |
def __init__(self, *, img_size=64, patch_size=1, in_chans=3, embed_dim=96,
             depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7,
             compress_ratio: int | float = 3, squeeze_factor: int | float = 30,
             conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4.0, qkv_bias=True,
             qk_scale=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1,
             norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
             use_checkpoint=False, upscale=1, img_range=1.0,
             upsampler: Literal['pixelshuffle'] = 'pixelshuffle',
             resi_connection='1conv', num_feat=64):
(source)
Undocumented
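The img_range parameter works together with the mean attribute: SwinIR-style models subtract a per-channel mean from the input, scale by img_range before the network body, and invert both steps on the output. A minimal sketch of that convention on plain floats (assumption: the DIV2K RGB means commonly used by SwinIR/HAT; normalize/denormalize are hypothetical names):

```python
# DIV2K per-channel means commonly used when in_chans == 3 (assumption)
RGB_MEAN = (0.4488, 0.4371, 0.4040)

def normalize(pixel: float, mean: float, img_range: float) -> float:
    """Mean-subtract and scale, as done before the network body."""
    return (pixel - mean) * img_range

def denormalize(value: float, mean: float, img_range: float) -> float:
    """Invert the normalization on the network output."""
    return value / img_range + mean
```

Round-tripping a value through normalize and denormalize with the same mean and img_range recovers the original pixel, which is why the choice of 1.0 vs 255.0 must match the checkpoint the weights were trained with.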