# Source code for foscat.planar_vit

# planar_vit.py
# (Planar Vision Transformer baseline for lat–lon grids)
from __future__ import annotations
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------
# Building blocks
# ---------------------------

class _MLP(nn.Module):
    """Position-wise feed-forward network used inside each ViT block.

    Computes ``Dropout(fc2(Dropout(GELU(fc1(x)))))``. The hidden width is
    ``int(dim * mlp_ratio)`` and the output width equals ``dim``. The same
    Dropout module is applied after the activation and after ``fc2``.
    """

    def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.1):
        super().__init__()
        hidden_features = int(dim * mlp_ratio)
        # Submodules are registered in the order fc1, fc2, act, drop so that
        # named_parameters()/state_dict() keys and ordering stay stable.
        self.fc1 = nn.Linear(dim, hidden_features)
        self.fc2 = nn.Linear(hidden_features, dim)
        self.act = nn.GELU()
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.fc1(x)
        h = self.act(h)
        h = self.drop(h)
        h = self.fc2(h)
        return self.drop(h)


class _ViTBlock(nn.Module):
    """
    Transformer block (Pre-LN):
      x = x + Drop(MHA(LN(x)))
      x = x + Drop(MLP(LN(x)))

    Args:
        dim: token embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP relative to ``dim``.
        drop: dropout probability used for attention weights, the MLP and
            the residual branches.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.

    Note: ``drop_path`` is a plain element-wise ``nn.Dropout`` (not stochastic
    depth); the attribute name is kept for compatibility.
    """

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop: float = 0.1):
        super().__init__()
        # Explicit raise instead of assert so the check survives `python -O`.
        if dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _MLP(dim, mlp_ratio, drop)
        self.drop_path = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Multi-head self-attention: normalize once and reuse for Q, K and V.
        h = self.norm1(x)
        # need_weights=False: the averaged attention maps are never used, and
        # skipping them lets PyTorch take its fused attention path.
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop_path(attn_out)
        # Feed-forward
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


# ---------------------------
# Planar ViT (lat–lon images)
# ---------------------------

class PlanarViT(nn.Module):
    """
    Vision Transformer for 2D lat–lon grids (planar baseline).

    Input : (B, C=T_in, H, W)
    Output: (B, out_ch, H, W)   # dense per-pixel prediction

    Pipeline
    --------
    1) Patch embedding via Conv2d(kernel_size=patch, stride=patch) -> embed_dim
    2) Optional CLS token (disabled by default for dense output)
    3) Learned positional embeddings (or none)
    4) Stack of Transformer blocks
    5) Linear head per token, then nearest upsample back to (H, W)

    Notes
    -----
    - H and W must be divisible by `patch`.
    - For residual-of-persistence training (recommended for monthly SST):
          pred = x[:, -1:, ...] + model(x)
      and train the loss on `pred` vs target.

    Raises
    ------
    ValueError
        At construction if `patch` is not positive, if H or W is not
        divisible by `patch`, if `embed_dim` is not divisible by `num_heads`,
        or if `pos_embed` is not "learned" / "none" (None is accepted as
        "none").
    """

    def __init__(
        self,
        in_ch: int,  # e.g., T_in months
        H: int,
        W: int,
        *,
        embed_dim: int = 384,
        depth: int = 8,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        patch: int = 4,
        out_ch: int = 1,
        dropout: float = 0.1,
        cls_token: bool = False,  # keep False for dense prediction
        pos_embed: Optional[str] = "learned",  # or "none"
    ):
        super().__init__()
        # Explicit raises (not assert): under `python -O` an assert would be
        # stripped and a non-divisible H/W would silently yield an output
        # smaller than the input grid.
        if patch <= 0:
            raise ValueError("patch must be a positive integer")
        if H % patch != 0 or W % patch != 0:
            raise ValueError("H and W must be divisible by patch")
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        if pos_embed is None:
            pos_embed = "none"
        # Reject unknown values instead of silently dropping positional info.
        if pos_embed not in ("learned", "none"):
            raise ValueError(f'pos_embed must be "learned" or "none", got {pos_embed!r}')

        self.in_ch = in_ch
        self.H, self.W = H, W
        self.patch = patch
        self.embed_dim = embed_dim
        self.cls_token_enabled = bool(cls_token)
        self.use_pos_embed = pos_embed == "learned"

        # 1) Patch embedding (Conv2d with stride=patch) -> one token per patch
        self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)

        # 2) Token bookkeeping & positional embeddings
        Hp, Wp = H // patch, W // patch
        self.num_tokens = Hp * Wp + (1 if self.cls_token_enabled else 0)

        if self.cls_token_enabled:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        else:
            self.cls_token = None

        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

        # 3) Transformer encoder
        self.blocks = nn.ModuleList(
            [
                _ViTBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, drop=dropout)
                for _ in range(depth)
            ]
        )

        # 4) Patch-wise head (token -> out_ch)
        self.head = nn.Linear(embed_dim, out_ch)

        # Store for unpatching
        self.Hp, self.Wp = Hp, Wp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, H, W) with H, W fixed to the construction-time H, W
        returns: (B, out_ch, H, W)

        Raises ValueError if the spatial size differs from (self.H, self.W).
        """
        B, C, H, W = x.shape
        if (H != self.H) or (W != self.W):
            raise ValueError(f"Input H,W must be ({self.H},{self.W}), got ({H},{W}).")

        # Patch embedding -> (B, E, Hp, Wp) -> (B, Np, E); tokens are row-major over (Hp, Wp)
        z = self.patch_embed(x)
        z = z.flatten(2).transpose(1, 2)

        # Optional CLS token prepended at position 0
        if self.cls_token_enabled:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, E)
            z = torch.cat([cls, z], dim=1)  # (B, 1+Np, E)

        # Positional embedding
        if self.pos_embed is not None:
            z = z + self.pos_embed[:, : z.shape[1], :]

        # Transformer
        for blk in self.blocks:
            z = blk(z)  # (B, N, E)

        # Drop CLS for dense output
        tokens = z[:, 1:, :] if self.cls_token_enabled else z  # (B, Np, E)

        # Token head -> (B, out_ch, Np) -> (B, out_ch, Hp, Wp)
        y = self.head(tokens).transpose(1, 2)
        y = y.reshape(B, -1, self.Hp, self.Wp)
        # Nearest upsample: every pixel of a patch receives its token's value.
        # An explicit size ties the output to the construction-time grid.
        y = F.interpolate(y, size=(self.H, self.W), mode="nearest")
        return y
# ---------------------------
# Utilities
# ---------------------------
def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (total_params, trainable_params).

    Both counts are numbers of scalar elements over ``model.parameters()``;
    a parameter is counted as trainable when its ``requires_grad`` is True.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elems = param.numel()
        total += n_elems
        if param.requires_grad:
            trainable += n_elems
    return total, trainable
# ---------------------------
# Smoke test
# ---------------------------
if __name__ == "__main__":
    # Example: T_in=6, grid 128x256, predict 1 channel
    B, C, H, W = 2, 6, 128, 256
    x = torch.randn(B, C, H, W)
    model = PlanarViT(
        in_ch=C,
        H=H,
        W=W,
        embed_dim=384,
        depth=8,
        num_heads=12,
        mlp_ratio=4.0,
        patch=4,
        out_ch=1,
        dropout=0.1,
        cls_token=False,
        pos_embed="learned",
    )
    # Forward-only check: no autograd graph is needed.
    with torch.no_grad():
        y = model(x)
    # Fail loudly rather than just printing a wrong shape.
    if tuple(y.shape) != (B, 1, H, W):
        raise RuntimeError(f"unexpected output shape {tuple(y.shape)}, expected {(B, 1, H, W)}")
    tot, trn = count_parameters(model)
    print("Output:", tuple(y.shape))
    print("Params:", f"total={tot:,}", f"trainable={trn:,}")