Component Separation
Component separation uses scattering-covariance statistics as differentiable morphological priors to disentangle mixed physical fields. These statistics encode correlations across scales and orientations, so they describe morphological structure that a power spectrum alone cannot capture. This makes them strong priors for separating physically distinct components.
This workflow corresponds to Remove_CMB.ipynb in demo-foscat-pangeo-eosc.
The problem
Given an observed mixture \(d = s_1 + s_2 + \cdots + s_n\) of \(n\) unknown components,
find component maps \(\hat{s}_1, \ldots, \hat{s}_n\) such that:
\(\hat{s}_1 + \cdots + \hat{s}_n \approx d\) (mixture consistency)
\(\Phi(\hat{s}_i) \approx \Phi_i^\text{ref}\) (each component matches its expected morphology)
where \(\Phi_i^\text{ref}\) are reference statistics for each component (from a training set, a physical model, or a prior observation).
This is an underdetermined inverse problem, because \(n\) unknown maps must be recovered from a single observed map, so a prior is needed. FOSCAT regularises it with scattering statistics rather than with smoothness or sparsity priors.
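The two conditions can be combined into one weighted least-squares objective. The weights \(\lambda_i\) are introduced here for illustration only. With \(n = 2\) and all \(\lambda_i = 1\), this is the sum of the three losses in the workflow below. The example code uses means instead of sums, which only rescales each term.

\[
(\hat{s}_1, \ldots, \hat{s}_n) = \arg\min_{s_1, \ldots, s_n} \; \Big\| \sum_{i=1}^{n} s_i - d \Big\|^2 + \sum_{i=1}^{n} \lambda_i \, \big\| \Phi(s_i) - \Phi_i^\text{ref} \big\|^2
\]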
General workflow
```
Observed mixture d = s₁ + s₂ (known)
│
▼ optimise over (ŝ₁, ŝ₂)
┌──────────────────────────────────────────────────────────┐
│ L_mix = ‖ŝ₁ + ŝ₂ − d‖² (mixture fidelity) │
│ L_stat₁ = ‖Φ(ŝ₁) − Φ(s₁_ref)‖² (morphology prior) │
│ L_stat₂ = ‖Φ(ŝ₂) − Φ(s₂_ref)‖² (morphology prior) │
└──────────────────────────────────────────────────────────┘
│
▼
Synthesis([L_mix, L_stat₁, L_stat₂]).run(x0)
│
▼
ŝ₁*, ŝ₂* — separated components
```
Example: CMB-like background removal
```python
import numpy as np
import foscat.scat_cov as sc
from foscat.Synthesis import Loss, Synthesis

nside = 64
npix = 12 * nside**2

# --- input data ---
mixture = np.load("observed_mixture.npy")  # d = foreground + CMB
cmb_ref = np.load("cmb_simulations.npy")   # ensemble of CMB realisations
fg_ref = np.load("foreground_ref.npy")     # reference foreground map

# --- operator ---
scat_op = sc.funct(KERNELSZ=5, NORIENT=4, nstep_max=4, all_type='float64')

# compute reference statistics (average over CMB ensemble if available)
cmb_stat = scat_op.eval(cmb_ref)
fg_stat = scat_op.eval(fg_ref)

# --- optimise over (ŝ_fg, ŝ_cmb) jointly ---
# pack both fields into a single vector: x = [ŝ_fg | ŝ_cmb]
def mixture_loss(x, scat_op, args):
    # args holds the extra arguments passed to Loss (here: the mixture d)
    d = args[0]
    s_fg = x[:npix]   # first npix entries: foreground estimate
    s_cmb = x[npix:]  # last npix entries: CMB estimate
    residual = (s_fg + s_cmb - d) ** 2
    return scat_op.backend.bk_mean(residual)

def fg_stat_loss(x, scat_op, args):
    ref = args[0]
    s_fg = x[:npix]
    stat = scat_op.eval(s_fg)
    # squared distance to the reference foreground statistics
    return stat.reduce_mean_batch((stat - ref) ** 2)

def cmb_stat_loss(x, scat_op, args):
    ref = args[0]
    s_cmb = x[npix:]
    stat = scat_op.eval(s_cmb)
    # squared distance to the reference CMB statistics
    return stat.reduce_mean_batch((stat - ref) ** 2)

loss_mix = Loss(mixture_loss, scat_op, mixture)
loss_fg = Loss(fg_stat_loss, scat_op, fg_stat)
loss_cmb = Loss(cmb_stat_loss, scat_op, cmb_stat)

# initialise: foreground = mixture, CMB = zero
x0 = np.concatenate([mixture, np.zeros(npix)])

solver = Synthesis([loss_mix, loss_fg, loss_cmb], eta=0.01)
result = solver.run(x0, NUM_EPOCHS=500, EVAL_FREQUENCY=20)

fg_hat = result[:npix]   # separated foreground
cmb_hat = result[npix:]  # separated CMB
```
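The packing convention can be checked without FOSCAT. The numpy sketch below uses hypothetical helpers (`pack`, `unpack`, `mixture_residual`) that mirror `mixture_loss` on plain arrays, with random data standing in for the observed mixture. It also makes one property of the initialisation explicit: with \(x_0 = [d \,|\, 0]\) the mixture term is exactly zero, so at the first step only the two statistic losses contribute to the gradient.

```python
import numpy as np


def pack(s_fg, s_cmb):
    """Concatenate two maps into one vector x = [s_fg | s_cmb]."""
    return np.concatenate([s_fg, s_cmb])


def unpack(x, npix):
    """Split a packed vector back into (s_fg, s_cmb)."""
    return x[:npix], x[npix:]


def mixture_residual(x, d):
    """Mean squared mismatch between the summed components and the data d."""
    s_fg, s_cmb = unpack(x, d.size)
    return np.mean((s_fg + s_cmb - d) ** 2)


nside = 4
npix = 12 * nside**2
rng = np.random.default_rng(0)
d = rng.normal(size=npix)          # synthetic stand-in for the observed mixture

x0 = pack(d, np.zeros(npix))       # same initialisation as the example above
print(mixture_residual(x0, d))     # 0.0
```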
Weighting losses
The three losses above are on different scales, so one of them can easily dominate the total. To balance them, pass a weight to Loss as an extra argument and scale the value returned by the loss function:
```python
def cmb_stat_loss(x, scat_op, args):
    ref, weight = args  # extra arguments given to Loss, in order
    s_cmb = x[npix:]
    stat = scat_op.eval(s_cmb)
    return weight * stat.reduce_mean_batch((stat - ref) ** 2)

loss_cmb = Loss(cmb_stat_loss, scat_op, cmb_stat, 0.5)  # weight = 0.5
```
A good heuristic: at the first iteration, print each loss component separately and set weights so they are all \(O(1)\).
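The heuristic can be written as a small helper. `balance_weights` is hypothetical and works on plain numbers read off the first printed iteration: it sets each weight to the reciprocal of that term's initial value. With the initialisation used in the example, the mixture term starts at exactly zero, so its reciprocal is undefined; the sketch leaves near-zero terms at weight 1.0 instead of dividing by them.

```python
def balance_weights(initial_losses, tol=1e-12):
    """Return {name: weight} so that weight * initial value is about 1.

    Terms whose initial value is below `tol` (for example a data term that the
    initialisation already satisfies) keep weight 1.0 instead of 1 / 0.
    """
    weights = {}
    for name, value in initial_losses.items():
        magnitude = abs(float(value))
        weights[name] = 1.0 if magnitude < tol else 1.0 / magnitude
    return weights


# Hypothetical first-iteration values; replace with the ones your run prints.
weights = balance_weights({"mix": 0.0, "fg": 250.0, "cmb": 4e-3})
print(weights)
```

Each weight would then be passed as the extra argument of the corresponding Loss, for example `Loss(cmb_stat_loss, scat_op, cmb_stat, weights["cmb"])`, after giving every loss function the same `ref, weight = args` unpacking.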
Gradient masking for partial-sky separation
When the observation covers only part of the sky, freeze the pixels outside the survey footprint and update only those inside it by passing a gradient mask (grd_mask) to run:
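```python
# survey_mask: boolean array of length npix, True inside the survey footprint
grd_mask = survey_mask.astype(np.float64)  # 1 = update, 0 = freeze
grd_mask_full = np.tile(grd_mask, 2)       # same mask for ŝ_fg and ŝ_cmb, matching x = [ŝ_fg | ŝ_cmb]
result = solver.run(x0, NUM_EPOCHS=500, grd_mask=grd_mask_full)
```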
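The numpy sketch below illustrates the masking idea with a plain gradient-descent step, assuming the mask acts as a per-pixel multiplier on the gradient. It is not FOSCAT's optimiser, and `masked_step` and the toy footprint are hypothetical. Multiplying the gradient by the mask leaves frozen pixels at their initial values, and `np.tile` repeats the footprint so the first `npix` entries mask ŝ_fg and the last `npix` entries mask ŝ_cmb, matching the packing of `x`.

```python
import numpy as np


def masked_step(x, grad, grd_mask, eta=0.01):
    """One gradient-descent step; entries where grd_mask == 0 do not move."""
    return x - eta * grad * grd_mask


npix = 12                                   # toy size (HEALPix nside = 1)
survey_mask = np.zeros(npix, dtype=bool)
survey_mask[:5] = True                      # hypothetical footprint: 5 observed pixels

grd_mask_full = np.tile(survey_mask.astype(np.float64), 2)  # [fg mask | cmb mask]

rng = np.random.default_rng(1)
x = rng.normal(size=2 * npix)
grad = rng.normal(size=2 * npix)
x_new = masked_step(x, grad, grd_mask_full)

frozen = grd_mask_full == 0
print(np.array_equal(x_new[frozen], x[frozen]))  # True: frozen pixels unchanged
print(np.any(x_new[~frozen] != x[~frozen]))      # True: footprint pixels updated
```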
Practical notes
Reference quality. The quality of the separation depends critically on the reference statistics \(\Phi_i^\text{ref}\). Statistics averaged over an ensemble of simulations give better priors than statistics computed from a single reference map.
Amplitude ambiguity. Scattering statistics are not amplitude-preserving by default. If the amplitude ratio between components matters, add a power-spectrum constraint as an additional loss term. healpy.anafast can provide the reference spectrum, but it is not differentiable, so the spectrum used inside the loss itself must be computed with differentiable backend operations.
Number of iterations. Component separation typically needs more iterations than simple synthesis (500–1000 epochs). The loss may decrease slowly if the two components are morphologically similar.
Initialisation. Initialising the first component with the observed mixture and the second with zeros often converges faster than starting from random noise.
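The reference-quality note can be illustrated with a toy statistic. In the numpy sketch below, `toy_stat` is a hypothetical stand-in for `scat_op.eval` that returns the second and fourth moments of a map, and the ensemble consists of synthetic Gaussian maps. Averaging the per-map statistics gives the reference, and the spread across the ensemble shows how far a reference taken from a single map could deviate from it.

```python
import numpy as np


def toy_stat(m):
    """Hypothetical stand-in for a statistic vector: second and fourth moments."""
    return np.array([np.mean(m**2), np.mean(m**4)])


def ensemble_reference(maps, stat_fn):
    """Mean and standard deviation of stat_fn over maps of shape (n_sim, npix)."""
    stats = np.stack([stat_fn(m) for m in maps])
    return stats.mean(axis=0), stats.std(axis=0)


rng = np.random.default_rng(2)
sims = rng.normal(size=(50, 12 * 16**2))  # 50 synthetic realisations, nside = 16
ref_mean, ref_std = ensemble_reference(sims, toy_stat)
print(ref_mean)  # close to [1, 3] for unit-variance Gaussian maps
print(ref_std)   # scatter of a single-map estimate around that mean
```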