foscat.SynthHalfUNet2D
SynthHalfUNet2D — fast texture synthesis via a decoder-only U-Net.
Architecture
Input: white noise z ~ N(0,1), shape [N, 1, H, W]
- Skip connections (frozen wavelet bank, no encoder):
- At each scale j (j=0 finest, j=Jmax coarsest):
z_j = AvgPool2d^j(z)                            [N, 1, H/2^j, W/2^j]
skip_j = [Re(ψ_l ★ z_j), Im(ψ_l ★ z_j)]_{l}     [N, 2L, H/2^j, W/2^j]
where {ψ_l} are the L FOSCAT oriented complex wavelets (fixed, from FoCUS).
- At the coarsest level Jmax an extra low-frequency channel is added:
z_avg = z_Jmax [N, 1, H/2^Jmax, W/2^Jmax]
so the initial input to the decoder is [skip_Jmax, z_avg] → 2L+1 channels.
- Decoder (all parameters are learned):
x = InitBlock( cat[skip_Jmax, z_avg] )      [N, C[0], H/2^Jmax, W/2^Jmax]
for j in Jmax-1 .. 0:
    x = Upsample(x, ×2)
    x = DecBlock_j( cat[x, skip_j] )        [N, C[Jmax-j], H/2^j, W/2^j]
out = Conv1×1(x) [N, out_ch, H, W]
Each DecBlock is: Conv3×3 → InstanceNorm → LeakyReLU → Conv3×3 → InstanceNorm → LeakyReLU
- Training loss:
z_i ~ N(0,1)    (sampled fresh each epoch)
x_i = network(z_i)
loss = reduce_distance( mean_scat_cov({x_i}), target_scat_cov )
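The tensor shapes in the outline above can be traced without building the network. The following is a minimal sketch in plain Python, not part of foscat: the helper name `decoder_shapes` is hypothetical, and it assumes that H and W are divisible by 2^Jmax so that every pooling and upsampling step is exact (the outline does not state this requirement).

```python
def decoder_shapes(N, H, W, Jmax, L, channel_list, out_channels=1):
    """Return (name, shape) pairs that follow the architecture outline.

    Hypothetical helper for illustration; it only does shape bookkeeping.
    """
    if len(channel_list) != Jmax + 1:
        raise ValueError("channel_list must have length Jmax + 1")
    if H % 2**Jmax or W % 2**Jmax:
        # Assumption: exact halving at every scale.
        raise ValueError("H and W should be divisible by 2**Jmax")

    shapes = []
    h, w = H // 2**Jmax, W // 2**Jmax
    # Coarsest level: 2L wavelet channels plus the low-frequency channel z_avg.
    shapes.append(("init_in", (N, 2 * L + 1, h, w)))
    shapes.append(("init_out", (N, channel_list[0], h, w)))
    for j in range(Jmax - 1, -1, -1):
        h, w = H // 2**j, W // 2**j
        # Upsampled features are concatenated with the 2L skip channels of scale j.
        in_ch = channel_list[Jmax - j - 1] + 2 * L
        shapes.append((f"dec{j}_in", (N, in_ch, h, w)))
        shapes.append((f"dec{j}_out", (N, channel_list[Jmax - j], h, w)))
    shapes.append(("out", (N, out_channels, H, W)))
    return shapes


plan = decoder_shapes(N=8, H=256, W=256, Jmax=3, L=4,
                      channel_list=[256, 128, 64, 32])
for name, shape in plan:
    print(f"{name:9s} {shape}")
```

With L = 4 orientations, the decoder input at the coarsest level has 2L + 1 = 9 channels, and each decoder block receives 2L = 8 extra channels from its skip connection.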
Classes
ConvBlock2D | Two stacked Conv3×3 + InstanceNorm + LeakyReLU layers.
SynthHalfUNet2D | Decoder-only U-Net for fast scattering-covariance texture synthesis.
Functions
train_synth_unet | Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.
generate_samples | Draw n_samples textures from a trained SynthHalfUNet2D.
Module Contents
- class foscat.SynthHalfUNet2D.ConvBlock2D(in_ch: int, out_ch: int, negative_slope: float = 0.2)
Bases: torch.nn.Module
Two stacked Conv3×3 + InstanceNorm + LeakyReLU layers.
InstanceNorm is used instead of BatchNorm so that each sample in the batch is normalised independently. BatchNorm computes statistics across the batch dimension and therefore averages out the sample-to-sample variation introduced by the noise input z, causing mode collapse (all z produce the same output). InstanceNorm has no such effect.
- block
- forward(x: torch.Tensor) → torch.Tensor
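The difference between the two normalisations described for ConvBlock2D can be seen on a toy example. This is a sketch in plain Python, not the torch layers themselves: each "image" is a flat list of pixel values for a single channel, the helper names and the `EPS` value are assumptions, and it only illustrates the coupling between samples, not mode collapse itself.

```python
from statistics import fmean, pvariance

EPS = 1e-5  # small stabilising constant (assumed value)


def _standardise(values, mean, var):
    return [(v - mean) / (var + EPS) ** 0.5 for v in values]


def instance_norm(batch):
    """Normalise each sample with its own mean and variance."""
    return [_standardise(s, fmean(s), pvariance(s)) for s in batch]


def batch_norm(batch):
    """Normalise all samples with one mean and variance pooled over the batch."""
    pooled = [v for s in batch for v in s]
    mean, var = fmean(pooled), pvariance(pooled)
    return [_standardise(s, mean, var) for s in batch]


sample_a = [0.0, 1.0, 2.0, 3.0]
batch_1 = [sample_a, [10.0, 11.0, 12.0, 13.0]]
batch_2 = [sample_a, [-5.0, 0.0, 5.0, 10.0]]

# Does the normalised version of sample_a depend on its batch neighbour?
same_under_instance_norm = instance_norm(batch_1)[0] == instance_norm(batch_2)[0]
same_under_batch_norm = batch_norm(batch_1)[0] == batch_norm(batch_2)[0]
print(same_under_instance_norm, same_under_batch_norm)
```

Under instance normalisation the output for `sample_a` is unaffected by the other sample in the batch, whereas the pooled statistics of batch normalisation tie every output to its neighbours.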
- class foscat.SynthHalfUNet2D.SynthHalfUNet2D(scat_op, Jmax: int, channel_list: list[int] | None = None, out_channels: int = 1)
Bases: torch.nn.Module
Decoder-only U-Net for fast scattering-covariance texture synthesis.
At inference time a single forward pass (milliseconds) replaces the slow gradient-descent synthesis (minutes). The network is trained by overfitting to a single target scattering covariance using train_synth_unet().
- Parameters:
scat_op (FoCUS) – A FOSCAT FoCUS object initialised with use_2D=True. Its oriented wavelet kernels (ww_RealT[1], ww_ImagT[1]) are extracted and registered as frozen buffers (not trained).
Jmax (int) – Number of wavelet scales (decoder depth). The spatial resolution at the coarsest level is H/2^Jmax × W/2^Jmax.
channel_list (list[int] or None) – Feature-map channels at each decoder level, ordered coarsest → finest: channel_list[0] is used at level Jmax (coarsest), channel_list[-1] is used just before the output Conv1×1. Must have length Jmax + 1. Default: [min(32·2^(Jmax-j), 256) for j in 0..Jmax], e.g. [256, 128, 64, 32] for Jmax=3.
out_channels (int) – Number of output image channels (1 for single-channel textures). Generate N independent samples by setting the batch size of z to N.
Examples
>>> import torch
>>> from foscat.SynthHalfUNet2D import SynthHalfUNet2D, train_synth_unet, generate_samples
>>> import foscat.scat_cov2D as sc
>>>
>>> scat_op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)
>>> target = torch.tensor(my_image_2d)  # shape [H, W]
>>>
>>> # Train (overfits to the scattering covariance of target)
>>> model = train_synth_unet(target, scat_op, Jmax=3, n_epochs=2000)
>>>
>>> # Generate new samples instantly
>>> samples = generate_samples(model, n_samples=8, H=256, W=256)
>>> # samples: torch.Tensor [8, 1, 256, 256]
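The default channel_list rule quoted in the parameter description can be reproduced in a few lines. This is a sketch of the stated formula only: the function name `default_channel_list` and the keyword names `base` and `cap` are hypothetical and are not taken from foscat.

```python
def default_channel_list(Jmax, base=32, cap=256):
    """Channels ordered coarsest -> finest: min(base * 2^(Jmax - j), cap)."""
    return [min(base * 2 ** (Jmax - j), cap) for j in range(Jmax + 1)]


def check_channel_list(channel_list, Jmax):
    """Mirror the documented constraint: one entry per decoder level."""
    if len(channel_list) != Jmax + 1:
        raise ValueError(
            f"channel_list must have length Jmax + 1 = {Jmax + 1}, "
            f"got {len(channel_list)}"
        )
    return channel_list


shallow = default_channel_list(3)
deep = default_channel_list(5)
print(shallow)
print(deep)
```

For Jmax = 3 this gives the list quoted in the documentation. For deeper decoders the cap of 256 takes effect, so the coarsest levels share the same width.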
- Jmax
- norient
- kernelsz
- channel_list = None
- init_block
- decoder_blocks
- output_conv
- forward(z: torch.Tensor) → torch.Tensor
Generate synthesised images from white noise.
- Parameters:
z (torch.Tensor, shape [N, 1, H, W]) – White noise input. N independent samples are produced in parallel.
- Returns:
torch.Tensor, shape [N, out_channels, H, W]
- foscat.SynthHalfUNet2D.train_synth_unet(target_image: torch.Tensor, scat_op, Jmax: int, n_samples: int = 4, channel_list: list[int] | None = None, out_channels: int = 1, lr: float = 0.001, n_epochs: int = 2000, eval_frequency: int = 100, norm: str = 'auto', Jmax_scat=None, edge: bool = False, iso_ang: bool = False, fft_ang: bool = False, fft_nharm: int = 1, fft_imaginary: bool = True, microcanonical: bool = True, sigma_synth: bool = True, micro_eps: float = 1e-06, device: str | None = None) → SynthHalfUNet2D
Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.
The network is over-fitted to a single target (no generalisation). Once trained, drawing a new z ~ N(0,1) and calling model(z) generates a new texture sample in a single forward pass.
- Parameters:
target_image (torch.Tensor, shape [1, H, W] or [H, W]) – The reference texture whose scattering covariance we want to match.
scat_op (FoCUS) – Initialised FOSCAT operator (use_2D=True).
Jmax (int) – Number of upsampling levels in the U-Net decoder.
n_samples (int) – Noise batch size per training step. More samples → more stable gradient, more memory.
channel_list (list[int] or None) – Feature channels per decoder level, ordered coarsest → finest. Default: [min(32·2^(Jmax-j), 256) for j in 0..Jmax].
out_channels (int) – Output channels per image (1 for single-channel textures).
lr (float) – Initial Adam learning rate.
n_epochs (int) – Number of training epochs.
eval_frequency (int) – Print loss every this many epochs.
norm (str) – Normalisation passed to scat_op.eval (e.g. 'auto').
Jmax_scat (int or None) – Maximum wavelet scale for the scattering-covariance loss. None uses all scales available in scat_op. Independent of the U-Net depth Jmax.
edge (bool) – If True, pass edge=True to scat_op.eval to compute statistics on edge pixels as well. Must match how the target statistics are used in the rest of the pipeline.
iso_ang (bool) – If True, apply iso_mean() to the scattering covariance after eval, collapsing the orientation axes to a single isotropic mean. Cannot be combined with fft_ang.
fft_ang (bool) – If True, apply fft_ang() to the scattering covariance after eval, projecting the orientation axes onto the first fft_nharm Fourier harmonics. Cannot be combined with iso_ang.
fft_nharm (int) – Number of harmonics kept by fft_ang (beyond the DC term). Default 1. Ignored when fft_ang=False.
fft_imaginary (bool) – If True (default), keep both cosine and sine components in fft_ang, giving rotation-invariant amplitudes. Ignored when fft_ang=False.
microcanonical (bool) – If True (default), use the microcanonical loss _microcanonical_loss(): the batch mean of statistics must match the target, normalised by the batch standard deviation. Mode collapse is naturally penalised because σ_batch → 0 makes the loss diverge. If False, use the classical per-sample distance (each generated image must independently match the target) averaged over the batch. Requires n_samples >= 2.
sigma_synth (bool) – Only used when microcanonical=True. If True (default), the denominator is \(\sigma_{k,\text{batch}} \times |\Phi^*_k|\), combining the anti-collapse property (from \(\sigma_\text{batch}\)) with the per-coefficient scaling from the synthesis sigma (\(|\Phi^*_k|\)). If False, use only \(\sigma^2_{k,\text{batch}}\) (pure microcanonical).
micro_eps (float) – Variance floor for the microcanonical loss (prevents division by exactly zero at the very start of training). Default 1e-6.
device (str or None) – 'cuda', 'cpu', or None (auto-detect).
- Returns:
SynthHalfUNet2D – The trained network (in eval mode).
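The anti-collapse behaviour described for microcanonical and sigma_synth can be illustrated with a toy loss. This is a sketch in plain Python under stated assumptions and is not the foscat `_microcanonical_loss()` implementation: it assumes real-valued statistics, a squared mean-minus-target numerator, an average over coefficients, and `eps` added to the denominator. Only the two denominators come from the parameter descriptions.

```python
from statistics import fmean, pstdev


def toy_microcanonical_loss(batch_stats, target, sigma_synth=True, eps=1e-6):
    """batch_stats: one statistics vector per generated sample; target: vector.

    Assumed form: mean over k of (mean_k - target_k)^2 / (denominator_k + eps).
    """
    if len(batch_stats) < 2:
        raise ValueError("a batch standard deviation needs n_samples >= 2")
    terms = []
    for k, phi_k in enumerate(target):
        column = [sample[k] for sample in batch_stats]
        mean_k = fmean(column)
        sigma_k = pstdev(column)
        if sigma_synth:
            denom = sigma_k * abs(phi_k)   # sigma_batch x |target coefficient|
        else:
            denom = sigma_k ** 2           # pure microcanonical variant
        terms.append((mean_k - phi_k) ** 2 / (denom + eps))
    return fmean(terms)


target = [1.0, 2.0]
matched = [[0.9, 2.2], [1.1, 1.8]]      # batch mean equals the target
offset = [[1.4, 2.4], [1.6, 2.6]]       # biased mean, samples still differ
collapsed = [[1.5, 2.5], [1.5, 2.5]]    # same bias, identical samples

loss_matched = toy_microcanonical_loss(matched, target)
loss_offset = toy_microcanonical_loss(offset, target)
loss_collapsed = toy_microcanonical_loss(collapsed, target)
print(loss_matched, loss_offset, loss_collapsed)
```

The offset and collapsed batches have the same mean error, but the collapsed batch has zero spread, so only the `eps` floor keeps its loss finite and it is penalised far more heavily.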
Examples
>>> model = train_synth_unet(
...     target, scat_op, Jmax=4,
...     edge=True,       # same edge handling as scat_op.synthesis
...     iso_ang=True,    # isotropic loss (fewer statistics)
...     n_epochs=2000,
... )
>>> model = train_synth_unet(
...     target, scat_op, Jmax=4,
...     fft_ang=True,    # orientation-aware Fourier loss
...     fft_nharm=1,
...     fft_imaginary=True,
...     n_epochs=3000,
... )
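The two orientation reductions used in these examples can be compared on a single vector of per-orientation values. This is a sketch in plain Python, not the foscat `iso_mean()` or `fft_ang()` code: the function names, the 1/L normalisation, and the one-dimensional layout are assumptions made for illustration.

```python
import cmath


def toy_iso_mean(values):
    """Collapse the orientation axis to its mean."""
    return sum(values) / len(values)


def toy_fft_ang(values, nharm=1):
    """Discrete Fourier coefficients over orientation, harmonics 0..nharm."""
    L = len(values)
    return [
        sum(v * cmath.exp(-2j * cmath.pi * m * idx / L)
            for idx, v in enumerate(values)) / L
        for m in range(nharm + 1)
    ]


stats = [1.0, 2.0, 3.0, 2.0]                 # one value per orientation (L = 4)
rotated = stats[-1:] + stats[:-1]            # cyclic shift = rotate by one orientation

harmonics = toy_fft_ang(stats, nharm=1)
harmonics_rotated = toy_fft_ang(rotated, nharm=1)

print(toy_iso_mean(stats), harmonics[0].real)          # DC term equals the isotropic mean
print(abs(harmonics[1]), abs(harmonics_rotated[1]))    # amplitude unchanged by rotation
```

In this toy setting the isotropic mean is the zeroth harmonic, and keeping both the cosine (real) and sine (imaginary) parts lets the harmonic amplitude stay constant when the orientations are cyclically shifted, while the phase records the rotation.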
- foscat.SynthHalfUNet2D.generate_samples(model: SynthHalfUNet2D, n_samples: int, H: int, W: int, device: str | None = None, seed: int | None = None) → torch.Tensor
Draw n_samples textures from a trained SynthHalfUNet2D.
- Parameters:
model (SynthHalfUNet2D) – A trained network (output of train_synth_unet()).
n_samples (int) – Number of independent textures to generate.
H, W (int) – Spatial dimensions of the output.
device (str or None) – Target device. Defaults to the device of the model parameters.
seed (int or None) – Optional random seed for reproducibility.
- Returns:
torch.Tensor, shape [n_samples, out_channels, H, W]