Source code for foscat.SynthHalfUNet2D

"""
SynthHalfUNet2D — fast texture synthesis via a decoder-only U-Net.

Architecture
------------
Input: white noise  z ~ N(0,1),  shape [N, 1, H, W]

Skip connections (frozen wavelet bank, no encoder):
  At each scale j (j=0 finest, j=Jmax coarsest):
    z_j     = AvgPool2d^j(z)                                   [N, 1,    H/2^j, W/2^j]
    skip_j  = [Re(ψ_l ★ z_j), Im(ψ_l ★ z_j)]_{l} ⊕ z_j        [N, 2L+1, H/2^j, W/2^j]

  where {ψ_l} are the L FOSCAT oriented complex wavelets (fixed, from FoCUS).

  The raw (downsampled) noise z_j is appended as an extra channel at *every*
  scale, not only at the coarsest one, so each skip has 2L+1 channels and the
  initial input to the decoder is simply skip_Jmax (2L+1 channels).

Decoder (all parameters are learned):
  x = InitBlock( skip_Jmax )                           [N, C[0], H/2^Jmax, W/2^Jmax]
      (skip_Jmax has 2L+1 channels: wavelet responses plus the raw z_Jmax)
  for j in Jmax-1 .. 0:
    x = Upsample(x, ×2)
    x = DecBlock_j( cat[x, skip_j] )                  [N, C[Jmax-j], H/2^j, W/2^j]
  out = Conv1×1(x)                                     [N, out_ch, H, W]

Each DecBlock is:  Conv3×3 → InstanceNorm → LeakyReLU → Conv3×3 → InstanceNorm → LeakyReLU

Training loss:
  z_i ~ N(0,1)  (sampled fresh each epoch)
  x_i = network(z_i)
  loss = reduce_distance( mean_scat_cov({x_i}), target_scat_cov )
"""

from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------

class ConvBlock2D(nn.Module):
    """Double ``Conv3×3 → InstanceNorm → LeakyReLU`` stage used by the decoder.

    Normalisation is done per sample (InstanceNorm) rather than per batch
    (BatchNorm).  Batch statistics would average away the sample-to-sample
    variation carried by the noise input z and push every z towards the same
    output (mode collapse); per-sample statistics do not.

    Parameters
    ----------
    in_ch : int
        Number of channels of the incoming feature map.
    out_ch : int
        Number of channels produced by both convolutions.
    negative_slope : float
        Slope of the LeakyReLU for negative inputs.
    """

    def __init__(self, in_ch: int, out_ch: int, negative_slope: float = 0.2):
        super().__init__()
        stages = []
        width_in = in_ch
        # Two identical stages; only the first one changes the channel count.
        for _ in range(2):
            stages.extend(
                (
                    nn.Conv2d(width_in, out_ch, 3, padding=1, bias=True),
                    nn.InstanceNorm2d(out_ch, affine=True),
                    nn.LeakyReLU(negative_slope, inplace=True),
                )
            )
            width_in = out_ch
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the six-layer stack and return the result."""
        out = self.block(x)
        return out
# ---------------------------------------------------------------------------
# Main network
# ---------------------------------------------------------------------------
class SynthHalfUNet2D(nn.Module):
    """Decoder-only U-Net for fast scattering-covariance texture synthesis.

    At inference time a single forward pass (milliseconds) replaces the slow
    gradient-descent synthesis (minutes).  The network is trained by
    overfitting to a single target scattering covariance using
    :func:`train_synth_unet`.

    Parameters
    ----------
    scat_op : FoCUS
        A FOSCAT FoCUS object initialised with ``use_2D=True``.  Its oriented
        wavelet kernels (``ww_RealT[1]``, ``ww_ImagT[1]``) are extracted and
        registered as **frozen** buffers (not trained).  Only the attributes
        ``NORIENT``, ``KERNELSZ``, ``ww_RealT`` and ``ww_ImagT`` are read.
    Jmax : int
        Number of wavelet scales (decoder depth).  The spatial resolution at
        the coarsest level is H/2^Jmax × W/2^Jmax (floored).
    channel_list : list[int] or None
        Feature-map channels at each decoder level, ordered
        **coarsest → finest**: ``channel_list[0]`` is used at level Jmax
        (coarsest), ``channel_list[-1]`` is used just before the output
        Conv1×1.  Must have length ``Jmax + 1``.
        Default: ``[min(32·2^(Jmax-j), 256) for j in 0..Jmax]``,
        e.g. ``[256, 128, 64, 32]`` for Jmax=3.
    out_channels : int
        Number of output image channels (1 for single-channel textures).
        Generate N independent samples by setting the batch size of z to N.

    Raises
    ------
    ValueError
        If ``channel_list`` does not have exactly ``Jmax + 1`` entries.

    Examples
    --------
    >>> import torch
    >>> from foscat.SynthHalfUNet2D import SynthHalfUNet2D, train_synth_unet, generate_samples
    >>> import foscat.scat_cov2D as sc
    >>>
    >>> scat_op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)
    >>> target = torch.tensor(my_image_2d)          # shape [H, W]
    >>>
    >>> # Train (overfits to the scattering covariance of target)
    >>> model = train_synth_unet(target, scat_op, Jmax=3, n_epochs=2000)
    >>>
    >>> # Generate new samples instantly
    >>> samples = generate_samples(model, n_samples=8, H=256, W=256)
    >>> # samples: torch.Tensor [8, 1, 256, 256]
    """

    def __init__(
        self,
        scat_op,
        Jmax: int,
        channel_list: list[int] | None = None,
        out_channels: int = 1,
    ):
        super().__init__()
        self.Jmax = Jmax
        L = scat_op.NORIENT
        K = scat_op.KERNELSZ
        self.norient = L
        self.kernelsz = K
        pad = K // 2

        # -------------------------------------------------------------- #
        # Frozen wavelet bank — extracted from the FoCUS object          #
        # Reshaped to [L, 1, K, K], the weight layout F.conv2d expects   #
        # for a 1-channel input and L output channels.                   #
        # -------------------------------------------------------------- #
        def _to_tensor(t):
            # NumPy kernels are copied; tensor kernels are detached so the
            # bank never takes part in autograd.
            if isinstance(t, np.ndarray):
                return torch.tensor(t, dtype=torch.float32)
            return t.detach().float()

        wc = _to_tensor(scat_op.ww_RealT[1]).reshape(L, 1, K, K)
        ws = _to_tensor(scat_op.ww_ImagT[1]).reshape(L, 1, K, K)
        self.register_buffer("_wc", wc)  # real part  [L, 1, K, K]
        self.register_buffer("_ws", ws)  # imag part  [L, 1, K, K]

        # -------------------------------------------------------------- #
        # Channel layout                                                 #
        # channel_list[0] = coarsest (Jmax), channel_list[-1] = finest   #
        # -------------------------------------------------------------- #
        if channel_list is None:
            channel_list = [min(32 * 2 ** (Jmax - j), 256) for j in range(Jmax + 1)]
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if len(channel_list) != Jmax + 1:
            raise ValueError(
                f"channel_list must have Jmax+1={Jmax+1} entries, got {len(channel_list)}"
            )
        self.channel_list = channel_list

        # Each skip contains Re(ψ_l ★ z_j), Im(ψ_l ★ z_j) and z_j itself.
        # The extra z_j channel gives the decoder a direct, unfiltered path
        # from the noise to each decoder level.
        skip_ch = 2 * L + 1  # real + imaginary (2L) + raw z (1)

        # -------------------------------------------------------------- #
        # Decoder blocks                                                 #
        # -------------------------------------------------------------- #
        # InitBlock: input = skip_Jmax (2L+1 ch, already contains z_Jmax)
        self.init_block = ConvBlock2D(skip_ch, channel_list[0])

        # One ConvBlock per upsampling step (Jmax steps, from Jmax-1 to 0)
        self.decoder_blocks = nn.ModuleList()
        for k in range(Jmax):
            in_ch = channel_list[k] + skip_ch  # upsampled features + skip
            out_ch = channel_list[k + 1]
            self.decoder_blocks.append(ConvBlock2D(in_ch, out_ch))

        # 1×1 projection to output channels
        self.output_conv = nn.Conv2d(channel_list[-1], out_channels, kernel_size=1)

        self._pad = pad

    # ------------------------------------------------------------------ #
    # Wavelet skip-connection bank                                       #
    # ------------------------------------------------------------------ #
    def _wavelet_skips(self, z: torch.Tensor) -> list[torch.Tensor]:
        """Return the skip tensors at every scale.

        Each skip holds the oriented wavelet responses of z at that scale
        **plus the raw (downsampled) noise z_j itself**.

        Parameters
        ----------
        z : torch.Tensor, shape [N, 1, H, W]
            Noise input (single channel, as required by the [L, 1, K, K]
            wavelet bank).

        Returns
        -------
        skips : list of tensors, len = Jmax+1
            ``skips[j]`` has ``2L+1`` channels at the resolution obtained
            after ``j`` successive 2×2 average poolings.  Channels:
            ``[Re(ψ_0★z_j), …, Re(ψ_{L-1}★z_j),``
            `` Im(ψ_0★z_j), …, Im(ψ_{L-1}★z_j),``
            `` z_j]``  (j=0 finest, j=Jmax coarsest).
        """
        pad = self._pad
        wc = self._wc  # [L, 1, K, K]
        ws = self._ws

        skips = []
        z_j = z  # start at full resolution
        for j in range(self.Jmax + 1):
            if j > 0:
                # avg_pool2d floors odd sizes, so z_j may be smaller than
                # exactly half of the previous level.
                z_j = F.avg_pool2d(z_j, kernel_size=2, stride=2)
            real = F.conv2d(z_j, wc, padding=pad)
            imag = F.conv2d(z_j, ws, padding=pad)
            # Concatenate wavelet responses + raw noise along channels.
            skips.append(torch.cat([real, imag, z_j], dim=1))

        return skips  # index 0 = finest, index Jmax = coarsest

    # ------------------------------------------------------------------ #
    # Forward                                                            #
    # ------------------------------------------------------------------ #
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Generate synthesised images from white noise.

        Each upsampling step targets the spatial size of the skip it is fused
        with rather than a fixed ×2 factor.  When H and W are multiples of
        2^Jmax this is the same ×2 bilinear upsampling; otherwise it keeps the
        decoder aligned with the floor-divided skip pyramid, so such sizes no
        longer fail in ``torch.cat``.

        Parameters
        ----------
        z : torch.Tensor, shape [N, 1, H, W]
            White noise input.  N independent samples are produced in
            parallel.

        Returns
        -------
        torch.Tensor, shape [N, out_channels, H, W]
        """
        skips = self._wavelet_skips(z)  # skips[0]=finest .. skips[Jmax]=coarsest

        # ---- Initial block at coarsest scale ----
        x = self.init_block(skips[self.Jmax])

        # ---- Decode from Jmax-1 down to 0 ----
        # reversed(skips[:-1]) yields skips[Jmax-1], ..., skips[0].
        for block, skip in zip(self.decoder_blocks, reversed(skips[:-1])):
            x = F.interpolate(
                x, size=skip.shape[-2:], mode="bilinear", align_corners=False
            )
            x = block(torch.cat([x, skip], dim=1))

        return self.output_conv(x)
# ---------------------------------------------------------------------------
# Microcanonical loss helper
# ---------------------------------------------------------------------------

def _batch_normalised_sq_error(samples, target, sigma_synth, eps):
    """Return Σ_k (batch mean_k − target_k)² / denom_k for real tensors.

    ``samples`` carries the batch on dim 0; ``target`` has that dim removed.
    """
    mean = samples.mean(dim=0)
    var = samples.var(dim=0, unbiased=False)
    if sigma_synth:
        # sqrt has an infinite derivative at 0: a coefficient that is
        # constant over the batch would turn the (zero) upstream gradient
        # into NaN.  Flooring the variance at the smallest normal float keeps
        # the value essentially unchanged and the gradient finite.
        std = var.clamp(min=torch.finfo(var.dtype).tiny).sqrt()
        denom = (std * target.abs()).clamp(min=eps)
    else:
        # std² is the variance itself; skipping the sqrt round trip avoids
        # the same NaN gradient.
        denom = var.clamp(min=eps)
    return ((mean - target) ** 2 / denom).sum()


def _microcanonical_loss(
    synth_sc,
    target_sc,
    sigma_synth: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Microcanonical scattering-covariance loss.

    Instead of penalising each generated image individually, this loss
    constrains the **distribution** of statistics across the N-sample batch.
    The statistics are read from the attributes ``S0``, ``S1``, ``S2``,
    ``S3``, ``S3P`` and ``S4`` of both arguments; an attribute that is missing
    or ``None`` on either side is skipped.  Complex statistics contribute
    their real and imaginary parts as two independent real terms.

    Two normalisation modes are supported:

    **With** ``sigma_synth=True`` (default) — combined normalisation:

    .. math::
        \\mathcal{L} = \\sum_k \\frac{(\\bar{\\Phi}_k - \\Phi^*_k)^2}
                                   {\\max(\\sigma_{k,\\text{batch}} \\cdot |\\Phi^*_k|,\\ \\varepsilon)}

    The denominator is the product of two complementary terms:

    * :math:`\\sigma_{k,\\text{batch}}` — empirical (biased) standard
      deviation of coefficient *k* across the N generated images.  It shrinks
      when all images become identical, which inflates the loss (up to the
      ``1/eps`` cap imposed by the floor).
    * :math:`|\\Phi^*_k|` — absolute value of the target coefficient, which
      re-scales coefficients with very different magnitudes.

    **With** ``sigma_synth=False`` — pure microcanonical:

    .. math::
        \\mathcal{L} = \\sum_k \\frac{(\\bar{\\Phi}_k - \\Phi^*_k)^2}
                                   {\\max(\\sigma^2_{k,\\text{batch}},\\ \\varepsilon)}

    Parameters
    ----------
    synth_sc : scat_cov
        Scattering covariance of the N generated images (batch on dim 0).
    target_sc : scat_cov
        Scattering covariance of the target image.  Its dim 0 is averaged
        out, so a batch of 1 is simply squeezed.
    sigma_synth : bool
        If ``True`` (default), multiply the batch std by ``|Φ*|``.
        If ``False``, use only the batch variance.
    eps : float
        Lower bound applied to the denominator (a clamp, not an additive
        term) to avoid division by zero.

    Returns
    -------
    torch.Tensor (scalar)
        Real-valued loss on the device of the first available statistic of
        ``synth_sc``; a CPU zero when ``synth_sc`` carries no statistic.
    """
    # Locate a non-None tensor to get device / dtype
    probe = (getattr(synth_sc, a, None) for a in ("S2", "S3", "S4", "S1", "S3P", "S0"))
    ref = next((t for t in probe if t is not None), None)
    if ref is None:
        # Nothing to compare: the loop below skips every attribute.
        loss = torch.zeros([])
    else:
        real_dtype = ref.real.dtype if torch.is_complex(ref) else ref.dtype
        loss = torch.zeros([], device=ref.device, dtype=real_dtype)

    for attr in ("S0", "S1", "S2", "S3", "S3P", "S4"):
        s_t = getattr(synth_sc, attr, None)
        t_t = getattr(target_sc, attr, None)
        if s_t is None or t_t is None:
            continue
        # target may have batch=1; take the mean over it to be safe
        t = t_t.mean(dim=0)
        if torch.is_complex(s_t):
            loss = (
                loss
                + _batch_normalised_sq_error(s_t.real, t.real, sigma_synth, eps)
                + _batch_normalised_sq_error(s_t.imag, t.imag, sigma_synth, eps)
            )
        else:
            loss = loss + _batch_normalised_sq_error(s_t, t, sigma_synth, eps)

    return loss


# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train_synth_unet(
    target_image: torch.Tensor,
    scat_op,
    Jmax: int,
    n_samples: int = 4,
    channel_list: list[int] | None = None,
    out_channels: int = 1,
    lr: float = 1e-3,
    n_epochs: int = 2000,
    eval_frequency: int = 100,
    norm: str = "auto",
    Jmax_scat=None,
    edge: bool = False,
    iso_ang: bool = False,
    fft_ang: bool = False,
    fft_nharm: int = 1,
    fft_imaginary: bool = True,
    microcanonical: bool = True,
    sigma_synth: bool = True,
    micro_eps: float = 1e-6,
    device: str | None = None,
) -> SynthHalfUNet2D:
    """Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.

    The network is over-fitted to a single target (no generalisation).  Once
    trained, drawing a new z ~ N(0,1) and calling ``model(z)`` generates a new
    texture sample in a single forward pass.

    Parameters
    ----------
    target_image : torch.Tensor, shape [1, H, W] or [H, W]
        The reference texture whose scattering covariance we want to match.
        A 2-D input gets a leading batch axis.
    scat_op : FoCUS
        Initialised FOSCAT operator (``use_2D=True``).  Used both to build the
        model's wavelet bank and to evaluate statistics via ``scat_op.eval``.
    Jmax : int
        Number of upsampling levels in the U-Net decoder.
    n_samples : int
        Noise batch size per training step.
    channel_list : list[int] or None
        Feature channels per decoder level, ordered **coarsest → finest**;
        forwarded to :class:`SynthHalfUNet2D`.
    out_channels : int
        Output channels per image.  NOTE(review): only channel 0 of the
        network output is fed to ``scat_op.eval``, so additional channels
        receive no training signal — confirm this is intended.
    lr : float
        Initial Adam learning rate (annealed with a cosine schedule over
        ``n_epochs``).
    n_epochs : int
        Number of training epochs (one optimiser step each).
    eval_frequency : int
        Print the loss every this many epochs (epoch 1 is always printed).
        ``0`` disables the periodic print instead of raising
        ``ZeroDivisionError``.
    norm : str
        Normalisation passed to ``scat_op.eval`` (e.g. ``'auto'``).
    Jmax_scat : int or None
        ``Jmax`` argument passed to ``scat_op.eval`` for the loss statistics.
        **Independent of the U-Net depth** ``Jmax``.
    edge : bool
        Passed as ``edge=`` to ``scat_op.eval`` for both target and samples.
    iso_ang : bool
        If ``True``, apply ``iso_mean()`` to every scattering covariance after
        ``eval``.  Cannot be combined with ``fft_ang``.
    fft_ang : bool
        If ``True``, apply ``fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)``
        to every scattering covariance after ``eval``.  Cannot be combined
        with ``iso_ang``.
    fft_nharm : int
        ``nharm`` argument of ``fft_ang``.  Ignored when ``fft_ang=False``.
    fft_imaginary : bool
        ``imaginary`` argument of ``fft_ang``.  Ignored when
        ``fft_ang=False``.
    microcanonical : bool
        If ``True`` (default), use :func:`_microcanonical_loss`: the batch
        mean of the statistics must match the target, normalised by the batch
        spread.  Requires ``n_samples >= 2``.
        If ``False``, use ``scat_op.reduce_distance(synth, target)`` divided
        by ``n_samples``.
    sigma_synth : bool
        Only used when ``microcanonical=True``; forwarded to
        :func:`_microcanonical_loss`.
    micro_eps : float
        Denominator floor of the microcanonical loss.  Default ``1e-6``.
    device : str or None
        ``'cuda'``, ``'cpu'``, or None (CUDA when available, else CPU).

    Returns
    -------
    SynthHalfUNet2D
        The trained network (in eval mode).

    Raises
    ------
    ValueError
        If both ``iso_ang`` and ``fft_ang`` are set, or if
        ``microcanonical`` is set with ``n_samples < 2``.

    Examples
    --------
    >>> model = train_synth_unet(
    ...     target, scat_op, Jmax=4,
    ...     edge=True,       # same edge handling as scat_op.synthesis
    ...     iso_ang=True,    # isotropic loss (fewer statistics)
    ...     n_epochs=2000,
    ... )

    >>> model = train_synth_unet(
    ...     target, scat_op, Jmax=4,
    ...     fft_ang=True,    # orientation-aware Fourier loss
    ...     fft_nharm=1,
    ...     fft_imaginary=True,
    ...     n_epochs=3000,
    ... )
    """
    if iso_ang and fft_ang:
        raise ValueError("iso_ang and fft_ang are mutually exclusive.")
    if microcanonical and n_samples < 2:
        raise ValueError(
            "microcanonical loss requires n_samples >= 2 "
            "(variance needs at least two data points)."
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Helpers: evaluate + reduce statistics ------------------------- #
    def _reduce(sc):
        """Apply iso_ang / fft_ang reduction (or none) to a scat_cov."""
        if iso_ang:
            return sc.iso_mean()
        if fft_ang:
            return sc.fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)
        return sc

    def _statistics(images):
        """Scattering covariance of ``images`` with the shared settings."""
        return _reduce(scat_op.eval(images, Jmax=Jmax_scat, norm=norm, edge=edge))

    # ---- Prepare target image ----------------------------------------- #
    if target_image.dim() == 2:
        target_image = target_image.unsqueeze(0)  # [1, H, W]
    H, W = target_image.shape[-2], target_image.shape[-1]
    target_image = target_image.to(device)

    # ---- Compute target scattering covariance (once) ------------------- #
    with torch.no_grad():
        target_sc = _statistics(target_image)

    # ---- Build model --------------------------------------------------- #
    model = SynthHalfUNet2D(
        scat_op,
        Jmax=Jmax,
        channel_list=channel_list,
        out_channels=out_channels,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

    # ---- Training loop ------------------------------------------------- #
    model.train()
    for epoch in range(1, n_epochs + 1):
        optimizer.zero_grad()

        # Fresh noise every step: the model must match the target for any z.
        z = torch.randn(n_samples, 1, H, W, device=device)
        x_synth = model(z)  # [N, out_ch, H, W]
        x_input = x_synth[:, 0, :, :]  # [N, H, W] — first channel only

        synth_sc = _statistics(x_input)

        if microcanonical:
            # Batch mean must match the target, normalised by the batch
            # spread (and by |Φ*| when sigma_synth is set).
            loss = _microcanonical_loss(
                synth_sc,
                target_sc,
                sigma_synth=sigma_synth,
                eps=micro_eps,
            )
        else:
            # Classical distance between sample and target statistics.
            loss = scat_op.reduce_distance(synth_sc, target_sc) / n_samples

        loss.backward()
        optimizer.step()
        scheduler.step()

        # ``eval_frequency == 0`` would make the modulo raise; treat it as
        # "no periodic logging" while still reporting the first epoch.
        periodic = bool(eval_frequency) and epoch % eval_frequency == 0
        if periodic or epoch == 1:
            print(f"Epoch {epoch:5d}/{n_epochs} loss={loss.item():.6f}")

    model.eval()
    return model
# ---------------------------------------------------------------------------
# Convenience: generate samples from a trained model
# ---------------------------------------------------------------------------
def generate_samples(
    model: SynthHalfUNet2D,
    n_samples: int,
    H: int,
    W: int,
    device: str | None = None,
    seed: int | None = None,
) -> torch.Tensor:
    """Sample ``n_samples`` textures from a trained network in one forward pass.

    The model is switched to eval mode, a batch of white noise of shape
    ``[n_samples, 1, H, W]`` is drawn on ``device`` and pushed through the
    model with gradients disabled.

    Parameters
    ----------
    model : SynthHalfUNet2D
        A trained network (output of :func:`train_synth_unet`).
    n_samples : int
        Number of independent textures to generate.
    H, W : int
        Spatial dimensions of the noise, and hence of the output.
    device : str or None
        Device on which the noise is created.  When ``None``, the device of
        the model's first parameter is used.
    seed : int or None
        If given, passed to ``torch.manual_seed`` before sampling (this
        reseeds the global torch RNG).

    Returns
    -------
    torch.Tensor, shape [n_samples, out_channels, H, W]
    """
    noise_device = device
    if noise_device is None:
        first_param = next(model.parameters())
        noise_device = first_param.device

    if seed is not None:
        torch.manual_seed(seed)

    model.eval()
    noise_shape = (n_samples, 1, H, W)
    with torch.no_grad():
        noise = torch.randn(*noise_shape, device=noise_device)
        textures = model(noise)
    return textures