foscat.planar_vit
Attributes
Classes
PlanarViT | Vision Transformer for 2D lat–lon grids (planar baseline).
Functions
count_parameters(model) | Return (total_params, trainable_params).
Module Contents
- class foscat.planar_vit.PlanarViT(in_ch: int, H: int, W: int, *, embed_dim: int = 384, depth: int = 8, num_heads: int = 12, mlp_ratio: float = 4.0, patch: int = 4, out_ch: int = 1, dropout: float = 0.1, cls_token: bool = False, pos_embed: str = 'learned')
Bases: torch.nn.Module
Vision Transformer for 2D lat–lon grids (planar baseline).
Input: (B, C=T_in, H, W)
Output: (B, out_ch, H, W), a dense per-pixel prediction
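The input/output contract, combined with a patch embedding whose kernel size equals its stride, fixes the token grid the Transformer sees. The helper below, `planar_vit_shapes`, is hypothetical (it is not part of foscat) and only reproduces that arithmetic under the assumption that the grid is (H/patch, W/patch).

```python
def planar_vit_shapes(
    B: int, C: int, H: int, W: int, patch: int = 4, out_ch: int = 1
) -> tuple[tuple[int, int], int, tuple[int, int, int, int]]:
    """Hypothetical helper: token grid, token count and output shape.

    Mirrors the shape bookkeeping described for PlanarViT; not foscat code.
    C (= T_in) only sets the input channels of the patch Conv2d and does not
    change the number of tokens.
    """
    # kernel_size == stride == patch, so H and W must divide evenly.
    if H % patch or W % patch:
        raise ValueError(f"H={H} and W={W} must be divisible by patch={patch}")
    grid = (H // patch, W // patch)   # tokens along latitude and longitude
    num_tokens = grid[0] * grid[1]    # sequence length seen by the Transformer
    out_shape = (B, out_ch, H, W)     # dense per-pixel prediction
    return grid, num_tokens, out_shape


grid, n, out = planar_vit_shapes(B=8, C=12, H=64, W=128, patch=4)
print(grid, n, out)  # (16, 32) 512 (8, 1, 64, 128)
```

With the default patch = 4, a 64 × 128 grid becomes 512 tokens; halving the patch size quadruples the sequence length and therefore the attention cost.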
Pipeline
- Patch embedding: Conv2d(kernel_size=patch, stride=patch) maps each non-overlapping patch to an embed_dim vector
- Optional CLS token (disabled by default, since the output is dense)
- Learned positional embeddings (or none)
- Stack of Transformer blocks
- Linear head applied per token, then nearest-neighbor upsampling back to (H, W)
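As a rough illustration, the pipeline above can be reproduced with standard PyTorch layers. `PlanarViTSketch` is a hypothetical, minimal stand-in written for this page, not foscat's implementation: it assumes `nn.TransformerEncoderLayer` as the Transformer block, zero-initialised learned positional embeddings and no CLS token, and it uses small default sizes so it runs quickly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PlanarViTSketch(nn.Module):
    """Hypothetical minimal re-creation of the documented pipeline (not foscat code)."""

    def __init__(self, in_ch, H, W, embed_dim=64, depth=2, num_heads=4,
                 mlp_ratio=4.0, patch=4, out_ch=1, dropout=0.1):
        super().__init__()
        if H % patch or W % patch:
            raise ValueError("H and W must be divisible by patch")
        self.H, self.W = H, W
        self.gh, self.gw = H // patch, W // patch
        # 1) Patch embedding: non-overlapping patch x patch windows -> embed_dim channels.
        self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
        # 2) Learned positional embedding, one vector per token (zero init for simplicity).
        self.pos_embed = nn.Parameter(torch.zeros(1, self.gh * self.gw, embed_dim))
        # 3) Stack of Transformer encoder blocks (PyTorch's built-in layer as a stand-in).
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=dropout, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        # 4) Linear head applied to every token independently.
        self.head = nn.Linear(embed_dim, out_ch)

    def forward(self, x):
        B = x.shape[0]
        t = self.patch_embed(x)               # (B, D, H/p, W/p)
        t = t.flatten(2).transpose(1, 2)      # (B, N, D), N = (H/p) * (W/p)
        t = self.blocks(t + self.pos_embed)   # (B, N, D)
        y = self.head(t)                      # (B, N, out_ch)
        y = y.transpose(1, 2).reshape(B, -1, self.gh, self.gw)  # (B, out_ch, H/p, W/p)
        # 5) Nearest-neighbor upsampling back to the full (H, W) grid.
        return F.interpolate(y, size=(self.H, self.W), mode="nearest")


model = PlanarViTSketch(in_ch=3, H=16, W=32, patch=4, out_ch=2)
y = model(torch.randn(2, 3, 16, 32))
print(y.shape)
```

Because the positional embedding in this sketch has exactly (H/patch)·(W/patch) rows, the model only accepts the H and W it was built with, and nearest upsampling makes every patch × patch block of the output constant.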
Notes
- Keep H and W divisible by patch.
- For residual-of-persistence training (recommended for monthly SST), predict
  pred = x[:, -1:, ...] + model(x)
  and compute the training loss between pred and target.
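A single training step for the residual-of-persistence recipe might look as follows. The convolution is a hypothetical stand-in for a `PlanarViT(in_ch=T_in, H=H, W=W, out_ch=1)` instance so that the snippet runs without foscat; the random tensors, the Adam optimizer and the MSE loss are assumptions, since the note does not say which loss to use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

B, T_in, H, W = 4, 6, 16, 32
# Hypothetical stand-in for PlanarViT(in_ch=T_in, H=H, W=W, out_ch=1):
# any module mapping (B, T_in, H, W) -> (B, 1, H, W) works here.
model = nn.Conv2d(T_in, 1, kernel_size=3, padding=1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(B, T_in, H, W)     # input fields, most recent last (assumed ordering)
target = torch.randn(B, 1, H, W)   # next field to predict (synthetic data)

# Persistence baseline: repeat the most recent input frame.
persistence = x[:, -1:, ...]       # (B, 1, H, W); the -1: slice keeps the channel axis
# The network only has to learn the correction to persistence.
pred = persistence + model(x)
loss = F.mse_loss(pred, target)    # loss on the full prediction, not on the residual

opt.zero_grad()
loss.backward()
opt.step()
print(pred.shape, float(loss))
```

Slicing with `-1:` rather than `-1` keeps a singleton channel dimension, so the persistence term broadcasts cleanly against an `out_ch=1` output.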
- patch = 4
- embed_dim = 384
- cls_token_enabled = False
- use_pos_embed
- patch_embed
- num_tokens
- blocks
- head
- forward(x: torch.Tensor) → torch.Tensor
x: (B, C, H, W), where H and W must equal the values given at construction.
Returns: (B, out_ch, H, W).
- foscat.planar_vit.count_parameters(model: torch.nn.Module) → tuple[int, int]
Return (total_params, trainable_params).
- foscat.planar_vit.x
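The documented return value of `count_parameters` matches the common PyTorch idiom of summing `numel()` over `model.parameters()`, filtering on `requires_grad` for the trainable count. `count_parameters_sketch` below is a hypothetical equivalent for illustration; foscat's own implementation may differ in detail.

```python
import torch.nn as nn


def count_parameters_sketch(model: nn.Module) -> tuple[int, int]:
    """Hypothetical re-implementation returning (total_params, trainable_params)."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


net = nn.Linear(10, 3)                # 10*3 weights + 3 biases = 33 parameters
net.bias.requires_grad_(False)        # freeze the bias
print(count_parameters_sketch(net))   # (33, 30)
```

Comparing the two numbers is a quick check that freezing (for example, a pretrained patch embedding) took effect before training starts.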