foscat.planar_vit

Attributes

x

Classes

PlanarViT

Vision Transformer for 2D lat–lon grids (planar baseline).

Functions

count_parameters(→ tuple[int, int])

Return (total_params, trainable_params).

Module Contents

class foscat.planar_vit.PlanarViT(in_ch: int, H: int, W: int, *, embed_dim: int = 384, depth: int = 8, num_heads: int = 12, mlp_ratio: float = 4.0, patch: int = 4, out_ch: int = 1, dropout: float = 0.1, cls_token: bool = False, pos_embed: str = 'learned')

Bases: torch.nn.Module

Vision Transformer for 2D lat–lon grids (planar baseline).

Input: (B, C=T_in, H, W)
Output: (B, out_ch, H, W), a dense per-pixel prediction

Pipeline

  1. Patch embedding via Conv2d(kernel_size=patch, stride=patch) -> embed_dim

  2. Optional CLS token (disabled by default for dense output)

  3. Learned positional embeddings (or none)

  4. Stack of Transformer blocks

  5. Linear head per token, then nearest upsample back to (H, W)
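The five steps above can be traced as a shape walkthrough. This is a minimal sketch built from standard PyTorch layers, not the foscat implementation: the small sizes, the use of nn.TransformerEncoder for the block stack, and every variable name here are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes only (assumed, not the class defaults).
B, C, H, W = 2, 3, 16, 32
patch, embed_dim, out_ch = 4, 64, 1

x = torch.randn(B, C, H, W)

# 1. Patch embedding: a strided convolution turns each patch into one token.
patch_embed = nn.Conv2d(C, embed_dim, kernel_size=patch, stride=patch)
feat = patch_embed(x)                        # (B, embed_dim, H/patch, W/patch)
gh, gw = feat.shape[-2:]
tokens = feat.flatten(2).transpose(1, 2)     # (B, gh*gw, embed_dim)

# 2. CLS token is skipped here, matching the dense-output default.

# 3. Learned positional embeddings, one vector per token.
pos = nn.Parameter(torch.zeros(1, gh * gw, embed_dim))
tokens = tokens + pos

# 4. Stack of Transformer blocks (generic encoder used as a stand-in).
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=4, dim_feedforward=4 * embed_dim,
    dropout=0.0, batch_first=True,
)
blocks = nn.TransformerEncoder(layer, num_layers=2)
tokens = blocks(tokens)                      # (B, gh*gw, embed_dim)

# 5. Linear head per token, then nearest upsample back to (H, W).
head = nn.Linear(embed_dim, out_ch)
y = head(tokens)                             # (B, gh*gw, out_ch)
y = y.transpose(1, 2).reshape(B, out_ch, gh, gw)
y = F.interpolate(y, size=(H, W), mode="nearest")

print(y.shape)
```

Because the head emits one value per token and the upsampling is nearest-neighbour, every pixel inside a patch receives the same prediction in this sketch.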

Notes

  • Keep H, W divisible by patch.

  • For residual-of-persistence training (recommended for monthly SST), let the model predict the change relative to the last input frame:

    pred = x[:, -1:, ...] + model(x)

    and compute the loss between pred and the target.
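A single optimisation step with the residual-of-persistence formulation might look like the sketch below. To keep it runnable without foscat installed, a small Conv2d stands in for the model; the random data, the MSE loss, and the AdamW optimiser are all assumptions, not choices documented here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T_in, H, W = 4, 6, 16, 32

# Stand-in for PlanarViT(in_ch=T_in, H=H, W=W); any module mapping
# (B, T_in, H, W) -> (B, 1, H, W) works the same way here.
model = nn.Conv2d(T_in, 1, kernel_size=3, padding=1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(B, T_in, H, W)       # T_in past frames as channels
target = torch.randn(B, 1, H, W)     # frame to predict

last = x[:, -1:, ...]                # persistence forecast, (B, 1, H, W)
pred = last + model(x)               # model only learns the correction
loss = nn.functional.mse_loss(pred, target)

opt.zero_grad()
loss.backward()
opt.step()
```

The slice -1: (rather than -1) keeps the channel dimension, so the persistence term has shape (B, 1, H, W) and adds cleanly to the model output.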

patch = 4
embed_dim = 384
cls_token_enabled = False
use_pos_embed
patch_embed
num_tokens
blocks
head
forward(x: torch.Tensor) → torch.Tensor

x: (B, C, H, W), with H and W equal to the values given at construction time.
Returns: (B, out_ch, H, W)
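Since H and W are fixed at construction and should be divisible by patch, it can help to check the grid size before building the model. The two helpers below are hypothetical (they are not part of foscat) and use only plain Python arithmetic.

```python
def token_grid(H: int, W: int, patch: int = 4) -> tuple[int, int]:
    """Return the (rows, cols) of the patch grid, or raise if not divisible."""
    if H % patch or W % patch:
        raise ValueError(f"H={H}, W={W} must both be divisible by patch={patch}")
    return H // patch, W // patch


def pad_to_multiple(n: int, patch: int = 4) -> int:
    """Round n up to the nearest multiple of patch (ceiling division)."""
    return -(-n // patch) * patch


rows, cols = token_grid(16, 32, patch=4)
print(rows, cols, rows * cols)        # patch-grid size and token count without a CLS token
print(pad_to_multiple(18, patch=4))   # smallest valid size that fits 18 rows
```

A grid whose size is not a multiple of patch could be padded to the value returned by pad_to_multiple before training and cropped afterwards; whether that suits a given dataset is a modelling decision outside this reference.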

foscat.planar_vit.count_parameters(model: torch.nn.Module) → tuple[int, int]

Return (total_params, trainable_params).
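A common way to compute such a pair in PyTorch is to sum numel() over the module's parameters, filtering on requires_grad for the trainable count. The function below is a sketch of that idiom under a different name (count_parameters_sketch, hypothetical); the library's own implementation may differ in detail.

```python
import torch.nn as nn


def count_parameters_sketch(model: nn.Module) -> tuple[int, int]:
    """Return (total_params, trainable_params) for any nn.Module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Linear(3, 2) has a 2x3 weight and a length-2 bias: 8 parameters in total.
layer = nn.Linear(3, 2)
layer.bias.requires_grad_(False)      # freeze the bias, leaving 6 trainable
total, trainable = count_parameters_sketch(layer)
print(total, trainable)
```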

foscat.planar_vit.x