# Source code for foscat.unet_2_d_from_healpix_params

from __future__ import annotations
from typing import List, Optional, Literal, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext

class PlanarUNet(nn.Module):
    """U-Net for regular 2D images (H x W) mirroring HealpixUNet's parameterization.

    Compatibility with HealpixUNet:
    - Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task,
      out_channels, final_activation, device, down_type, dtype, head_reduce.
    - Two convolutions per level (encoder and decoder), each followed by
      GroupNorm + ReLU.
    - Downsampling by a factor of 2 between encoder levels; the decoder
      upsamples back symmetrically.
    - The head produces ``out_channels`` maps, followed by a GroupNorm when
      ``task == 'segmentation'`` and by the final activation.

    Differences from the spherical version:
    - Operates on regular 2D images of size (3*in_nside, 4*in_nside).
    - Uses standard ``nn.Conv2d`` (with 'same' spatial padding) instead of a
      custom spherical stencil.
    - No gauges (G=1 implicitly) and no cell_ids.

    Shapes
    ------
    Input  : (B, C_in, 3*in_nside, 4*in_nside)
    Output : (B, C_out, 3*in_nside, 4*in_nside)

    Constraints
    -----------
    ``in_nside`` must be divisible by ``2**depth``, where ``depth == len(chanlist)``.
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        chanlist: List[int],
        KERNELSZ: int = 3,
        task: Literal['regression', 'segmentation'] = 'regression',
        out_channels: int = 1,
        final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
        device: Optional[torch.device | str] = None,
        down_type: Optional[Literal['mean', 'max']] = 'max',
        dtype: Literal['float32', 'float64'] = 'float32',
        head_reduce: Literal['mean', 'learned'] = 'mean',  # kept for API symmetry
    ) -> None:
        """Build the network and move it to ``device`` / ``dtype``.

        Raises:
            ValueError: if ``chanlist`` is empty or ``in_nside`` is not
                divisible by ``2**len(chanlist)``.
        """
        super().__init__()
        if len(chanlist) == 0:
            raise ValueError("chanlist must be non-empty (depth >= 1)")

        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.chanlist = list(map(int, chanlist))
        self.KERNELSZ = int(KERNELSZ)
        self.task = task
        self.out_channels = int(out_channels)
        self.down_type = down_type
        self.dtype = torch.float32 if dtype == 'float32' else torch.float64
        self.head_reduce = head_reduce  # unused here; stored for API symmetry

        # Default final activation, consistent with HealpixUNet.
        if final_activation is None:
            if task == 'regression':
                self.final_activation = 'none'
            else:
                self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
        else:
            self.final_activation = final_activation

        # Resolve device: prefer CUDA when available.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        depth = len(self.chanlist)
        if (self.in_nside % (2 ** depth)) != 0:
            raise ValueError(
                f"in_nside={self.in_nside} must be divisible by 2**depth where depth={depth}"
            )

        # --- Encoder ---
        self.skips_channels: List[int] = []
        enc_layers = []
        in_ch = self.n_chan_in
        for out_ch in self.chanlist:
            enc_layers.append(self._double_conv(in_ch, out_ch, self.KERNELSZ))
            in_ch = out_ch
            self.skips_channels.append(out_ch)
        self.encoder = nn.ModuleList(enc_layers)

        # --- Pooling: 'max' -> max pooling; anything else ('mean' or None) -> average.
        if self.down_type == 'max':
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        else:
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # --- Decoder (built deepest level first, same order forward() uses) ---
        dec_layers = []
        upconvs = []
        for l in reversed(range(depth)):
            skipC = self.skips_channels[l]
            upC = self.skips_channels[l + 1] if (l + 1) < depth else skipC
            # NOTE: forward() never upsamples at the deepest level, so that level's
            # ConvTranspose2d is unused; it is kept so module indices and saved
            # state_dicts remain unchanged.
            upconvs.append(nn.ConvTranspose2d(upC, upC, kernel_size=2, stride=2))
            dec_layers.append(self._double_conv(upC + skipC, skipC, self.KERNELSZ))
        self.upconvs = nn.ModuleList(upconvs)
        self.decoder = nn.ModuleList(dec_layers)

        # --- Head ---
        self.head_conv = nn.Conv2d(
            self.chanlist[0], self.out_channels, kernel_size=self.KERNELSZ, padding='same'
        )
        self.head_bn = (
            _norm_2d(self.out_channels, kind="group") if self.task == 'segmentation' else None
        )
        # Learned mixer kept for API compatibility (no gauges here).
        self.head_mixer = nn.Identity()

        self.to(self.device, dtype=self.dtype)

    @staticmethod
    def _double_conv(in_ch: int, out_ch: int, kernel_size: int) -> nn.Sequential:
        """Two (Conv2d -> GroupNorm -> ReLU) stages that preserve spatial size.

        padding='same' keeps H x W unchanged for odd AND even kernel sizes
        (for odd sizes it equals kernel_size // 2).
        """
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding='same', bias=False),
            _norm_2d(out_ch, kind="group"),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=kernel_size, padding='same', bias=False),
            _norm_2d(out_ch, kind="group"),
            nn.ReLU(inplace=True),
        )

    def to_tensor(self, x):
        """Convert ``x`` to a tensor on the model device (dtype inferred from ``x``)."""
        return torch.tensor(x, device=self.device)

    def to_numpy(self, x):
        """Return ``x`` as a NumPy array; ndarrays are returned unchanged."""
        if isinstance(x, np.ndarray):
            return x
        # detach() so tensors that require grad can be converted too.
        return x.detach().cpu().numpy()

    # -------------------------- forward --------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the U-Net on ``x`` of shape (B, C_in, H, W).

        H=3*in_nside and W=4*in_nside are expected. ``x`` is moved to the model
        device/dtype. Returns a tensor of shape (B, out_channels, H, W).

        Raises:
            ValueError: if ``x`` is not 4D or has the wrong channel count.
        """
        if x.dim() != 4:
            raise ValueError("Input must be (B, C, H, W)")
        if x.shape[1] != self.n_chan_in:
            raise ValueError(f"Expected {self.n_chan_in} input channels, got {x.shape[1]}")
        x = x.to(self.device, dtype=self.dtype)

        depth = len(self.chanlist)

        # Encoder: keep each level's output as a skip; pool only between levels.
        skips: List[torch.Tensor] = []
        z = x
        for l, block in enumerate(self.encoder):
            z = block(z)
            skips.append(z)
            if l < len(self.encoder) - 1:
                z = self.pool(z)

        # Decoder, deepest level first. The deepest level is not upsampled and is
        # concatenated with itself (its skip is the bottleneck output), which
        # matches the upC + skipC input channels set in __init__.
        for d, l in enumerate(reversed(range(depth))):
            if l < depth - 1:
                z = self.upconvs[d](z)
            # Pooling floors odd sizes, so an upsampled map can be smaller than
            # its skip; zero-pad it to match.
            sh = skips[l].shape
            if z.shape[-2:] != sh[-2:]:
                z = _pad_to_match(z, sh[-2], sh[-1])
            z = torch.cat([skips[l], z], dim=1)
            z = self.decoder[d](z)

        y = self.head_conv(z)
        if self.task == 'segmentation' and self.head_bn is not None:
            y = self.head_bn(y)
        if self.final_activation == 'sigmoid':
            y = torch.sigmoid(y)
        elif self.final_activation == 'softmax':
            y = torch.softmax(y, dim=1)
        return y

    @torch.no_grad()
    def predict(
        self,
        x: torch.Tensor,
        batch_size: int = 8,
        *,
        amp: bool = False,
        out_device: Optional[str] = 'cpu',
        out_dtype: Literal['float32', 'float16'] = 'float32',
        show_pbar: bool = False,
    ) -> torch.Tensor:
        """Memory-safe batched inference.

        - Streams mini-batches under ``torch.inference_mode()``, with optional
          CUDA autocast (``amp=True``; ignored when the model is not on CUDA).
        - Moves each output batch to ``out_device`` (CPU by default; ``None``
          keeps it on the model device) and casts it to ``out_dtype`` to free VRAM.
        - Validates the input shape up front with explicit error messages.

        Side effect: on CUDA, enables ``torch.backends.cudnn.benchmark`` globally.

        Returns:
            Tensor of shape (N, out_channels, H, W) in ``out_dtype``.

        Raises:
            ValueError: for non-4D input, a channel-count mismatch, or
                ``batch_size < 1``.
        """
        self.eval()

        # --- input checks & normalisation ---
        x = x if torch.is_tensor(x) else torch.as_tensor(x)
        if x.ndim != 4:
            raise ValueError(f"predict expects (N,C,H,W), got {tuple(getattr(x,'shape',()))}")
        if x.shape[1] != self.n_chan_in:
            raise ValueError(f"predict expected {self.n_chan_in} channels, got {x.shape[1]}")

        out_dtype_t = {'float32': torch.float32, 'float16': torch.float16}[out_dtype]

        n = int(x.shape[0])
        if n == 0:
            H, W = int(x.shape[-2]), int(x.shape[-1])
            return torch.empty(
                (0, self.out_channels, H, W),
                device=out_device or self.device,
                dtype=out_dtype_t,
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        use_cuda = self.device.type == 'cuda'
        if use_cuda:
            torch.backends.cudnn.benchmark = True

        batch_starts = range(0, n, batch_size)
        if show_pbar:
            try:
                from tqdm import tqdm  # type: ignore
                batch_starts = tqdm(batch_starts, total=len(batch_starts), desc='predict')
            except Exception:
                pass  # progress bar is best-effort only

        # --- batch-by-batch inference ---
        out_list: List[torch.Tensor] = []
        with torch.inference_mode():
            for i in batch_starts:
                xb = x[i:i + batch_size].to(self.device, dtype=self.dtype, non_blocking=True)
                ctx = torch.autocast(device_type='cuda') if (amp and use_cuda) else nullcontext()
                with ctx:
                    yb = self.forward(xb)
                # Move output to the target device (CPU by default).
                if out_device is not None:
                    yb = yb.to(out_device, dtype=out_dtype_t)
                else:
                    yb = yb.to(dtype=out_dtype_t)
                out_list.append(yb)
                del xb, yb
                if use_cuda:
                    torch.cuda.empty_cache()

        if not out_list:
            raise RuntimeError(f"predict produced no outputs; check input shape {tuple(x.shape)} and batch_size={batch_size}")
        return torch.cat(out_list, dim=0)
# ----------------------------- # Helpers # ----------------------------- def _norm_2d(C: int, kind: str = "group", **kwargs) -> nn.Module: if kind == "group": num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1) while C % num_groups != 0 and num_groups > 1: num_groups //= 2 return nn.GroupNorm(num_groups=num_groups, num_channels=C) elif kind == "instance": return nn.InstanceNorm2d(C, affine=True, track_running_stats=False) elif kind == "batch": return nn.BatchNorm2d(C) else: raise ValueError(f"Unknown norm kind: {kind}") def _pad_to_match(x: torch.Tensor, H: int, W: int) -> torch.Tensor: """Pad x (B,C,h,w) with zeros on right/bottom to reach (H,W).""" _, _, h, w = x.shape ph = max(0, H - h) pw = max(0, W - w) if ph == 0 and pw == 0: return x return F.pad(x, (0, pw, 0, ph), mode='constant', value=0) # ----------------------------- # Training utilities (mirror of Healpix fit) # ----------------------------- from typing import Union import numpy as np from torch.utils.data import DataLoader, TensorDataset
def fit(
    model: nn.Module,
    x_train: Union[torch.Tensor, np.ndarray],
    y_train: Union[torch.Tensor, np.ndarray],
    *,
    n_epoch: int = 10,
    view_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    clip_grad_norm: Optional[float] = None,
    verbose: bool = True,
    optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
) -> dict:
    """Train ``model`` on 2D image data; mirrors ``healpix_unet_torch.fit``.

    Parameters
    ----------
    model : nn.Module
        Model trained in place. Its optional ``task``, ``out_channels`` and
        ``final_activation`` attributes (defaults: 'regression', 1, 'none')
        select the loss.
    x_train : tensor or ndarray of shape (B, C, H, W)
        Inputs (H=3*nside, W=4*nside for PlanarUNet). The whole array is cast
        to the dtype of the model's first parameter and moved to its device.
    y_train : tensor or ndarray
        Targets. For multiclass segmentation (task != 'regression' and
        out_channels > 1), targets with one fewer dimension than ``x_train``
        (B, H, W) are integer class indices. Same-rank targets (B, C, H, W) are
        per-class probabilities, e.g. one-hot.
    n_epoch : int
        Number of passes for ADAM. For LBFGS, ``20 * max(1, n_epoch // 20)``
        passes are run, each a sequence of LBFGS steps with ``max_iter=20``.
    view_epoch : int
        Print the loss every ``view_epoch`` passes (and after the first one).
        Values <= 0 only print the first pass.
    clip_grad_norm : float, optional
        Gradient-norm clipping; applied on the ADAM path only.

    Loss
    ----
    - regression: MSE.
    - binary segmentation (out_channels == 1): BCEWithLogits if
      final_activation == 'none', else BCE on probabilities.
    - multiclass segmentation: cross-entropy. If the model already outputs
      softmax probabilities (final_activation == 'softmax'), the loss is the
      negative log-likelihood of those probabilities. Passing them to
      CrossEntropyLoss would apply softmax a second time.

    Returns
    -------
    dict
        ``{"loss": history}``: the sample-weighted mean loss of each pass.

    Raises
    ------
    ValueError
        If ``optimizer`` is neither 'ADAM' nor 'LBFGS' (case-insensitive).
    """
    first_param = next(model.parameters())
    device, model_dtype = first_param.device, first_param.dtype
    if not model_dtype.is_floating_point:
        model_dtype = torch.float32
    model.to(device)

    task = getattr(model, 'task', 'regression')
    out_channels = getattr(model, 'out_channels', 1)
    final_activation = getattr(model, 'final_activation', 'none')
    seg_multiclass = task != 'regression' and out_channels > 1

    # ---- Data: match the model dtype so float64 models train too.
    x_t = torch.as_tensor(x_train, dtype=model_dtype, device=device)
    # Class indices only when the target has no channel axis.
    y_as_indices = seg_multiclass and np.ndim(y_train) == x_t.ndim - 1
    y_t = torch.as_tensor(y_train, dtype=torch.long if y_as_indices else model_dtype, device=device)
    loader = DataLoader(TensorDataset(x_t, y_t), batch_size=batch_size, shuffle=True, drop_last=False)

    # ---- Loss
    def prob_cross_entropy(probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # probs are already softmax-normalised along dim 1; clamp avoids log(0).
        logp = torch.log(probs.clamp_min(torch.finfo(probs.dtype).tiny))
        if target.dtype == torch.long:
            return F.nll_loss(logp, target)
        return -(target * logp).sum(dim=1).mean()

    if task == 'regression':
        criterion = nn.MSELoss(reduction='mean')
    elif out_channels == 1:
        criterion = nn.BCEWithLogitsLoss() if final_activation == 'none' else nn.BCELoss()
    elif final_activation == 'softmax':
        criterion = prob_cross_entropy
    else:
        criterion = nn.CrossEntropyLoss()

    # ---- Optimizer
    opt_name = optimizer.upper()
    if opt_name == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        outer, inner = n_epoch, 1
    elif opt_name == 'LBFGS':
        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20, history_size=50, line_search_fn='strong_wolfe')
        outer, inner = max(1, n_epoch // 20), 20
    else:
        raise ValueError("optimizer must be 'ADAM' or 'LBFGS'")
    use_lbfgs = isinstance(optim, torch.optim.LBFGS)

    # ---- Train
    history: List[float] = []
    model.train()
    for _ in range(outer):
        for _ in range(inner):
            epoch_loss, n_samples = 0.0, 0
            for xb, yb in loader:
                xb = xb.to(device, dtype=model_dtype, non_blocking=True)
                yb = yb.to(device, non_blocking=True)
                if use_lbfgs:
                    def closure():
                        optim.zero_grad(set_to_none=True)
                        loss = criterion(model(xb), yb)
                        loss.backward()
                        return loss
                    loss_val = float(optim.step(closure).item())
                else:
                    optim.zero_grad(set_to_none=True)
                    loss = criterion(model(xb), yb)
                    loss.backward()
                    if clip_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                    optim.step()
                    loss_val = float(loss.item())
                epoch_loss += loss_val * xb.shape[0]
                n_samples += xb.shape[0]
            epoch_loss /= max(1, n_samples)
            history.append(epoch_loss)
            if verbose and (len(history) == 1 or (view_epoch > 0 and len(history) % view_epoch == 0)):
                print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
    return {"loss": history}
# -----------------------------
# Minimal smoke test
# -----------------------------
if __name__ == "__main__":
    # Build a small regression U-Net and push one random batch through it,
    # printing the input and output shapes.
    torch.manual_seed(0)
    nside = 32
    height, width = 3 * nside, 4 * nside
    model = PlanarUNet(
        in_nside=nside,
        n_chan_in=3,
        chanlist=[16, 32, 64],
        KERNELSZ=3,
        task='regression',
        out_channels=1,
    )
    batch = torch.randn(2, 3, height, width)
    prediction = model(batch)
    print('input:', batch.shape, 'output:', prediction.shape)