foscat.SynthHalfUNet2D

SynthHalfUNet2D — fast texture synthesis via a decoder-only U-Net.

Architecture

Input: white noise z ~ N(0,1), shape [N, 1, H, W]

Skip connections (frozen wavelet bank, no encoder):
At each scale j (j=0 finest, j=Jmax coarsest):

z_j    = AvgPool2d^j(z)                             [N, 1,  H/2^j, W/2^j]
skip_j = [Re(ψ_l ★ z_j), Im(ψ_l ★ z_j)]_{l=1..L}    [N, 2L, H/2^j, W/2^j]

where {ψ_l} are the L FOSCAT oriented complex wavelets (fixed, from FoCUS).

At the coarsest level Jmax an extra low-frequency channel is added:

z_avg = z_Jmax [N, 1, H/2^Jmax, W/2^Jmax]

so the initial input to the decoder is [skip_Jmax, z_avg] → 2L+1 channels.
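To make the skip-connection definition concrete, here is a minimal NumPy sketch of the frozen, encoder-free feature pyramid, including z_avg = z_Jmax. It is not FOSCAT code: the complex Gabor-like kernels from the hypothetical helper `toy_oriented_kernels` stand in for the FoCUS wavelets ψ_l, ★ is assumed to be a periodic (circular) convolution computed with FFTs, and AvgPool2d is assumed to be a 2×2, stride-2 average pooling applied j times.

```python
import numpy as np

def avg_pool2(x):
    """2x2 average pooling with stride 2 on an array of shape [N, C, H, W]."""
    n, c, h, w = x.shape
    return x.reshape(n, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))

def toy_oriented_kernels(L, size=3):
    """Hypothetical complex Gabor-like kernels, one per orientation (NOT FOSCAT's wavelets)."""
    r = np.arange(size) - size // 2
    yy, xx = np.meshgrid(r, r, indexing="ij")
    env = np.exp(-(xx**2 + yy**2) / 2.0)              # Gaussian envelope
    kernels = []
    for l in range(L):
        theta = np.pi * l / L                          # orientations in [0, pi)
        u = xx * np.cos(theta) + yy * np.sin(theta)    # coordinate along the orientation
        psi = env * np.exp(1j * (np.pi / 2) * u)       # modulated complex kernel
        psi = psi - env * psi.sum() / env.sum()        # remove the DC response
        kernels.append(psi)
    return np.stack(kernels)                           # [L, size, size], complex

def circular_conv(img, kernel):
    """Periodic 2-D convolution of img [h, w] with a small centred kernel (assumption: periodic borders)."""
    h, w = img.shape
    ks = kernel.shape[0]
    pad = np.zeros((h, w), dtype=complex)
    pad[:ks, :ks] = kernel
    pad = np.roll(pad, (-(ks // 2), -(ks // 2)), axis=(0, 1))  # centre the kernel at (0, 0)
    return np.fft.ifft2(np.fft.fft2(img) * np.fft.fft2(pad))

def skip_features(z, kernels, Jmax):
    """Return ([skip_0, ..., skip_Jmax], z_avg); skip_j has shape [N, 2L, H/2^j, W/2^j]."""
    skips, z_j = [], z
    for j in range(Jmax + 1):
        if j > 0:
            z_j = avg_pool2(z_j)                       # z_j = AvgPool2d^j(z)
        resp = np.array([[circular_conv(z_j[n, 0], k) for k in kernels]
                         for n in range(z_j.shape[0])])          # [N, L, h, w], complex
        skips.append(np.concatenate([resp.real, resp.imag], axis=1))  # [N, 2L, h, w]
    return skips, z_j                                  # z_j is now z_Jmax = z_avg

rng = np.random.default_rng(0)
z = rng.standard_normal((2, 1, 32, 32))                # N = 2 white-noise images
skips, z_avg = skip_features(z, toy_oriented_kernels(L=4), Jmax=3)
print([s.shape for s in skips], z_avg.shape)
```

With L = 4, the decoder input at the coarsest level has 2L + 1 = 9 channels (the 8 skip channels plus z_avg).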

Decoder (all parameters are learned):

x = InitBlock( cat[skip_Jmax, z_avg] )        [N, C[0], H/2^Jmax, W/2^Jmax]

for j in Jmax-1 .. 0:
    x = Upsample(x, ×2)
    x = DecBlock_j( cat[x, skip_j] )          [N, C[Jmax-j], H/2^j, W/2^j]

out = Conv1×1(x) [N, out_ch, H, W]
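The shape annotations above can be checked with a small, weight-free trace. This pure-Python sketch (`decoder_shapes` is a hypothetical helper) assumes that z_avg contributes a single channel, that each skip_j has 2L channels, and that H and W are divisible by 2^Jmax; the source does not say how other sizes are handled.

```python
def decoder_shapes(N, H, W, L, Jmax, channel_list, out_ch=1):
    """Trace tensor shapes through the decoder (hypothetical helper, no weights involved)."""
    if H % 2**Jmax or W % 2**Jmax:
        raise ValueError("this sketch assumes H and W are divisible by 2**Jmax")
    if len(channel_list) != Jmax + 1:
        raise ValueError("channel_list must have length Jmax + 1")
    h, w = H // 2**Jmax, W // 2**Jmax
    trace = [("InitBlock in", (N, 2 * L + 1, h, w)),           # cat[skip_Jmax, z_avg]
             ("InitBlock out", (N, channel_list[0], h, w))]
    for j in range(Jmax - 1, -1, -1):
        h, w = 2 * h, 2 * w                                    # Upsample(x, x2)
        c_in = channel_list[Jmax - j - 1] + 2 * L              # cat[x, skip_j]
        trace.append((f"DecBlock_{j} in", (N, c_in, h, w)))
        trace.append((f"DecBlock_{j} out", (N, channel_list[Jmax - j], h, w)))
    trace.append(("Conv1x1 out", (N, out_ch, h, w)))
    return trace

for name, shape in decoder_shapes(8, 256, 256, L=4, Jmax=3, channel_list=[256, 128, 64, 32]):
    print(f"{name:16s} {shape}")
```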

Each DecBlock is a ConvBlock2D: Conv3×3 → InstanceNorm → LeakyReLU → Conv3×3 → InstanceNorm → LeakyReLU
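A DecBlock can be sketched in PyTorch as below. The normalisation layer follows the ConvBlock2D description in this module (InstanceNorm rather than BatchNorm); the zero padding (`padding=1`), `affine=True`, and the class name `ConvBlockSketch` are assumptions made for illustration, not the library's code.

```python
import torch
import torch.nn as nn

class ConvBlockSketch(nn.Module):
    """Hypothetical re-implementation: (Conv3x3 -> InstanceNorm -> LeakyReLU) x 2."""

    def __init__(self, in_ch: int, out_ch: int, negative_slope: float = 0.2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),  # padding=1 keeps H, W (assumed)
            nn.InstanceNorm2d(out_ch, affine=True),              # per-sample, per-channel statistics
            nn.LeakyReLU(negative_slope),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_ch, affine=True),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)

x = torch.randn(4, 9, 32, 32)        # e.g. cat[skip_Jmax, z_avg] with L = 4
y = ConvBlockSketch(9, 256)(x)
print(y.shape)                       # torch.Size([4, 256, 32, 32])
```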

Training loss:

z_i  ~ N(0,1)        (sampled fresh each epoch)
x_i  = network(z_i)
loss = reduce_distance( mean_scat_cov({x_i}), target_scat_cov )
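The training objective above compares the batch-averaged statistics of the generated images with the target statistics. The NumPy sketch below shows that structure under explicit assumptions: `toy_stats` is a hypothetical stand-in for the scattering covariance (mean, variance and a horizontal lag-1 autocovariance), `reduce_distance` is replaced by a plain squared Euclidean distance, and the network is replaced by the identity so that the example runs without a model.

```python
import numpy as np

def toy_stats(x):
    """Hypothetical stand-in for scattering covariance: [mean, variance, lag-1 autocovariance]."""
    xc = x - x.mean()
    return np.array([x.mean(), x.var(), (xc * np.roll(xc, 1, axis=1)).mean()])

def mean_stats_loss(batch, target_stats, stats_fn=toy_stats):
    """Squared distance between batch-mean statistics and target statistics (assumed metric)."""
    phi = np.stack([stats_fn(x) for x in batch])       # [N, K]
    return float(np.sum((phi.mean(axis=0) - target_stats) ** 2))

rng = np.random.default_rng(0)
target = rng.standard_normal((64, 64))
target_stats = toy_stats(target)
for step in range(3):
    z = rng.standard_normal((4, 64, 64))               # fresh noise at every step
    x = z                                              # placeholder for network(z)
    print(step, mean_stats_loss(x, target_stats))
```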

Classes

ConvBlock2D

Two stacked Conv3×3 + InstanceNorm + LeakyReLU layers.

SynthHalfUNet2D

Decoder-only U-Net for fast scattering-covariance texture synthesis.

Functions

train_synth_unet(...) → SynthHalfUNet2D

Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.

generate_samples(...) → torch.Tensor

Draw n_samples textures from a trained SynthHalfUNet2D.

Module Contents

class foscat.SynthHalfUNet2D.ConvBlock2D(in_ch: int, out_ch: int, negative_slope: float = 0.2)

Bases: torch.nn.Module

Two stacked Conv3×3 + InstanceNorm + LeakyReLU layers.

InstanceNorm is used instead of BatchNorm so that each sample in the batch is normalised independently. BatchNorm computes statistics across the batch dimension, which averages out the sample-to-sample variation introduced by the noise input z and causes mode collapse (all z produce the same output). InstanceNorm computes its statistics per sample and per channel, so it does not couple the samples of a batch.
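The batch coupling that motivates InstanceNorm can be checked directly. In this NumPy sketch (hypothetical helpers `batch_norm` and `instance_norm`, training-mode statistics, no learned affine parameters), changing only the second sample alters the BatchNorm output of the first sample, while its InstanceNorm output is untouched. It illustrates the coupling only; it does not reproduce the mode collapse itself.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """BatchNorm-style normalisation (training mode, no affine): statistics over N, H, W."""
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-5):
    """InstanceNorm-style normalisation (no affine): statistics over H, W for each sample."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))
x_mod = x.copy()
x_mod[1] *= 10.0                     # change only the second sample

# Sample 0's output depends on sample 1 under BatchNorm, but not under InstanceNorm.
print(np.allclose(batch_norm(x)[0], batch_norm(x_mod)[0]))        # False
print(np.allclose(instance_norm(x)[0], instance_norm(x_mod)[0]))  # True
```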

block
forward(x: torch.Tensor) → torch.Tensor
class foscat.SynthHalfUNet2D.SynthHalfUNet2D(scat_op, Jmax: int, channel_list: list[int] | None = None, out_channels: int = 1)

Bases: torch.nn.Module

Decoder-only U-Net for fast scattering-covariance texture synthesis.

At inference time, a single forward pass (milliseconds) replaces the iterative gradient-descent synthesis (minutes). The network is trained by overfitting to the scattering covariance of a single target using train_synth_unet().

Parameters:
  • scat_op (FoCUS) – A FOSCAT FoCUS object initialised with use_2D=True. Its oriented wavelet kernels (ww_RealT[1], ww_ImagT[1]) are extracted and registered as frozen buffers (not trained).

  • Jmax (int) – Number of wavelet scales (decoder depth). The spatial resolution at the coarsest level is H/2^Jmax × W/2^Jmax.

  • channel_list (list[int] or None) – Feature-map channels at each decoder level, ordered coarsest → finest: channel_list[0] is used at level Jmax (coarsest), channel_list[-1] is used just before the output Conv1×1. Must have length Jmax + 1. Default: channel_list[i] = min(32·2^(Jmax-i), 256) for i = 0..Jmax, e.g. [256, 128, 64, 32] for Jmax=3.

  • out_channels (int) – Number of output image channels (1 for single-channel textures). The number of generated samples is not a constructor argument: pass a noise tensor z with batch size N to forward() to obtain N independent samples.
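The default channel_list rule can be written as a one-line helper. `default_channel_list` is hypothetical and simply evaluates the formula quoted above; note that the 256 cap makes the coarsest entries repeat once Jmax exceeds 3.

```python
def default_channel_list(Jmax: int, base: int = 32, cap: int = 256) -> list[int]:
    """Hypothetical helper: channel_list[i] = min(base * 2**(Jmax - i), cap), i = 0 (coarsest) .. Jmax."""
    return [min(base * 2 ** (Jmax - i), cap) for i in range(Jmax + 1)]

print(default_channel_list(3))   # [256, 128, 64, 32]
print(default_channel_list(5))   # [256, 256, 256, 128, 64, 32]
```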

Examples

>>> import torch
>>> from foscat.SynthHalfUNet2D import SynthHalfUNet2D, train_synth_unet, generate_samples
>>> import foscat.scat_cov2D as sc
>>>
>>> scat_op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)
>>> target = torch.tensor(my_image_2d, dtype=torch.float32)   # my_image_2d: your 2-D image, shape [H, W]
>>>
>>> # Train (overfits to the scattering covariance of target)
>>> model = train_synth_unet(target, scat_op, Jmax=3, n_epochs=2000)
>>>
>>> # Generate new samples instantly
>>> samples = generate_samples(model, n_samples=8, H=256, W=256)
>>> # samples: torch.Tensor [8, 1, 256, 256]
Jmax
norient
kernelsz
channel_list = None
init_block
decoder_blocks
output_conv
forward(z: torch.Tensor) → torch.Tensor

Generate synthesised images from white noise.

Parameters:

z (torch.Tensor, shape [N, 1, H, W]) – White noise input. N independent samples are produced in parallel.

Returns:

torch.Tensor, shape [N, out_channels, H, W]

foscat.SynthHalfUNet2D.train_synth_unet(target_image: torch.Tensor, scat_op, Jmax: int, n_samples: int = 4, channel_list: list[int] | None = None, out_channels: int = 1, lr: float = 0.001, n_epochs: int = 2000, eval_frequency: int = 100, norm: str = 'auto', Jmax_scat=None, edge: bool = False, iso_ang: bool = False, fft_ang: bool = False, fft_nharm: int = 1, fft_imaginary: bool = True, microcanonical: bool = True, sigma_synth: bool = True, micro_eps: float = 1e-06, device: str | None = None) → SynthHalfUNet2D

Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.

The network is over-fitted to a single target (no generalisation). Once trained, drawing a new z ~ N(0,1) and calling model(z) generates a new texture sample in a single forward pass.

Parameters:
  • target_image (torch.Tensor, shape [1, H, W] or [H, W]) – The reference texture whose scattering covariance we want to match.

  • scat_op (FoCUS) – Initialised FOSCAT operator (use_2D=True).

  • Jmax (int) – Number of upsampling levels in the U-Net decoder.

  • n_samples (int) – Noise batch size per training step. More samples → more stable gradient, more memory.

  • channel_list (list[int] or None) – Feature channels per decoder level, ordered coarsest → finest. Default: channel_list[i] = min(32·2^(Jmax-i), 256) for i = 0..Jmax.

  • out_channels (int) – Output channels per image (1 for single-channel textures).

  • lr (float) – Initial Adam learning rate.

  • n_epochs (int) – Number of training epochs.

  • eval_frequency (int) – Print the loss every eval_frequency epochs.

  • norm (str) – Normalisation passed to scat_op.eval (e.g. 'auto').

  • Jmax_scat (int or None) – Maximum wavelet scale for the scattering-covariance loss. None uses all scales available in scat_op. Independent of the U-Net depth Jmax.

  • edge (bool) – If True, pass edge=True to scat_op.eval to compute statistics on edge pixels as well. Must match how the target statistics are used in the rest of the pipeline.

  • iso_ang (bool) – If True, apply iso_mean() to the scattering covariance after eval, collapsing the orientation axes to a single isotropic mean. Cannot be combined with fft_ang.

  • fft_ang (bool) – If True, apply fft_ang() to the scattering covariance after eval, projecting the orientation axes onto the first fft_nharm Fourier harmonics. Cannot be combined with iso_ang.

  • fft_nharm (int) – Number of harmonics kept by fft_ang (beyond the DC term). Default 1. Ignored when fft_ang=False.

  • fft_imaginary (bool) – If True (default), keep both cosine and sine components in fft_ang, giving rotation-invariant amplitudes. Ignored when fft_ang=False.

  • microcanonical (bool) – If True (default), use the microcanonical loss _microcanonical_loss(): the batch mean of the statistics must match the target, normalised by the batch standard deviation. This mode requires n_samples >= 2. Mode collapse is naturally penalised because σ_batch → 0 makes the loss diverge. If False, use the classical per-sample distance (each generated image must independently match the target), averaged over the batch.

  • sigma_synth (bool) – Only used when microcanonical=True. If True (default), the denominator is \(\sigma_{k,\text{batch}} \times |\Phi^*_k|\), combining the anti-collapse property (from \(\sigma_\text{batch}\)) with the per-coefficient scaling from the synthesis sigma (\(|\Phi^*_k|\)). If False, use only \(\sigma^2_{k,\text{batch}}\) (pure microcanonical).

  • micro_eps (float) – Variance floor for the microcanonical loss (prevents division by exactly zero at the very start of training). Default 1e-6.

  • device (str or None) – 'cuda', 'cpu', or None (auto-detect).
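A minimal NumPy sketch of the microcanonical objective described for microcanonical=True follows. It is an illustration, not `_microcanonical_loss()`: the unbiased batch variance (`ddof=1`), the averaging over coefficients k, and the way micro_eps floors both the variance and the final denominator are assumptions.

```python
import numpy as np

def microcanonical_loss_sketch(phi_batch, phi_target, sigma_synth=True, eps=1e-6):
    """Hypothetical sketch of the microcanonical loss.

    phi_batch:  [N, K] statistics of N generated samples (N >= 2)
    phi_target: [K]    target statistics Phi*
    """
    if phi_batch.shape[0] < 2:
        raise ValueError("the batch standard deviation needs n_samples >= 2")
    diff2 = (phi_batch.mean(axis=0) - phi_target) ** 2
    var = np.maximum(phi_batch.var(axis=0, ddof=1), eps)   # variance floor (micro_eps, assumed)
    if sigma_synth:
        denom = np.sqrt(var) * np.abs(phi_target)           # sigma_batch * |Phi*|
    else:
        denom = var                                         # sigma_batch^2 (pure microcanonical)
    return float(np.mean(diff2 / np.maximum(denom, eps)))

target = np.array([1.5, 2.5])
collapsed = np.tile([1.0, 2.0], (4, 1))        # every sample identical, mean off target
spread = np.array([[2.0, 3.0], [0.0, 1.0]])    # same mean offset, but diverse samples
diverse = np.array([[2.5, 1.5], [0.5, 3.5]])   # batch mean equals the target
print(microcanonical_loss_sketch(collapsed, target, sigma_synth=False))  # very large
print(microcanonical_loss_sketch(spread, target, sigma_synth=False))     # small
print(microcanonical_loss_sketch(diverse, target))                       # 0.0
```

A collapsed batch whose mean misses the target makes the loss very large through the floored variance, whereas a diverse batch whose mean hits the target costs nothing; this is the anti-collapse behaviour described for microcanonical=True.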

Returns:

SynthHalfUNet2D – The trained network (in eval mode).

Examples

>>> model = train_synth_unet(
...     target, scat_op, Jmax=4,
...     edge=True,          # same edge handling as scat_op.synthesis
...     iso_ang=True,       # isotropic loss (fewer statistics)
...     n_epochs=2000,
... )
>>> model = train_synth_unet(
...     target, scat_op, Jmax=4,
...     fft_ang=True,       # orientation-aware Fourier loss
...     fft_nharm=1,
...     fft_imaginary=True,
...     n_epochs=3000,
... )
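The two orientation reductions used in these calls can be sketched on a toy array whose last axis indexes the NORIENT orientations. This is not FOSCAT's `iso_mean()` or `fft_ang()`, which operate on full scattering-covariance objects with several orientation axes; the single orientation axis, the 1/NORIENT normalisation, and orientations evenly spaced over [0, π) are assumptions.

```python
import numpy as np

def iso_mean_sketch(c):
    """Average over the orientation axis (assumed to be the last axis, size NORIENT)."""
    return c.mean(axis=-1)

def fft_ang_sketch(c, nharm=1, imaginary=True):
    """Keep Fourier harmonics 0..nharm of the orientation axis (hypothetical helper)."""
    f = np.fft.rfft(c, axis=-1)[..., : nharm + 1] / c.shape[-1]
    if imaginary:
        # cosine (real) parts of all kept harmonics, sine (imag) parts beyond the DC term
        return np.concatenate([f.real, f.imag[..., 1:]], axis=-1)
    return f.real

def harmonic_amplitudes(c):
    """|harmonic| of the orientation axis, unchanged by cyclic shifts of the orientations."""
    return np.abs(np.fft.rfft(c, axis=-1))

rng = np.random.default_rng(0)
c = rng.standard_normal((5, 4))          # 5 coefficients x NORIENT = 4 orientations
print(iso_mean_sketch(c).shape)          # (5,)
print(fft_ang_sketch(c).shape)           # (5, 3): DC, cos 1, sin 1

# Shifting the orientation index by one (a rotation by pi/NORIENT) keeps the amplitudes.
print(np.allclose(harmonic_amplitudes(c), harmonic_amplitudes(np.roll(c, 1, axis=-1))))  # True
```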
foscat.SynthHalfUNet2D.generate_samples(model: SynthHalfUNet2D, n_samples: int, H: int, W: int, device: str | None = None, seed: int | None = None) → torch.Tensor

Draw n_samples textures from a trained SynthHalfUNet2D.

Parameters:
  • model (SynthHalfUNet2D) – A trained network (output of train_synth_unet()).

  • n_samples (int) – Number of independent textures to generate.

  • H, W (int) – Spatial dimensions of the output.

  • device (str or None) – Target device. Defaults to the device of the model parameters.

  • seed (int or None) – Optional random seed for reproducibility.

Returns:

torch.Tensor, shape [n_samples, out_channels, H, W]
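One way to implement the documented behaviour of generate_samples is sketched below in PyTorch. `generate_samples_sketch` is hypothetical: it draws z ~ N(0,1) with shape [n_samples, 1, H, W] from a dedicated `torch.Generator` (so that `seed` gives reproducible noise without touching the global RNG) and runs the model in eval mode under `torch.no_grad()`. A single convolution stands in for a trained SynthHalfUNet2D.

```python
import torch
import torch.nn as nn

def generate_samples_sketch(model: nn.Module, n_samples: int, H: int, W: int,
                            device=None, seed=None) -> torch.Tensor:
    """Hypothetical re-implementation of the documented generate_samples behaviour."""
    if device is None:
        device = next(model.parameters()).device      # default: the model's device
    model = model.to(device)                          # note: moves the module in place
    gen = torch.Generator(device="cpu")
    if seed is not None:
        gen.manual_seed(seed)                         # reproducible noise
    else:
        gen.seed()                                    # fresh, non-deterministic seed
    z = torch.randn(n_samples, 1, H, W, generator=gen).to(device)  # white noise N(0, 1)
    model.eval()
    with torch.no_grad():
        return model(z)

toy = nn.Conv2d(1, 1, kernel_size=3, padding=1)     # stand-in for a trained SynthHalfUNet2D
a = generate_samples_sketch(toy, n_samples=8, H=64, W=64, seed=0)
b = generate_samples_sketch(toy, n_samples=8, H=64, W=64, seed=0)
print(a.shape, torch.equal(a, b))                   # torch.Size([8, 1, 64, 64]) True
```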