Fast Synthesis with a Decoder U-Net
The classical synthesis workflow in FOSCAT uses L-BFGS-B gradient descent: starting from Gaussian noise, the optimiser iteratively adjusts pixel values until the scattering-covariance statistics of the synthesised image match the target. This produces excellent quality but can take minutes per image, and every new sample requires restarting the optimisation from scratch.
SynthHalfUNet2D offers a complementary approach: train a small neural network
once (a few thousand epochs), then generate as many new samples as needed with a
single forward pass (milliseconds).
Core idea
Instead of searching for one good image, we train a generative network that maps white noise to texture. The network is over-fitted to a single target scattering covariance: after training, any Gaussian noise vector \(z\) fed to the network produces a new image whose statistics match the target.
The key architectural choice is how to make the network multi-scale without
an encoder. The solution: the skip connections that normally carry encoder
features are replaced by oriented wavelet responses of the input noise — the
same FOSCAT wavelets used by FoCUS, applied to progressively downsampled
versions of \(z\).
Architecture: decoder-only U-Net
z ~ N(0,1) [N, 1, H, W]
│
├─ Wavelet bank (frozen FOSCAT kernels)
│ ├── ψ_l ★ z → skip[0] [N, 2L, H, W ] finest
│ ├── ψ_l ★ Pool(z) → skip[1] [N, 2L, H/2, W/2 ]
│ ├── ψ_l ★ Pool²(z) → skip[2] [N, 2L, H/4, W/4 ]
│ └── ψ_l ★ Pool^J(z) → skip[J] [N, 2L, H/2^J, W/2^J ] coarsest
│
│ Decoder (parameters learned during training)
│
├── InitBlock( cat[skip[J], Pool^J(z)] ) → [N, C[0], H/2^J, W/2^J]
│ ↑ 2L+1 input channels
├── Upsample ×2 → cat(skip[J-1]) → ConvBlock → [N, C[1], H/2^(J-1), W/2^(J-1)]
├── Upsample ×2 → cat(skip[J-2]) → ConvBlock → [N, C[2], H/2^(J-2), W/2^(J-2)]
│ ...
├── Upsample ×2 → cat(skip[0]) → ConvBlock → [N, C[J], H, W ]
│
└── Conv 1×1 → x_out [N, out_ch, H, W]
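Wavelet skip connections. At each scale \(j\), the noise is average-pooled \(j\) times (a downsampling factor of \(2^j\)) and then convolved with all \(L\) oriented FOSCAT wavelets:
\[ z_j = \operatorname{AvgPool}^j(z), \qquad \text{skip}[j] = \big\{ \operatorname{Re}(\psi_l \star z_j),\ \operatorname{Im}(\psi_l \star z_j) \big\}_{l=1}^{L} \]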
This gives \(2L\) channels per scale (real and imaginary parts of each of the \(L\) oriented wavelets). The real and imaginary parts together preserve the full complex wavelet response, allowing the decoder to reconstruct both the amplitude and phase of the angular modulation at each scale.
Why not use the modulus \(|\psi_l \star z_j|\) as skip? The modulus discards the phase of the wavelet response, limiting the diversity of generated samples. Keeping both components lets the network exploit the full information in \(z\).
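The construction of the wavelet skips can be illustrated with a minimal NumPy sketch. It is not the FOSCAT implementation: the kernels built by `toy_wavelets` are hypothetical Gabor-like filters standing in for the real FOSCAT wavelets, the periodic FFT convolution is an assumed boundary treatment, and all function names are made up for this example. The sketch only shows how the real and imaginary parts stack into \(2L\) channels at each of the \(J+1\) scales.

```python
import numpy as np

def avg_pool2(x):
    """Downsample a [H, W] array by 2 with 2x2 average pooling."""
    H, W = x.shape
    return x.reshape(H // 2, 2, W // 2, 2).mean(axis=(1, 3))

def toy_wavelets(L, size=5):
    """Hypothetical complex oriented kernels (Gabor-like), NOT the FOSCAT ones."""
    r = np.arange(size) - size // 2
    yy, xx = np.meshgrid(r, r, indexing="ij")
    envelope = np.exp(-(xx**2 + yy**2) / 2.0)
    kernels = []
    for l in range(L):
        theta = np.pi * l / L
        phase = 1.5 * (xx * np.cos(theta) + yy * np.sin(theta))
        k = envelope * np.exp(1j * phase)
        kernels.append(k - k.mean())  # zero mean, like a band-pass filter
    return kernels

def wavelet_skips(z, J, kernels):
    """Return (skips, z_J); skips[j] has shape [2L, H/2^j, W/2^j]."""
    skips, z_j = [], z
    for j in range(J + 1):
        channels = []
        for k in kernels:
            # Periodic convolution via FFT (the boundary handling is an assumption).
            resp = np.fft.ifft2(np.fft.fft2(z_j) * np.fft.fft2(k, s=z_j.shape))
            channels += [resp.real, resp.imag]
        skips.append(np.stack(channels))
        if j < J:
            z_j = avg_pool2(z_j)
    return skips, z_j  # z_j is now AvgPool^J(z), the low-frequency channel

rng = np.random.default_rng(0)
z = rng.standard_normal((64, 64))
skips, z_low = wavelet_skips(z, J=3, kernels=toy_wavelets(L=4))
for j, s in enumerate(skips):
    print(j, s.shape)
```

With \(L = 4\) every scale carries 8 channels, and the spatial size halves from one scale to the next.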
Low-frequency channel. At the coarsest scale \(J\), the spatially averaged noise \(z_J = \operatorname{AvgPool}^J(z)\) is concatenated alongside the wavelet skips, providing a direct low-frequency anchor. The InitBlock input therefore has \(2L + 1\) channels.
ConvBlock. Each decoder block applies two rounds of:
No skip connections from an encoder; the wavelet responses of \(z\) play that role.
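The channel and resolution bookkeeping implied by the diagram can be checked with a few lines of plain Python. This is a sketch derived from the diagram only: `decoder_shapes` is a hypothetical helper, and the requirement that \(H\) and \(W\) be divisible by \(2^J\) is an assumption made here so that every pooled map has an integer size.

```python
def decoder_shapes(H, W, J, L, channel_list):
    """List (name, in_channels, out_channels, height, width) for every block."""
    assert len(channel_list) == J + 1, "one entry per decoder level"
    assert H % 2**J == 0 and W % 2**J == 0, "assumed: H and W divisible by 2^J"
    h, w = H // 2**J, W // 2**J
    # InitBlock sees the 2L wavelet channels plus the pooled noise itself.
    rows = [("InitBlock", 2 * L + 1, channel_list[0], h, w)]
    for k in range(1, J + 1):
        h, w = 2 * h, 2 * w  # Upsample x2
        # Previous features are concatenated with skip[J-k] (2L channels).
        rows.append((f"ConvBlock{k}", channel_list[k - 1] + 2 * L, channel_list[k], h, w))
    return rows

for row in decoder_shapes(H=256, W=256, J=3, L=4, channel_list=[256, 128, 64, 32]):
    print(row)
```

Each decoder block therefore receives the previous level's channels plus \(2L\) skip channels, and only the InitBlock receives the extra low-frequency channel.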
Training objective
Microcanonical loss (default)
The default training objective is the microcanonical loss, which constrains the distribution of statistics across the N generated images rather than each image individually.
Combined normalisation (sigma_synth=True, default):
The denominator is the product of two complementary normalisation terms:
| Term | Formula | Role |
|---|---|---|
| Batch std | \(\sigma_{k,\text{batch}} = \sqrt{\tfrac{1}{N}\sum_i (\Phi_k(x_i) - \bar\Phi_k)^2}\) | Anti-collapse: diverges if all images identical |
| Synthesis sigma | \(\lvert\Phi^*_k\rvert\) | |
where \(\bar{\Phi}_k = \frac{1}{N}\sum_{i=1}^N \Phi_k(G_\theta(z_i))\).
Pure microcanonical (sigma_synth=False):
Why microcanonical? The analogy is to statistical physics:
The canonical approach (classical gradient-descent synthesis) forces every microstate (generated image) to individually satisfy the constraints.
The microcanonical approach only requires the ensemble average to match the target. Individual images may differ, as long as their collective statistics are correct.
Key property — built-in anti-collapse: if all N images become identical (mode collapse), \(\sigma_{k,\text{batch}} \to 0\) and the loss diverges. The gradient therefore simultaneously pushes:
\(\bar{\Phi}_k \to \Phi^*_k\) (match the target mean), and
\(\sigma_{k,\text{batch}}\) to remain nonzero (enforce sample diversity).
The equilibrium is the microcanonical ensemble of textures consistent with \(\Phi^*\): distinct images whose collective statistics reproduce the target.
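The anti-collapse behaviour can be reproduced with a small NumPy sketch of the pure microcanonical form. Several points are assumptions rather than the FOSCAT code: the squared-ratio form of the loss, the value of the `eps` floor, the omission of the synthesis-sigma term, and the function name `microcanonical_loss`. The statistics \(\Phi_k\) are replaced by random numbers, since only the normalisation matters here.

```python
import numpy as np

def microcanonical_loss(phi, phi_target, eps=1e-6):
    """phi: [N, K] statistics of N generated images; phi_target: [K]."""
    phi_bar = phi.mean(axis=0)            # ensemble mean of each coefficient
    sigma_batch = phi.std(axis=0)         # 1/N normalisation, as in the table
    denom = np.maximum(sigma_batch, eps)  # floor on the denominator
    return float(np.mean(((phi_bar - phi_target) / denom) ** 2))

rng = np.random.default_rng(0)
phi_target = np.array([1.0, 2.0, 3.0])

# A diverse batch and a collapsed batch with the same small bias of 0.1.
diverse = phi_target + 0.1 + 0.5 * rng.standard_normal((8, 3))
collapsed = np.tile(phi_target + 0.1, (8, 1))

loss_diverse = microcanonical_loss(diverse, phi_target)
loss_collapsed = microcanonical_loss(collapsed, phi_target)
print(loss_diverse, loss_collapsed)
```

For the collapsed batch the batch standard deviation vanishes, so the same bias is divided by the floor instead and the loss becomes many orders of magnitude larger than for the diverse batch.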
Classical loss (microcanonical=False)
Setting microcanonical=False reverts to the canonical per-sample loss:
This penalises each image independently; it converges reliably but is prone to mode collapse (all samples look the same), especially with expressive networks.
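A toy NumPy example shows why a per-sample loss cannot detect collapse. The unnormalised mean-squared form and the name `per_sample_loss` are assumptions for illustration; the actual loss may include a normalisation.

```python
import numpy as np

def per_sample_loss(phi, phi_target):
    """Mean squared mismatch, each of the N images penalised on its own."""
    return float(np.mean((phi - phi_target) ** 2))

phi_target = np.array([1.0, 2.0, 3.0])

# Eight identical images that all hit the target exactly: a collapsed batch.
collapsed = np.tile(phi_target, (8, 1))

# Eight images scattered around the target, with the correct ensemble mean.
spread = np.concatenate([collapsed[:4] + 0.5, collapsed[4:] - 0.5])

print(per_sample_loss(collapsed, phi_target))  # 0.0
print(per_sample_loss(spread, phi_target))     # 0.25
```

The collapsed batch is the global minimum of this loss, while a batch whose ensemble mean is exactly right but whose members differ is penalised.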
At each epoch, a fresh batch of \(N\) noise vectors is drawn, ensuring the network learns to map any noise to a valid texture. The optimizer is Adam with a cosine-annealing learning-rate schedule.
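For reference, the standard closed form of a cosine-annealing schedule without restarts is sketched below in plain Python. Whether the training routine anneals to exactly zero is not stated above, so `lr_min=0.0` is an assumption, and `cosine_lr` is a hypothetical helper. In PyTorch the same curve is produced by `torch.optim.lr_scheduler.CosineAnnealingLR` with `T_max` set to the number of epochs.

```python
import math

def cosine_lr(epoch, n_epochs, lr0, lr_min=0.0):
    """Learning rate at a given epoch under cosine annealing without restarts."""
    return lr_min + 0.5 * (lr0 - lr_min) * (1.0 + math.cos(math.pi * epoch / n_epochs))

for epoch in (0, 750, 1500, 2250, 3000):
    print(epoch, cosine_lr(epoch, n_epochs=3000, lr0=1e-3))
```

The rate starts at `lr0`, passes through half of it at mid-training, and reaches `lr_min` at the last epoch.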
Parameters
SynthHalfUNet2D
| Parameter | Type | Description |
|---|---|---|
| | | FOSCAT operator initialised with |
| Jmax | | Decoder depth (number of upsampling steps). The coarsest feature map has resolution \(H/2^{J} \times W/2^{J}\). Independent of the number of scales in |
| channel_list | | Feature-map channels per level, ordered coarsest → finest, length |
| | | Output channels per image. Use |
train_synth_unet
| Parameter | Default | Description |
|---|---|---|
| | — | Reference texture, shape |
| | — | FOSCAT 2D operator. |
| Jmax | — | U-Net decoder depth (see above). |
| n_samples | | Noise batch size per training step. Larger → more stable gradient, more memory. |
| channel_list | | See above. |
| | | Output image channels. |
| lr | | Initial learning rate (Adam). |
| n_epochs | | Training epochs. |
| eval_frequency | | Print loss every N epochs. |
| | | Normalisation passed to |
| Jmax_scat | None | Maximum scale for scattering-covariance computation. |
| | | If |
| | | Apply |
| | | Apply |
| fft_nharm | | Number of harmonics beyond DC kept by |
| | | If |
| microcanonical | True | Use the microcanonical loss (see Training objective). Constrains the ensemble mean of statistics to match the target, normalised by the batch std. Requires |
| sigma_synth | True | Only used when |
| | | Floor on the denominator of the microcanonical loss. Prevents division by zero at the very start of training. |
| | auto | |
generate_samples
samples = generate_samples(model, n_samples=16, H=256, W=256, seed=42)
# → torch.Tensor [16, 1, 256, 256]
| Parameter | Description |
|---|---|
| model | Trained |
| n_samples | Number of independent textures to generate. |
| H, W | Spatial dimensions (must match the training image). |
| seed | Optional random seed for reproducibility. |
Channel layout
channel_list is indexed from coarsest (index 0) to finest (last index),
one entry per decoder level:
Jmax = 3, image 256×256, channel_list = [256, 128, 64, 32]
Level Scale Resolution channel_list[k]
───── ───── ────────── ───────────────
0 j=3 32 × 32 256 ← InitBlock output
1 j=2 64 × 64 128
2 j=1 128 × 128 64
3 j=0 256 × 256 32 ← last decoder block, feeds Conv1×1
A uniform layout (same channels at all scales) is also valid and sometimes easier to tune:
channel_list = [64, 64, 64, 64] # 64 channels everywhere, Jmax=3
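Both layouts are easy to generate programmatically, which avoids length mismatches when Jmax changes. The helpers below are hypothetical conveniences, not part of FOSCAT; the `floor` argument is an assumed safeguard that keeps the finest levels from becoming too narrow.

```python
def halving_channels(Jmax, coarsest=256, floor=8):
    """Channel counts from coarsest to finest, halved at each level."""
    return [max(coarsest // 2**k, floor) for k in range(Jmax + 1)]

def uniform_channels(Jmax, width=64):
    """Same channel count at every one of the Jmax + 1 levels."""
    return [width] * (Jmax + 1)

print(halving_channels(3))                # [256, 128, 64, 32]
print(halving_channels(4, coarsest=128))  # [128, 64, 32, 16, 8]
print(uniform_channels(3))                # [64, 64, 64, 64]
```

Either list has Jmax + 1 entries, one per decoder level.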
Two independent Jmax parameters
A common source of confusion: there are two separate depth parameters.
| Parameter | Controls | Who uses it |
|---|---|---|
| U-Net Jmax | How many upsampling levels the decoder has. Determines the coarsest spatial resolution of the feature maps. | The network only. |
| Jmax_scat | How many wavelet scales are used when computing the scattering covariance loss. | |
They are independent: you can have a deep decoder (Jmax=5) with a
conservative scattering loss (Jmax_scat=3), or vice versa.
Setting Jmax_scat=None (default) tells FOSCAT to use all scales it was
initialised with — the safest choice.
Complete usage example
import matplotlib.pyplot as plt
import numpy as np
import torch

import foscat.scat_cov2D as sc
from foscat.SynthHalfUNet2D import SynthHalfUNet2D, train_synth_unet, generate_samples

# 1. Create the FOSCAT operator
scat_op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)

# 2. Load the target image (numpy array → torch tensor), shape [H, W]
target = torch.tensor(np.load("my_texture.npy"), dtype=torch.float32)

# 3. Train the decoder U-Net
#    - Jmax=4: 4 upsampling levels (coarsest map is 16 × 16 for a 256 × 256 image)
#    - channel_list: 128 channels at the coarsest level, 8 at the finest
#    - edge=True: include edge pixels in the statistics (match synthesis convention)
#    - fft_ang=True: orientation-aware Fourier loss (keeps DC = iso_mean + harmonics)
#    - Jmax_scat=None (default): use all FOSCAT scales
model = train_synth_unet(
    target,
    scat_op,
    Jmax=4,
    channel_list=[128, 64, 32, 16, 8],
    n_samples=8,
    lr=1e-3,
    n_epochs=3000,
    eval_frequency=200,
    edge=True,
    fft_ang=True,  # or iso_ang=True for an isotropic loss
    fft_nharm=1,
    fft_imaginary=True,
)

# 4. Generate new samples with a single forward pass
samples = generate_samples(model, n_samples=16, H=256, W=256, seed=0)
# samples: torch.Tensor of shape [16, 1, 256, 256]

# 5. Display the first 8 samples
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i, 0].cpu().numpy(), cmap="gray")
    ax.axis("off")
plt.suptitle("Synthesised textures (single forward pass each)")
plt.tight_layout()
plt.show()

# 6. Save the trained weights for later
torch.save(model.state_dict(), "synth_unet_texture.pt")

# 7. Reload: rebuild the network with the same architecture, then load the weights
model2 = SynthHalfUNet2D(scat_op, Jmax=4, channel_list=[128, 64, 32, 16, 8])
model2.load_state_dict(torch.load("synth_unet_texture.pt"))
model2.eval()
Comparison with gradient-descent synthesis
| | | |
|---|---|---|
| Method | L-BFGS-B gradient descent | Neural network (Adam) |
| Training cost | None | ~minutes (once per texture) |
| Inference cost per sample | Minutes | Milliseconds |
| Quality | Reference | Comparable after sufficient training |
| Multiple samples | Restart from scratch | Instant (single forward pass) |
| Memory | Low | Depends on |
| Geometry | 2D and HEALPix | 2D only (this module) |
| Orientation reduction | | |
| Diversity enforcement | Manual (restart needed) | Microcanonical loss (built-in) |
The two approaches are complementary: gradient descent is the gold standard for a single high-quality synthesis; the U-Net is preferable whenever many independent samples are needed quickly (Monte Carlo studies, uncertainty estimation, data augmentation).
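The trade-off can be made concrete with a back-of-the-envelope break-even count. The timings below are purely hypothetical placeholders, not measurements, and `break_even_samples` is a helper written for this illustration.

```python
def break_even_samples(train_seconds, gd_seconds_per_sample, forward_seconds_per_sample):
    """Smallest number of samples for which training once is cheaper overall."""
    saving = gd_seconds_per_sample - forward_seconds_per_sample
    if saving <= 0:
        raise ValueError("the forward pass must be faster than one optimisation run")
    # Training pays off once train + n * forward < n * gd.
    return int(train_seconds // saving) + 1

# Hypothetical timings: 10 min of training, 2 min per gradient-descent
# synthesis, 5 ms per forward pass.
print(break_even_samples(600.0, 120.0, 0.005))  # 6
```

Under these assumed numbers the network is already the cheaper option from the sixth sample onward; with real timings the threshold shifts, but the cost of each extra sample stays negligible once training is done.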