foscat.planar_vit
=================

.. py:module:: foscat.planar_vit


Attributes
----------

.. autoapisummary::

   foscat.planar_vit.x


Classes
-------

.. autoapisummary::

   foscat.planar_vit.PlanarViT


Functions
---------

.. autoapisummary::

   foscat.planar_vit.count_parameters


Module Contents
---------------

.. py:class:: PlanarViT(in_ch: int, H: int, W: int, *, embed_dim: int = 384, depth: int = 8, num_heads: int = 12, mlp_ratio: float = 4.0, patch: int = 4, out_ch: int = 1, dropout: float = 0.1, cls_token: bool = False, pos_embed: str = 'learned')

   Bases: :py:obj:`torch.nn.Module`


   Vision Transformer for 2D lat–lon grids (planar baseline).

   | Input: ``(B, C=T_in, H, W)``
   | Output: ``(B, out_ch, H, W)`` (dense per-pixel prediction)

   .. rubric:: Pipeline

   1) Patch embedding via ``Conv2d(kernel_size=patch, stride=patch)`` -> ``embed_dim``
   2) Optional CLS token (disabled by default for dense output)
   3) Learned positional embeddings (or none)
   4) Stack of Transformer blocks
   5) Linear head per token, then nearest upsample back to ``(H, W)``

   .. admonition:: Notes

      - Keep ``H`` and ``W`` divisible by ``patch``.
      - For residual-of-persistence training (recommended for monthly SST),
        form ``pred = x[:, -1:, ...] + model(x)`` and train the loss on
        ``pred`` vs. the target.


   .. py:attribute:: patch
      :value: 4


   .. py:attribute:: embed_dim
      :value: 384


   .. py:attribute:: cls_token_enabled
      :value: False


   .. py:attribute:: use_pos_embed


   .. py:attribute:: patch_embed


   .. py:attribute:: num_tokens


   .. py:attribute:: blocks


   .. py:attribute:: head


   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      | x: ``(B, C, H, W)``, with ``H`` and ``W`` fixed to the construction-time values
      | returns: ``(B, out_ch, H, W)``



.. py:function:: count_parameters(model: torch.nn.Module) -> tuple[int, int]

   Return ``(total_params, trainable_params)``.


.. py:data:: x
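
The ``PlanarViT`` notes ask that ``H`` and ``W`` be divisible by ``patch``. The sketch below shows one way to check a grid size before constructing the model and to estimate the token sequence length. It uses plain Python only. The helpers ``patch_grid`` and ``padded_size`` are hypothetical illustrations and are not part of ``foscat``. Counting an enabled CLS token as one extra token is an assumption, so the result may not match the module's ``num_tokens`` attribute.

```python
def patch_grid(H: int, W: int, patch: int = 4, cls_token: bool = False):
    """Return (rows, cols, num_tokens) for an H x W grid cut into patches.

    Hypothetical helper: mirrors the documented constraint that H and W
    must be divisible by `patch`; it does not call into foscat.
    """
    if H % patch or W % patch:
        raise ValueError(
            f"H={H} and W={W} must both be divisible by patch={patch}"
        )
    rows, cols = H // patch, W // patch
    # Assumption: an enabled CLS token adds exactly one token to the sequence.
    return rows, cols, rows * cols + int(cls_token)


def padded_size(n: int, patch: int = 4) -> int:
    """Smallest multiple of `patch` that is >= n (ceiling division)."""
    return -(-n // patch) * patch


if __name__ == "__main__":
    print(patch_grid(64, 128))   # (16, 32, 512)
    print(padded_size(90))       # 92
```

If a grid fails the check, padding or cropping to ``padded_size`` along each axis before building the model is one option; whether that suits a given lat–lon dataset depends on how the boundaries should be treated.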