foscat.SynthHalfUNet2D
======================

.. py:module:: foscat.SynthHalfUNet2D

.. autoapi-nested-parse::

   SynthHalfUNet2D: fast texture synthesis via a decoder-only U-Net.

   Architecture
   ------------
   Input: white noise z ~ N(0,1), shape [N, 1, H, W]

   Skip connections (frozen wavelet bank, no encoder). At each scale j
   (j=0 finest, j=Jmax coarsest)::

       z_j    = AvgPool2d^j(z)                          [N, 1,  H/2^j, W/2^j]
       skip_j = [Re(ψ_l ★ z_j), Im(ψ_l ★ z_j)]_{l}      [N, 2L, H/2^j, W/2^j]

   where {ψ_l} are the L FOSCAT oriented complex wavelets (fixed, from FoCUS).

   At the coarsest level Jmax an extra low-frequency channel is added::

       z_avg = z_Jmax                                   [N, 1, H/2^Jmax, W/2^Jmax]

   so the initial input to the decoder is [skip_Jmax, z_avg] → 2L+1 channels.

   Decoder (all parameters are learned)::

       x = InitBlock( cat[skip_Jmax, z_avg] )           [N, C[0], H/2^Jmax, W/2^Jmax]
       for j in Jmax-1 .. 0:
           x = Upsample(x, ×2)
           x = DecBlock_j( cat[x, skip_j] )             [N, C[Jmax-j], H/2^j, W/2^j]
       out = Conv1×1(x)                                 [N, out_ch, H, W]

   Each DecBlock is: Conv3×3 → InstanceNorm → LeakyReLU → Conv3×3 → InstanceNorm → LeakyReLU

   Training loss::

       z_i ~ N(0,1)   (sampled fresh each epoch)
       x_i = network(z_i)
       loss = reduce_distance( mean_scat_cov({x_i}), target_scat_cov )


Classes
-------

.. autoapisummary::

   foscat.SynthHalfUNet2D.ConvBlock2D
   foscat.SynthHalfUNet2D.SynthHalfUNet2D


Functions
---------

.. autoapisummary::

   foscat.SynthHalfUNet2D.train_synth_unet
   foscat.SynthHalfUNet2D.generate_samples


Module Contents
---------------

.. py:class:: ConvBlock2D(in_ch: int, out_ch: int, negative_slope: float = 0.2)

   Bases: :py:obj:`torch.nn.Module`


   Two stacked Conv3×3 + InstanceNorm + LeakyReLU layers.

   InstanceNorm is used instead of BatchNorm so that each sample in the
   batch is normalised independently. BatchNorm computes statistics across
   the batch dimension, which averages out the sample-to-sample variation
   introduced by the noise input z and causes mode collapse (all z produce
   the same output). InstanceNorm has no such effect.


   .. py:attribute:: block


   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor


.. py:class:: SynthHalfUNet2D(scat_op, Jmax: int, channel_list: list[int] | None = None, out_channels: int = 1)

   Bases: :py:obj:`torch.nn.Module`


   Decoder-only U-Net for fast scattering-covariance texture synthesis.

   At inference time a single forward pass (milliseconds) replaces the slow
   gradient-descent synthesis (minutes). The network is trained by
   overfitting to a single target scattering covariance using
   :func:`train_synth_unet`.

   :Parameters:

       * **scat_op** (:py:class:`FoCUS`) -- A FOSCAT FoCUS object initialised with ``use_2D=True``.
         Its oriented wavelet kernels (``ww_RealT[1]``, ``ww_ImagT[1]``) are
         extracted and registered as **frozen** buffers (not trained).
       * **Jmax** (:py:class:`int`) -- Number of wavelet scales (decoder depth). The spatial resolution at
         the coarsest level is H/2^Jmax × W/2^Jmax.
       * **channel_list** (:py:class:`list[int]` or :py:obj:`None`) -- Feature-map channels at each decoder level, ordered
         **coarsest → finest**: ``channel_list[0]`` is used at level Jmax
         (coarsest), ``channel_list[-1]`` is used just before the output
         Conv1×1. Must have length ``Jmax + 1``.
         Default: ``[min(32·2^(Jmax-j), 256) for j in 0..Jmax]``,
         e.g. ``[256, 128, 64, 32]`` for Jmax=3.
       * **out_channels** (:py:class:`int`) -- Number of output image channels (1 for single-channel textures).
         Generate N independent samples by setting the batch size of z to N.

   .. admonition:: Examples

      >>> import torch
      >>> from foscat.SynthHalfUNet2D import SynthHalfUNet2D, train_synth_unet, generate_samples
      >>> import foscat.scat_cov2D as sc
      >>>
      >>> scat_op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)
      >>> target = torch.tensor(my_image_2d)   # shape [H, W]
      >>>
      >>> # Train (overfits to the scattering covariance of target)
      >>> model = train_synth_unet(target, scat_op, Jmax=3, n_epochs=2000)
      >>>
      >>> # Generate new samples instantly
      >>> samples = generate_samples(model, n_samples=8, H=256, W=256)
      >>> # samples: torch.Tensor [8, 1, 256, 256]


   .. py:attribute:: Jmax


   .. py:attribute:: norient


   .. py:attribute:: kernelsz


   .. py:attribute:: channel_list
      :value: None


   .. py:attribute:: init_block


   .. py:attribute:: decoder_blocks


   .. py:attribute:: output_conv


   .. py:method:: forward(z: torch.Tensor) -> torch.Tensor

      Generate synthesised images from white noise.

      :Parameters:

          **z** (:py:class:`torch.Tensor`, shape ``[N, 1, H, W]``) -- White noise input. N independent samples are produced in parallel.

      :returns: :py:class:`torch.Tensor`, shape ``[N, out_channels, H, W]``


.. py:function:: train_synth_unet(target_image: torch.Tensor, scat_op, Jmax: int, n_samples: int = 4, channel_list: list[int] | None = None, out_channels: int = 1, lr: float = 0.001, n_epochs: int = 2000, eval_frequency: int = 100, norm: str = 'auto', Jmax_scat=None, edge: bool = False, iso_ang: bool = False, fft_ang: bool = False, fft_nharm: int = 1, fft_imaginary: bool = True, microcanonical: bool = True, sigma_synth: bool = True, micro_eps: float = 1e-06, device: str | None = None) -> SynthHalfUNet2D

   Train a SynthHalfUNet2D to reproduce the scattering covariance of a target.

   The network is over-fitted to a single target (no generalisation). Once
   trained, drawing a new z ~ N(0,1) and calling ``model(z)`` generates a
   new texture sample in a single forward pass.

   :Parameters:

       * **target_image** (:py:class:`torch.Tensor`, shape ``[1, H, W]`` or ``[H, W]``) -- The reference texture whose scattering covariance we want to match.
       * **scat_op** (:py:class:`FoCUS`) -- Initialised FOSCAT operator (``use_2D=True``).
       * **Jmax** (:py:class:`int`) -- Number of upsampling levels in the U-Net decoder.
       * **n_samples** (:py:class:`int`) -- Noise batch size per training step. More samples → more stable
         gradient, more memory.
       * **channel_list** (:py:class:`list[int]` or :py:obj:`None`) -- Feature channels per decoder level, ordered **coarsest → finest**.
         Default: ``[min(32·2^(Jmax-j), 256) for j in 0..Jmax]``.
       * **out_channels** (:py:class:`int`) -- Output channels per image (1 for single-channel textures).
       * **lr** (:py:class:`float`) -- Initial Adam learning rate.
       * **n_epochs** (:py:class:`int`) -- Number of training epochs.
       * **eval_frequency** (:py:class:`int`) -- Print the loss every ``eval_frequency`` epochs.
       * **norm** (:py:class:`str`) -- Normalisation passed to ``scat_op.eval`` (e.g. ``'auto'``).
       * **Jmax_scat** (:py:class:`int` or :py:obj:`None`) -- Maximum wavelet scale for the scattering-covariance loss.
         ``None`` uses all scales available in ``scat_op``.
         **Independent of the U-Net depth** ``Jmax``.
       * **edge** (:py:class:`bool`) -- If ``True``, pass ``edge=True`` to ``scat_op.eval`` to compute
         statistics on edge pixels as well. Must match how the target
         statistics are used in the rest of the pipeline.
       * **iso_ang** (:py:class:`bool`) -- If ``True``, apply :meth:`~foscat.scat_cov.scat_cov.iso_mean` to the
         scattering covariance after ``eval``, collapsing the orientation axes
         to a single isotropic mean. Cannot be combined with ``fft_ang``.
       * **fft_ang** (:py:class:`bool`) -- If ``True``, apply :meth:`~foscat.scat_cov.scat_cov.fft_ang` to the
         scattering covariance after ``eval``, projecting the orientation axes
         onto the first ``fft_nharm`` Fourier harmonics. Cannot be combined
         with ``iso_ang``.
       * **fft_nharm** (:py:class:`int`) -- Number of harmonics kept by ``fft_ang`` (beyond the DC term).
         Default 1. Ignored when ``fft_ang=False``.
       * **fft_imaginary** (:py:class:`bool`) -- If ``True`` (default), keep both cosine and sine components in
         ``fft_ang``, giving rotation-invariant amplitudes. Ignored when
         ``fft_ang=False``.
       * **microcanonical** (:py:class:`bool`) -- If ``True`` (default), use the **microcanonical loss**
         :func:`_microcanonical_loss`: the batch mean of the statistics must
         match the target, normalised by the batch standard deviation. Mode
         collapse is naturally penalised because σ_batch → 0 makes the loss
         diverge. This mode requires ``n_samples >= 2``, since a batch
         standard deviation is not meaningful for a single sample.
         If ``False``, use the classical per-sample distance (each generated
         image must independently match the target), averaged over the batch.
       * **sigma_synth** (:py:class:`bool`) -- Only used when ``microcanonical=True``.
         If ``True`` (default), the denominator is
         :math:`\sigma_{k,\text{batch}} \times |\Phi^*_k|`, combining the
         anti-collapse property (from :math:`\sigma_\text{batch}`) with the
         per-coefficient scaling from the synthesis sigma
         (:math:`|\Phi^*_k|`).
         If ``False``, use only :math:`\sigma^2_{k,\text{batch}}` (pure
         microcanonical).
       * **micro_eps** (:py:class:`float`) -- Variance floor for the microcanonical loss (prevents division by
         exactly zero at the very start of training). Default ``1e-6``.
       * **device** (:py:class:`str` or :py:obj:`None`) -- ``'cuda'``, ``'cpu'``, or ``None`` (auto-detect).

   :returns: :py:class:`SynthHalfUNet2D` -- The trained network (in eval mode).

   .. admonition:: Examples

      >>> model = train_synth_unet(
      ...     target, scat_op, Jmax=4,
      ...     edge=True,          # same edge handling as scat_op.synthesis
      ...     iso_ang=True,       # isotropic loss (fewer statistics)
      ...     n_epochs=2000,
      ... )

      >>> model = train_synth_unet(
      ...     target, scat_op, Jmax=4,
      ...     fft_ang=True,       # orientation-aware Fourier loss
      ...     fft_nharm=1,
      ...     fft_imaginary=True,
      ...     n_epochs=3000,
      ... )


.. py:function:: generate_samples(model: SynthHalfUNet2D, n_samples: int, H: int, W: int, device: str | None = None, seed: int | None = None) -> torch.Tensor

   Draw ``n_samples`` textures from a trained SynthHalfUNet2D.

   :Parameters:

       * **model** (:py:class:`SynthHalfUNet2D`) -- A trained network (output of :func:`train_synth_unet`).
       * **n_samples** (:py:class:`int`) -- Number of independent textures to generate.
       * **H, W** (:py:class:`int`) -- Spatial dimensions of the output.
       * **device** (:py:class:`str` or :py:obj:`None`) -- Target device. Defaults to the device of the model parameters.
       * **seed** (:py:class:`int` or :py:obj:`None`) -- Optional random seed for reproducibility.

   :returns: :py:class:`torch.Tensor`, shape ``[n_samples, out_channels, H, W]``
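
The channel and resolution bookkeeping described in the Architecture section can be checked with a few lines of plain Python. This is only an illustrative sketch: ``default_channel_list`` and ``decoder_shapes`` are hypothetical helper names that are not part of foscat, and the code assumes that H and W are divisible by 2^Jmax so that repeated pooling and ×2 upsampling line up.

```python
from __future__ import annotations


def default_channel_list(jmax: int) -> list[int]:
    """Default quoted in the docs: min(32 * 2**(Jmax - j), 256) for j = 0..Jmax."""
    return [min(32 * 2 ** (jmax - j), 256) for j in range(jmax + 1)]


def decoder_shapes(
    jmax: int,
    norient: int,
    height: int,
    width: int,
    channel_list: list[int] | None = None,
) -> list[tuple[str, int, int, int, int]]:
    """Return (stage, in_channels, out_channels, H_j, W_j) for each decoder stage."""
    channels = default_channel_list(jmax) if channel_list is None else channel_list
    if len(channels) != jmax + 1:
        raise ValueError("channel_list must have length Jmax + 1")
    # Assumption of this sketch (not stated in the docs): sizes divide evenly.
    if height % 2 ** jmax or width % 2 ** jmax:
        raise ValueError("H and W are assumed to be divisible by 2**Jmax")

    skip = 2 * norient  # real and imaginary part for each of the L orientations
    # Coarsest level: wavelet skip plus the extra low-frequency channel z_avg.
    stages = [("init", skip + 1, channels[0], height >> jmax, width >> jmax)]
    for j in range(jmax - 1, -1, -1):
        idx = jmax - j  # position in channel_list (coarsest -> finest)
        stages.append(
            (f"dec_{j}", channels[idx - 1] + skip, channels[idx], height >> j, width >> j)
        )
    return stages


for stage in decoder_shapes(jmax=3, norient=4, height=256, width=256):
    print(stage)
```

For ``Jmax=3`` and four orientations this lists one init stage with 2L+1 = 9 input channels, followed by three decoder stages, the last of which runs at full resolution.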