import pickle
import sys
import healpy as hp
import numpy as np
import foscat as foscat
# import foscat.backend as bk
import foscat.FoCUS as FOC
# Check whether TensorFlow has already been imported (by the caller or another
# module); this module only uses it when it is already present in sys.modules.
tf_defined = "tensorflow" in sys.modules
if tf_defined:
    import tensorflow as tf

    # Optional: compile decorated functions with TensorFlow when it is in use.
    tf_function = tf.function
else:

    def tf_function(func):
        """Identity decorator used when TensorFlow is not loaded."""
        return func
[docs]
def read(filename):
    """Load statistics previously written with ``scat_cov.save(filename)``.

    Parameters
    ----------
    filename : str
        Path without the ``.pkl`` extension.
    """
    # The placeholder instance only provides the read() method. It has no
    # backend, so skip the numel computation, which needs one.
    thescat = scat_cov(1, 1, 1, 1, return_data=True)
    return thescat.read(filename)
# Module-level flag initialised to 0; it is not read anywhere in the visible part of this file.
testwarn = 0
[docs]
class scat_cov:
def __init__(
    self,
    s0,
    s2,
    s3,
    s4,
    s1=None,
    s3p=None,
    backend=None,
    use_1D=False,
    return_data=False,
):
    """Hold the S0..S4 (and optional S1, S3P) statistics together.

    Parameters
    ----------
    s0, s2, s3, s4 :
        Mandatory statistic tensors.
    s1, s3p : optional
        Optional statistic tensors; None when absent.
    backend : optional
        Backend object providing the tensor operations (``bk_len``, ...).
    use_1D : bool
        Flag kept on the instance and forwarded to derived objects.
    return_data : bool
        When True, skip computing ``numel``.
    """
    self.S0 = s0
    self.S2 = s2
    self.S3 = s3
    self.S4 = s4
    self.S1 = s1
    self.S3P = s3p
    self.backend = backend
    self.idx1 = None
    self.idx2 = None
    self.use_1D = use_1D
    # numel needs the backend's bk_len, so it can only be computed when a
    # backend is available. Absent optional terms contribute nothing.
    if not return_data and backend is not None:
        self.numel = sum(
            backend.bk_len(x) for x in (s0, s1, s2, s3, s4, s3p) if x is not None
        )
[docs]
def numpy(self):
    """Return these statistics with every tensor converted to NumPy.

    If ``self.BACKEND`` is ``"numpy"``, the object itself is returned.
    Otherwise a new scat_cov is built from ``backend.to_numpy`` of each term.
    Absent optional terms (S1, S3P) stay None.
    """
    # BACKEND is not assigned in __init__, so read it defensively instead of
    # raising AttributeError on every call.
    if getattr(self, "BACKEND", None) == "numpy":
        return self
    to_np = self.backend.to_numpy
    s1 = None if self.S1 is None else to_np(self.S1)
    s3p = None if self.S3P is None else to_np(self.S3P)
    return scat_cov(
        to_np(self.S0),
        to_np(self.S2),
        to_np(self.S3),
        to_np(self.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def constant(self):
    """Return a new scat_cov whose tensors went through ``backend.constant``.

    Optional terms (S1, S3P) remain None when they are absent.
    """
    to_const = self.backend.constant
    s1 = to_const(self.S1) if self.S1 is not None else None
    s3p = to_const(self.S3P) if self.S3P is not None else None
    return scat_cov(
        to_const(self.S0),
        to_const(self.S2),
        to_const(self.S3),
        to_const(self.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def conv2complex(self, val):
    """Return ``val`` as a complex tensor.

    Complex inputs are returned unchanged. Real inputs get a zero imaginary
    part built with the backend's ``bk_complex``.
    """
    if self.backend.bk_is_complex(val):
        return val
    return self.backend.bk_complex(val, 0 * val)
# -------------------------------------------------------
[docs]
def flatten(self):
    """Concatenate all statistics into one complex 2-D tensor.

    Each term is reshaped to ``[shape[0], prod(shape[1:])]`` using its own
    shape, made complex with ``conv2complex``, then concatenated along axis 1
    in the order S0, S1 (if present), S2, S3, S3P (if present), S4.
    Folding every trailing axis covers both the 1-D and the 2-D layouts.
    """

    def _as_rows(x):
        # Keep the leading axis and fold all remaining axes into one column axis.
        shape = list(x.shape)
        ncol = 1
        for dim in shape[1:]:
            ncol *= int(dim)
        return self.conv2complex(self.backend.bk_reshape(x, [shape[0], ncol]))

    parts = [self.S0]
    if self.S1 is not None:
        parts.append(self.S1)
    parts += [self.S2, self.S3]
    if self.S3P is not None:
        parts.append(self.S3P)
    parts.append(self.S4)
    return self.backend.bk_concat([_as_rows(p) for p in parts], 1)
# -------------------------------------------------------
[docs]
def flattenMask(self):
    """Return one flattened row per entry of the leading axis of S2.

    Row ``k`` is the concatenation of ``S0[k]``, ``S1[k]`` (if S1 exists),
    ``S2[k]``, ``S3[k]``, ``S3P[k]`` (if S3P exists) and ``S4[k]``, each
    flattened. NumPy operations are used when S2 is an ndarray. Otherwise the
    backend's ``bk_flattenR`` / ``bk_concat`` / ``bk_expand_dims`` are used.
    """
    # Terms in output order; only S1 and S3P are optional.
    comps = [self.S0]
    if self.S1 is not None:
        comps.append(self.S1)
    comps += [self.S2, self.S3]
    if self.S3P is not None:
        comps.append(self.S3P)
    comps.append(self.S4)
    nbatch = self.S2.shape[0]

    if isinstance(self.S2, np.ndarray):
        rows = [
            np.concatenate([c[k].flatten() for c in comps], 0) for k in range(nbatch)
        ]
        # Stack all rows once rather than growing the result row by row.
        return np.stack(rows, 0)

    bk = self.backend
    rows = [
        bk.bk_expand_dims(bk.bk_concat([bk.bk_flattenR(c[k]) for c in comps], 0), 0)
        for k in range(nbatch)
    ]
    return bk.bk_concat(rows, 0)
[docs]
def get_S0(self):
    """Return the stored S0 statistic (the same object, not a copy)."""
    value = self.S0
    return value
[docs]
def get_S1(self):
    """Return the stored S1 statistic, or None when it is absent."""
    value = self.S1
    return value
[docs]
def get_S2(self):
    """Return the stored S2 statistic (the same object, not a copy)."""
    value = self.S2
    return value
[docs]
def reset_S2(self):
    """Replace S2 by ``S2 * 0``, i.e. zeros with the same layout."""
    self.S2 = self.S2 * 0
[docs]
def get_S3(self):
    """Return the stored S3 statistic (the same object, not a copy)."""
    value = self.S3
    return value
[docs]
def get_S3P(self):
    """Return the stored S3P statistic, or None when it is absent."""
    value = self.S3P
    return value
[docs]
def get_S4(self):
    """Return the stored S4 statistic (the same object, not a copy)."""
    value = self.S4
    return value
[docs]
def get_j_idx(self):
    """Return index arrays (j1, j2) listing every scale pair with j1 <= j2.

    The number of scales is read from axis 2 of S2 when S2 has rank 3 or 4,
    and from axis 3 otherwise. Pairs are ordered by j2, then by j1.
    """
    dims = list(self.S2.shape)
    nscale = dims[2] if len(dims) in (3, 4) else dims[3]
    # Lower-triangle indices come out row-major: row = j2, column = j1.
    rows, cols = np.tril_indices(nscale)
    return cols.astype("int"), rows.astype("int")
[docs]
def get_js4_idx(self):
    """Return index arrays (j1, j2, j3) for every scale triple j1 <= j2 <= j3.

    The number of scales is read from axis 2 of S2. Triples are ordered by
    j3, then j2, then j1.
    """
    nscale = list(self.S2.shape)[2]
    triples = [
        (k, j, i)
        for i in range(nscale)
        for j in range(i + 1)
        for k in range(j + 1)
    ]
    j1 = np.array([t[0] for t in triples], dtype="int")
    j2 = np.array([t[1] for t in triples], dtype="int")
    j3 = np.array([t[2] for t in triples], dtype="int")
    return (j1, j2, j3)
def __add__(self, other):
    """Element-wise sum with another scat_cov or with a scalar.

    With a scat_cov operand, terms are combined through ``doadd``. An
    optional term (S1, S3P, S4) is None in the result when either side lacks
    it. With a scalar, the scalar is added to every present term.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr):
        # Only look at other's attribute when other is a scat_cov.
        if mine is None:
            return None
        if not paired:
            return mine + other
        theirs = getattr(other, attr)
        return None if theirs is None else self.doadd(mine, theirs)

    s1 = _optional(self.S1, "S1")
    s3p = _optional(self.S3P, "S3P")
    s4 = _optional(self.S4, "S4")
    if paired:
        s0 = self.doadd(self.S0, other.S0)
        s2 = self.doadd(self.S2, other.S2)
        s3 = self.doadd(self.S3, other.S3)
    else:
        s0 = self.S0 + other
        s2 = self.S2 + other
        s3 = self.S3 + other
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
[docs]
def relu(self):
    """Apply the backend's ``bk_relu`` to every present statistic.

    Returns a new scat_cov; absent optional terms (S1, S3P, S4) stay None.
    """
    act = self.backend.bk_relu
    s1 = None if self.S1 is None else act(self.S1)
    # The attributes are S3P / S4 (upper case); the lower-case names do not exist.
    s3p = None if self.S3P is None else act(self.S3P)
    s4 = None if self.S4 is None else act(self.S4)
    return scat_cov(
        act(self.S0),
        act(self.S2),
        act(self.S3),
        s4,
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
def __radd__(self, other):
    """Reflected addition (``scalar + stats``); addition is symmetric here."""
    add = self.__add__
    return add(other)
def __truediv__(self, other):
    """Element-wise division by another scat_cov or by a scalar.

    With a scat_cov operand, terms are divided through ``dodiv``. An optional
    term (S1, S3P, S4) is None in the result when either side lacks it.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr, by_scalar):
        if mine is None:
            return None
        if not paired:
            return by_scalar(mine)
        theirs = getattr(other, attr)
        return None if theirs is None else self.dodiv(mine, theirs)

    s1 = _optional(self.S1, "S1", lambda t: self.dodiv(t, other))
    s3p = _optional(self.S3P, "S3P", lambda t: self.dodiv(t, other))
    s4 = _optional(self.S4, "S4", lambda t: t / other)
    if paired:
        s0 = self.dodiv(self.S0, other.S0)
        s2 = self.dodiv(self.S2, other.S2)
        s3 = self.dodiv(self.S3, other.S3)
    else:
        s0 = self.S0 / other
        s2 = self.S2 / other
        s3 = self.S3 / other
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
def __rtruediv__(self, other):
    """Reflected division: compute ``other / self`` element-wise.

    With a scat_cov numerator, terms are divided through ``dodiv``. An
    optional term (S1, S3P, S4) is None in the result when either side lacks
    it, as in the other operators.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr):
        if mine is None:
            return None
        if not paired:
            return other / mine
        theirs = getattr(other, attr)
        # A missing numerator term cannot be divided: propagate None.
        return None if theirs is None else self.dodiv(theirs, mine)

    s1 = _optional(self.S1, "S1")
    s3p = _optional(self.S3P, "S3P")
    s4 = _optional(self.S4, "S4")
    if paired:
        s0 = self.dodiv(other.S0, self.S0)
        s2 = self.dodiv(other.S2, self.S2)
        s3 = self.dodiv(other.S3, self.S3)
    else:
        s0 = other / self.S0
        s2 = other / self.S2
        s3 = other / self.S3
    # Use the precomputed s4 so that an absent S4 stays None instead of failing.
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
def __rsub__(self, other):
    """Reflected subtraction: compute ``other - self`` element-wise.

    With a scat_cov operand, terms are combined through ``domin``. An
    optional term (S1, S3P, S4) is None in the result when either side lacks
    it.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr):
        if mine is None:
            return None
        if not paired:
            return other - mine
        theirs = getattr(other, attr)
        return None if theirs is None else self.domin(theirs, mine)

    s1 = _optional(self.S1, "S1")
    s3p = _optional(self.S3P, "S3P")
    s4 = _optional(self.S4, "S4")
    if paired:
        s0 = self.domin(other.S0, self.S0)
        s2 = self.domin(other.S2, self.S2)
        s3 = self.domin(other.S3, self.S3)
    else:
        s0 = other - self.S0
        s2 = other - self.S2
        s3 = other - self.S3
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
def __sub__(self, other):
    """Element-wise difference ``self - other`` with a scat_cov or a scalar.

    With a scat_cov operand, terms are combined through ``domin``. An
    optional term (S1, S3P, S4) is None in the result when either side lacks
    it.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr):
        if mine is None:
            return None
        if not paired:
            return mine - other
        theirs = getattr(other, attr)
        return None if theirs is None else self.domin(mine, theirs)

    s1 = _optional(self.S1, "S1")
    s3p = _optional(self.S3P, "S3P")
    s4 = _optional(self.S4, "S4")
    if paired:
        s0 = self.domin(self.S0, other.S0)
        s2 = self.domin(self.S2, other.S2)
        s3 = self.domin(self.S3, other.S3)
    else:
        s0 = self.S0 - other
        s2 = self.S2 - other
        s3 = self.S3 - other
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
[docs]
def domult(self, x, y):
    """Return ``x * y``, falling back to part-wise products for real/complex mixes.

    If the direct product raises and the dtypes differ, the complex operand
    is split with ``bk_real`` / ``bk_imag`` and each part is multiplied by the
    real operand. If the dtypes are equal, the original error is re-raised.
    """
    try:
        return x * y
    except Exception:
        # Same dtype: the failure is not a real/complex mismatch, so propagate it.
        if x.dtype == y.dtype:
            raise
        bk = self.backend
        if bk.bk_is_complex(x):
            return bk.bk_complex(bk.bk_real(x) * y, bk.bk_imag(x) * y)
        return bk.bk_complex(bk.bk_real(y) * x, bk.bk_imag(y) * x)
[docs]
def dodiv(self, x, y):
    """Return ``x / y``, falling back to explicit complex arithmetic for real/complex mixes.

    If the direct division raises and the dtypes differ:
    - complex x, real y: real and imaginary parts of x are divided by y;
    - real x, complex y = a + ib: x / y = x (a - ib) / (a^2 + b^2).
    If the dtypes are equal, the original error is re-raised.
    """
    try:
        return x / y
    except Exception:
        if x.dtype == y.dtype:
            raise
        bk = self.backend
        if bk.bk_is_complex(x):
            return bk.bk_complex(bk.bk_real(x) / y, bk.bk_imag(x) / y)
        re_y = bk.bk_real(y)
        im_y = bk.bk_imag(y)
        # Divide by a complex number via its conjugate; the previous
        # x/re + i*x/im was mathematically wrong.
        norm2 = re_y * re_y + im_y * im_y
        return bk.bk_complex(x * re_y / norm2, -(x * im_y) / norm2)
[docs]
def domin(self, x, y):
    """Return ``x - y``, falling back to explicit complex arithmetic for real/complex mixes.

    If the direct subtraction raises and the dtypes differ, the real operand
    only affects the real part:
    - complex x, real y: (Re x - y) + i Im x;
    - real x, complex y: (x - Re y) - i Im y.
    If the dtypes are equal, the original error is re-raised.
    """
    try:
        return x - y
    except Exception:
        if x.dtype == y.dtype:
            raise
        bk = self.backend
        # The "+ 0 * other" terms only give the imaginary part the same
        # broadcast shape as the real part.
        if bk.bk_is_complex(x):
            return bk.bk_complex(bk.bk_real(x) - y, bk.bk_imag(x) + 0 * y)
        return bk.bk_complex(x - bk.bk_real(y), 0 * x - bk.bk_imag(y))
[docs]
def doadd(self, x, y):
    """Return ``x + y``, falling back to explicit complex arithmetic for real/complex mixes.

    If the direct addition raises and the dtypes differ, the real operand is
    added to the real part only:
    - complex x, real y: (Re x + y) + i Im x;
    - real x, complex y: (x + Re y) + i Im y.
    If the dtypes are equal, the original error is re-raised.
    """
    try:
        return x + y
    except Exception:
        if x.dtype == y.dtype:
            raise
        bk = self.backend
        # The "+ 0 * other" terms only give the imaginary part the same
        # broadcast shape as the real part.
        if bk.bk_is_complex(x):
            return bk.bk_complex(bk.bk_real(x) + y, bk.bk_imag(x) + 0 * y)
        return bk.bk_complex(x + bk.bk_real(y), bk.bk_imag(y) + 0 * x)
def __mul__(self, other):
    """Element-wise product with another scat_cov or with a scalar.

    With a scat_cov operand, terms are combined through ``domult``. An
    optional term (S1, S3P, S4) is None in the result when either side lacks
    it.
    """
    assert isinstance(other, (float, np.float32, int, bool, scat_cov))
    paired = isinstance(other, scat_cov)

    def _optional(mine, attr):
        if mine is None:
            return None
        if not paired:
            return mine * other
        theirs = getattr(other, attr)
        return None if theirs is None else self.domult(mine, theirs)

    s1 = _optional(self.S1, "S1")
    s3p = _optional(self.S3P, "S3P")
    s4 = _optional(self.S4, "S4")
    if paired:
        s0 = self.domult(self.S0, other.S0)
        s2 = self.domult(self.S2, other.S2)
        s3 = self.domult(self.S3, other.S3)
    else:
        s0 = self.S0 * other
        s2 = self.S2 * other
        s3 = self.S3 * other
    return scat_cov(
        s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
    )
def __rmul__(self, other):
    """Reflected multiplication (``scalar * stats``); the product is symmetric here."""
    multiply = self.__mul__
    return multiply(other)
# -------------------------------------------------------
[docs]
def interp(self, nscale, extend=True, constant=False):
    """Overwrite the first ``nscale`` scale entries by extrapolation.

    Along the scale axis (axis 2), entries are refilled from index
    ``nscale - 1`` down to 0. Each one is derived from the two entries just
    above it:
    - geometrically, ``exp(2*log(a) - log(b))``, by default;
    - by copying the next entry when ``constant`` is True.
    S1/S2 are handled per scale, and S3/S3P/S4 per scale pair/triple. S0 is
    returned unchanged.

    Parameters
    ----------
    nscale : int
        Number of leading scale entries to overwrite; S2 must hold at least
        ``nscale + 2`` scales, otherwise a message is printed and an
        unchanged copy is returned.
    extend : bool
        Unused; kept for backward compatibility.
    constant : bool
        Copy instead of extrapolating.

    Returns
    -------
    scat_cov
        A new object; the tensors held by ``self`` are not modified.
    """
    if nscale + 2 > self.S2.shape[2]:
        print(
            "Can not *interp* %d with a statistic described over %d"
            % (nscale, self.S2.shape[2])
        )
        # The original call forgot S0 and raised a TypeError here.
        return scat_cov(
            self.S0,
            self.S2,
            self.S3,
            self.S4,
            s1=self.S1,
            s3p=self.S3P,
            backend=self.backend,
            use_1D=self.use_1D,
        )

    def _writable(x):
        # Private NumPy copy: the in-place updates below must never alias the
        # arrays held by self, which would otherwise happen for NumPy inputs.
        if x is None:
            return None
        if not isinstance(x, np.ndarray):
            x = self.backend.to_numpy(x)
        return np.array(x)

    def _fill(arr, dst, near, far):
        # arr[dst] <- copy of arr[near], or geometric extrapolation from
        # arr[near] and arr[far].
        if constant:
            arr[dst] = arr[near]
        else:
            arr[dst] = np.exp(2 * np.log(arr[near]) - np.log(arr[far]))

    # S1 / S2: one entry per scale.
    s1 = _writable(self.S1)
    s2 = _writable(self.S2)
    for k in range(nscale):
        dst = np.s_[:, :, nscale - 1 - k, :]
        near = np.s_[:, :, nscale - k, :]
        far = np.s_[:, :, nscale + 1 - k, :]
        if s1 is not None:
            _fill(s1, dst, near, far)
        _fill(s2, dst, near, far)

    # S3 / S3P: one entry per scale pair (j1, j2).
    j1, j2 = self.get_j_idx()
    s3p = _writable(self.S3P)
    s3 = _writable(self.S3)
    for k in range(nscale):
        for l_orient in range(nscale - k):
            jj = nscale - 1 - k - l_orient
            i0 = np.where((j1 == jj) * (j2 == nscale - 1 - k))[0]
            i1 = np.where((j1 == jj) * (j2 == nscale - k))[0]
            i2 = np.where((j1 == jj) * (j2 == nscale + 1 - k))[0]
            for arr in (s3p, s3):
                if arr is not None:
                    _fill(arr, np.s_[:, :, i0], np.s_[:, :, i1], np.s_[:, :, i2])

    # S4: one entry per scale triple (j1, j2, j3).
    s4 = _writable(self.S4)
    if s4 is not None:
        j1, j2, j3 = self.get_js4_idx()
        for k in range(nscale):
            for l_orient in range(nscale - k):
                for m in range(nscale - k - l_orient):
                    base = (j1 == nscale - 1 - k - l_orient - m) * (
                        j2 == nscale - 1 - k - l_orient
                    )
                    i0 = np.where(base * (j3 == nscale - 1 - k))[0]
                    i1 = np.where(base * (j3 == nscale - k))[0]
                    i2 = np.where(base * (j3 == nscale + 1 - k))[0]
                    _fill(s4, np.s_[:, :, i0], np.s_[:, :, i1], np.s_[:, :, i2])

    to_const = self.backend.constant
    s1 = None if s1 is None else to_const(s1)
    s3p = None if s3p is None else to_const(s3p)
    return scat_cov(
        self.S0,
        to_const(s2),
        to_const(s3),
        None if s4 is None else to_const(s4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def plot(self, name=None, hold=True, color="blue", lw=1, legend=True, norm=False):
    """Draw the statistics on a 2x2 grid of matplotlib panels.

    Panels:
    - (2,2,1): S1, when present;
    - (2,2,2): S2;
    - (2,2,3): S3, or S3P when it is set;
    - (2,2,4): S4.
    Absolute values are plotted on a log y-axis.

    Parameters
    ----------
    name : str, optional
        Prefix put in front of the legend labels.
    hold : bool
        If True, open a new 16x8 figure before drawing; otherwise draw into
        the current figure (e.g. to overlay several objects).
    color : str
        Line colour passed to matplotlib.
    lw : float
        Line width passed to matplotlib.
    legend : bool
        Attach a legend label to the first multi-point S3/S4 curve.
    norm : bool
        Divide S1, S3 (S3P) and S4 by products of sqrt(|S2|) before plotting.
    """
    import matplotlib.pyplot as plt

    if name is None:
        name = ""
    j1, j2 = self.get_j_idx()
    if hold:
        plt.figure(figsize=(16, 8))
    test = None

    # ---- S2 panel -------------------------------------------------------
    plt.subplot(2, 2, 2)
    tmp = abs(self.get_np(self.S2))
    # sqrt(|S2|) is the normalisation used by the other panels when norm=True.
    ntmp = np.sqrt(tmp)
    if len(tmp.shape) > 3:
        for k in range(tmp.shape[3]):
            for i1 in range(tmp.shape[0]):
                for i2 in range(tmp.shape[1]):
                    if test is None:
                        test = 1
                        plt.plot(
                            tmp[i1, i2, :, k],
                            color=color,
                            label=r"%s $S_2$" % (name),
                            lw=lw,
                        )
                    else:
                        plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
    else:
        for i1 in range(tmp.shape[0]):
            for i2 in range(tmp.shape[1]):
                if test is None:
                    test = 1
                    plt.plot(
                        tmp[i1, i2, :],
                        color=color,
                        label=r"%s $S_2$" % (name),
                        lw=lw,
                    )
                else:
                    plt.plot(tmp[i1, i2, :], color=color, lw=lw)
    plt.yscale("log")
    plt.ylabel("$S_2$")
    plt.xlabel(r"$j_{1}$")
    plt.legend(frameon=0)

    # ---- S1 panel (only when S1 exists) ---------------------------------
    if self.S1 is not None:
        plt.subplot(2, 2, 1)
        tmp = abs(self.get_np(self.S1))
        test = None
        if len(tmp.shape) > 3:
            for k in range(tmp.shape[3]):
                for i1 in range(tmp.shape[0]):
                    for i2 in range(tmp.shape[1]):
                        if test is None:
                            test = 1
                            if norm:
                                plt.plot(
                                    tmp[
                                        i1,
                                        i2,
                                        :,
                                        k,
                                    ]
                                    / ntmp[i1, i2, :, k],
                                    color=color,
                                    label=r"%s norm. $S_1$" % (name),
                                    lw=lw,
                                )
                            else:
                                plt.plot(
                                    tmp[i1, i2, :, k],
                                    color=color,
                                    label=r"%s $S_1$" % (name),
                                    lw=lw,
                                )
                        else:
                            if norm:
                                plt.plot(
                                    tmp[i1, i2, :, k] / ntmp[i1, i2, :, k],
                                    color=color,
                                    lw=lw,
                                )
                            else:
                                plt.plot(tmp[i1, i2, :, k], color=color, lw=lw)
        else:
            for i1 in range(tmp.shape[0]):
                for i2 in range(tmp.shape[1]):
                    if test is None:
                        test = 1
                        plt.plot(
                            tmp[i1, i2, :],
                            color=color,
                            label=r"%s $S_1$" % (name),
                            lw=lw,
                        )
                    else:
                        plt.plot(tmp[i1, i2, :], color=color, lw=lw)
        plt.yscale("log")
        plt.legend(frameon=0)
        if norm:
            plt.ylabel(r"$\frac{S_1}{\sqrt{S_2}}$")
        else:
            plt.ylabel("$S_1$")
        plt.xlabel(r"$j_{1}$")

    # ---- S3 / S3P panel -------------------------------------------------
    ax1 = plt.subplot(2, 2, 3)
    ax2 = ax1.twiny()
    n = 0
    tmp = abs(self.get_np(self.S3))
    if norm:
        lname = r"%s norm. $S_{3}$" % (name)
        ax1.set_ylabel(r"$\frac{S_3}{\sqrt{S_{2,j_1}S_{2,j_2}}}$")
    else:
        lname = r"%s $S_3$" % (name)
        ax1.set_ylabel(r"$S_3$")
    if self.S3P is not None:
        # The panel is relabelled as \tilde{S}_3, so plot S3P itself (the
        # previous code re-read S3 here and mislabelled it).
        tmp = abs(self.get_np(self.S3P))
        if norm:
            lname = r"%s norm. $\tilde{S}_{3}$" % (name)
            ax1.set_ylabel(r"$\frac{\tilde{S}_{3}}{\sqrt{S_{2,j_1}S_{2,j_2}}}$")
        else:
            lname = r"%s $\tilde{S}_{3}$" % (name)
            ax1.set_ylabel(r"$\tilde{S}_{3}$")
    test = None
    tabx = []
    tabnx = []
    tab2x = []
    tab2nx = []
    if len(tmp.shape) > 4:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    for i3 in range(tmp.shape[3]):
                        for i4 in range(tmp.shape[4]):
                            dtmp = tmp[i0, i1, j1 == i2, i3, i4]
                            if norm:
                                dtmp = dtmp / (
                                    ntmp[i0, i1, i2, i3]
                                    * ntmp[i0, i1, j2[j1 == i2], i3]
                                )
                            if j2[j1 == i2].shape[0] == 1:
                                ax1.plot(
                                    j2[j1 == i2] + n, dtmp, ".", color=color, lw=lw
                                )
                            else:
                                if legend and test is None:
                                    ax1.plot(
                                        j2[j1 == i2] + n,
                                        dtmp,
                                        color=color,
                                        label=lname,
                                        lw=lw,
                                    )
                                    test = 1
                                ax1.plot(j2[j1 == i2] + n, dtmp, color=color, lw=lw)
                    tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
                    tabx = tabx + [k + n for k in j2[j1 == i2]]
                    tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(
                        (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
                    )
                    n = n + j2[j1 == i2].shape[0] - 1
    elif len(tmp.shape) == 3:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    dtmp = tmp[i0, i1, j1 == i2]
                    if norm:
                        dtmp = dtmp / (
                            ntmp[i0, i1, i2] * ntmp[i0, i1, j2[j1 == i2]]
                        )
                    if j2[j1 == i2].shape[0] == 1:
                        ax1.plot(j2[j1 == i2] + n, dtmp, ".", color=color, lw=lw)
                    else:
                        if legend and test is None:
                            ax1.plot(
                                j2[j1 == i2] + n,
                                dtmp,
                                color=color,
                                label=lname,
                                lw=lw,
                            )
                            test = 1
                        ax1.plot(j2[j1 == i2] + n, dtmp, color=color, lw=lw)
                    tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
                    tabx = tabx + [k + n for k in j2[j1 == i2]]
                    tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(
                        (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
                    )
                    n = n + j2[j1 == i2].shape[0] - 1
    else:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    for i3 in range(tmp.shape[3]):
                        dtmp = tmp[i0, i1, j1 == i2, i3]
                        if norm:
                            dtmp = dtmp / (
                                ntmp[i0, i1, i2] * ntmp[i0, i1, j2[j1 == i2]]
                            )
                        if j2[j1 == i2].shape[0] == 1:
                            ax1.plot(
                                j2[j1 == i2] + n, dtmp, ".", color=color, lw=lw
                            )
                        else:
                            if legend and test is None:
                                ax1.plot(
                                    j2[j1 == i2] + n,
                                    dtmp,
                                    color=color,
                                    label=lname,
                                    lw=lw,
                                )
                                test = 1
                            ax1.plot(j2[j1 == i2] + n, dtmp, color=color, lw=lw)
                    tabnx = tabnx + [r"%d" % (k) for k in j2[j1 == i2]]
                    tabx = tabx + [k + n for k in j2[j1 == i2]]
                    tab2x = tab2x + [(j2[j1 == i2] + n).mean()]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(
                        (j2[j1 == i2] + n).max() + 0.5, ls=":", color="gray"
                    )
                    n = n + j2[j1 == i2].shape[0] - 1
    plt.yscale("log")
    ax1.set_xlim(0, n + 2)
    ax1.set_xticks(tabx)
    ax1.set_xticklabels(tabnx, fontsize=6)
    ax1.set_xlabel(r"$j_{2}$", fontsize=6)
    # Move twinned axis ticks and label from top to bottom
    ax2.xaxis.set_ticks_position("bottom")
    ax2.xaxis.set_label_position("bottom")
    # Offset the twin axis below the host
    ax2.spines["bottom"].set_position(("axes", -0.15))
    # Turn on the frame for the twin axis, but then hide all
    # but the bottom spine
    ax2.set_frame_on(True)
    ax2.patch.set_visible(False)
    for sp in ax2.spines.values():
        sp.set_visible(False)
    ax2.spines["bottom"].set_visible(True)
    ax2.set_xlim(0, n + 2)
    ax2.set_xticks(tab2x)
    ax2.set_xticklabels(tab2nx, fontsize=6)
    ax2.set_xlabel(r"$j_{1}$", fontsize=6)
    ax1.legend(frameon=0)

    # ---- S4 panel -------------------------------------------------------
    ax1 = plt.subplot(2, 2, 4)
    j1, j2, j3 = self.get_js4_idx()
    ax2 = ax1.twiny()
    n = 1
    tmp = abs(self.get_np(self.S4))
    lname = r"%s $S_4$" % (name)
    test = None
    tabx = []
    tabnx = []
    tab2x = []
    tab2nx = []
    if len(tmp.shape) > 4:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    nprev = n
                    for i2b in range(j2[j1 == i2].max() + 1):
                        idx = np.where((j1 == i2) * (j2 == i2b))[0]
                        for i3 in range(tmp.shape[3]):
                            for i4 in range(tmp.shape[4]):
                                for i5 in range(tmp.shape[5]):
                                    dtmp = tmp[i0, i1, idx, i3, i4, i5]
                                    if norm:
                                        dtmp = dtmp / (
                                            ntmp[i0, i1, i2, i3]
                                            * ntmp[i0, i1, i2b, i3]
                                        )
                                    if len(idx) == 1:
                                        ax1.plot(
                                            np.arange(len(idx)) + n,
                                            dtmp,
                                            ".",
                                            color=color,
                                            lw=lw,
                                        )
                                    else:
                                        if legend and test is None:
                                            ax1.plot(
                                                np.arange(len(idx)) + n,
                                                dtmp,
                                                color=color,
                                                label=lname,
                                                lw=lw,
                                            )
                                            test = 1
                                        ax1.plot(
                                            np.arange(len(idx)) + n,
                                            dtmp,
                                            color=color,
                                            lw=lw,
                                        )
                        tabnx = tabnx + [r"%d,%d" % (j2[k], j3[k]) for k in idx]
                        tabx = tabx + [k + n for k in range(len(idx))]
                        n = n + idx.shape[0]
                    tab2x = tab2x + [(n + nprev - 1) / 2]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(n - 0.5, ls=":", color="gray")
    elif len(tmp.shape) == 3:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    nprev = n
                    for i2b in range(j2[j1 == i2].max() + 1):
                        idx = np.where((j1 == i2) * (j2 == i2b))[0]
                        dtmp = tmp[i0, i1, idx]
                        if norm:
                            dtmp = dtmp / (ntmp[i0, i1, i2] * ntmp[i0, i1, i2b])
                        if len(idx) == 1:
                            ax1.plot(
                                np.arange(len(idx)) + n,
                                dtmp,
                                ".",
                                color=color,
                                lw=lw,
                            )
                        else:
                            if legend and test is None:
                                ax1.plot(
                                    np.arange(len(idx)) + n,
                                    dtmp,
                                    color=color,
                                    label=lname,
                                    lw=lw,
                                )
                                test = 1
                            ax1.plot(
                                np.arange(len(idx)) + n, dtmp, color=color, lw=lw
                            )
                        tabnx = tabnx + [r"%d,%d" % (j2[k], j3[k]) for k in idx]
                        tabx = tabx + [k + n for k in range(len(idx))]
                        n = n + idx.shape[0]
                    tab2x = tab2x + [(n + nprev - 1) / 2]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(n - 0.5, ls=":", color="gray")
    else:
        for i0 in range(tmp.shape[0]):
            for i1 in range(tmp.shape[1]):
                for i2 in range(j1.max() + 1):
                    nprev = n
                    for i2b in range(j2[j1 == i2].max() + 1):
                        idx = np.where((j1 == i2) * (j2 == i2b))[0]
                        for i3 in range(tmp.shape[3]):
                            dtmp = tmp[i0, i1, idx, i3]
                            if norm:
                                dtmp = dtmp / (ntmp[i0, i1, i2] * ntmp[i0, i1, i2b])
                            if len(idx) == 1:
                                ax1.plot(
                                    np.arange(len(idx)) + n,
                                    dtmp,
                                    ".",
                                    color=color,
                                    lw=lw,
                                )
                            else:
                                if legend and test is None:
                                    ax1.plot(
                                        np.arange(len(idx)) + n,
                                        dtmp,
                                        color=color,
                                        label=lname,
                                        lw=lw,
                                    )
                                    test = 1
                                ax1.plot(
                                    np.arange(len(idx)) + n,
                                    dtmp,
                                    color=color,
                                    lw=lw,
                                )
                        tabnx = tabnx + [r"%d,%d" % (j2[k], j3[k]) for k in idx]
                        tabx = tabx + [k + n for k in range(len(idx))]
                        n = n + idx.shape[0]
                    tab2x = tab2x + [(n + nprev - 1) / 2]
                    tab2nx = tab2nx + ["%d" % (i2)]
                    ax1.axvline(n - 0.5, ls=":", color="gray")
    plt.yscale("log")
    if norm:
        ax1.set_ylabel(r"$\frac{S_4}{\sqrt{S_{2,j_1}S_{2,j_2}}}$")
    else:
        ax1.set_ylabel(r"$S_4$")
    ax1.set_xticks(tabx)
    ax1.set_xticklabels(tabnx, fontsize=6)
    ax1.set_xlabel(r"$j_{2},j_{3}$", fontsize=6)
    ax1.set_xlim(0, n)
    # Move twinned axis ticks and label from top to bottom
    ax2.xaxis.set_ticks_position("bottom")
    ax2.xaxis.set_label_position("bottom")
    # Offset the twin axis below the host
    ax2.spines["bottom"].set_position(("axes", -0.15))
    # Turn on the frame for the twin axis, but then hide all
    # but the bottom spine
    ax2.set_frame_on(True)
    ax2.patch.set_visible(False)
    for sp in ax2.spines.values():
        sp.set_visible(False)
    ax2.spines["bottom"].set_visible(True)
    ax2.set_xlim(0, n)
    ax2.set_xticks(tab2x)
    ax2.set_xticklabels(tab2nx, fontsize=6)
    ax2.set_xlabel(r"$j_{1}$", fontsize=6)
    ax1.legend(frameon=0)
[docs]
def get_np(self, x):
    """Return ``x`` as a NumPy array.

    None is passed through, ndarrays are returned as-is, and anything else
    goes through ``backend.to_numpy``.
    """
    if x is None:
        return None
    return x if isinstance(x, np.ndarray) else self.backend.to_numpy(x)
[docs]
def save(self, filename):
    """Pickle the statistics to ``<filename>.pkl``.

    The file holds the list ``[S0, S1, S3P, S3, S4, S2]``, each converted
    with ``get_np``; absent terms are stored as None.
    """
    outlist = [
        self.get_np(self.S0),
        self.get_np(self.S1),
        self.get_np(self.S3P),
        self.get_np(self.S3),
        self.get_np(self.S4),
        self.get_np(self.S2),
    ]
    # The context manager closes the file even if pickling fails.
    with open("%s.pkl" % (filename), "wb") as myout:
        pickle.dump(outlist, myout)
[docs]
def read(self, filename):
    """Load statistics from ``<filename>.pkl`` into a new scat_cov.

    The pickle is expected to hold ``[S0, S1, S3P, S3, S4, S2]``, which is the
    order save() writes. The result reuses this object's backend and use_1D.

    Security: ``pickle.load`` can execute arbitrary code; only read trusted
    files.
    """
    with open("%s.pkl" % (filename), "rb") as infile:
        outlist = pickle.load(infile)
    return scat_cov(
        outlist[0],
        outlist[5],
        outlist[3],
        outlist[4],
        s1=outlist[1],
        s3p=outlist[2],
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def std(self):
    """Combine the spread of every statistic into one scalar.

    Computes ``sqrt(sum(std(|S|)**2) / 4)`` over:
    - S0, S1, S3, S4, S2 when S1 is present ("auto" statistics);
    - S0, S3, S3P, S4, S2 otherwise ("cross").

    NOTE(review): five terms are summed but the divisor is 4; confirm whether
    this is intended before relying on the absolute value.
    """
    if self.S1 is not None:  # Auto
        terms = (self.S0, self.S1, self.S3, self.S4, self.S2)
    else:  # Cross
        terms = (self.S0, self.S3, self.S3P, self.S4, self.S2)
    total = sum((abs(self.get_np(t)).std()) ** 2 for t in terms)
    return np.sqrt(total / 4)
[docs]
def mean(self):
    """Return the mean absolute value over all statistics.

    Sums ``|S|`` over:
    - S0, S1, S3, S4, S2 when S1 is present ("auto");
    - S0, S3, S3P, S4, S2 otherwise ("cross").
    The total is divided by ``self.numel``.
    """
    if self.S1 is not None:  # Auto
        terms = (self.S0, self.S1, self.S3, self.S4, self.S2)
    else:  # Cross
        terms = (self.S0, self.S3, self.S3P, self.S4, self.S2)
    total = sum(abs(self.get_np(t)).sum() for t in terms)
    return total / self.numel
[docs]
def initdx(self, norient):
    """Build the orientation-shift index tables and store them as backend constants.

    ``idx1`` has ``norient**2`` entries: block ``i`` is
    ``(arange(norient) + i) % norient + i * norient``.
    ``idx2`` has ``norient**3`` entries: block ``(i, j)`` is
    ``((k + i) % norient) * norient + (k + i + j) % norient + k * norient**2``
    for ``k`` in ``arange(norient)``.
    """
    k = np.arange(norient)
    idx1 = np.array(
        [(k + i) % norient + i * norient for i in range(norient)], dtype="int"
    ).reshape(-1)
    idx2 = np.array(
        [
            ((k + i) % norient) * norient
            + (k + i + j) % norient
            + k * norient * norient
            for i in range(norient)
            for j in range(norient)
        ],
        dtype="int",
    ).reshape(-1)
    self.idx1 = self.backend.constant(idx1)
    self.idx2 = self.backend.constant(idx2)
[docs]
def sqrt(self):
    """Return a new scat_cov with ``backend.bk_sqrt`` applied to every term.

    Optional S1 and S3P stay None when they are absent.
    """
    op = self.backend.bk_sqrt
    s1 = op(self.S1) if self.S1 is not None else None
    s3p = op(self.S3P) if self.S3P is not None else None
    return scat_cov(
        op(self.S0),
        op(self.S2),
        op(self.S3),
        op(self.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def L1(self):
    """Return a new scat_cov with ``backend.bk_L1`` applied to every term.

    Optional S1 and S3P stay None when they are absent.
    """
    op = self.backend.bk_L1
    s1 = op(self.S1) if self.S1 is not None else None
    s3p = op(self.S3P) if self.S3P is not None else None
    return scat_cov(
        op(self.S0),
        op(self.S2),
        op(self.S3),
        op(self.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def square_comp(self):
    """Return a new scat_cov with ``backend.bk_square_comp`` applied to every term.

    Optional S1 and S3P stay None when they are absent.
    """
    op = self.backend.bk_square_comp
    s1 = op(self.S1) if self.S1 is not None else None
    s3p = op(self.S3P) if self.S3P is not None else None
    return scat_cov(
        op(self.S0),
        op(self.S2),
        op(self.S3),
        op(self.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def iso_mean(self, repeat=False):
    """Reduce the statistics to their orientation-isotropic part.

    - S1 and S2 are averaged over their orientation axis (axis 3).
    - S3 and S3P (``[a, b, c, L, L]``) are right-multiplied by the backend's
      ``_iso_orient`` (or ``_iso_orient_C`` for complex data) matrix, giving
      ``[a, b, c, L]``.
    - S4 (``[a, b, c, L, L, L]``) is right-multiplied by ``_iso_orient3``
      (``_iso_orient3_C``), giving ``[a, b, c, L, L]``.
    - S0 is returned unchanged.

    Parameters
    ----------
    repeat : bool, optional
        If True, map the reduced values back to the original orientation
        shape. S1/S2 use ``bk_repeat``; S3/S3P/S4 use the transposed
        matrices (``*_T``). The result can then be combined element-wise
        with ``self``, as :meth:`iso_std` does.

    Returns
    -------
    scat_cov
    """
    bk = self.backend
    norient = list(self.S2.shape)[3]

    S1 = self.S1
    if self.S1 is not None:
        S1 = bk.bk_reduce_mean(self.S1, 3)
        if repeat:
            S1 = bk.bk_reshape(bk.bk_repeat(S1, norient, 2), self.S1.shape)

    S2 = bk.bk_reduce_mean(self.S2, 3)
    if repeat:
        S2 = bk.bk_reshape(bk.bk_repeat(S2, norient, 2), self.S2.shape)

    if norient not in bk._iso_orient:
        bk.calc_iso_orient(norient)

    def _iso_pair(x):
        # [a, b, c, L, L] -> [a, b, c, L]; back to [a, b, c, L, L] if repeat.
        # Each input uses its OWN shape.  Before this fix, the shape was read
        # from S3 even when S3 was None, and it was reused for S3P.
        shape = list(x.shape)
        nbatch = shape[0] * shape[1] * shape[2]
        if bk.bk_is_complex(x):
            lmat = bk._iso_orient_C[norient]
            lmat_T = bk._iso_orient_C_T[norient]
        else:
            lmat = bk._iso_orient[norient]
            lmat_T = bk._iso_orient_T[norient]
        out = bk.bk_reshape(
            bk.backend.matmul(bk.bk_reshape(x, [nbatch, norient * norient]), lmat),
            [shape[0], shape[1], shape[2], norient],
        )
        if repeat:
            out = bk.bk_reshape(
                bk.backend.matmul(bk.bk_reshape(out, [nbatch, norient]), lmat_T),
                [shape[0], shape[1], shape[2], norient, norient],
            )
        return out

    S3 = self.S3
    if self.S3 is not None:
        S3 = _iso_pair(self.S3)

    S3P = self.S3P
    if self.S3P is not None:
        S3P = _iso_pair(self.S3P)

    S4 = self.S4
    if self.S4 is not None:
        # S4 has shape [a, b, c, L1, L2, L3].
        # The isotropic reduction averages over the global orientation l1
        # while keeping the two relative angles (dl12, dl13) fixed:
        #   output[dl12, dl13] = (1/L) sum_{l1} S4[l1, (l1+dl12)%L, (l1+dl13)%L]
        # Result shape: [a, b, c, L, L] (two relative-angle indices).
        if norient not in bk._iso_orient3:
            bk.calc_iso_orient3(norient)
        if bk.bk_is_complex(self.S4):
            lmat3 = bk._iso_orient3_C[norient]
            lmat3_T = bk._iso_orient3_C_T[norient]
        else:
            lmat3 = bk._iso_orient3[norient]
            lmat3_T = bk._iso_orient3_T[norient]
        shape = list(self.S4.shape)
        nbatch = shape[0] * shape[1] * shape[2]
        # [a, b, c, L, L, L] -> [a*b*c, L^3] @ [L^3, L^2] -> [a, b, c, L, L]
        S4 = bk.bk_reshape(
            bk.backend.matmul(
                bk.bk_reshape(self.S4, [nbatch, norient * norient * norient]),
                lmat3,
            ),
            [shape[0], shape[1], shape[2], norient, norient],
        )
        if repeat:
            # [a, b, c, L, L] -> [a*b*c, L^2] @ [L^2, L^3] -> [a, b, c, L, L, L]
            S4 = bk.bk_reshape(
                bk.backend.matmul(
                    bk.bk_reshape(S4, [nbatch, norient * norient]),
                    lmat3_T,
                ),
                [shape[0], shape[1], shape[2], norient, norient, norient],
            )

    return scat_cov(
        self.S0,
        S2,
        S3,
        S4,
        s1=S1,
        s3p=S3P,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def fft_ang(self, nharm=1, imaginary=False):
    """Project orientation axes onto the first Fourier harmonics.

    This is a softer alternative to :meth:`iso_mean`. Instead of
    collapsing each orientation axis to a single mean value, it keeps
    the first ``nharm`` harmonics of the discrete Fourier transform along
    each orientation axis L. This preserves the *amplitude* of the angular
    variation while reducing the number of descriptors from L to
    ``nout = 1 + 2*nharm`` (with ``imaginary=True``) per orientation axis.

    .. rubric:: Why ``imaginary=True`` is the recommended mode

    With ``imaginary=False`` the projection basis is
    ``{1, cos(2π·l/L), cos(4π·l/L), …}``. A field whose dominant
    orientation sits at the zero-crossing of the cosine would give a
    near-zero first-harmonic coefficient even though the field is
    strongly anisotropic.

    With ``imaginary=True`` the basis is
    ``{1, cos(2π·l/L), sin(2π·l/L), cos(4π·l/L), sin(4π·l/L), …}``.
    The amplitude of the first harmonic,

    .. math::

        A_1 = \\sqrt{c_1^2 + s_1^2}

    is **rotation-invariant**: it is the same regardless of the absolute
    orientation of the field. Use ``imaginary=True`` whenever you want a
    description that does not depend on the image orientation.

    .. rubric:: Reduction per statistic

    ============ ==================== ==========================
    Statistic    Input shape          Output shape
    ============ ==================== ==========================
    S1, S2       ``[…, L]``           ``[…, nout]``
    S3, S3P      ``[…, L, L]``        ``[…, L, nout]``
    S4           ``[…, L, L, L]``     ``[…, L, L, nout]``
    ============ ==================== ==========================

    For S1/S2 the projection applies the Fourier basis directly on the
    single orientation axis.

    For S3/S3P and S4 the projection is **not** a tensor product.
    Instead, the statistics are first reindexed by the relative-orientation
    differences (exactly as in :meth:`iso_mean`), then the absolute
    orientation axis :math:`l_1` is projected onto the Fourier basis:

    .. math::

        \\text{S3\\_out}[\\Delta l,\\, k] =
        \\sum_{l_1} \\phi_k(l_1)\\, S3[l_1,\\,(l_1+\\Delta l)\\bmod L]

    .. math::

        \\text{S4\\_out}[\\Delta l_{12},\\, \\Delta l_{13},\\, k] =
        \\sum_{l_1} \\phi_k(l_1)\\,
        S4[l_1,\\,(l_1+\\Delta l_{12})\\bmod L,\\,(l_1+\\Delta l_{13})\\bmod L]

    where :math:`\\phi_0(l)=1`, :math:`\\phi_1(l)=\\cos(2\\pi l/L)`,
    :math:`\\phi_2(l)=\\sin(2\\pi l/L)`.

    This keeps the relative-orientation axes (as :meth:`iso_mean` does)
    and also captures how strongly the statistics vary as the whole frame
    rotates. The :math:`k=0` component corresponds to the :meth:`iso_mean`
    result.

    Parameters
    ----------
    nharm : int, optional
        Number of harmonics to keep beyond the DC term. ``nharm=1``
        (default) keeps the mean and the first angular harmonic.
    imaginary : bool, optional
        If ``False`` (default), keep only the cosine components
        ``{1, cos, cos(2·), …}``, so ``nout = 1 + nharm``.
        If ``True``, keep both cosine and sine components
        ``{1, cos, sin, cos(2·), sin(2·), …}``, so ``nout = 1 + 2·nharm``.
        **Recommended: ``True``** to obtain rotation-invariant amplitudes.

    Returns
    -------
    scat_cov
        A new statistics object with orientation axes compressed from L to
        ``nout``. S0 is passed through unchanged.

    Examples
    --------
    >>> stat = scat_op.eval(image)
    >>> stat_fft = stat.fft_ang(nharm=1, imaginary=True)
    >>> # S2 shape: (..., L) -> (..., 3)          [DC, cos, sin]
    >>> # S3 shape: (..., L, L) -> (..., L, 3)    [Δl axis kept, l1 compressed]
    >>> # S4 shape: (..., L, L, L) -> (..., L, L, 3)
    >>>
    >>> # Rotation-invariant first-harmonic amplitude for S2:
    >>> import numpy as np
    >>> A1_S2 = np.sqrt(stat_fft.S2[..., 1]**2 + stat_fft.S2[..., 2]**2)
    >>>
    >>> # Angular modulation of S3 at a given relative angle Δl:
    >>> A1_S3_delta0 = np.sqrt(stat_fft.S3[..., 0, 1]**2 + stat_fft.S3[..., 0, 2]**2)
    """
    bk = self.backend
    norient = list(self.S2.shape)[3]
    key = (norient, nharm, imaginary)
    # The projection tables are stored and looked up under the full
    # (norient, nharm, imaginary) key.  Checking only (norient, nharm) never
    # matched, so the tables were rebuilt on every call.
    if key not in bk._fft_1_orient:
        bk.calc_fft_orient(norient, nharm, imaginary)
    nout = 1 + 2 * nharm if imaginary else 1 + nharm

    def _project(x, table, n_in, out_tail):
        # Flatten the three leading axes, right-multiply the orientation
        # block by the projection matrix, then restore the leading axes.
        # The complex table ("<table>_C") is only accessed for complex data.
        shape = list(x.shape)
        nbatch = shape[0] * shape[1] * shape[2]
        name = table + "_C" if bk.bk_is_complex(x) else table
        lmat = getattr(bk, name)[key]
        flat = bk.bk_reshape(x, [nbatch, n_in])
        return bk.bk_reshape(bk.backend.matmul(flat, lmat), shape[:3] + out_tail)

    S1 = self.S1
    if self.S1 is not None:
        # [B, j1, j2, L] -> [B, j1, j2, nout]
        S1 = _project(self.S1, "_fft_1_orient", norient, [nout])
    S2 = _project(self.S2, "_fft_1_orient", norient, [nout])

    # Ensure angular-FFT matrices exist for S3/S3P/S4
    if key not in bk._fft_ang2_orient:
        bk.calc_fft_ang_orient(norient, nharm, imaginary)

    S3 = self.S3
    if self.S3 is not None:
        # [B, j1, j2, L, L] -> [B*j1*j2, L*L] @ lmat -> [B, j1, j2, L, nout]
        S3 = _project(self.S3, "_fft_ang2_orient", norient * norient, [norient, nout])

    S3P = self.S3P
    if self.S3P is not None:
        S3P = _project(self.S3P, "_fft_ang2_orient", norient * norient, [norient, nout])

    S4 = self.S4
    if self.S4 is not None:
        # [B, j1, j2, L, L, L] -> [B*j1*j2, L^3] @ lmat -> [B, j1, j2, L, L, nout]
        S4 = _project(
            self.S4,
            "_fft_ang3_orient",
            norient * norient * norient,
            [norient, norient, nout],
        )

    return scat_cov(
        self.S0,
        S2,
        S3,
        S4,
        s1=S1,
        s3p=S3P,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def fft_ang_sigma(self, nharm=1, imaginary=False):
    """Propagate per-coefficient standard deviations through ``fft_ang``.

    When the input object contains standard deviations ``sigma_j`` (e.g. the
    variance output of :meth:`eval` with ``calc_var=True``), applying
    :meth:`fft_ang` directly would be **wrong**. A linear projection of
    sigma gives near-zero values for the harmonic components, because
    cosine/sine sum to zero over uniform sigma values.

    The formula for a linear map :math:`y_k = \\sum_j A_{jk}\\,x_j`
    with independent inputs of variance :math:`\\sigma_j^2` is:

    .. math::

        \\sigma_{y_k} = \\sqrt{\\sum_j A_{jk}^2 \\, \\sigma_j^2}

    This method applies that formula with the element-wise squared
    (real-valued) projection matrices. It takes ``sqrt(|.|)`` of the
    propagated variance, so the output is non-negative and can be used as
    the ``sigma`` argument of :meth:`reduce_distance`.

    Parameters
    ----------
    nharm : int, optional
        Passed through to the underlying projection matrices (same as in
        :meth:`fft_ang`).
    imaginary : bool, optional
        Passed through (same as in :meth:`fft_ang`).

    Returns
    -------
    scat_cov
        A new statistics object with the same shape as the ``fft_ang``
        output but containing propagated standard deviations. S0 is passed
        through unchanged.
    """
    bk = self.backend
    norient = list(self.S2.shape)[3]
    key = (norient, nharm, imaginary)
    if key not in bk._fft_1_orient:
        bk.calc_fft_orient(norient, nharm, imaginary)
    nout = 1 + nharm * 2 if imaginary else 1 + nharm

    def _propagate(x, lmat, flat_shape, out_shape):
        # sigma_out = sqrt(| (x**2) @ (A**2) |): variance propagation through y = x @ A
        variance = bk.backend.matmul(bk.bk_reshape(x * x, flat_shape), lmat * lmat)
        return bk.bk_reshape(bk.bk_sqrt(bk.bk_abs(variance)), out_shape)

    def _lead(x):
        # (product of the 3 leading dims, list of the 3 leading dims)
        dims = list(x.shape)
        return dims[0] * dims[1] * dims[2], dims[:3]

    # --- S1 / S2: single orientation axis [..., L] -> [..., nout]
    S1 = self.S1
    if self.S1 is not None:
        S1 = _propagate(
            self.S1,
            bk._fft_1_orient[key],
            [-1, norient],
            list(self.S1.shape[:-1]) + [nout],
        )
    S2 = _propagate(
        self.S2,
        bk._fft_1_orient[key],
        [-1, norient],
        list(self.S2.shape[:-1]) + [nout],
    )

    if key not in bk._fft_ang2_orient:
        bk.calc_fft_ang_orient(norient, nharm, imaginary)

    # --- S3 / S3P: [a, b, c, L, L] -> [a, b, c, L, nout]
    S3 = self.S3
    if self.S3 is not None:
        nflat, lead = _lead(self.S3)
        S3 = _propagate(
            self.S3,
            bk._fft_ang2_orient[key],
            [nflat, norient * norient],
            lead + [norient, nout],
        )

    S3P = self.S3P
    if self.S3P is not None:
        nflat, lead = _lead(self.S3P)
        S3P = _propagate(
            self.S3P,
            bk._fft_ang2_orient[key],
            [nflat, norient * norient],
            lead + [norient, nout],
        )

    # --- S4: [a, b, c, L, L, L] -> [a, b, c, L, L, nout]
    S4 = self.S4
    if self.S4 is not None:
        nflat, lead = _lead(self.S4)
        S4 = _propagate(
            self.S4,
            bk._fft_ang3_orient[key],
            [nflat, norient * norient * norient],
            lead + [norient, norient, nout],
        )

    return scat_cov(
        self.S0,
        S2,
        S3,
        S4,
        s1=S1,
        s3p=S3P,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def iso_std(self, repeat=False):
    """Return the spread of the statistics around their isotropic mean.

    The method subtracts ``iso_mean(repeat=True)`` from ``self`` and applies
    ``square_comp``. It then reduces with ``iso_mean(repeat=repeat)`` and
    finishes with ``L1``.
    """
    centred = self - self.iso_mean(repeat=True)
    spread = centred.square_comp().iso_mean(repeat=repeat)
    return spread.L1()
[docs]
def get_nscale(self):
    """Return the number of scales, i.e. the size of axis 2 of ``S2``."""
    scale_axis = 2
    return self.S2.shape[scale_axis]
[docs]
def get_norient(self):
    """Return the number of orientations, i.e. the size of axis 3 of ``S2``."""
    orient_axis = 3
    return self.S2.shape[orient_axis]
[docs]
def add_data_from_log_slope(self, y, n, ds=3):
    """Extrapolate ``n`` values preceding ``y`` from an exponential fit.

    A straight line is fitted to ``log(y[:npt])`` at abscissae ``0..npt-1``,
    where ``npt = min(len(y), ds)``. The line is evaluated at
    ``-n, ..., -1``, i.e. the ``n`` slots immediately before ``y[0]``, and
    its exponential is returned in that (increasing-position) order.

    Parameters
    ----------
    y : array_like
        Known values, ordered so that ``y[0]`` is adjacent to the gap.
    n : int
        Number of values to extrapolate.
    ds : int, optional
        Maximum number of leading samples of ``y`` used for the fit.

    Returns
    -------
    numpy.ndarray
        ``n`` extrapolated values. If only one sample is usable, it is
        repeated ``n`` times.

    Raises
    ------
    ValueError
        If no sample can be used (empty ``y`` or ``ds < 1``).
    """
    npt = min(len(y), ds)
    if npt < 1:
        # previously an UnboundLocalError on the fit coefficients
        raise ValueError("add_data_from_log_slope needs at least one sample")
    if npt == 1:
        return np.repeat(y[0], n)
    # Fit over every available sample (up to ds).  Before this change,
    # 2 < len(y) < ds left the coefficients undefined.
    a = np.polyfit(np.arange(npt), np.log(y[0:npt]), 1)
    # Positions -n..-1 are the n slots right before y[0] (which sits at 0).
    # The former "- 1 - n" offset skipped position -1 entirely.
    return np.exp((np.arange(n) - n) * a[0] + a[1])
[docs]
def add_data_from_slope(self, y, n, ds=3):
    """Extrapolate ``n`` values preceding ``y`` from a linear fit.

    A straight line is fitted to ``y[:npt]`` at abscissae ``0..npt-1``,
    where ``npt = min(len(y), ds)``. The line is evaluated at
    ``-n, ..., -1``, i.e. the ``n`` slots immediately before ``y[0]``.

    Parameters
    ----------
    y : array_like
        Known values, ordered so that ``y[0]`` is adjacent to the gap.
        A 2-D ``y`` is fitted column-wise by ``numpy.polyfit``.
    n : int
        Number of values to extrapolate.
    ds : int, optional
        Maximum number of leading samples of ``y`` used for the fit.

    Returns
    -------
    numpy.ndarray
        ``n`` extrapolated values. If only one sample is usable, it is
        repeated ``n`` times.

    Raises
    ------
    ValueError
        If no sample can be used (empty ``y`` or ``ds < 1``).
    """
    npt = min(len(y), ds)
    if npt < 1:
        # previously an UnboundLocalError on the fit coefficients
        raise ValueError("add_data_from_slope needs at least one sample")
    if npt == 1:
        return np.repeat(y[0], n)
    # Fit over every available sample (up to ds).  Before this change,
    # 2 < len(y) < ds left the coefficients undefined.
    a = np.polyfit(np.arange(npt), y[0:npt], 1)
    # Positions -n..-1 are the n slots right before y[0] (which sits at 0).
    # The former "- 1 - n" offset skipped position -1 entirely.
    return (np.arange(n) - n) * a[0] + a[1]
[docs]
def up_grade(self, nscale, ds=3):
    """Extend the statistics to ``nscale`` scales by extrapolation.

    The existing coefficients are placed at the upper end of the new scale
    axis (offset ``noff = nscale - current_nscale``). The missing
    smallest-index scales are then filled in:

    - S1/S2: log-linear extrapolation along the scale axis
      (:meth:`add_data_from_log_slope`).
    - S3: linear extrapolation (:meth:`add_data_from_slope`), followed by a
      pass that fills the remaining entries along ``j1``.
    - S4: log-linear extrapolation along the triplet axis, then two
      two-point linear passes for entries that are still exactly zero.

    Results are complex numpy arrays (S1 is a real numpy array).
    S3P is not extrapolated and is dropped from the result.

    Parameters
    ----------
    nscale : int
        Target number of scales; must be >= the current number of scales.
    ds : int, optional
        Number of samples used by the slope fits.

    Returns
    -------
    scat_cov

    Raises
    ------
    ValueError
        If ``nscale`` is smaller than the current number of scales.
    """
    noff = nscale - self.S2.shape[2]
    if noff == 0:
        return scat_cov(
            self.S0,
            self.S2,
            self.S3,
            self.S4,
            s1=self.S1,
            s3p=self.S3P,
            backend=self.backend,
            use_1D=self.use_1D,
        )
    if noff < 0:
        raise ValueError(
            "up_grade: nscale (%d) is smaller than the current number of scales (%d)"
            % (nscale, self.S2.shape[2])
        )

    def _to_np(x):
        # Convert once, outside the loops.  The previous code converted the
        # whole tensor on every innermost iteration and used S4.numpy(),
        # which bypasses the backend conversion.
        if self.BACKEND == "numpy":
            return x
        return self.backend.to_numpy(x)

    inscale = self.S2.shape[2]

    # ---------------- S2 ----------------
    s2 = np.zeros(
        [self.S2.shape[0], self.S2.shape[1], nscale, self.S2.shape[3]],
        dtype="complex",
    )
    s2[:, :, noff:, :] = _to_np(self.S2)
    for i in range(self.S2.shape[0]):
        for j in range(self.S2.shape[1]):
            for k in range(self.S2.shape[3]):
                s2[i, j, 0:noff, k] = self.add_data_from_log_slope(
                    s2[i, j, noff:, k], noff, ds=ds
                )

    # ---------------- S1 (optional) ----------------
    s1 = None
    if self.S1 is not None:
        s1 = np.zeros([self.S1.shape[0], self.S1.shape[1], nscale, self.S1.shape[3]])
        s1[:, :, noff:, :] = _to_np(self.S1)
        for i in range(self.S1.shape[0]):
            for j in range(self.S1.shape[1]):
                for k in range(self.S1.shape[3]):
                    s1[i, j, 0:noff, k] = self.add_data_from_log_slope(
                        s1[i, j, noff:, k], noff, ds=ds
                    )

    # ---------------- S3 ----------------
    S3_np = _to_np(self.S3)
    nout = 0
    for i in range(1, nscale):
        nout = nout + i
    s3 = np.zeros(
        [
            self.S3.shape[0],
            self.S3.shape[1],
            nout,
            self.S3.shape[3],
            self.S3.shape[4],
        ],
        dtype="complex",
    )
    # (j1, j2) pairs of the output scale axis
    jo1 = np.zeros([nout])
    jo2 = np.zeros([nout])
    n = 0
    for i in range(1, nscale):
        jo1[n : n + i] = np.arange(i)
        jo2[n : n + i] = i
        n = n + i
    # (j1, j2) pairs of the input scale axis
    j1_in = np.zeros([self.S3.shape[2]])
    j2_in = np.zeros([self.S3.shape[2]])
    n = 0
    for i in range(1, self.S2.shape[2]):
        j1_in[n : n + i] = np.arange(i)
        j2_in[n : n + i] = i
        n = n + i
    for i in range(self.S3.shape[0]):
        for j in range(self.S3.shape[1]):
            for k in range(self.S3.shape[3]):
                for l_orient in range(self.S3.shape[4]):
                    for ij in range(noff + 1, nscale):
                        idx = np.where(jo2 == ij)[0]
                        known = S3_np[i, j, j2_in == ij - noff, k, l_orient]
                        s3[i, j, idx[noff:], k, l_orient] = known
                        s3[i, j, idx[:noff], k, l_orient] = self.add_data_from_slope(
                            known, noff, ds=ds
                        )
                    for ij in range(nscale):
                        idx = np.where(jo1 == ij)[0]
                        if idx.shape[0] > noff:
                            s3[i, j, idx[:noff], k, l_orient] = self.add_data_from_slope(
                                s3[i, j, idx[noff:], k, l_orient], noff, ds=ds
                            )
                        else:
                            s3[i, j, idx, k, l_orient] = np.mean(
                                s3[i, j, jo1 == ij - 1, k, l_orient]
                            )

    # ---------------- S4 ----------------
    S4_np = _to_np(self.S4)
    nout = 0
    for j3 in range(nscale):
        for j2 in range(0, j3):
            for j1 in range(0, j2):
                nout = nout + 1
    s4 = np.zeros(
        [
            self.S4.shape[0],
            self.S4.shape[1],
            nout,
            self.S4.shape[3],
            self.S4.shape[4],
            self.S4.shape[5],
        ],
        dtype="complex",
    )
    # (j1, j2, j3) triplets of the output axis
    jo1 = np.zeros([nout])
    jo2 = np.zeros([nout])
    jo3 = np.zeros([nout])
    nout = 0
    for j3 in range(nscale):
        for j2 in range(0, j3):
            for j1 in range(0, j2):
                jo1[nout] = j1
                jo2[nout] = j2
                jo3[nout] = j3
                nout = nout + 1
    # (j1, j2, j3) triplets of the input axis
    ncross = self.S4.shape[2]
    jj1 = np.zeros([ncross])
    jj2 = np.zeros([ncross])
    jj3 = np.zeros([ncross])
    n = 0
    for j3 in range(inscale):
        for j2 in range(0, j3):
            for j1 in range(0, j2):
                jj1[n] = j1
                jj2[n] = j2
                jj3[n] = j3
                n = n + 1
    for j3 in range(nscale):
        for j2 in range(j3):
            idx = np.where((jj3 == j3) * (jj2 == j2))[0]
            if idx.shape[0] > 0:
                idx2 = np.where((jo3 == j3 + noff) * (jo2 == j2 + noff))[0]
                for i in range(self.S4.shape[0]):
                    for j in range(self.S4.shape[1]):
                        for k in range(self.S4.shape[3]):
                            for l_orient in range(self.S4.shape[4]):
                                for m in range(self.S4.shape[5]):
                                    known = S4_np[i, j, idx, k, l_orient, m]
                                    s4[i, j, idx2[noff:], k, l_orient, m] = known
                                    s4[i, j, idx2[:noff], k, l_orient, m] = (
                                        self.add_data_from_log_slope(known, noff, ds=ds)
                                    )
    # Fill remaining zeros from the diagonal neighbours (j1+1, j2+1, j3+1)...
    idx = np.where(abs(s4[0, 0, :, 0, 0, 0]) == 0)[0]
    for iii in idx:
        iii1 = np.where(
            (jo1 == jo1[iii] + 1) * (jo2 == jo2[iii] + 1) * (jo3 == jo3[iii] + 1)
        )[0]
        iii2 = np.where(
            (jo1 == jo1[iii] + 2) * (jo2 == jo2[iii] + 2) * (jo3 == jo3[iii] + 2)
        )[0]
        if iii2.shape[0] > 0:
            for i in range(self.S4.shape[0]):
                for j in range(self.S4.shape[1]):
                    for k in range(self.S4.shape[3]):
                        for l_orient in range(self.S4.shape[4]):
                            for m in range(self.S4.shape[5]):
                                s4[i, j, iii, k, l_orient, m] = self.add_data_from_slope(
                                    s4[i, j, [iii1, iii2], k, l_orient, m],
                                    1,
                                    ds=2,
                                )[0]
    # ...then from the neighbours along j3 (j3-1, j3-2).
    idx = np.where(abs(s4[0, 0, :, 0, 0, 0]) == 0)[0]
    for iii in idx:
        iii1 = np.where(
            (jo1 == jo1[iii]) * (jo2 == jo2[iii]) * (jo3 == jo3[iii] - 1)
        )[0]
        iii2 = np.where(
            (jo1 == jo1[iii]) * (jo2 == jo2[iii]) * (jo3 == jo3[iii] - 2)
        )[0]
        if iii2.shape[0] > 0:
            for i in range(self.S4.shape[0]):
                for j in range(self.S4.shape[1]):
                    for k in range(self.S4.shape[3]):
                        for l_orient in range(self.S4.shape[4]):
                            for m in range(self.S4.shape[5]):
                                s4[i, j, iii, k, l_orient, m] = self.add_data_from_slope(
                                    s4[i, j, [iii1, iii2], k, l_orient, m],
                                    1,
                                    ds=2,
                                )[0]
    return scat_cov(
        self.S0,
        s2,
        s3,
        s4,
        s1=s1,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
class funct(FOC.FoCUS):
[docs]
def fill(self, im, nullval=hp.UNSEEN):
    """Fill ``nullval`` pixels of ``im`` using the geometry-specific filler.

    Dispatches to ``fill_2d`` when ``use_2D`` is set, otherwise to
    ``fill_1d`` when ``use_1D`` is set, and to ``fill_healpy`` otherwise.
    """
    if self.use_2D:
        filler = self.fill_2d
    elif self.use_1D:
        filler = self.fill_1d
    else:
        filler = self.fill_healpy
    return filler(im, nullval=nullval)
[docs]
def moments(self, list_scat):
    """Compute the mean and standard deviation of scattering statistics.

    Parameters
    ----------
    list_scat : scat_cov, mapping or iterable of scat_cov
        Either a single ``scat_cov`` whose coefficient tensors carry the
        samples along axis 0, or a collection of ``scat_cov`` objects.
        For a mapping its values are used; any other iterable (list,
        tuple, ...) is used directly.

    Returns
    -------
    (scat_cov, scat_cov)
        The mean and the standard deviation over the samples.
        For a single ``scat_cov``, the reduced axis 0 is kept with size 1
        (``bk_expand_dims``). For a collection, the statistics are stacked
        along a new axis 0 that is then reduced away. S1/S3P are ``None``
        when absent.

    Raises
    ------
    ValueError
        If ``list_scat`` is an empty collection.
    """
    from collections.abc import Mapping

    bk = self.backend
    if isinstance(list_scat, foscat.scat_cov.scat_cov):

        def _mean(x):
            return bk.bk_expand_dims(bk.bk_reduce_mean(x, 0), 0)

        def _std(x):
            return bk.bk_expand_dims(bk.bk_reduce_std(x, 0), 0)

        mS0 = _mean(list_scat.S0)
        mS2 = _mean(list_scat.S2)
        mS3 = _mean(list_scat.S3)
        mS4 = _mean(list_scat.S4)
        sS0 = _std(list_scat.S0)
        sS2 = _std(list_scat.S2)
        sS3 = _std(list_scat.S3)
        sS4 = _std(list_scat.S4)
        mS1 = sS1 = None
        if list_scat.S1 is not None:
            mS1 = _mean(list_scat.S1)
            sS1 = _std(list_scat.S1)
        mS3P = sS3P = None
        if list_scat.S3P is not None:
            mS3P = _mean(list_scat.S3P)
            sS3P = _std(list_scat.S3P)
    else:
        # Iterating a list and then indexing it with its own elements raised
        # TypeError, so only mappings worked.  Accept any iterable as well.
        if isinstance(list_scat, Mapping):
            entries = list(list_scat.values())
        else:
            entries = list(list_scat)
        if not entries:
            raise ValueError("moments: list_scat must contain at least one scat_cov")

        def _as_np(x):
            if self.BACKEND == "numpy":
                return x
            return bk.to_numpy(x)

        def _stack(name):
            # Stack the non-None values of one statistic along a new axis 0.
            return np.stack(
                [_as_np(getattr(e, name)) for e in entries if getattr(e, name) is not None],
                0,
            )

        S0 = _stack("S0")
        S2 = _stack("S2")
        S3 = _stack("S3")
        S4 = _stack("S4")
        sS0 = bk.bk_cast(np.std(S0, 0))
        sS2 = bk.bk_cast(np.std(S2, 0))
        sS3 = bk.bk_cast(np.std(S3, 0))
        sS4 = bk.bk_cast(np.std(S4, 0))
        mS0 = bk.bk_cast(np.mean(S0, 0))
        mS2 = bk.bk_cast(np.mean(S2, 0))
        mS3 = bk.bk_cast(np.mean(S3, 0))
        mS4 = bk.bk_cast(np.mean(S4, 0))

        # As before, the presence of S3P/S1 is decided by the last entry.
        last = entries[-1]
        if last.S3P is not None:
            S3P = _stack("S3P")
            sS3P = bk.bk_cast(np.std(S3P, 0))
            mS3P = bk.bk_cast(np.mean(S3P, 0))
        else:
            sS3P = None
            mS3P = None
        if last.S1 is not None:
            S1 = _stack("S1")
            sS1 = bk.bk_cast(np.std(S1, 0))
            mS1 = bk.bk_cast(np.mean(S1, 0))
        else:
            sS1 = None
            mS1 = None

    return scat_cov(
        mS0,
        mS2,
        mS3,
        mS4,
        s1=mS1,
        s3p=mS3P,
        backend=self.backend,
        use_1D=self.use_1D,
    ), scat_cov(
        sS0,
        sS2,
        sS3,
        sS4,
        s1=sS1,
        s3p=sS3P,
        backend=self.backend,
        use_1D=self.use_1D,
    )
# compute local direction to make the statistical analysis more efficient
[docs]
def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0, spin=0):
    """Estimate local orientations and build per-scale rotation matrices.

    At each scale ``k``, the modulus of ``convol(tmp)`` is formed; when
    ``image2`` is given, the real part of ``bk_L1`` of the cross product is
    used instead. Its cosine- and sine-weighted averages over the NORIENT
    channels give a local phase, which is turned into cosine interpolation
    weights between the two nearest orientation bins. These weights form
    ``cmat[k]``. The same procedure, applied to the first-order response
    rotated by ``mat`` and convolved again, gives ``cmat2[k]``. Between
    scales the maps are reduced with ``ud_grade_2``.

    Parameters
    ----------
    im : tensor
        Input map; the number of scales is derived from ``im.shape[-1]``
        assuming ``12 * nside**2`` pixels.
    image2 : tensor, optional
        Second map for the cross-statistics variant.
    upscale : bool, optional
        If True, maps are first upgraded to ``2 * nside``.
    smooth_scale : int, optional
        Number of ``ud_grade_2`` smoothing steps applied to the phase maps
        (followed by ``up_grade`` back to full resolution).
    spin : int, optional
        Passed to ``convol`` for the first-order response. Non-zero spin
        adds an axis of size 2 to the matrices.

    Returns
    -------
    cmat, cmat2 : dict[int, tensor]
        ``cmat[k]`` is the complex64 cast of ``mat[None, ...]``, with ``mat``
        of shape ``[N*N, npix]`` (spin 0) or ``[2, N*N, npix]``.
        ``cmat2[k]`` is the cast of ``mat2[0:k+1, None, ...]``, with ``mat2``
        of shape ``[k+1, N*N, N, npix]`` (spin 0) or
        ``[k+1, 2, N*N, N, npix]``.
    """
    norient = self.NORIENT
    tmp = im
    tmpi2 = image2
    if upscale:
        l_nside = int(np.sqrt(tmp.shape[-1] // 12))
        tmp = self.up_grade(tmp, l_nside * 2)
        if image2 is not None:
            tmpi2 = self.up_grade(tmpi2, l_nside * 2)
    l_nside = int(np.sqrt(tmp.shape[-1] // 12))
    nscale = int(np.log(l_nside) / np.log(2) + 1)
    cmat = {}
    cmat2 = {}
    # Loop over scales
    for k in range(nscale):
        if image2 is not None:
            sim = self.backend.bk_real(
                self.backend.bk_L1(
                    self.convol(tmp, spin=spin)
                    * self.backend.bk_conjugate(self.convol(tmpi2, spin=spin))
                )
            )
        else:
            sim = self.backend.bk_abs(self.convol(tmp, spin=spin))
        # Instead of differencing "opposite" channels, use the cosine/sine
        # weighted average over all channels (axis -2 of sim).
        chan_angles = 2 * np.pi * np.arange(norient) / norient
        if spin == 0:
            angles = self.backend.bk_cast(chan_angles.reshape(1, norient, 1))
        else:
            angles = self.backend.bk_cast(chan_angles.reshape(1, 1, norient, 1))
        weighted_cos = self.backend.bk_reduce_mean(
            sim * self.backend.bk_cos(angles), axis=-2
        )
        weighted_sin = self.backend.bk_reduce_mean(
            sim * self.backend.bk_sin(angles), axis=-2
        )
        # For simplicity, take first element of the batch
        cc = weighted_cos[0]
        ss = weighted_sin[0]
        if smooth_scale > 0:
            for m in range(smooth_scale):
                if cc.shape[0] > 12:
                    cc, _ = self.ud_grade_2(cc)
                    ss, _ = self.ud_grade_2(ss)
        if cc.shape[-1] != tmp.shape[-1]:
            ll_nside = int(np.sqrt(tmp.shape[-1] // 12))
            cc = self.up_grade(self.backend.bk_cast(cc), ll_nside)
            ss = self.up_grade(self.backend.bk_cast(ss), ll_nside)
        # local phase from the weighted cos and sin, mapped to [0, 2*pi)
        if self.BACKEND == "numpy":
            phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
        else:
            phase = np.fmod(
                np.arctan2(self.backend.to_numpy(ss), self.backend.to_numpy(cc))
                + 2 * np.pi,
                2 * np.pi,
            )
        # cosine-based (instead of linear) interpolation between the two
        # nearest orientation bins
        phase_scaled = norient * phase / (2 * np.pi)
        iph = np.floor(phase_scaled).astype("int")  # lower bin index
        delta = phase_scaled - iph  # fractional part in [0, 1)
        w0 = np.cos(delta * np.pi / 2) ** 2
        w1 = np.sin(delta * np.pi / 2) ** 2
        # first-order rotation matrix
        if spin == 0:
            mat = np.zeros([norient * norient, sim.shape[-1]])
        else:
            mat = np.zeros([2, norient * norient, sim.shape[-1]])
        lidx = np.arange(sim.shape[-1])
        for ell in range(norient):
            col0 = norient * ((ell + iph) % norient) + ell
            col1 = norient * ((ell + iph + 1) % norient) + ell
            if spin == 0:
                mat[col0, lidx] = w0
                mat[col1, lidx] = w1
            else:
                mat[0, col0, lidx] = w0[0]
                mat[0, col1, lidx] = w1[0]
                mat[1, col0, lidx] = w0[1]
                mat[1, col1, lidx] = w1[1]
        cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))

        # second-order rotation matrix
        if spin == 0:
            mat2 = np.zeros([k + 1, norient * norient, norient, sim.shape[-1]])
        else:
            mat2 = np.zeros(
                [k + 1, 2, norient * norient, norient, sim.shape[-1]]
            )
        # Nothing below depends on the slice index k2 of mat2 (the original
        # loop recomputed the same convolution k+1 times).  Compute mat2[0]
        # once and replicate it.
        tmp2 = self.backend.bk_expand_dims(sim, -2)
        if spin == 0:
            sim2 = self.backend.bk_reduce_sum(
                self.backend.bk_reshape(
                    self.backend.bk_cast(
                        mat.reshape(1, norient, norient, mat.shape[-1])
                    )
                    * tmp2,
                    [sim.shape[0], norient, norient, mat.shape[-1]],
                ),
                1,
            )
        else:
            sim2 = self.backend.bk_reduce_sum(
                self.backend.bk_reshape(
                    self.backend.bk_cast(
                        mat.reshape(1, 2, norient, norient, mat.shape[-1])
                    )
                    * tmp2,
                    [sim.shape[0], 2, norient, norient, mat.shape[-1]],
                ),
                2,
            )
        sim2 = self.backend.bk_abs(self.convol(sim2))
        angles = self.backend.bk_reshape(angles, [1, norient, 1, 1])
        weighted_cos2 = self.backend.bk_reduce_mean(
            sim2 * self.backend.bk_cos(angles), axis=-3
        )
        weighted_sin2 = self.backend.bk_reduce_mean(
            sim2 * self.backend.bk_sin(angles), axis=-3
        )
        cc2 = weighted_cos2[0]
        ss2 = weighted_sin2[0]
        if smooth_scale > 0:
            for m in range(smooth_scale):
                if cc2.shape[1] > 12:
                    cc2, _ = self.ud_grade_2(cc2)
                    ss2, _ = self.ud_grade_2(ss2)
        if cc2.shape[-1] != sim.shape[-1]:
            ll_nside = int(np.sqrt(sim.shape[-1] // 12))
            cc2 = self.up_grade(self.backend.bk_cast(cc2), ll_nside)
            ss2 = self.up_grade(self.backend.bk_cast(ss2), ll_nside)
        if self.BACKEND == "numpy":
            phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
        else:
            phase2 = np.fmod(
                np.arctan2(self.backend.to_numpy(ss2), self.backend.to_numpy(cc2))
                + 2 * np.pi,
                2 * np.pi,
            )
        phase2_scaled = norient * phase2 / (2 * np.pi)
        iph2 = np.floor(phase2_scaled).astype("int")
        delta2 = phase2_scaled - iph2
        w0_2 = np.cos(delta2 * np.pi / 2) ** 2
        w1_2 = np.sin(delta2 * np.pi / 2) ** 2
        if spin == 0:
            for m in range(norient):
                for ell in range(norient):
                    col0 = norient * ((ell + iph2[m]) % norient) + ell
                    col1 = norient * ((ell + iph2[m] + 1) % norient) + ell
                    mat2[0, col0, m, lidx] = w0_2[m, lidx]
                    mat2[0, col1, m, lidx] = w1_2[m, lidx]
        else:
            for sidx in range(2):
                for m in range(norient):
                    for ell in range(norient):
                        col0 = norient * ((ell + iph2[sidx, m]) % norient) + ell
                        col1 = norient * ((ell + iph2[sidx, m] + 1) % norient) + ell
                        mat2[0, sidx, col0, m, lidx] = w0_2[sidx, m, lidx]
                        mat2[0, sidx, col1, m, lidx] = w1_2[sidx, m, lidx]
        mat2[1:] = mat2[0]
        cmat2[k] = self.backend.bk_cast(
            mat2[0 : k + 1, None, ...].astype("complex64")
        )
        if k < nscale - 1:
            tmp, _ = self.ud_grade_2(tmp)
            if image2 is not None:
                # was ud_grade_2(tmpi): NameError on an undefined variable
                tmpi2, _ = self.ud_grade_2(tmpi2)
    return cmat, cmat2
[docs]
def stat_cfft_cell_ids(self, im, image2=None, upscale=False, smooth_scale=0, spin=0, cell_ids=None):
    """
    Full-sky stat_cfft + per-scale slicing by cell_ids.
    If cell_ids is None, the matrices from stat_cfft are returned as is
    (after casting the inputs with ``bk_cast``).
    If provided, cell_ids must be NEST indices at the input resolution.
    At every scale the last (pixel) axis of both matrices is restricted to
    the current ids. The ids are then mapped to the next, coarser scale
    through the ``cell_ids`` output of ``ud_grade_2``.
    """
    bk = self.backend

    # Cast inputs to backend tensors to avoid numpy/torch mixups
    cast_im = bk.bk_cast(im)
    cast_im2 = bk.bk_cast(image2) if image2 is not None else None

    # Canonical full-sky orientation matrices
    full_cmat, full_cmat2 = self.stat_cfft(
        cast_im, image2=cast_im2, upscale=upscale, smooth_scale=smooth_scale, spin=spin
    )
    if cell_ids is None:
        return full_cmat, full_cmat2

    to_np = getattr(bk, "to_numpy", lambda t: t)

    pix = np.asarray(cell_ids, dtype=np.int64)
    if upscale:
        # stat_cfft started at 2*nside: replace each parent by its 4 NEST children
        pix = (pix.reshape(-1, 1) * 4 + np.arange(4, dtype=np.int64)).reshape(-1)

    def _take_last(x, sel):
        # Slice the last axis, staying on device for torch tensors.
        try:
            import torch

            if isinstance(x, torch.Tensor):
                return x.index_select(
                    -1, torch.as_tensor(sel, device=x.device, dtype=torch.long)
                )
        except Exception:
            pass
        return bk.bk_cast(to_np(x)[..., sel])

    def _next_ids(cur, npix):
        # Ask ud_grade_2 (with a zero placeholder map) for the coarser ids.
        placeholder = bk.bk_cast(np.zeros((1, npix), dtype=np.float32))
        _, coarse = self.ud_grade_2(placeholder, cell_ids=cur)
        return np.asarray(to_np(coarse), dtype=np.int64)

    sub_cmat = {}
    sub_cmat2 = {}
    nlevel = len(full_cmat)
    for k in range(nlevel):
        npix_k = full_cmat[k].shape[-1]
        sub_cmat[k] = _take_last(full_cmat[k], pix)
        sub_cmat2[k] = _take_last(full_cmat2[k], pix)
        if k < nlevel - 1:
            pix = _next_ids(pix, npix_k)
    return sub_cmat, sub_cmat2
[docs]
def div_norm(self, complex_value, float_value):
    """Divide the real and imaginary parts of ``complex_value`` by ``float_value``.

    The two parts are divided separately and recombined with
    ``backend.bk_complex``.
    """
    bk = self.backend
    real_part = bk.bk_real(complex_value) / float_value
    imag_part = bk.bk_imag(complex_value) / float_value
    return bk.bk_complex(real_part, imag_part)
[docs]
def eval(
    self,
    image1,
    image2=None,
    mask=None,
    norm=None,
    calc_var=False,
    cmat=None,
    cmat2=None,
    Jmax=None,
    out_nside=None,
    edge=True,
    nside=None,
    cell_ids=None,
    spin=0
):
    """
    Calculate the scattering covariance coefficients of a batch of images.
    Means are taken over pixels, weighted by the mask.

    mean of modulus:
                    S1 = <|I * Psi_j3|>
        Normalization : divide by S2_j3^0.5 (any non-None ``norm``)
    power spectrum:
                    S2 = <|I * Psi_j3|^2>
        Normalization : divide by the reference S2 (``norm='auto'`` only)
    orig. x modulus:
                    S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
        Normalization : divide by (S2_j2 * S2_j3)^0.5
    modulus x modulus:
                    S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
        Normalization : divide by (S2_j1 * S2_j2)^0.5

    Parameters
    ----------
    image1: tensor
        Image(s) on which the scattering coefficients are computed. For
        HEALPix data the pixel axis is the last one. A 1D input (or a 2D
        input when use_2D is set or spin > 0) gets a leading batch axis.
    image2: tensor
        Second image. If not None, cross-scattering covariance coefficients
        are computed. Must have the same shape as image1.
    mask: tensor
        Weights used for the pixel averages. Its last axis (last two in 2D)
        must match the image. Defaults to ones.
    norm: None or str
        If None no normalization is applied. If 'auto', normalize by the
        reference S2 stored in self.P1_dic (computed and stored on the
        first call while it is None). If 'self', normalize by the S2 of the
        current call.
    calc_var: bool
        If True, also return a second scat_cov built from the second
        outputs of ``masked_mean(..., calc_var=True)``.
    cmat, cmat2:
        Optional per-scale orientation-mixing matrices applied to the
        first-order and second-order convolutions respectively.
    Jmax: int
        Number of scales; defaults to the maximum allowed by the data.
    out_nside: int
        Only used when self.return_data is True: per-pixel outputs are
        averaged down to this nside when it is coarser than the current one.
    edge: bool
        For 1D/2D data, zero the mask within KERNELSZ//2 pixels of the border.
    nside: int
        HEALPix nside; inferred from the number of pixels when None.
    cell_ids: array
        Optional subset of pixel ids, forwarded to convol and ud_grade_2.
    spin : Integer
        If different from 0 compute spinned data (U,V to Divergence/Rotational spin==1) or (Q,U to E,B spin=2).
        This implies that the input data is 2*12*nside^2.

    Returns
    -------
    scat_cov holding S0, S1, S2, S3, S4 (and S3P in cross mode), or a pair
    (coefficients, variances) when calc_var is True. Returns None when the
    inputs are inconsistent.
    """
    return_data = self.return_data

    def _normalize(value, factor):
        # div_norm rebuilds a complex tensor from its real and imaginary
        # parts; real-valued statistics are divided directly so they stay real.
        if self.backend.bk_is_complex(value):
            return self.div_norm(value, factor)
        return value / factor

    def _to_out_nside(value, nside_cur):
        # Average the trailing pixel axis from nside_cur down to out_nside by
        # grouping consecutive blocks of (nside_cur // out_nside)**2 pixels.
        # This is the same reshape-and-average used for S0. All leading axes
        # are preserved, whatever the number of orientation axes.
        if out_nside is None or out_nside >= nside_cur:
            return value
        shape = list(value.shape[:-1]) + [
            12 * out_nside**2,
            (nside_cur // out_nside) ** 2,
        ]
        return self.backend.bk_reduce_mean(
            self.backend.bk_reshape(value, shape), len(shape) - 1
        )

    def _apply_cmat(conv, j):
        # Mix the NORIENT orientations of conv with the matrices cmat[j].
        tmp2 = self.backend.bk_repeat(conv, self.NORIENT, axis=-2)
        if spin == 0:
            shape = [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j].shape[2]]
            sum_axis = 1
        else:
            shape = [tmp2.shape[0], 2, self.NORIENT, self.NORIENT, cmat[j].shape[3]]
            sum_axis = 2
        return self.backend.bk_reduce_sum(
            self.backend.bk_reshape(cmat[j] * tmp2, shape), sum_axis
        )

    # Check input consistency
    if image2 is not None:
        if list(image1.shape) != list(image2.shape):
            print(
                "The two input image should have the same size to eval Scattering Covariance"
            )
            return None
    if mask is not None:
        if self.use_2D:
            if (
                image1.shape[-2] != mask.shape[-2]
                or image1.shape[-1] != mask.shape[-1]
            ):
                print(
                    "The LAST 2 COLUMNs of the mask should have the same size ",
                    mask.shape,
                    "than the input image ",
                    image1.shape,
                    "to eval Scattering Covariance",
                )
                return None
        else:
            if image1.shape[-1] != mask.shape[-1]:
                print(
                    "The LAST COLUMN of the mask should have the same size ",
                    mask.shape,
                    "than the input image ",
                    image1.shape,
                    "to eval Scattering Covariance",
                )
                return None
    if self.use_2D and len(image1.shape) < 2:
        print(
            "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
        )
        return None

    ### AUTO OR CROSS
    cross = image2 is not None

    ### PARAMETERS
    # determine jmax and nside corresponding to the input map
    im_shape = image1.shape
    if self.use_2D:
        if len(image1.shape) == 2:
            nside = np.min([im_shape[0], im_shape[1]])
            npix = im_shape[0] * im_shape[1]  # Number of pixels
            x1 = im_shape[0]
            x2 = im_shape[1]
        else:
            nside = np.min([im_shape[1], im_shape[2]])
            npix = im_shape[1] * im_shape[2]  # Number of pixels
            x1 = im_shape[1]
            x2 = im_shape[2]
        J = int(np.log(nside - self.KERNELSZ) / np.log(2))  # Number of j scales
        if J == 0:
            print("Use of too small 2D domain does not work J_max=", J)
            return None
    elif self.use_1D:
        if len(image1.shape) == 2:
            npix = int(im_shape[1])  # Number of pixels
        else:
            npix = int(im_shape[0])  # Number of pixels
        nside = int(npix)
        J = int(np.log(nside) / np.log(2))  # Number of j scales
    else:
        npix = int(im_shape[-1])
        if nside is None:
            nside = int(np.sqrt(npix // 12))
        J = int(np.log2(nside) + 1)  # Number of j scales
        if cell_ids is not None:
            J = np.min([J, int(np.log(cell_ids.shape[0]) / (2 * np.log(2))) - 1])
    if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
        J -= 1
    if Jmax is None:
        Jmax = J  # Number of steps for the loop on scales
    if Jmax > J:
        print("==========\n\n")
        print(
            "The Jmax you requested is larger than the data size ", J, ", which may cause problems while computing the scattering transform."
        )
        print("\n\n==========")

    ### LOCAL VARIABLES (IMAGES and MASK)
    if (
        len(image1.shape) == 1
        or (len(image1.shape) == 2 and self.use_2D)
        or (len(image1.shape) == 2 and spin > 0)
    ):
        I1 = self.backend.bk_cast(
            self.backend.bk_expand_dims(image1, 0)
        )  # Local image1 [Nbatch, Npix]
        if cross:
            I2 = self.backend.bk_cast(
                self.backend.bk_expand_dims(image2, 0)
            )  # Local image2 [Nbatch, Npix]
    else:
        I1 = self.backend.bk_cast(image1)  # Local image1 [Nbatch, Npix]
        if cross:
            I2 = self.backend.bk_cast(image2)  # Local image2 [Nbatch, Npix]

    if mask is None:
        if self.use_2D:
            vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
        else:
            vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
    else:
        vmask = self.backend.bk_cast(mask)  # [Nmask, Npix]

    if self.KERNELSZ > 3 and not self.use_2D and cell_ids is None:
        # If the kernel size is bigger than 3, increase the binning before
        # smoothing (2D inputs never reach this block).
        if self.use_1D:
            vmask = self.up_grade(vmask, I1.shape[-1] * 2)
            I1 = self.up_grade(I1, I1.shape[-1] * 2)
            if cross:
                I2 = self.up_grade(I2, I2.shape[-1] * 2)
            nside = nside * 2
        else:
            I1 = self.up_grade(I1, nside * 2)
            vmask = self.up_grade(vmask, nside * 2)
            if cross:
                I2 = self.up_grade(I2, nside * 2)
            nside = nside * 2
            # One extra scale for the doubled HEALPix nside.
            # NOTE(review): confirm this should not also apply to the 1D branch.
            Jmax = Jmax + 1

    ### INITIALIZATION
    # Coefficients: dicts of per-pixel maps when return_data, else lists of
    # averaged coefficients concatenated at the end.
    if return_data:
        S1 = {}
        S2 = {}
        S3 = {}
        S3P = {}
        S4 = {}
    else:
        S1 = []
        S2 = []
        S3 = []
        S4 = []
        S3P = []
    VS1 = []
    VS2 = []
    VS3 = []
    VS3P = []
    VS4 = []

    off_S2 = -2
    off_S3 = -3
    off_S4 = -4
    if self.use_1D:
        off_S2 = -1
        off_S3 = -1
        off_S4 = -1

    # S2 for normalization
    cond_init_P1_dic = (norm == "self") or (
        (norm == "auto") and (self.P1_dic is None)
    )
    if norm is None:
        pass
    elif cond_init_P1_dic:
        P1_dic = {}
        if cross:
            P2_dic = {}
    elif (norm == "auto") and (self.P1_dic is not None):
        P1_dic = self.P1_dic
        if cross:
            P2_dic = self.P2_dic

    if return_data:
        s0 = _to_out_nside(I1, nside)
    else:
        if not cross:
            s0, l_vs0 = self.masked_mean(I1, vmask, calc_var=True)
        else:
            s0, l_vs0 = self.masked_mean(
                self.backend.bk_L1(I1 * I2), vmask, calc_var=True
            )
        # S0 stores the masked mean followed by the second output of
        # masked_mean(calc_var=True)
        vs0 = self.backend.bk_concat([l_vs0, l_vs0], -1)
        s0 = self.backend.bk_concat([s0, l_vs0], -1)
        if spin > 0:
            vs0 = self.backend.bk_reshape(
                vs0, [vs0.shape[0], vs0.shape[1], 2, vs0.shape[2] // 2]
            )
            s0 = self.backend.bk_reshape(
                s0, [s0.shape[0], s0.shape[1], 2, s0.shape[2] // 2]
            )

    #### COMPUTE S1, S2, S3 and S4
    nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)

    M1_dic = {}
    M2_dic = {}

    cell_ids_j3 = cell_ids

    for j3 in range(Jmax):
        if edge:
            if self.mask_mask is None:
                self.mask_mask = {}
            if self.use_2D:
                if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
                    mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
                    mask_mask[
                        0,
                        self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
                        self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
                    ] = 1.0
                    self.mask_mask[(vmask.shape[1], vmask.shape[2])] = (
                        self.backend.bk_cast(mask_mask)
                    )
                vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
            if self.use_1D:
                if (vmask.shape[1]) not in self.mask_mask:
                    mask_mask = np.zeros([1, vmask.shape[1]])
                    mask_mask[0, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
                    self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(
                        mask_mask
                    )
                vmask = vmask * self.mask_mask[(vmask.shape[1])]

        if return_data:
            S3[j3] = None
            S3P[j3] = None
            S4[j3] = None

        ####### S1 and S2
        ### Make the convolution I1 * Psi_j3
        conv1 = self.convol(
            I1, cell_ids=cell_ids_j3, nside=nside_j3, spin=spin
        )  # [Nbatch, Norient3, Npix_j3]
        if cmat is not None:
            conv1 = _apply_cmat(conv1, j3)

        ### Take the module M1 = |I1 * Psi_j3|
        M1_square = conv1 * self.backend.bk_conjugate(
            conv1
        )  # [Nbatch, Norient3, Npix_j3]
        M1 = self.backend.bk_L1(M1_square)  # [Nbatch, Norient3, Npix_j3]
        # Store M1_j3 in a dictionary
        M1_dic[j3] = M1

        if not cross:  # Auto
            M1_square = self.backend.bk_real(M1_square)

            ### S2_auto = < M1^2 >_pix
            # Apply the mask [Nmask, Npix_j3] and average over pixels
            if return_data:
                s2 = M1_square
            elif calc_var:
                s2, vs2 = self.masked_mean(M1_square, vmask, rank=j3, calc_var=True)
            else:
                s2 = self.masked_mean(M1_square, vmask, rank=j3)

            if cond_init_P1_dic:
                # We fill P1_dic with S2 for normalisation of S3 and S4
                P1_dic[j3] = self.backend.bk_real(s2)  # [Nbatch, Nmask, Norient3]

            # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
            if return_data:
                S2[j3] = _to_out_nside(s2, nside_j3)
            else:
                if norm == "auto":  # Normalize S2
                    # Out of place on purpose: bk_real may return its input
                    # unchanged for real tensors, so an in-place "/=" could
                    # also overwrite P1_dic[j3].
                    s2 = s2 / P1_dic[j3]
                S2.append(
                    self.backend.bk_expand_dims(s2, off_S2)
                )  # Add a dimension for NS2
                if calc_var:
                    VS2.append(
                        self.backend.bk_expand_dims(vs2, off_S2)
                    )  # Add a dimension for NS2

            #### S1_auto computation
            ### Image 1 : S1 = < M1 >_pix
            # Apply the mask [Nmask, Npix_j3] and average over pixels
            if return_data:
                S1[j3] = _to_out_nside(M1, nside_j3)
            else:
                if calc_var:
                    s1, vs1 = self.masked_mean(
                        M1, vmask, rank=j3, calc_var=True
                    )  # [Nbatch, Nmask, Norient3]
                else:
                    s1 = self.masked_mean(M1, vmask, rank=j3)  # [Nbatch, Nmask, Norient3]
                ### Normalize S1
                if norm is not None:
                    s1 = _normalize(s1, (P1_dic[j3]) ** 0.5)
                ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
                S1.append(
                    self.backend.bk_expand_dims(s1, off_S2)
                )  # Add a dimension for NS1
                if calc_var:
                    VS1.append(
                        self.backend.bk_expand_dims(vs1, off_S2)
                    )  # Add a dimension for NS1

        else:  # Cross
            ### Make the convolution I2 * Psi_j3
            conv2 = self.convol(
                I2, cell_ids=cell_ids_j3, nside=nside_j3, spin=spin
            )  # [Nbatch, Norient3, Npix_j3]
            if cmat is not None:
                conv2 = _apply_cmat(conv2, j3)

            ### Take the module M2 = |I2 * Psi_j3|
            M2_square = conv2 * self.backend.bk_conjugate(
                conv2
            )  # [Nbatch, Norient3, Npix_j3]
            M2 = self.backend.bk_L1(M2_square)  # [Nbatch, Norient3, Npix_j3]
            # Store M2_j3 in a dictionary
            M2_dic[j3] = M2

            ### S2_auto of both images: not returned, only for normalization
            if cond_init_P1_dic:
                # Apply the mask [Nmask, Npix_j3] and average over pixels
                if return_data:
                    p1 = M1_square
                    p2 = M2_square
                elif calc_var:
                    p1, _ = self.masked_mean(
                        M1_square, vmask, rank=j3, calc_var=True
                    )  # [Nbatch, Nmask, Norient3]
                    p2, _ = self.masked_mean(
                        M2_square, vmask, rank=j3, calc_var=True
                    )  # [Nbatch, Nmask, Norient3]
                else:
                    p1 = self.masked_mean(M1_square, vmask, rank=j3)  # [Nbatch, Nmask, Norient3]
                    p2 = self.masked_mean(M2_square, vmask, rank=j3)  # [Nbatch, Nmask, Norient3]
                # We fill P1_dic and P2_dic with S2 for normalisation of S3 and S4
                P1_dic[j3] = self.backend.bk_real(p1)  # [Nbatch, Nmask, Norient3]
                P2_dic[j3] = self.backend.bk_real(p2)  # [Nbatch, Nmask, Norient3]

            ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
            # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
            s2 = conv1 * self.backend.bk_conjugate(conv2)
            MX = self.backend.bk_L1(s2)
            # Apply the mask [Nmask, Npix_j3] and average over pixels
            if return_data:
                S2[j3] = _to_out_nside(s2, nside_j3)
            else:
                if calc_var:
                    s2, vs2 = self.masked_mean(s2, vmask, rank=j3, calc_var=True)
                else:
                    s2 = self.masked_mean(s2, vmask, rank=j3)
                ### Keep the real part of S2_cross [Nbatch, Nmask, NS2, Norient3]
                s2 = self.backend.bk_real(s2)
                ### Normalize S2_cross
                if norm == "auto":
                    s2 = s2 / (P1_dic[j3] * P2_dic[j3]) ** 0.5
                S2.append(
                    self.backend.bk_expand_dims(s2, off_S2)
                )  # Add a dimension for NS2
                if calc_var:
                    VS2.append(
                        self.backend.bk_expand_dims(vs2, off_S2)
                    )  # Add a dimension for NS2

            #### S1_cross computation: S1 = < bk_L1(conv1 x conv2^*) >_pix
            # Apply the mask [Nmask, Npix_j3] and average over pixels
            if return_data:
                S1[j3] = _to_out_nside(MX, nside_j3)
            else:
                if calc_var:
                    s1, vs1 = self.masked_mean(
                        MX, vmask, rank=j3, calc_var=True
                    )  # [Nbatch, Nmask, Norient3]
                else:
                    s1 = self.masked_mean(MX, vmask, rank=j3)  # [Nbatch, Nmask, Norient3]
                ### Normalize S1
                if norm is not None:
                    s1 = _normalize(s1, (P1_dic[j3]) ** 0.5)
                ### We store S1 [Nbatch, Nmask, NS1, Norient3]
                S1.append(
                    self.backend.bk_expand_dims(s1, off_S2)
                )  # Add a dimension for NS1
                if calc_var:
                    VS1.append(
                        self.backend.bk_expand_dims(vs1, off_S2)
                    )  # Add a dimension for NS1

        # Dictionaries for |I1*Psi_j| * Psi_j3 (filled by _compute_S3, read by _compute_S4)
        M1convPsi_dic = {}
        if cross:
            # Dictionaries for |I2*Psi_j| * Psi_j3
            M2convPsi_dic = {}

        ###### S3
        for j2 in range(0, j3 + 1):  # j2 <= j3
            if return_data:
                if S4[j3] is None:
                    S4[j3] = {}
                S4[j3][j2] = None

            ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
            if not cross:
                if calc_var:
                    s3, vs3 = self._compute_S3(
                        j2,
                        j3,
                        conv1,
                        vmask,
                        M1_dic,
                        M1convPsi_dic,
                        calc_var=True,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )  # [Nbatch, Nmask, Norient3, Norient2]
                else:
                    s3 = self._compute_S3(
                        j2,
                        j3,
                        conv1,
                        vmask,
                        M1_dic,
                        M1convPsi_dic,
                        return_data=return_data,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )  # [Nbatch, Nmask, Norient3, Norient2]

                if return_data:
                    if S3[j3] is None:
                        S3[j3] = {}
                    S3[j3][j2] = _to_out_nside(s3, nside_j3)
                else:
                    ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
                    if norm is not None:
                        s3 = _normalize(
                            s3,
                            (
                                self.backend.bk_expand_dims(P1_dic[j2], off_S2)
                                * self.backend.bk_expand_dims(P1_dic[j3], -1)
                            )
                            ** 0.5,
                        )  # [Nbatch, Nmask, Norient3, Norient2]
                    ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
                    S3.append(
                        self.backend.bk_expand_dims(s3, off_S3)
                    )  # Add a dimension for NS3
                    if calc_var:
                        VS3.append(
                            self.backend.bk_expand_dims(vs3, off_S3)
                        )  # Add a dimension for NS3

            ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
            ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
            else:
                if calc_var:
                    s3, vs3 = self._compute_S3(
                        j2,
                        j3,
                        conv1,
                        vmask,
                        M2_dic,
                        M2convPsi_dic,
                        calc_var=True,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )
                    s3p, vs3p = self._compute_S3(
                        j2,
                        j3,
                        conv2,
                        vmask,
                        M1_dic,
                        M1convPsi_dic,
                        calc_var=True,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )
                else:
                    s3p = self._compute_S3(
                        j2,
                        j3,
                        conv2,
                        vmask,
                        M1_dic,
                        M1convPsi_dic,
                        return_data=return_data,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )
                    s3 = self._compute_S3(
                        j2,
                        j3,
                        conv1,
                        vmask,
                        M2_dic,
                        M2convPsi_dic,
                        return_data=return_data,
                        cmat2=cmat2,
                        cell_ids=cell_ids_j3,
                        nside_j2=nside_j3,
                        spin=spin,
                    )

                if return_data:
                    if S3[j3] is None:
                        S3[j3] = {}
                        S3P[j3] = {}
                    S3[j3][j2] = _to_out_nside(s3, nside_j3)
                    S3P[j3][j2] = _to_out_nside(s3p, nside_j3)
                else:
                    ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
                    if norm is not None:
                        s3 = _normalize(
                            s3,
                            (
                                self.backend.bk_expand_dims(P2_dic[j2], off_S2)
                                * self.backend.bk_expand_dims(P1_dic[j3], -1)
                            )
                            ** 0.5,
                        )  # [Nbatch, Nmask, Norient3, Norient2]
                        s3p = _normalize(
                            s3p,
                            (
                                self.backend.bk_expand_dims(P1_dic[j2], off_S2)
                                * self.backend.bk_expand_dims(P2_dic[j3], -1)
                            )
                            ** 0.5,
                        )  # [Nbatch, Nmask, Norient3, Norient2]
                    ### Store S3 and S3P as complex [Nbatch, Nmask, NS3, Norient3, Norient2]
                    S3.append(
                        self.backend.bk_expand_dims(s3, off_S3)
                    )  # Add a dimension for NS3
                    if calc_var:
                        VS3.append(
                            self.backend.bk_expand_dims(vs3, off_S3)
                        )  # Add a dimension for NS3
                    S3P.append(
                        self.backend.bk_expand_dims(s3p, off_S3)
                    )  # Add a dimension for NS3
                    if calc_var:
                        VS3P.append(
                            self.backend.bk_expand_dims(vs3p, off_S3)
                        )  # Add a dimension for NS3

            ##### S4
            # Auto: <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
            # Cross: <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
            for j1 in range(0, j2 + 1):  # j1 <= j2
                other_dic = M2convPsi_dic if cross else None
                if calc_var:
                    s4, vs4 = self._compute_S4(
                        j1,
                        j2,
                        vmask,
                        M1convPsi_dic,
                        M2convPsi_dic=other_dic,
                        calc_var=True,
                    )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
                else:
                    s4 = self._compute_S4(
                        j1,
                        j2,
                        vmask,
                        M1convPsi_dic,
                        M2convPsi_dic=other_dic,
                        return_data=return_data,
                    )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]

                if return_data:
                    if S4[j3][j2] is None:
                        S4[j3][j2] = {}
                    S4[j3][j2][j1] = _to_out_nside(s4, nside_j3)
                else:
                    ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
                    if norm is not None:
                        p_j2 = P2_dic[j2] if cross else P1_dic[j2]
                        s4 = _normalize(
                            s4,
                            (
                                self.backend.bk_expand_dims(
                                    self.backend.bk_expand_dims(P1_dic[j1], off_S2),
                                    off_S2,
                                )
                                * self.backend.bk_expand_dims(
                                    self.backend.bk_expand_dims(p_j2, off_S2),
                                    -1,
                                )
                            )
                            ** 0.5,
                        )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
                    ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
                    S4.append(
                        self.backend.bk_expand_dims(s4, off_S4)
                    )  # Add a dimension for NS4
                    if calc_var:
                        VS4.append(
                            self.backend.bk_expand_dims(vs4, off_S4)
                        )  # Add a dimension for NS4

        ###### Reshape for next iteration on j3
        if j3 != Jmax - 1:
            ### Image I1: downscale [Nbatch, Npix_j3]
            I1, new_cell_ids_j3 = self.ud_grade_2(
                I1, cell_ids=cell_ids_j3, nside=nside_j3
            )
            ### Image I2
            if cross:
                I2, new_cell_ids_j3 = self.ud_grade_2(
                    I2, cell_ids=cell_ids_j3, nside=nside_j3
                )
            ### Modules
            for j2 in range(0, j3 + 1):  # j2 <= j3
                ### Dictionary M1_dic[j2]
                M1_dic[j2], _ = self.ud_grade_2(
                    M1_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
                )
                ### Dictionary M2_dic[j2]
                if cross:
                    M2_dic[j2], _ = self.ud_grade_2(
                        M2_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
                    )
            ### Mask
            vmask, new_cell_ids_j3 = self.ud_grade_2(
                vmask, cell_ids=cell_ids_j3, nside=nside_j3
            )
            if self.mask_thres is not None:
                vmask = self.backend.bk_threshold(vmask, self.mask_thres)
            ### NSIDE_j3
            nside_j3 = nside_j3 // 2
            cell_ids_j3 = new_cell_ids_j3

    ### Store P1_dic and P2_dic in self
    if (norm == "auto") and (self.P1_dic is None):
        self.P1_dic = P1_dic
        if cross:
            self.P2_dic = P2_dic

    if not return_data:
        if not self.use_1D:
            S1 = self.backend.bk_concat(S1, -2)
            S2 = self.backend.bk_concat(S2, -2)
            S3 = self.backend.bk_concat(S3, -3)
            S4 = self.backend.bk_concat(S4, -4)
            if cross:
                S3P = self.backend.bk_concat(S3P, -3)
            if calc_var:
                VS1 = self.backend.bk_concat(VS1, -2)
                VS2 = self.backend.bk_concat(VS2, -2)
                VS3 = self.backend.bk_concat(VS3, -3)
                VS4 = self.backend.bk_concat(VS4, -4)
                if cross:
                    VS3P = self.backend.bk_concat(VS3P, -3)
        else:
            S1 = self.backend.bk_concat(S1, -1)
            S2 = self.backend.bk_concat(S2, -1)
            S3 = self.backend.bk_concat(S3, -1)
            S4 = self.backend.bk_concat(S4, -1)
            if cross:
                S3P = self.backend.bk_concat(S3P, -1)
            if calc_var:
                VS1 = self.backend.bk_concat(VS1, -1)
                VS2 = self.backend.bk_concat(VS2, -1)
                VS3 = self.backend.bk_concat(VS3, -1)
                VS4 = self.backend.bk_concat(VS4, -1)
                if cross:
                    VS3P = self.backend.bk_concat(VS3P, -1)

    if calc_var:
        if not cross:
            return scat_cov(
                s0, S2, S3, S4, s1=S1, backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            ), scat_cov(
                vs0,
                VS2,
                VS3,
                VS4,
                s1=VS1,
                backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            )
        else:
            return scat_cov(
                s0,
                S2,
                S3,
                S4,
                s1=S1,
                s3p=S3P,
                backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            ), scat_cov(
                vs0,
                VS2,
                VS3,
                VS4,
                s1=VS1,
                s3p=VS3P,
                backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            )
    else:
        if not cross:
            return scat_cov(
                s0, S2, S3, S4,
                s1=S1,
                backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            )
        else:
            return scat_cov(
                s0,
                S2,
                S3,
                S4,
                s1=S1,
                s3p=S3P,
                backend=self.backend,
                use_1D=self.use_1D,
                return_data=self.return_data
            )
[docs]
def clean_norm(self):
    """
    Reset the cached normalisation spectra ``P1_dic`` and ``P2_dic`` to None.

    ``eval`` with ``norm="auto"`` stores the S2 spectra of its first call in
    these attributes and reuses them afterwards. Clearing them makes the next
    such call compute and store fresh ones.
    """
    self.P1_dic = self.P2_dic = None
def _compute_S3(
    self,
    j2,
    j3,
    conv,
    vmask,
    M_dic,
    MconvPsi_dic,
    calc_var=False,
    return_data=False,
    cmat2=None,
    cell_ids=None,
    nside_j2=None,
    spin=0,
):
    """
    Compute S3 coefficients (auto or cross).

        S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix

    Parameters
    ----------
    j2, j3 : int
        Scale indices. ``M_dic[j2]`` is convolved again, and ``cmat2[j3][j2]``
        is used when orientation mixing is requested.
    conv : tensor
        The first-order convolution ``Ia * Psi_j3``.
    vmask : tensor
        Mask forwarded to ``masked_mean``.
    M_dic : dict
        Moduli ``|Ib * Psi_j2|`` indexed by scale (the caller keeps them
        downgraded to the current j3 resolution).
    MconvPsi_dic : dict
        Output cache. ``MconvPsi_dic[j2]`` receives ``M_dic[j2] * Psi``,
        which the S4 computation reuses.
    calc_var : bool
        Also return the second output of ``masked_mean(..., calc_var=True)``.
    return_data : bool
        Return the per-pixel product instead of its masked mean (takes
        precedence over ``calc_var``).
    cmat2 : dict of dict or None
        Optional orientation-mixing matrices applied to ``M_dic[j2] * Psi``.
    cell_ids, nside_j2 :
        Forwarded to ``convol``.
    spin : int
        When non-zero, the ``cmat2`` reshape carries an extra axis of size 2.

    Returns
    -------
    The per-pixel product when ``return_data``; otherwise its masked mean,
    or the pair (mean, second output) when ``calc_var`` is True.
    """
    bk = self.backend

    # |Ib * Psi_j2| * Psi_j3
    m_conv_psi = self.convol(M_dic[j2], cell_ids=cell_ids, nside=nside_j2)

    if cmat2 is not None:
        n_or = self.NORIENT
        mix = cmat2[j3][j2]
        tiled = bk.bk_repeat(m_conv_psi, n_or, axis=-3)
        if spin == 0:
            new_shape = [tiled.shape[0], n_or, n_or, n_or, mix.shape[3]]
            sum_axis = 1
        else:
            new_shape = [tiled.shape[0], 2, n_or, n_or, n_or, mix.shape[4]]
            sum_axis = 2
        m_conv_psi = bk.bk_reduce_sum(bk.bk_reshape(mix * tiled, new_shape), sum_axis)

    # Cached for the S4 computation
    MconvPsi_dic[j2] = m_conv_psi

    # (Ia * Psi)_j3 x (M_j2 * Psi_j3)^*
    if self.use_1D:
        product = conv * bk.bk_conjugate(m_conv_psi)
    else:
        lead_axis = -4 if self.use_2D else -3
        product = bk.bk_expand_dims(conv, lead_axis) * bk.bk_conjugate(m_conv_psi)

    if return_data:
        return product
    if not calc_var:
        return self.masked_mean(product, vmask, rank=j2)
    mean, spread = self.masked_mean(product, vmask, rank=j2, calc_var=True)
    return mean, spread
def _compute_S4(
    self,
    j1,
    j2,
    vmask,
    M1convPsi_dic,
    M2convPsi_dic=None,
    calc_var=False,
    return_data=False,
):
    """
    Compute S4 coefficients (auto or cross).

        S4 = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>_pix

    Parameters
    ----------
    j1, j2 : int
        Scale indices into the convolved-moduli dictionaries.
    vmask : tensor
        Mask forwarded to ``masked_mean``.
    M1convPsi_dic : dict
        ``|I1 * Psi_j| * Psi_j3`` indexed by j (filled by ``_compute_S3``).
    M2convPsi_dic : dict or None
        Same for the second image. None means auto coefficients, in which
        case the j2 term is also read from ``M1convPsi_dic``.
    calc_var : bool
        Also return the second output of ``masked_mean(..., calc_var=True)``.
    return_data : bool
        Return the per-pixel product instead of its masked mean.

    Returns
    -------
    The per-pixel product when ``return_data``; otherwise its masked mean,
    or the pair (mean, second output) when ``calc_var`` is True.
    """
    bk = self.backend
    first = M1convPsi_dic[j1]
    second_source = M1convPsi_dic if M2convPsi_dic is None else M2convPsi_dic
    second = second_source[j2]

    # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
    if self.use_1D:
        product = first * bk.bk_conjugate(second)
    else:
        product = bk.bk_expand_dims(first, -4) * bk.bk_conjugate(
            bk.bk_expand_dims(second, -3)
        )

    if return_data:
        return product
    if not calc_var:
        return self.masked_mean(product, vmask, rank=j2)
    mean, spread = self.masked_mean(product, vmask, rank=j2, calc_var=True)
    return mean, spread
[docs]
def computer_filter(self, M, N, J, L):
    """
    Build a bank of 2D Morlet-like wavelet filters in Fourier space.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    Parameters
    ----------
    M, N : int
        Size of the 2D grid. ``N == 0`` requests 1D filters, which are not
        implemented.
    J : int
        Number of scales; scale j uses ``sigma = 0.8 * 2**j``.
    L : int
        Number of orientations.

    Returns
    -------
    The filters, shape ``[J, L, M, N]``, dtype complex64, passed through
    ``self.backend.bk_cast``. The zero-frequency coefficient of each filter
    is set to 0.

    Raises
    ------
    NotImplementedError
        If ``N == 0`` (1D filters are not available yet).
    """
    if N == 0:
        # Raise instead of exiting the interpreter so callers can handle it.
        raise NotImplementedError("filter for 1D not yet available")

    filters = np.zeros([J, L, M, N], dtype="complex64")
    slant = 4.0 / L

    # Coordinate grids shifted by -M/-N and 0 along each axis; summing the
    # Gaussian over the four shifts makes the filter wrap around the grid.
    # They do not depend on j or ell, so they are built once.
    xx = np.empty((2, 2, M, N))
    yy = np.empty((2, 2, M, N))
    for ii, ex in enumerate([-1, 0]):
        for jj, ey in enumerate([-1, 0]):
            xx[ii, jj], yy[ii, jj] = np.mgrid[
                ex * M : M + ex * M, ey * N : N + ey * N
            ]

    for j in range(J):
        sigma = 0.8 * 2**j
        xi = 3.0 / 4.0 * np.pi / 2**j
        norm_factor = 2 * np.pi * sigma * sigma / slant
        for ell in range(L):
            theta = (int(L - L / 2 - 1) - ell) * np.pi / L
            R = np.array(
                [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
                np.float64,
            )
            R_inv = np.array(
                [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
                np.float64,
            )
            D = np.array([[1, 0], [0, slant * slant]])
            curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
            arg = -(
                curv[0, 0] * xx * xx
                + (curv[0, 1] + curv[1, 0]) * xx * yy
                + curv[1, 1] * yy * yy
            )
            argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
            gabi = np.exp(argi).sum((0, 1)) / norm_factor
            gab = np.exp(arg).sum((0, 1)) / norm_factor
            # Subtract a scaled Gaussian so the wavelet has zero mean
            K = gabi.sum() / gab.sum()
            filters[j, ell] = np.fft.fft2(gabi - K * gab)
            filters[j, ell, 0, 0] = 0.0
    return self.backend.bk_cast(filters)
# ------------------------------------------------------------------------------------------
#
# utility functions
#
# ------------------------------------------------------------------------------------------
[docs]
def cut_high_k_off(self, data_f, dx, dy):
    """
    Keep only the low-frequency corners of a 2D Fourier array.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    Along the second-to-last axis, the first ``dx + odd`` entries and the
    last ``dx`` entries are kept, where ``odd`` is 1 when that axis has odd
    length. The last axis is treated the same way with ``dy``. The four
    corner blocks are concatenated back together.

    Parameters
    ----------
    data_f : tensor
        Fourier-space data; the last two axes are cropped.
    dx, dy : int
        Number of frequencies kept on each side of each axis.

    Returns
    -------
    The cropped tensor.
    """
    bk = self.backend
    lib = bk.backend

    if bk.BACKEND == "torch":
        odd_rows = data_f.shape[-2] % 2 == 1
        odd_cols = data_f.shape[-1] % 2 == 1

        def join(parts, ax):
            return lib.cat(tuple(parts), ax)
    else:
        # Parities are taken from the dynamic shape on the TensorFlow path
        odd_rows = lib.cast(lib.shape(data_f)[-2] % 2 == 1, lib.int32)
        odd_cols = lib.cast(lib.shape(data_f)[-1] % 2 == 1, lib.int32)

        def join(parts, ax):
            return lib.concat(list(parts), axis=ax)

    head_r, tail_r = slice(None, dx + odd_rows), slice(-dx, None)
    head_c, tail_c = slice(None, dy + odd_cols), slice(-dy, None)

    low_cols = join([data_f[..., head_r, head_c], data_f[..., tail_r, head_c]], -2)
    high_cols = join([data_f[..., head_r, tail_c], data_f[..., tail_r, tail_c]], -2)
    return join([low_cols, high_cols], -1)
# ---------------------------------------------------------------------------
#
# utility functions for computing scattering coef and covariance
#
# ---------------------------------------------------------------------------
[docs]
def get_dxdy(self, j, M, N):
    """
    Return the Fourier half-widths kept at scale ``j`` along both axes.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    For each axis of size ``size``, the width is ``ceil(size / 2**j)``,
    capped at ``size // 2`` and never below 8.

    Returns
    -------
    (dx, dy) : tuple of int
    """
    def _half_width(size):
        return int(max(8, min(np.ceil(size / 2**j), size // 2)))

    return _half_width(M), _half_width(N)
[docs]
def get_edge_masks(self, M, N, J, d0=1, in_mask=None, edge_dx=None, edge_dy=None):
    """
    Build masks that exclude a border band, optionally combined with an input mask.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    Parameters
    ----------
    M, N : int
        Grid size.
    J : int
        Number of scales (used when ``edge_dx`` is None).
    d0 : int
        Base border width. At scale j the band is ``min(M // 4, 2**j * d0)``
        along the first axis and ``min(N // 4, 2**j * d0)`` along the second.
    in_mask : array or None
        Optional 2D mask. If its shape differs from (M, N) it is averaged
        over blocks down to (M, N). It is then eroded, separately for each
        scale, by ``1 + max(edge_dx, edge_dy)`` iterations and multiplied in.
    edge_dx, edge_dy : int or None
        Fixed border widths. When given, a single (M, N) mask is returned
        instead of one per scale.

    Returns
    -------
    With ``edge_dx`` None: masks of shape (J, 1, M, N), each divided by its
    mean. Otherwise a single (M, N) mask divided by its mean. Both are passed
    through ``self.backend.bk_cast``.
    """
    X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")

    l_mask = None
    if in_mask is not None:
        from scipy.ndimage import binary_erosion

        if in_mask.shape[0] != M or in_mask.shape[1] != N:
            # Average in_mask over blocks to bring it down to M x N
            l_mask = in_mask.reshape(
                M, in_mask.shape[0] // M, N, in_mask.shape[1] // N
            )
            l_mask = (
                np.sum(np.sum(l_mask, 1), 2)
                * (M * N)
                / (in_mask.shape[0] * in_mask.shape[1])
            )
        else:
            l_mask = in_mask

    if edge_dx is None:
        edge_masks = np.empty((J, M, N))
        for j in range(J):
            edge_dx = min(M // 4, 2**j * d0)
            edge_dy = min(N // 4, 2**j * d0)
            edge_masks[j] = (
                (X >= edge_dx)
                * (X < M - edge_dx)
                * (Y >= edge_dy)
                * (Y < N - edge_dy)
            )
            if l_mask is not None:
                # Erode the input mask afresh for each scale; eroding the
                # previous scale's result would accumulate the erosions.
                edge_masks[j] *= binary_erosion(
                    l_mask, iterations=1 + np.max([edge_dx, edge_dy])
                )
        edge_masks = edge_masks[:, None, :, :]
        edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
    else:
        edge_masks = (
            (X >= edge_dx) * (X < M - edge_dx) * (Y >= edge_dy) * (Y < N - edge_dy)
        )
        if l_mask is not None:
            edge_masks *= binary_erosion(
                l_mask, iterations=1 + np.max([edge_dx, edge_dy])
            )
        edge_masks = edge_masks / edge_masks.mean((-2, -1))
    return self.backend.bk_cast(edge_masks)
# ---------------------------------------------------------------------------
#
# scattering cov
#
# ---------------------------------------------------------------------------
[docs]
def scattering_cov(
self,
data,
data2=None,
Jmax=None,
if_large_batch=False,
S4_criteria=None,
use_ref=False,
normalization="S2",
edge=False,
in_mask=None,
pseudo_coef=1,
get_variance=False,
ref_sigma=None,
iso_ang=False,
return_table=False,
fft_ang=False,
fft_nharm=1,
fft_imaginary=True,
):
"""
Calculates the scattering correlations for a batch of images, including:
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
Done by Sihao Cheng and Rudy Morel.
orig. x orig.:
P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
orig. x modulus:
C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
modulus x modulus:
C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
C11 = C11_pre_norm / factor
when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
modulus x modulus (auto):
P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
Parameters
----------
data : numpy array or torch tensor
image set, with size [N_image, x-sidelength, y-sidelength]
if_large_batch : Bool (=False)
It is recommended to use "False" unless one meets a memory issue
C11_criteria : str or None (=None)
Only C11 coefficients that satisfy this criteria will be computed.
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
use_ref : Bool (=False)
When normalizing, whether or not to use the normalization factor
computed from a reference field. For just computing the statistics,
the default is False. However, for synthesis, set it to "True" will
stablize the optimization process.
normalization : str 'P00' or 'P11' (='P00')
Whether 'P00' or 'P11' is used as the normalization factor for C01
and C11.
remove_edge : Bool (=False)
If true, the edge region with a width of rougly the size of the largest
wavelet involved is excluded when taking the global average to obtain
the scattering coefficients.
Returns
-------
'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
the power in each wavelet bands (the orig. x orig. term)
'S1' : torch tensor with size [N_image, J, L] (# image, j1, l1)
the 1st-order scattering coefficients, i.e., the mean of wavelet modulus fields
'C01' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
the orig. x modulus terms. Elements with j1 < j2 are all set to np.nan and not computed.
'C11' : torch tensor with size [N_image, J, J, J, L, L, L] (# image, j1, j2, j3, l1, l2, l3)
the modulus x modulus terms. Elements not satisfying j1 <= j2 <= j3 and the conditions
defined in 'C11_criteria' are all set to np.nan and not computed.
'C11_pre_norm' and 'C11_pre_norm_iso': pre-normalized modulus x modulus terms.
'P11' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
the modulus x modulus terms with the two wavelets within modulus the same. Elements not following
j1 <= j3 are set to np.nan and not computed.
'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
'P11' averaged over l1 while keeping l2-l1 constant.
"""
if S4_criteria is None:
S4_criteria = "j2>=j1"
if not edge and in_mask is not None:
edge = True
if self.all_bk_type == "float32":
C_ONE = np.complex64(1.0)
else:
C_ONE = np.complex128(1.0)
# determine jmax and nside corresponding to the input map
im_shape = data.shape
if self.use_2D:
if len(data.shape) == 2:
nside = np.min([im_shape[0], im_shape[1]])
M, N = im_shape[0], im_shape[1]
N_image = 1
N_image2 = 1
else:
nside = np.min([im_shape[1], im_shape[2]])
M, N = im_shape[1], im_shape[2]
N_image = data.shape[0]
if data2 is not None:
N_image2 = data2.shape[0]
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
if Jmax is not None:
J = min(J, Jmax) # clamp: only compute scales up to Jmax
dim=(-2,-1)
elif self.use_1D:
if len(data.shape) == 2:
npix = int(im_shape[1]) # Number of pixels
M = im_shape[1]
N=0
N_image = 1
N_image2 = 1
else:
npix = int(im_shape[0]) # Number of pixels
N_image = data.shape[0]
M = im_shape[0]
N=0
if data2 is not None:
N_image2 = data2.shape[0]
nside = int(npix)
dim=(-1)
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
else:
if len(data.shape) == 2:
npix = int(im_shape[1]) # Number of pixels
N_image = 1
N_image2 = 1
else:
npix = int(im_shape[0]) # Number of pixels
N_image = data.shape[0]
if data2 is not None:
N_image2 = data2.shape[0]
if spin==0:
nside = int(np.sqrt(npix // 12))
else:
nside = int(np.sqrt(npix // 24))
J = int(np.log(nside) / np.log(2)) # Number of j scales
if Jmax is not None:
if Jmax > J:
print("==========\n\n")
print(
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
)
print("\n\n==========")
J = Jmax # Number of steps for the loop on scales
L = self.NORIENT
norm_factor_S3 = 1.0
if self.backend.BACKEND == "torch":
if (M, N, J, L) not in self.filters_set:
self.filters_set[(M, N, J, L)] = self.computer_filter(
M, N, J, L
) # self.computer_filter(M,N,J,L)
filters_set = self.filters_set[(M, N, J, L)]
# weight = self.weight
if use_ref:
if normalization == "S2":
ref_S2 = self.ref_scattering_cov_S2
else:
ref_P11 = self.ref_scattering_cov["P11"]
# convert numpy array input into self.backend.bk_ tensors
data = self.backend.bk_cast(data)
data_f = self.backend.bk_fftn(data, dim=dim)
if data2 is not None:
data2 = self.backend.bk_cast(data2)
data2_f = self.backend.bk_fftn(data2, dim=dim)
# initialize tensors for scattering coefficients
S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
S1 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
Ndata_S3 = J * (J + 1) // 2
Ndata_S4 = J * (J + 1) * (J + 2) // 6
J_S4 = {}
S3 = self.backend.bk_zeros((N_image, Ndata_S3, L, L), dtype=data_f.dtype)
if data2 is not None:
S3p = self.backend.bk_zeros(
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
)
S4_pre_norm = self.backend.bk_zeros(
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
)
S4 = self.backend.bk_zeros((N_image, Ndata_S4, L, L, L), dtype=data_f.dtype)
# variance
if get_variance:
S2_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
S1_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
S3_sigma = self.backend.bk_zeros(
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
)
if data2 is not None:
S3p_sigma = self.backend.bk_zeros(
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
)
S4_sigma = self.backend.bk_zeros(
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
)
if iso_ang:
S3_iso = self.backend.bk_zeros(
(N_image, Ndata_S3, L), dtype=data_f.dtype
)
S4_iso = self.backend.bk_zeros(
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
)
if get_variance:
S3_sigma_iso = self.backend.bk_zeros(
(N_image, Ndata_S3, L), dtype=data_f.dtype
)
S4_sigma_iso = self.backend.bk_zeros(
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
)
if data2 is not None:
S3p_iso = self.backend.bk_zeros(
(N_image, Ndata_S3, L), dtype=data_f.dtype
)
if get_variance:
S3p_sigma_iso = self.backend.bk_zeros(
(N_image, Ndata_S3, L), dtype=data_f.dtype
)
#
if edge:
if (M, N, J) not in self.edge_masks:
self.edge_masks[(M, N, J)] = self.get_edge_masks(
M, N, J, in_mask=in_mask
)
edge_mask = self.edge_masks[(M, N, J)]
else:
edge_mask = 1
# calculate scattering fields
if data2 is None:
if self.use_2D:
if len(data.shape) == 2:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
).abs()
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
).abs()
elif self.use_1D:
if len(data.shape) == 1:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
).abs()
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
).abs()
else:
print("todo")
S2 = (I1**2 * edge_mask).mean(dim)
S1 = (I1 * edge_mask).mean(dim)
if get_variance:
S2_sigma = (I1**2 * edge_mask).std(dim)
S1_sigma = (I1 * edge_mask).std(dim)
else:
if self.use_2D:
if len(data.shape) == 2:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[None, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[:, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
elif self.use_1D:
if len(data.shape) == 1:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=dim,
)
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :] * filters_set[None, :J, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
dim=dim,
)
else:
print("todo")
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
if get_variance:
S2_sigma = self.backend.bk_reduce_std(
(I1 * edge_mask), axis=dim
)
I1 = self.backend.bk_L1(I1)
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
if get_variance:
S1_sigma = self.backend.bk_reduce_std(
(I1 * edge_mask), axis=dim
)
I1_f = self.backend.bk_fftn(I1, dim=dim)
if pseudo_coef != 1:
I1 = I1**pseudo_coef
Ndata_S3 = 0
Ndata_S4 = 0
# calculate the covariance and correlations of the scattering fields
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
for j3 in range(0, J):
J_S4[j3] = Ndata_S4
dx3, dy3 = self.get_dxdy(j3, M, N)
I1_f_small = self.cut_high_k_off(
I1_f[:, : j3 + 1], dx3, dy3
) # Nimage, J, L, x, y
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
if data2 is not None:
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
if edge:
I1_small = self.backend.bk_ifftn(
I1_f_small, dim=dim, norm="ortho"
)
data_small = self.backend.bk_ifftn(
data_f_small, dim=dim, norm="ortho"
)
if data2 is not None:
data2_small = self.backend.bk_ifftn(
data2_f_small, dim=dim, norm="ortho"
)
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
_, M3, N3 = wavelet_f3.shape
wavelet_f3_squared = wavelet_f3**2
if edge is True:
if (M3, N3, J, j3) not in self.edge_masks:
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
self.edge_masks[(M3, N3, J, j3)] = self.get_edge_masks(
M3, N3, J, in_mask=in_mask, edge_dx=edge_dx, edge_dy=edge_dy
)
edge_mask = self.edge_masks[(M3, N3, J, j3)]
else:
edge_mask = 1
# a normalization change due to the cutoff of frequency space
fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
for j2 in range(0, j3 + 1):
I1_f2_wf3_small = I1_f_small[:, j2].view(
N_image, L, 1, M3, N3
) * wavelet_f3.view(1, 1, L, M3, N3)
I1_f2_wf3_2_small = I1_f_small[:, j2].view(
N_image, L, 1, M3, N3
) * wavelet_f3_squared.view(1, 1, L, M3, N3)
if edge:
I12_w3_small = self.backend.bk_ifftn(
I1_f2_wf3_small, dim=dim, norm="ortho"
)
I12_w3_2_small = self.backend.bk_ifftn(
I1_f2_wf3_2_small, dim=dim, norm="ortho"
)
if use_ref:
if normalization == "P11":
norm_factor_S3 = (
ref_S2[:, None, j3, :]
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
) ** 0.5
if normalization == "S2":
norm_factor_S3 = (
ref_S2[:, None, j3, :]
* ref_S2[:, j2, :, None] ** pseudo_coef
) ** 0.5
else:
if normalization == "P11":
# [N_image,l2,l3,x,y]
P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
dim
) * fft_factor
norm_factor_S3 = (
S2[:, None, j3, :] * P11_temp**pseudo_coef
) ** 0.5
if normalization == "S2":
norm_factor_S3 = (
S2[:, None, j3, :] * S2[:, j2, :, None] ** pseudo_coef
) ** 0.5
if not edge:
S3[:, Ndata_S3, :, :] = (
(
data_f_small.view(N_image, 1, 1, M3, N3)
* self.backend.bk_conjugate(I1_f2_wf3_small)
).mean(dim)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3_sigma[:, Ndata_S3, :, :] = (
(
data_f_small.view(N_image, 1, 1, M3, N3)
* self.backend.bk_conjugate(I1_f2_wf3_small)
).std(dim)
* fft_factor
/ norm_factor_S3
)
else:
S3[:, Ndata_S3, :, :] = (
(
data_small.view(N_image, 1, 1, M3, N3)
* self.backend.bk_conjugate(I12_w3_small)
* edge_mask[None, None, None, :, :]
).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
dim
)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3_sigma[:, Ndata_S3, :, :] = (
(
data_small.view(N_image, 1, 1, M3, N3)
* self.backend.bk_conjugate(I12_w3_small)
* edge_mask[None, None, None, :, :]
).std(dim)
* fft_factor
/ norm_factor_S3
)
if data2 is not None:
if not edge:
S3p[:, Ndata_S3, :, :] = (
(
data2_f_small.view(N_image2, 1, 1, M3, N3)
* self.backend.bk_conjugate(I1_f2_wf3_small)
).mean(dim)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3p_sigma[:, Ndata_S3, :, :] = (
(
data2_f_small.view(N_image2, 1, 1, M3, N3)
* self.backend.bk_conjugate(I1_f2_wf3_small)
).std(dim)
* fft_factor
/ norm_factor_S3
)
else:
S3p[:, Ndata_S3, :, :] = (
(
data2_small.view(N_image2, 1, 1, M3, N3)
* self.backend.bk_conjugate(I12_w3_small)
* edge_mask[None, None, None, :, :]
).mean(dim)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3p_sigma[:, Ndata_S3, :, :] = (
(
data2_small.view(N_image2, 1, 1, M3, N3)
* self.backend.bk_conjugate(I12_w3_small)
* edge_mask[None, None, None, :, :]
).std(dim)
* fft_factor
/ norm_factor_S3
)
Ndata_S3 += 1
if j2 <= j3:
beg_n = Ndata_S4
for j1 in range(0, j2 + 1):
if eval(S4_criteria):
if not edge:
if not if_large_batch:
# [N_image,l1,l2,l3,x,y]
S4_pre_norm[:, Ndata_S4, :, :, :] = (
I1_f_small[:, j1].view(
N_image, L, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I1_f2_wf3_2_small.view(
N_image, 1, L, L, M3, N3
)
)
).mean(dim) * fft_factor
if get_variance:
S4_sigma[:, Ndata_S4, :, :, :] = (
I1_f_small[:, j1].view(
N_image, L, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I1_f2_wf3_2_small.view(
N_image, 1, L, L, M3, N3
)
)
).std(dim) * fft_factor
else:
for l1 in range(L):
# [N_image,l2,l3,x,y]
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
I1_f_small[:, j1, l1].view(
N_image, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I1_f2_wf3_2_small.view(
N_image, L, L, M3, N3
)
)
).mean(dim) * fft_factor
if get_variance:
S4_sigma[:, Ndata_S4, l1, :, :] = (
I1_f_small[:, j1, l1].view(
N_image, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I1_f2_wf3_2_small.view(
N_image, L, L, M3, N3
)
)
).std(dim) * fft_factor
else:
if not if_large_batch:
# [N_image,l1,l2,l3,x,y]
S4_pre_norm[:, Ndata_S4, :, :, :] = (
I1_small[:, j1].view(
N_image, L, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I12_w3_2_small.view(
N_image, 1, L, L, M3, N3
)
)
* edge_mask[None, None, None, None, :, :]
).mean(dim) * fft_factor
if get_variance:
S4_sigma[:, Ndata_S4, :, :, :] = (
I1_small[:, j1].view(
N_image, L, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I12_w3_2_small.view(
N_image, 1, L, L, M3, N3
)
)
* edge_mask[
None, None, None, None, :, :
]
).std(dim) * fft_factor
else:
for l1 in range(L):
# [N_image,l2,l3,x,y]
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
I1_small[:, j1].view(
N_image, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I12_w3_2_small.view(
N_image, L, L, M3, N3
)
)
* edge_mask[
None, None, None, None, :, :
]
).mean(dim) * fft_factor
if get_variance:
S4_sigma[:, Ndata_S4, l1, :, :] = (
I1_small[:, j1].view(
N_image, 1, 1, M3, N3
)
* self.backend.bk_conjugate(
I12_w3_2_small.view(
N_image, L, L, M3, N3
)
)
* edge_mask[
None, None, None, None, :, :
]
).std(dim) * fft_factor
Ndata_S4 += 1
if normalization == "S2":
if use_ref:
P = (
ref_S2[:, j3 : j3 + 1, :, None, None]
* ref_S2[:, j2 : j2 + 1, None, :, None]
) ** (0.5 * pseudo_coef)
else:
P = (
S2[:, j3 : j3 + 1, :, None, None]
* S2[:, j2 : j2 + 1, None, :, None]
) ** (0.5 * pseudo_coef)
S4[:, beg_n:Ndata_S4, :, :, :] = (
S4_pre_norm[:, beg_n:Ndata_S4, :, :, :].clone() / P
)
if get_variance:
S4_sigma[:, beg_n:Ndata_S4, :, :, :] = (
S4_sigma[:, beg_n:Ndata_S4, :, :, :] / P
)
else:
S4 = S4_pre_norm
# average over l1 to obtain simple isotropic statistics
if iso_ang:
S2_iso = S2.mean(-1)
S1_iso = S1.mean(-1)
for l1 in range(L):
for l2 in range(L):
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
if data2 is not None:
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
for l3 in range(L):
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[
..., l1, l2, l3
]
S3_iso /= L
S4_iso /= L
if data2 is not None:
S3p_iso /= L
if get_variance:
S2_sigma_iso = S2_sigma.mean(-1)
S1_sigma_iso = S1_sigma.mean(-1)
for l1 in range(L):
for l2 in range(L):
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
if data2 is not None:
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[
..., l1, l2
]
for l3 in range(L):
S4_sigma_iso[
..., (l2 - l1) % L, (l3 - l1) % L
] += S4_sigma[..., l1, l2, l3]
S3_sigma_iso /= L
S4_sigma_iso /= L
if data2 is not None:
S3p_sigma_iso /= L
# ---- fft_ang: project orientation axes onto Fourier harmonics ----
# Applied before building ref_sigma / for_synthesis so that all
# tensors (and their sigma counterparts) share the same nout shape.
# The raw S2 is preserved in S2_raw_for_norm because
# self.ref_scattering_cov_S2 must remain un-projected (it is used
# as a normalization divisor inside the scattering loops).
if fft_ang:
S2_raw_for_norm = S2
nout = 1 + fft_nharm * (2 if fft_imaginary else 1)
if (L, fft_nharm) not in self.backend._fft_1_orient:
self.backend.calc_fft_orient(L, fft_nharm, fft_imaginary)
if (L, fft_nharm, fft_imaginary) not in self.backend._fft_ang2_orient:
self.backend.calc_fft_ang_orient(L, fft_nharm, fft_imaginary)
m1 = self.backend._fft_1_orient[(L, fft_nharm, fft_imaginary)]
m1c = self.backend._fft_1_orient_C[(L, fft_nharm, fft_imaginary)]
m2 = self.backend._fft_ang2_orient[(L, fft_nharm, fft_imaginary)]
m2c = self.backend._fft_ang2_orient_C[(L, fft_nharm, fft_imaginary)]
m3 = self.backend._fft_ang3_orient[(L, fft_nharm, fft_imaginary)]
m3c = self.backend._fft_ang3_orient_C[(L, fft_nharm, fft_imaginary)]
bkc = self.backend.bk_is_complex
def _pa(x):
"""S1/S2 [N, J, L] → [N, J, nout]"""
n, j_ = x.shape[0], x.shape[1]
m = m1c if bkc(x) else m1
return self.backend.bk_reshape(
self.backend.backend.matmul(
self.backend.bk_reshape(x, [n * j_, L]), m
), [n, j_, nout]
)
def _pb(x):
"""S3/S3p [N, Nd, L, L] → [N, Nd, L, nout]"""
n, nd = x.shape[0], x.shape[1]
m = m2c if bkc(x) else m2
return self.backend.bk_reshape(
self.backend.backend.matmul(
self.backend.bk_reshape(x, [n * nd, L * L]), m
), [n, nd, L, nout]
)
def _pc(x):
"""S4 [N, Nd, L, L, L] → [N, Nd, L, L, nout]"""
n, nd = x.shape[0], x.shape[1]
m = m3c if bkc(x) else m3
return self.backend.bk_reshape(
self.backend.backend.matmul(
self.backend.bk_reshape(x, [n * nd, L * L * L]), m
), [n, nd, L, L, nout]
)
# S1/S2 are positive power spectra: log first, then project.
# This ensures the projected values represent log-space harmonics,
# which are always well-defined and avoids NaN from log(negative).
# S3/S4 are complex covariances: project directly (no log).
S2 = _pa(S2)
S1 = _pa(S1)
S3 = _pb(S3)
S4 = _pc(S4)
if get_variance:
S2_sigma = _pa(S2_sigma)
S1_sigma = _pa(S1_sigma)
S3_sigma = _pb(S3_sigma)
S4_sigma = _pc(S4_sigma)
if data2 is not None:
S3p = _pb(S3p)
if get_variance:
S3p_sigma = _pb(S3p_sigma)
# ref_sigma passed in is already fft_ang'd — do not re-project.
mean_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
if data2 is None:
mean_data[:, 0] = data.mean(dim)
std_data[:, 0] = data.std(dim)
else:
mean_data[:, 0] = (data2 * data).mean(dim)
std_data[:, 0] = (data2 * data).std(dim)
if get_variance:
ref_sigma = {}
if iso_ang:
ref_sigma["std_data"] = std_data
ref_sigma["S1_sigma"] = S1_sigma_iso
ref_sigma["S2_sigma"] = S2_sigma_iso
ref_sigma["S3_sigma"] = S3_sigma_iso
if data2 is not None:
ref_sigma["S3p_sigma"] = S3p_sigma_iso
ref_sigma["S4_sigma"] = S4_sigma_iso
else:
ref_sigma["std_data"] = std_data
ref_sigma["S1_sigma"] = S1_sigma
ref_sigma["S2_sigma"] = S2_sigma
ref_sigma["S3_sigma"] = S3_sigma
if data2 is not None:
ref_sigma["S3p_sigma"] = S3p_sigma
ref_sigma["S4_sigma"] = S4_sigma
if data2 is None:
if iso_ang:
if ref_sigma is not None:
if return_table:
return (S1_iso / ref_sigma["S1_sigma"]), \
(S2_iso / ref_sigma["S2_sigma"]) , \
(S3_iso / ref_sigma["S3_sigma"]) , \
(S4_iso / ref_sigma["S4_sigma"])
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
(S2_iso / ref_sigma["S2_sigma"])
.reshape((N_image, -1))
.log(),
(S1_iso / ref_sigma["S1_sigma"])
.reshape((N_image, -1))
.log(),
(S3_iso / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.real,
(S3_iso / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.imag,
(S4_iso / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.real,
(S4_iso / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.imag,
),
dim=-1,
)
else:
if return_table:
return S1_iso,S2_iso,S3_iso,S4_iso
for_synthesis = self.backend.backend.cat(
(
mean_data / std_data,
std_data,
S2_iso.reshape((N_image, -1)).log(),
S1_iso.reshape((N_image, -1)).log(),
S3_iso.reshape((N_image, -1)).real,
S3_iso.reshape((N_image, -1)).imag,
S4_iso.reshape((N_image, -1)).real,
S4_iso.reshape((N_image, -1)).imag,
),
dim=-1,
)
else:
if ref_sigma is not None:
if return_table:
return (S1 / ref_sigma["S1_sigma"]), \
(S2 / ref_sigma["S2_sigma"]), \
(S3 / ref_sigma["S3_sigma"]), \
(S4 / ref_sigma["S4_sigma"])
if fft_ang:
# S2/S1 are _pa(log(.)); ref_sigma sigmas are _pa(log(sigma)).
# Normalized log-projection = _pa(log S2 - log S2_sigma)
# = _pa(log S2) - _pa(log S2_sigma) = S2 - S2_sigma.
s2_n = (S2 - ref_sigma["S2_sigma"]).reshape((N_image, -1))
s1_n = (S1 - ref_sigma["S1_sigma"]).reshape((N_image, -1))
else:
s2_n = (S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)).log()
s1_n = (S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)).log()
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
s2_n,
s1_n,
(S3 / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.real,
(S3 / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.imag,
(S4 / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.real,
(S4 / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.imag,
),
dim=-1,
)
else:
if return_table:
return S1,S2,S3,S4
if fft_ang:
# S2/S1 already = _pa(log(S2_orig)): use directly.
s2_n = S2.reshape((N_image, -1))
s1_n = S1.reshape((N_image, -1))
else:
s2_n = S2.reshape((N_image, -1)).log()
s1_n = S1.reshape((N_image, -1)).log()
for_synthesis = self.backend.backend.cat(
(
mean_data / std_data,
std_data,
s2_n,
s1_n,
S3.reshape((N_image, -1)).real,
S3.reshape((N_image, -1)).imag,
S4.reshape((N_image, -1)).real,
S4.reshape((N_image, -1)).imag,
),
dim=-1,
)
else:
if iso_ang:
if ref_sigma is not None:
if return_table:
return (S1_iso / ref_sigma["S1_sigma"]), \
(S2_iso / ref_sigma["S2_sigma"]), \
(S3_iso / ref_sigma["S3_sigma"]), \
(S4_iso / ref_sigma["S4_sigma"])
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
(S2_iso / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
(S1_iso / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
(S3_iso / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.real,
(S3_iso / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.imag,
(S3p_iso / ref_sigma["S3p_sigma"])
.reshape((N_image, -1))
.real,
(S3p_iso / ref_sigma["S3p_sigma"])
.reshape((N_image, -1))
.imag,
(S4_iso / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.real,
(S4_iso / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.imag,
),
dim=-1,
)
else:
if return_table:
return S1_iso,S2_iso,S3_iso,S4_iso
for_synthesis = self.backend.backend.cat(
(
mean_data / std_data,
std_data,
S2_iso.reshape((N_image, -1)),
S1_iso.reshape((N_image, -1)),
S3_iso.reshape((N_image, -1)).real,
S3_iso.reshape((N_image, -1)).imag,
S3p_iso.reshape((N_image, -1)).real,
S3p_iso.reshape((N_image, -1)).imag,
S4_iso.reshape((N_image, -1)).real,
S4_iso.reshape((N_image, -1)).imag,
),
dim=-1,
)
else:
if ref_sigma is not None:
if return_table:
return (S1 / ref_sigma["S1_sigma"]), \
(S2 / ref_sigma["S2_sigma"]), \
(S3 / ref_sigma["S3_sigma"]), \
(S4 / ref_sigma["S4_sigma"])
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
(S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
(S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
(S3 / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.real,
(S3 / ref_sigma["S3_sigma"])
.reshape((N_image, -1))
.imag,
(S3p / ref_sigma["S3p_sigma"])
.reshape((N_image, -1))
.real,
(S3p / ref_sigma["S3p_sigma"])
.reshape((N_image, -1))
.imag,
(S4 / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.real,
(S4 / ref_sigma["S4_sigma"])
.reshape((N_image, -1))
.imag,
),
dim=-1,
)
else:
if return_table:
return S1,S2,S3,S4
for_synthesis = self.backend.backend.cat(
(
mean_data / std_data,
std_data,
S2.reshape((N_image, -1)),
S1.reshape((N_image, -1)),
S3.reshape((N_image, -1)).real,
S3.reshape((N_image, -1)).imag,
S3p.reshape((N_image, -1)).real,
S3p.reshape((N_image, -1)).imag,
S4.reshape((N_image, -1)).real,
S4.reshape((N_image, -1)).imag,
),
dim=-1,
)
if not use_ref:
# Store the original (un-projected, un-logged) S2 so that
# subsequent use_ref=True calls can normalise raw S3/S4 values
# by true wavelet power before projecting.
self.ref_scattering_cov_S2 = (
S2_raw_for_norm if fft_ang else S2
)
if get_variance:
return for_synthesis, ref_sigma
return for_synthesis
if (M, N, J, L) not in self.filters_set:
self.filters_set[(M, N, J, L)] = self.computer_filter(
M, N, J, L
) # self.computer_filter(M,N,J,L)
filters_set = self.filters_set[(M, N, J, L)]
# weight = self.weight
if use_ref:
if normalization == "S2":
ref_S2 = self.ref_scattering_cov_S2
else:
ref_P11 = self.ref_scattering_cov["P11"]
# convert numpy array input into self.backend.bk_ tensors
data = self.backend.bk_cast(data)
data_f = self.backend.bk_fftn(data, dim=dim)
if data2 is not None:
data2 = self.backend.bk_cast(data2)
data2_f = self.backend.bk_fftn(data2, dim=dim)
# initialize tensors for scattering coefficients
Ndata_S3 = J * (J + 1) // 2
Ndata_S4 = J * (J + 1) * (J + 2) // 6
J_S4 = {}
S3 = []
if data2 is not None:
S3p = []
S4_pre_norm = []
S4 = []
# variance
if get_variance:
S3_sigma = []
if data2 is not None:
S3p_sigma = []
S4_sigma = []
if iso_ang:
S3_iso = []
if data2 is not None:
S3p_iso = []
S4_iso = []
if get_variance:
S3_sigma_iso = []
if data2 is not None:
S3p_sigma_iso = []
S4_sigma_iso = []
#
if edge:
if (M, N, J) not in self.edge_masks:
self.edge_masks[(M, N, J)] = self.get_edge_masks(
M, N, J, in_mask=in_mask
)
edge_mask = self.edge_masks[(M, N, J)]
else:
edge_mask = 1
# calculate scattering fields
if data2 is None:
if self.use_2D:
if len(data.shape) == 2:
I1 = self.backend.bk_abs(
self.backend.bk_ifftn(
data_f[None, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
)
else:
I1 = self.backend.bk_abs(
self.backend.bk_ifftn(
data_f[:, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
)
elif self.use_1D:
if len(data.shape) == 1:
I1 = self.backend.bk_abs(
self.backend.bk_ifftn(
data_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
)
)
else:
I1 = self.backend.bk_abs(
self.backend.bk_ifftn(
data_f[:, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
)
)
else:
print("todo")
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=dim)
S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=dim)
if get_variance:
S2_sigma = self.backend.bk_reduce_std(
(I1**2 * edge_mask), axis=dim
)
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
I1_f = self.backend.bk_fftn(I1, dim=dim)
else:
if self.use_2D:
if len(data.shape) == 2:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[None, None, None, :, :]
* filters_set[None, :J, :, :, :],
dim=dim,
)
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
dim=dim,
)
I2 = self.backend.bk_ifftn(
data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
dim=dim,
)
elif self.use_1D:
if len(data.shape) == 1:
I1 = self.backend.bk_ifftn(
data_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
)
I2 = self.backend.bk_ifftn(
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
)
else:
I1 = self.backend.bk_ifftn(
data_f[:, None, None, :] * filters_set[None, :J, :, :], dim=(-1)
)
I2 = self.backend.bk_ifftn(
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
dim=(-1),
)
else:
print("todo")
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
if get_variance:
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
I1 = self.backend.bk_L1(I1)
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
if get_variance:
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
I1_f = self.backend.bk_fftn(I1, dim=dim)
if pseudo_coef != 1:
I1 = I1**pseudo_coef
Ndata_S3 = 0
Ndata_S4 = 0
# calculate the covariance and correlations of the scattering fields
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
for j3 in range(0, J):
J_S4[j3] = Ndata_S4
dx3, dy3 = self.get_dxdy(j3, M, N)
I1_f_small = self.cut_high_k_off(
I1_f[:, : j3 + 1], dx3, dy3
) # Nimage, J, L, x, y
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
if data2 is not None:
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
if edge:
I1_small = self.backend.bk_ifftn(I1_f_small, dim=dim, norm="ortho")
data_small = self.backend.bk_ifftn(
data_f_small, dim=dim, norm="ortho"
)
if data2 is not None:
data2_small = self.backend.bk_ifftn(
data2_f_small, dim=dim, norm="ortho"
)
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
_, M3, N3 = wavelet_f3.shape
wavelet_f3_squared = wavelet_f3**2
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
# a normalization change due to the cutoff of frequency space
if self.all_bk_type == "float32":
fft_factor = np.complex64(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
else:
fft_factor = np.complex128(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
for j2 in range(0, j3 + 1):
# I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
# I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
I1_f2_wf3_small = self.backend.bk_reshape(
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
) * self.backend.bk_reshape(wavelet_f3, [1, 1, 1, L, M3, N3])
I1_f2_wf3_2_small = self.backend.bk_reshape(
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
if edge:
I12_w3_small = self.backend.bk_ifftn(
I1_f2_wf3_small, dim=dim, norm="ortho"
)
I12_w3_2_small = self.backend.bk_ifftn(
I1_f2_wf3_2_small, dim=dim, norm="ortho"
)
if use_ref:
if normalization == "P11":
norm_factor_S3 = (
ref_S2[:, None, j3, :]
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
) ** 0.5
norm_factor_S3 = self.backend.bk_complex(
norm_factor_S3, 0 * norm_factor_S3
)
elif normalization == "S2":
norm_factor_S3 = (
ref_S2[:, None, j3, :]
* ref_S2[:, j2, :, None] ** pseudo_coef
) ** 0.5
norm_factor_S3 = self.backend.bk_complex(
norm_factor_S3, 0 * norm_factor_S3
)
else:
norm_factor_S3 = C_ONE
else:
if normalization == "P11":
# [N_image,l2,l3,x,y]
P11_temp = (
self.backend.bk_reduce_mean(
(I1_f2_wf3_small.abs() ** 2), axis=dim
)
* fft_factor
)
norm_factor_S3 = (
S2[:, None, j3, :] * P11_temp**pseudo_coef
) ** 0.5
norm_factor_S3 = self.backend.bk_complex(
norm_factor_S3, 0 * norm_factor_S3
)
elif normalization == "S2":
norm_factor_S3 = (
S2[:, None, j3, None, :]
* S2[:, None, j2, :, None] ** pseudo_coef
) ** 0.5
norm_factor_S3 = self.backend.bk_complex(
norm_factor_S3, 0 * norm_factor_S3
)
else:
norm_factor_S3 = C_ONE
if not edge:
S3.append(
self.backend.bk_reduce_mean(
self.backend.bk_reshape(
data_f_small, [N_image, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I1_f2_wf3_small),
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3_sigma.append(
self.backend.bk_reduce_std(
self.backend.bk_reshape(
data_f_small, [N_image, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I1_f2_wf3_small),
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
else:
S3.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
data_small, [N_image, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I12_w3_small)
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3_sigma.apend(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
data_small, [N_image, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I12_w3_small)
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if data2 is not None:
if not edge:
S3p.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I1_f2_wf3_small)
),
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3p_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I1_f2_wf3_small)
),
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
else:
S3p.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
data2_small, [N_image2, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I12_w3_small)
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if get_variance:
S3p_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
data2_small, [N_image2, 1, 1, 1, M3, N3]
)
* self.backend.bk_conjugate(I12_w3_small)
)[
...,
edge_dx : M3 - edge_dx,
edge_dy : N3 - edge_dy,
],
axis=dim,
)
* fft_factor
/ norm_factor_S3
)
if j2 <= j3:
if normalization == "S2":
if use_ref:
P = 1 / (
(
ref_S2[:, j3 : j3 + 1, :, None, None]
* ref_S2[:, j2 : j2 + 1, None, :, None]
)
** (0.5 * pseudo_coef)
)
else:
P = 1 / (
(
S2[:, j3 : j3 + 1, :, None, None]
* S2[:, j2 : j2 + 1, None, :, None]
)
** (0.5 * pseudo_coef)
)
P = self.backend.bk_complex(P, 0.0 * P)
else:
P = C_ONE
for j1 in range(0, j2 + 1):
if not edge:
if not if_large_batch:
# [N_image,l1,l2,l3,x,y]
S4.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
I1_f_small[:, j1],
[N_image, 1, L, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I1_f2_wf3_2_small,
[N_image, 1, 1, L, L, M3, N3],
)
)
),
axis=dim,
)
* fft_factor
* P
)
if get_variance:
S4_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
I1_f_small[:, j1],
[N_image, 1, L, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I1_f2_wf3_2_small,
[N_image, 1, 1, L, L, M3, N3],
)
)
),
axis=dim,
)
* fft_factor
* P
)
else:
for l1 in range(L):
# [N_image,l2,l3,x,y]
S4.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
I1_f_small[:, j1, l1],
[N_image, 1, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I1_f2_wf3_2_small,
[N_image, 1, L, L, M3, N3],
)
)
),
axis=dim,
)
* fft_factor
* P
)
if get_variance:
S4_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
I1_f_small[:, j1, l1],
[N_image, 1, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I1_f2_wf3_2_small,
[N_image, 1, L, L, M3, N3],
)
)
),
axis=dim,
)
* fft_factor
* P
)
else:
if not if_large_batch:
# [N_image,l1,l2,l3,x,y]
S4.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
I1_small[:, j1],
[N_image, 1, L, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I12_w3_2_small,
[N_image, 1, 1, L, L, M3, N3],
)
)
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
axis=dim,
)
* fft_factor
* P
)
if get_variance:
S4_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
I1_small[:, j1],
[N_image, 1, L, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I12_w3_2_small,
[N_image, 1, 1, L, L, M3, N3],
)
)
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
axis=dim,
)
* fft_factor
* P
)
else:
for l1 in range(L):
# [N_image,l2,l3,x,y]
S4.append(
self.backend.bk_reduce_mean(
(
self.backend.bk_reshape(
I1_small[:, j1],
[N_image, 1, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I12_w3_2_small,
[N_image, 1, L, L, M3, N3],
)
)
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
axis=dim,
)
* fft_factor
* P
)
if get_variance:
S4_sigma.append(
self.backend.bk_reduce_std(
(
self.backend.bk_reshape(
I1_small[:, j1],
[N_image, 1, 1, 1, M3, N3],
)
* self.backend.bk_conjugate(
self.backend.bk_reshape(
I12_w3_2_small,
[N_image, 1, L, L, M3, N3],
)
)
)[
...,
edge_dx:-edge_dx,
edge_dy:-edge_dy,
],
axis=dim,
)
* fft_factor
* P
)
S3 = self.backend.bk_concat(S3, axis=1)
S4 = self.backend.bk_concat(S4, axis=1)
if get_variance:
S3_sigma = self.backend.bk_concat(S3_sigma, axis=1)
S4_sigma = self.backend.bk_concat(S4_sigma, axis=1)
if data2 is not None:
S3p = self.backend.bk_concat(S3p, axis=1)
if get_variance:
S3p_sigma = self.backend.bk_concat(S3p_sigma, axis=1)
# average over l1 to obtain simple isotropic statistics
if iso_ang:
S2_iso = self.backend.bk_reduce_mean(S2, axis=(-1))
S1_iso = self.backend.bk_reduce_mean(S1, axis=(-1))
for l1 in range(L):
for l2 in range(L):
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
if data2 is not None:
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
for l3 in range(L):
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[..., l1, l2, l3]
S3_iso /= L
S4_iso /= L
if data2 is not None:
S3p_iso /= L
if get_variance:
S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma, axis=(-1))
S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma, axis=(-1))
for l1 in range(L):
for l2 in range(L):
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
if data2 is not None:
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[..., l1, l2]
for l3 in range(L):
S4_sigma_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4_sigma[
..., l1, l2, l3
]
S3_sigma_iso /= L
S4_sigma_iso /= L
if data2 is not None:
S3p_sigma_iso /= L
if data2 is None:
mean_data = self.backend.bk_reshape(
self.backend.bk_reduce_mean(data, axis=dim), [N_image, 1]
)
std_data = self.backend.bk_reshape(
self.backend.bk_reduce_std(data, axis=dim), [N_image, 1]
)
else:
mean_data = self.backend.bk_reshape(
self.backend.bk_reduce_mean(data * data2, axis=dim), [N_image, 1]
)
std_data = self.backend.bk_reshape(
self.backend.bk_reduce_std(data * data2, axis=dim), [N_image, 1]
)
if get_variance:
ref_sigma = {}
if iso_ang:
ref_sigma["std_data"] = std_data
ref_sigma["S1_sigma"] = S1_sigma_iso
ref_sigma["S2_sigma"] = S2_sigma_iso
ref_sigma["S3_sigma"] = S3_sigma_iso
ref_sigma["S4_sigma"] = S4_sigma_iso
if data2 is not None:
ref_sigma["S3p_sigma"] = S3p_sigma_iso
else:
ref_sigma["std_data"] = std_data
ref_sigma["S1_sigma"] = S1_sigma
ref_sigma["S2_sigma"] = S2_sigma
ref_sigma["S3_sigma"] = S3_sigma
ref_sigma["S4_sigma"] = S4_sigma
if data2 is not None:
ref_sigma["S3p_sigma"] = S3_sigma
if data2 is None:
if iso_ang:
if ref_sigma is not None:
for_synthesis = self.backend.bk_concat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
self.backend.bk_reshape(
self.backend.bk_log(S2_iso / ref_sigma["S2_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_log(S1_iso / ref_sigma["S1_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
[N_image, -1],
),
),
axis=-1,
)
else:
for_synthesis = self.backend.bk_concat(
(
mean_data / std_data,
std_data,
self.backend.bk_reshape(
self.backend.bk_log(S2_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_log(S1_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S4_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S4_iso), [N_image, -1]
),
),
axis=-1,
)
else:
if ref_sigma is not None:
for_synthesis = self.backend.bk_concat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
self.backend.bk_reshape(
self.backend.bk_log(S2 / ref_sigma["S2_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_log(S1 / ref_sigma["S1_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
[N_image, -1],
),
),
axis=-1,
)
else:
for_synthesis = self.backend.bk_concat(
(
mean_data / std_data,
std_data,
self.backend.bk_reshape(
self.backend.bk_log(S2), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_log(S1), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S4), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S4), [N_image, -1]
),
),
axis=-1,
)
else:
if iso_ang:
if ref_sigma is not None:
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
self.backend.bk_reshape(
self.backend.bk_real(S2_iso / ref_sigma["S2_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S1_iso / ref_sigma["S1_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3p_iso / ref_sigma["S3p_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3p_iso / ref_sigma["S3p_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
[N_image, -1],
),
),
axis=-1,
)
else:
for_synthesis = self.backend.backend.cat(
(
mean_data / std_data,
std_data,
self.backend.bk_reshape(
self.backend.bk_real(S2_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S1_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3p_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3p_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S4_iso), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S4_iso), [N_image, -1]
),
),
axis=-1,
)
else:
if ref_sigma is not None:
for_synthesis = self.backend.backend.cat(
(
mean_data / ref_sigma["std_data"],
std_data / ref_sigma["std_data"],
self.backend.bk_reshape(
self.backend.bk_real(S2 / ref_sigma["S2_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S1 / ref_sigma["S1_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S3p / ref_sigma["S3p_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S3p / ref_sigma["S3p_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
[N_image, -1],
),
self.backend.bk_reshape(
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
[N_image, -1],
),
),
axis=-1,
)
else:
for_synthesis = self.backend.bk_concat(
(
mean_data / std_data,
std_data,
self.backend.bk_reshape(
self.backend.bk_real(S2), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S1), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S3p), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S3p), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_real(S4), [N_image, -1]
),
self.backend.bk_reshape(
self.backend.bk_imag(S4), [N_image, -1]
),
),
axis=-1,
)
if not use_ref:
# Store the original (un-projected, un-logged) S2 so that
# subsequent use_ref=True calls can normalise raw S3/S4 values
# by true wavelet power before projecting.
self.ref_scattering_cov_S2 = (
S2_raw_for_norm if fft_ang else S2
)
if get_variance:
return for_synthesis, ref_sigma
return for_synthesis
[docs]
def purge_edge_mask(self):
    """Empty the edge-mask cache.

    Every key currently stored in ``self.edge_masks`` is deleted from that
    mapping (so any other reference to the same object sees it emptied), and
    ``self.edge_masks`` is then rebound to a brand-new empty dict.
    """
    # Snapshot the keys first: deleting while iterating a dict is an error.
    for cached_key in list(self.edge_masks):
        del self.edge_masks[cached_key]
    self.edge_masks = {}
[docs]
def to_gaussian(self, x, in_mask=None):
    """Map ``x`` onto a standard-normal histogram by rank (quantile) matching.

    Each selected pixel is replaced by ``norm.ppf((rank - 0.5) / n)``, where
    ``rank`` is its 1-based rank among the ``n`` selected pixels. As side
    effects, this stores:

    * ``self.f_gaussian`` -- a cubic ``interp1d`` from the Gaussianised
      values back to the original values (used by ``from_gaussian``);
    * ``self.val_min`` / ``self.val_max`` -- the smallest / largest
      Gaussianised value, i.e. the valid input range of ``f_gaussian``.

    Parameters
    ----------
    x : array-like
        Field to Gaussianise (any shape). It is not modified.
    in_mask : array-like, optional
        Same number of elements as ``x``. Only pixels with ``in_mask > 0``
        are Gaussianised and used to build the inverse map; all other pixels
        keep their original value in the output.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``x``. Integer (or boolean) inputs are
        promoted to float64 so the Gaussian quantiles are not truncated.
    """
    from scipy.interpolate import interp1d
    from scipy.stats import norm

    values = np.asarray(x).ravel()
    # Writing norm.ppf values into an integer array would silently truncate
    # them, so promote non-inexact dtypes. astype(copy=True) also guarantees
    # that ``x`` itself is never written to.
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.inexact) else np.float64
    im_target = values.astype(out_dtype, copy=True)

    if in_mask is None:
        selected = np.arange(values.size)
    else:
        selected = np.where(np.asarray(in_mask).ravel() > 0)[0]

    # Flat indices of the selected pixels, sorted by their original value.
    order = selected[np.argsort(values[selected])]
    n = order.shape[0]
    im_target[order] = norm.ppf((np.arange(1, n + 1) - 0.5) / n)

    # Cubic interpolation from Gaussian space back to the original values.
    self.f_gaussian = interp1d(im_target[order], values[order], kind="cubic")
    self.val_min = im_target[order[0]]
    self.val_max = im_target[order[-1]]
    return im_target.reshape(np.shape(x))
[docs]
def from_gaussian(self, x):
    """Map Gaussianised values back through ``self.f_gaussian``.

    ``x`` is first clipped with ``self.backend.bk_clip_by_value`` into
    ``[val_min + eps, val_max - eps]`` with ``eps = 1e-7 * (val_max - val_min)``,
    so that ``f_gaussian`` (set by ``to_gaussian``) is never evaluated outside
    its interpolation range. The clipped tensor is converted with
    ``backend.to_numpy`` before interpolation, and the interpolator's result
    is returned.
    """
    margin = 1E-7 * (self.val_max - self.val_min)
    lower = self.val_min + margin
    upper = self.val_max - margin
    clipped = self.backend.bk_clip_by_value(x, lower, upper)
    return self.f_gaussian(self.backend.to_numpy(clipped))
[docs]
def square(self, x):
    """Return the element-wise squared modulus of ``x``.

    Parameters
    ----------
    x : scat_cov or backend tensor
        For a ``scat_cov``, every coefficient set (S0, S2, S3, S4 and, when
        not ``None``, S1 and S3P) is mapped to ``bk_square(bk_abs(S))``.
        Any other value is returned as ``bk_abs(bk_square(x))``.

    Returns
    -------
    scat_cov or backend tensor
        Same kind as ``x``. S3P is now propagated instead of being dropped,
        consistent with how S1 is handled.
    """
    if not isinstance(x, scat_cov):
        return self.backend.bk_abs(self.backend.bk_square(x))

    def _sq_abs(coef):
        # Optional coefficient sets (S1, S3P) may be absent.
        if coef is None:
            return None
        return self.backend.bk_square(self.backend.bk_abs(coef))

    return scat_cov(
        _sq_abs(x.S0),
        _sq_abs(x.S2),
        _sq_abs(x.S3),
        _sq_abs(x.S4),
        s1=_sq_abs(x.S1),
        s3p=_sq_abs(x.S3P),
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def sqrt(self, x):
    """Return the element-wise square root of the modulus of ``x``.

    Parameters
    ----------
    x : scat_cov or backend tensor
        For a ``scat_cov``, every coefficient set (S0, S2, S3, S4 and, when
        not ``None``, S1 and S3P) is mapped to ``bk_sqrt(bk_abs(S))``.
        Any other value is returned as ``bk_abs(bk_sqrt(x))``.

    Returns
    -------
    scat_cov or backend tensor
        Same kind as ``x``. S3P is now propagated instead of being dropped,
        consistent with how S1 is handled.
    """
    if not isinstance(x, scat_cov):
        return self.backend.bk_abs(self.backend.bk_sqrt(x))

    def _sqrt_abs(coef):
        # Optional coefficient sets (S1, S3P) may be absent.
        if coef is None:
            return None
        return self.backend.bk_sqrt(self.backend.bk_abs(coef))

    return scat_cov(
        _sqrt_abs(x.S0),
        _sqrt_abs(x.S2),
        _sqrt_abs(x.S3),
        _sqrt_abs(x.S4),
        s1=_sqrt_abs(x.S1),
        s3p=_sqrt_abs(x.S3P),
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def reduce_mean(self, x):
    """Mean absolute value over all coefficients, or a batch-axis mean.

    For a ``scat_cov``: sum of ``|S|`` over S0, S2, S3, S4 (plus S1 and S3P
    when present) divided by the total number of elements in those sets.
    Any other value: ``backend.bk_reduce_mean(x, axis=0)``.
    """
    bk = self.backend
    if not isinstance(x, scat_cov):
        return bk.bk_reduce_mean(x, axis=0)

    # Accumulation order matches S0, S2, S3, S4, then S1, then S3P.
    components = [x.S0, x.S2, x.S3, x.S4]
    components += [extra for extra in (x.S1, x.S3P) if extra is not None]

    first, *others = components
    total = bk.bk_reduce_sum(bk.bk_abs(first))
    count = bk.bk_size(first)
    for comp in others:
        total = total + bk.bk_reduce_sum(bk.bk_abs(comp))
        count = count + bk.bk_size(comp)
    return total / bk.bk_cast(count)
[docs]
def reduce_mean_batch(self, x):
    """Average over the leading (batch) axis.

    For a ``scat_cov`` each coefficient set is reduced with
    ``backend.bk_reduce_mean(..., axis=0)`` and a new ``scat_cov`` is built
    (S1 / S3P stay ``None`` when absent). Any other value is reduced the same
    way and returned directly.
    """
    def _mean_axis0(tensor):
        return self.backend.bk_reduce_mean(tensor, axis=0)

    if not isinstance(x, scat_cov):
        return _mean_axis0(x)

    return scat_cov(
        _mean_axis0(x.S0),
        _mean_axis0(x.S2),
        _mean_axis0(x.S3),
        _mean_axis0(x.S4),
        s1=None if x.S1 is None else _mean_axis0(x.S1),
        s3p=None if x.S3P is None else _mean_axis0(x.S3P),
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def reduce_sum_batch(self, x):
    """Sum over the leading (batch) axis.

    For a ``scat_cov`` each coefficient set is reduced with
    ``backend.bk_reduce_sum(..., axis=0)`` and a new ``scat_cov`` is built
    (S1 / S3P stay ``None`` when absent). Any other value is reduced the same
    way; it previously went through ``bk_reduce_mean`` by mistake, which
    returned a mean instead of a sum.
    """
    def _sum_axis0(tensor):
        return self.backend.bk_reduce_sum(tensor, axis=0)

    if not isinstance(x, scat_cov):
        return _sum_axis0(x)

    return scat_cov(
        _sum_axis0(x.S0),
        _sum_axis0(x.S2),
        _sum_axis0(x.S3),
        _sum_axis0(x.S4),
        s1=None if x.S1 is None else _sum_axis0(x.S1),
        s3p=None if x.S3P is None else _sum_axis0(x.S3P),
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def reduce_distance(self, x, y, sigma=None):
    """Normalised distance between two statistics objects via ``self.diff_data``.

    For a ``scat_cov``:
    * accumulates ``self.diff_data(y.S*, x.S*)`` over S0 (with
      ``is_complex=False``), S1 and S3P (only when ``x`` has them), then S2,
      S3, S4;
    * when ``sigma`` is given, forwards the matching ``sigma.S*`` as
      ``sigma=``;
    * divides by the total element count of ``x``'s coefficient sets.

    For any other input, returns ``self.diff_data(x, y[, sigma=sigma])``
    divided by ``x.shape[0]``.
    """
    if not isinstance(x, scat_cov):
        extra = {} if sigma is None else {"sigma": sigma}
        tmp = self.diff_data(x, y, **extra)
        return tmp / x.shape[0]

    def _sigma_kw(name):
        # Only pass ``sigma=`` when a sigma object was supplied.
        return {} if sigma is None else {"sigma": getattr(sigma, name)}

    result = self.diff_data(y.S0, x.S0, is_complex=False, **_sigma_kw("S0"))
    for name in ("S1", "S3P", "S2", "S3", "S4"):
        x_part = getattr(x, name)
        if name in ("S1", "S3P") and x_part is None:
            continue
        result += self.diff_data(getattr(y, name), x_part, **_sigma_kw(name))

    size = self.backend.bk_size
    nval = size(x.S0) + size(x.S2) + size(x.S3) + size(x.S4)
    for optional in (x.S1, x.S3P):
        if optional is not None:
            nval += size(optional)
    result /= self.backend.bk_cast(nval)
    return result
[docs]
def reduce_sum(self, x):
    """Total sum of all coefficients, or a plain backend sum.

    For a ``scat_cov``: ``bk_reduce_sum`` of S0, S2, S1 (if present), S3, S4
    and S3P (if present), added together. S3P used to be omitted, unlike in
    ``reduce_mean`` and ``reduce_distance``. Any other value is passed to
    ``backend.bk_reduce_sum`` unchanged.
    """
    if not isinstance(x, scat_cov):
        return self.backend.bk_reduce_sum(x)

    total = self.backend.bk_reduce_sum(x.S0) + self.backend.bk_reduce_sum(x.S2)
    if x.S1 is not None:
        total = total + self.backend.bk_reduce_sum(x.S1)
    total = total + self.backend.bk_reduce_sum(x.S3) + self.backend.bk_reduce_sum(x.S4)
    if x.S3P is not None:
        total = total + self.backend.bk_reduce_sum(x.S3P)
    return total
[docs]
def ldiff(self, sig, x):
    """Return a ``scat_cov`` of squared ``x.domult(sig.S*, x.S*)`` terms.

    For every coefficient set of ``x`` (S0, S2, S3, S4, and S1 / S3P when they
    are not ``None``), the result holds ``d * d`` with
    ``d = x.domult(sig.S*, x.S*)``.

    The constructor takes lower-case ``s1=`` / ``s3p=`` keywords. The previous
    implementation passed ``S1=`` / ``S3P=``, which raised ``TypeError``
    whenever S1 or S3P was present. It also had the S3P branches swapped.
    """
    def _squared(sig_part, x_part):
        d = x.domult(sig_part, x_part)
        return d * d

    s1 = None if x.S1 is None else _squared(sig.S1, x.S1)
    s3p = None if x.S3P is None else _squared(sig.S3P, x.S3P)
    return scat_cov(
        _squared(sig.S0, x.S0),
        _squared(sig.S2, x.S2),
        _squared(sig.S3, x.S3),
        _squared(sig.S4, x.S4),
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def log(self, x):
    """Apply ``backend.bk_log``.

    For a ``scat_cov``, returns ``bk_log(S0) + bk_log(S2) [+ bk_log(S1)]
    + bk_log(S3) + bk_log(S4)`` (S1 only when present). Any other value is
    returned as ``bk_log(x)``.

    NOTE(review): the per-set log tensors are added together, which relies on
    their shapes broadcasting against each other, and S3P is not included.
    Confirm whether a per-component ``scat_cov`` result was intended.
    """
    bk_log = self.backend.bk_log
    if not isinstance(x, scat_cov):
        return bk_log(x)

    total = bk_log(x.S0) + bk_log(x.S2)
    if x.S1 is not None:
        total = total + bk_log(x.S1)
    total = total + bk_log(x.S3)
    total = total + bk_log(x.S4)
    return total
[docs]
@tf_function
def eval_comp_fast(
    self,
    image1,
    image2=None,
    mask=None,
    norm=None,
    cmat=None,
    cmat2=None,
):
    """Run ``self.eval`` and return its coefficient sets as a flat tuple.

    The tuple order is ``(S0, S2, S1, S3, S4, S3P)``. The method is wrapped
    by ``tf_function``, which is ``tf.function`` when TensorFlow is loaded
    and an identity decorator otherwise.

    NOTE(review): ``norm`` is accepted but not forwarded to ``self.eval``;
    confirm whether that is intentional.
    """
    stats = self.eval(
        image1,
        image2=image2,
        mask=mask,
        cmat=cmat,
        cmat2=cmat2,
    )
    s0, s2, s1 = stats.S0, stats.S2, stats.S1
    s3, s4, s3p = stats.S3, stats.S4, stats.S3P
    return (s0, s2, s1, s3, s4, s3p)
[docs]
def eval_fast(
    self,
    image1,
    image2=None,
    mask=None,
    norm=None,
    cmat=None,
    cmat2=None,
):
    """Compute statistics through ``eval_comp_fast`` and wrap them in a ``scat_cov``.

    ``eval_comp_fast`` returns ``(S0, S2, S1, S3, S4, S3P)``; these are
    re-assembled into a ``scat_cov`` using this operator's backend and
    ``use_1D`` flag.

    NOTE(review): ``norm`` is accepted but not forwarded.
    """
    components = self.eval_comp_fast(
        image1,
        image2=image2,
        mask=mask,
        cmat=cmat,
        cmat2=cmat2,
    )
    s0, s2, s1, s3, s4, s3p = components
    return scat_cov(
        s0,
        s2,
        s3,
        s4,
        s1=s1,
        s3p=s3p,
        backend=self.backend,
        use_1D=self.use_1D,
    )
[docs]
def calc_matrix_orientation(self, noise_map, image2=None):
    """Build per-pixel orientation interpolation matrices.

    Steps:

    1. ``mm`` is the mean over axis 0 of ``|self.convol(noise_map)|``, or,
       when ``image2`` is given, of the real part of
       ``bk_L1(im * conj(convol(image2)))``. Its first axis is taken as the
       orientation axis (``Norient = mm.shape[0]``).
       Assumes a layout like ``(batch, Norient, npix)`` for the convolved
       maps -- TODO confirm.
    2. For each pixel, a fractional orientation index ``o`` in
       ``[0, Norient)`` is derived from the phase of the first circular
       harmonic of ``mm`` along the orientation axis.
    3. ``alpha[xx, pix]`` holds linear-interpolation weights on the two
       orientation bins bracketing ``o``, with wrap-around.
    4. ``m[k] = roll(alpha, k, axis=0)``.

    The lower bound ``alpha >= 0`` is inclusive so that an exactly integer
    ``o`` (e.g. a pixel whose energy is entirely in bin 0, or an all-zero
    pixel) receives weight 1 on that bin. With the previous strict ``> 0``
    every weight was zero in that case.

    Returns
    -------
    ``backend.bk_cast`` of a float array of shape
    ``(Norient, Norient, mm.shape[1])``.
    """
    im = self.convol(noise_map)
    if image2 is None:
        mm = np.mean(abs(self.backend.to_numpy(im)), 0)
    else:
        im2 = self.convol(self.backend.bk_cast(image2))
        mm = np.mean(
            self.backend.to_numpy(
                self.backend.bk_L1(im * self.backend.bk_conjugate(im2))
            ).real,
            0,
        )

    Norient = mm.shape[0]
    angles = np.arange(Norient) / Norient * 2 * np.pi
    # First circular-harmonic coefficients along the orientation axis.
    a = np.sum(mm * np.cos(angles)[:, None], 0)
    b = np.sum(mm * np.sin(angles)[:, None], 0)
    o = np.fmod(Norient * np.arctan2(-b, a) / (2 * np.pi) + Norient, Norient)

    xx = np.arange(Norient)
    alpha = o[None, :] - xx[:, None]
    beta = np.fmod(1 + o[None, :] - xx[:, None], Norient)
    # Weight (1 - frac) on floor(o), frac on the next bin (mod Norient).
    alpha = (1 - alpha) * (alpha < 1) * (alpha >= 0) + beta * (beta < 1) * (beta > 0)

    m = np.zeros([Norient, Norient, mm.shape[1]])
    for k in range(Norient):
        m[k, :, :] = np.roll(alpha, k, 0)
    return self.backend.bk_cast(m)
[docs]
def synthesis(
self,
image_target,
reference=None,
nstep=4,
seed=1234,
Jmax=None,
edge=False,
to_gaussian=False,
use_variance=True,
synthesised_N=1,
input_image=None,
grd_mask=None,
in_mask=None,
iso_ang=False,
fft_ang=False,
fft_nharm=1,
fft_imaginary=True,
EVAL_FREQUENCY=100,
NUM_EPOCHS=300,
scat_cov_method='eval',
n_up=0,
):
"""Synthesise a new field whose scattering-covariance statistics match
those of ``image_target``.
This is the main high-level entry point for texture synthesis, denoising,
inpainting, and component separation with FOSCAT. Internally it runs
L-BFGS-B (``scipy.optimize.fmin_l_bfgs_b``) on the scattering-covariance
loss, with an optional coarse-to-fine multi-resolution schedule controlled
by ``nstep``.
The optimisation minimises:
.. math::
u^* = \\arg\\min_u \\mathcal{L}(u), \\qquad
\\mathcal{L}(u) = \\sum_k
\\frac{(\\Phi(u)_k - \\Phi(d)_k)^2}{\\sigma_k^2}
where :math:`\\Phi` is the scattering-covariance operator, :math:`d` is
``image_target``, and :math:`\\sigma_k^2` is the per-coefficient variance
of :math:`\\Phi(d)` (used only when ``use_variance=True``).
Parameters
----------
image_target : array-like
Reference field whose statistics are to be reproduced.
- **HEALPix:** shape ``(npix,)`` or ``(B, npix)`` with
``npix = 12 * nside**2``.
- **2D:** shape ``(H, W)`` or ``(B, H, W)``.
- **1D:** shape ``(N,)`` or ``(B, N)``.
Normalise the input before calling (zero mean, unit variance is
standard)::
xnorm = (image - image.mean()) / image.std()
reference : array-like, optional
Second reference field for **cross-covariance** synthesis. When
provided the loss matches the cross-statistics
:math:`\\Phi_\\times(u, d_2)` against the target
:math:`\\Phi_\\times(d_1, d_2)`, where ``d1 = image_target`` and
``d2 = reference``. Both fields must have the same shape.
Useful for component separation (e.g. CMB × dust template).
nstep : int, optional
Number of resolution levels in the coarse-to-fine schedule.
Default is ``4``.
The algorithm downsamples ``image_target`` by successive factors of 2
to build a resolution pyramid, then optimises from coarsest to finest,
using each result as the warm start for the next level. This
dramatically improves convergence for large maps.
+---------+----------------------------------------------+
| nstep | Resolutions visited (2D, 256 × 256 target) |
+=========+==============================================+
| 1 | 256 × 256 only |
+---------+----------------------------------------------+
| 2 | 128 × 128 → 256 × 256 |
+---------+----------------------------------------------+
| 3 | 64 × 64 → 128 × 128 → 256 × 256 |
+---------+----------------------------------------------+
| 4 | 32 → 64 → 128 → 256 |
+---------+----------------------------------------------+
Capped automatically when ``nstep > jmax - 1`` (map too small for
that many downsampling steps).
seed : int, optional
Random seed for the Gaussian white-noise initial condition used at
the coarsest resolution level. Default is ``1234``. Has no effect
when ``input_image`` is provided. Change to generate independent
realisations::
results = [scat_op.synthesis(xnorm, seed=s) for s in range(8)]
Jmax : int or None, optional
Maximum wavelet scale index included in the loss. ``None`` (default)
uses all scales available for the map size. Decremented by 1 at each
coarser resolution level during the multi-resolution schedule.
Reduce to constrain only small-scale statistics::
result = scat_op.synthesis(xnorm, Jmax=4) # ignore large scales
edge : bool, optional
Enable edge-aware boundary handling for non-periodic maps.
Default is ``False`` (periodic/full-sphere boundaries). Set to
``True`` for rectangular images or partial-sky patches.
Activated automatically when ``in_mask`` is provided.
to_gaussian : bool, optional
Gaussianise ``image_target`` before computing target statistics, then
invert at the end. Default is ``False``. Useful for highly
non-Gaussian fields (log-normal distributions, etc.) to decouple
the histogram constraint from the scattering statistics.
use_variance : bool, optional
Weight each loss term by the inverse variance of the corresponding
coefficient in ``image_target``. Default is ``True``.
- ``True``: loss is scale-invariant — large-scale and small-scale
coefficients contribute equally regardless of amplitude.
- ``False``: uniform weighting, dominated by the largest-amplitude
statistics (usually coarse scales).
synthesised_N : int, optional
Number of independent synthetic maps to produce simultaneously.
Default is ``1``. All ``synthesised_N`` maps share the same loss,
which amortises the per-iteration GPU cost. The output gains a
leading batch dimension::
result = scat_op.synthesis(xnorm, synthesised_N=4)
# result.shape == (4, H, W) for 2D
input_image : array-like or None, optional
Warm-start initial condition. When ``None`` (default), starts from
Gaussian white noise. When provided, ``input_image`` is downsampled
to each resolution level and used as the initial guess at the
coarsest level. Useful for iterative refinement or for injecting
a prior::
result_v1 = scat_op.synthesis(xnorm, NUM_EPOCHS=100)
result_v2 = scat_op.synthesis(xnorm, NUM_EPOCHS=500,
input_image=result_v1)
grd_mask : array-like or None, optional
Binary mask selecting which pixels the optimiser is allowed to
update. Pixels where ``grd_mask = 0`` are frozen to their values
in ``image_target``; only pixels where ``grd_mask = 1`` move.
Same shape as ``image_target``. Downsampled at each resolution
level. Useful for **inpainting** (reconstruct missing regions
while keeping observed pixels fixed)::
grd_mask = np.zeros_like(xnorm)
grd_mask[missing_pixels] = 1.0
result = scat_op.synthesis(xnorm, grd_mask=grd_mask)
in_mask : array-like or None, optional
Binary mask marking **invalid pixels in the input data** that should
not contribute to the reference statistics. Same shape as
``image_target``. Pixels with ``in_mask = 0`` are excluded from
the scattering-covariance computation at every scale level (the mask
is downsampled along with the map). Setting ``in_mask`` also
enables ``edge=True`` internally. Useful for partial-sky CMB maps
or images with missing regions::
mask = np.ones_like(xnorm)
mask[bad_pixels] = 0.0
result = scat_op.synthesis(xnorm, in_mask=mask)
iso_ang : bool, optional
Use isotropically averaged statistics in the loss. Default is
``False``.
- ``False``: the full oriented S1–S4 statistics are used. The
synthesised field can reproduce anisotropic structures (filaments,
oriented textures).
- ``True``: statistics are collapsed to their rotationally invariant
content via :meth:`iso_mean` before computing the loss. Reduces
the number of constraints and accelerates convergence for
statistically isotropic fields (e.g. CMB).
Angular-reduction summary with ``iso_ang=True``:
============ ==================== =======================
Statistic Input shape Output shape
============ ==================== =======================
S1, S2 ``[…, L]`` ``[…]`` (mean over L)
S3, S3P ``[…, L, L]`` ``[…, L]`` (by Δl = l₂−l₁)
S4 ``[…, L, L, L]`` ``[…, L, L]``
============ ==================== =======================
.. note::
``iso_ang`` and ``fft_ang`` should not be used together.
``iso_ang`` is the harder reduction (mean only);
``fft_ang`` is softer and keeps angular variation.
fft_ang : bool, optional
Use Fourier-compressed angular statistics in the loss. Default is
``False``. Softer alternative to ``iso_ang``: instead of collapsing
each orientation axis to a single mean, keeps the first ``fft_nharm``
Fourier harmonics along each axis.
With ``fft_nharm=1, fft_imaginary=True`` (defaults), each orientation
axis L is projected to 3 coefficients:
- index 0 — DC (mean, same as ``iso_ang``);
- index 1 — cosine of first harmonic;
- index 2 — sine of first harmonic.
The amplitude :math:`A_1 = \\sqrt{c_1^2 + s_1^2}` is
**rotation-invariant**: it does not depend on the absolute
orientation of the image.
Shape reduction:
============ ==================== ==========================
Statistic Input shape Output shape (nharm=1)
============ ==================== ==========================
S1, S2 ``[…, L]`` ``[…, 3]``
S3, S3P ``[…, L, L]`` ``[…, 3, 3]``
S4 ``[…, L, L, L]`` ``[…, 3, 3, 3]``
============ ==================== ==========================
For S3/S4 the projection is the tensor product of independent 1-D
Fourier projections on each orientation axis::
result = scat_op.synthesis(xnorm, fft_ang=True, NUM_EPOCHS=300)
fft_nharm : int, optional
Number of Fourier harmonics to keep beyond the DC term when
``fft_ang=True``. Default is ``1``. The number of output
coefficients per orientation axis is ``1 + 2*fft_nharm`` (with
``fft_imaginary=True``) or ``1 + fft_nharm`` (with
``fft_imaginary=False``).
fft_imaginary : bool, optional
Whether to include both cosine **and** sine components when
``fft_ang=True``. Default is ``True``.
- ``True`` *(recommended)*: output per axis = ``1 + 2*fft_nharm``
coefficients ``[DC, cos₁, sin₁, cos₂, sin₂, …]``.
The harmonic amplitude :math:`\\sqrt{c_k^2 + s_k^2}` is
independent of the image orientation.
- ``False``: output per axis = ``1 + fft_nharm`` coefficients
``[DC, cos₁, cos₂, …]``. A field oriented at 90° (zero of
cosine) gives a near-zero first-harmonic coefficient even if
strongly anisotropic — use only if orientation is fixed.
EVAL_FREQUENCY : int, optional
Print the current loss every N L-BFGS-B iterations. Default is
``100``.
NUM_EPOCHS : int, optional
Maximum number of L-BFGS-B iterations **per resolution level**.
Default is ``300``. The optimiser may stop earlier when its
convergence criterion is met. Total wall-clock time scales as
``nstep × NUM_EPOCHS``.
scat_cov_method : str, optional
Internal method for computing scattering covariances. Default is
``'eval'`` (recommended), which uses :meth:`eval` with
``norm='auto'`` and caches the normalisation after the first call.
Any other value falls back to the legacy ``scattering_cov()`` path
(2D only, kept for backward compatibility).
n_up : int, optional
Number of **extra upsampling steps** beyond the target size, keeping
the same ``Jmax``. Default is ``0`` (no extra steps).
With ``n_up=1``, after completing synthesis at N × N, the algorithm
continues at 2N × 2N using the same scattering statistics as the
N × N target. The result is a field that locally matches the
statistics of the original target, embedded in a larger canvas.
Only available for 2D maps.
+--------+-----------------------------+
| n_up | Output size (N × N target) |
+========+=============================+
| 0 | N × N (standard) |
+--------+-----------------------------+
| 1 | 2N × 2N |
+--------+-----------------------------+
| 2 | 4N × 4N |
+--------+-----------------------------+
The ``Jmax`` used during n_up steps is pinned to the value
effective for the N × N target so that wavelet filters and the norm
cache remain consistent::
result = scat_op.synthesis(xnorm_256, nstep=3, n_up=1)
# result.shape == (512, 512)
Returns
-------
numpy.ndarray
Synthesised field.
- Shape equals ``image_target.shape`` when ``synthesised_N=1`` and
``n_up=0``.
- When ``synthesised_N > 1``, a leading dimension of size
``synthesised_N`` is added.
- When ``n_up > 0`` (2D only), spatial dimensions are multiplied by
``2**n_up``.
Examples
--------
**Minimal 2D synthesis**::
import foscat.scat_cov2D as sc
import numpy as np
scat_op = sc.funct(NORIENT=4)
xnorm = (image - image.mean()) / image.std()
result = scat_op.synthesis(xnorm, seed=10, nstep=3, NUM_EPOCHS=300)
**Batch synthesis with mask**::
mask = np.ones_like(xnorm)
mask[invalid] = 0.0
results = scat_op.synthesis(
xnorm,
in_mask=mask,
synthesised_N=4,
nstep=3,
iso_ang=True,
NUM_EPOCHS=500,
)
# results.shape == (4, H, W)
**Inpainting — reconstruct missing pixels, freeze observed ones**::
grd_mask = np.zeros_like(xnorm)
grd_mask[hole_pixels] = 1.0
result = scat_op.synthesis(
xnorm, grd_mask=grd_mask, nstep=3, edge=True, NUM_EPOCHS=500
)
**Upsampled synthesis (n_up)**::
result = scat_op.synthesis(xnorm_256, nstep=3, n_up=1, NUM_EPOCHS=300)
# result.shape == (512, 512)
**Cross-covariance — component separation**::
result = scat_op.synthesis(
cmb_estimate,
reference=dust_template,
nstep=3,
iso_ang=True,
NUM_EPOCHS=300,
)
**Isotropic Gaussianised synthesis**::
result = scat_op.synthesis(
image_target,
to_gaussian=True,
iso_ang=True,
nstep=4,
NUM_EPOCHS=500,
)
Notes
-----
The optimiser is ``scipy.optimize.fmin_l_bfgs_b``, a quasi-Newton
method that uses a low-memory approximation of the inverse Hessian.
It converges in tens to a few hundred iterations for typical
scattering-covariance losses, far fewer than first-order methods.
For large maps (nside ≥ 256 for HEALPix, H/W ≥ 256 for 2D) always
use ``nstep ≥ 3``. Starting directly at full resolution wastes
compute and converges poorly.
"""
import time
import foscat.Synthesis as synthe
# fft_ang is now supported natively inside scattering_cov (the S tensors
# are projected onto Fourier harmonics before for_synthesis is built).
# No fallback to 'eval' is needed.
l_edge = edge
if in_mask is not None:
l_edge = True
if edge:
self.purge_edge_mask()
def The_loss_ref_image(u, scat_operator, args):
    """Weighted pixel-space penalty pulling ``u`` toward a reference image.

    ``args`` holds ``(input_image, mask)``. The value is
    ``1e-3 * mean(square(mask * (input_image - u)))``, evaluated with the
    operator's backend primitives ``bk_square`` and ``bk_reduce_mean``.
    """
    ref_img = args[0]
    weight = args[1]
    bk = scat_operator.backend
    residual = weight * (ref_img - u)
    return 1e-3 * bk.bk_reduce_mean(bk.bk_square(residual))
def The_loss(u, scat_operator, args):
    """Squared-error loss on scattering covariances (legacy ``scattering_cov`` path).

    ``args`` holds ``(ref, sref, use_v, ljmax)``. The statistics of the
    current synthesised map ``u`` are computed with
    ``scat_operator.scattering_cov(..., use_ref=True)`` and averaged over the
    batch with ``reduce_mean_batch``. When ``use_v`` is truthy, ``sref`` is
    forwarded as ``ref_sigma``. Edge handling (``l_edge``) and the angular
    options (``iso_ang``, ``fft_ang``, ``fft_nharm``, ``fft_imaginary``) are
    read from the enclosing ``synthesis`` call.
    """
    ref = args[0]
    sref = args[1]
    use_v = args[2]
    ljmax = args[3]
    cov_kwargs = {
        "edge": l_edge,
        "Jmax": ljmax,
        "use_ref": True,
        "iso_ang": iso_ang,
        "fft_ang": fft_ang,
        "fft_nharm": fft_nharm,
        "fft_imaginary": fft_imaginary,
    }
    if use_v:
        # Only the variance-weighted variant passes the reference sigma.
        cov_kwargs["ref_sigma"] = sref
    # Scattering covariance of the current synthesised map u, batch-averaged.
    learn = scat_operator.reduce_mean_batch(
        scat_operator.scattering_cov(u, **cov_kwargs)
    )
    # Mean of the squared difference with the reference coefficients.
    bk = scat_operator.backend
    return bk.bk_reduce_mean(bk.bk_square(learn - ref))
def The_lossH(u, scat_operator, args):
    """Distance between the ``eval`` statistics of ``u`` and the reference.

    ``args`` holds ``(ref, sref, use_v, ljmax)``. Statistics are computed with
    ``scat_operator.eval(u, Jmax=ljmax, norm='auto')``. They are then
    optionally reduced with ``iso_mean`` and projected with ``fft_ang``,
    according to the enclosing ``synthesis`` options. When
    ``synthesised_N > 1`` they are also batch-averaged. The loss is
    ``reduce_distance`` against ``ref``, weighted by ``sref`` when ``use_v``
    is truthy.
    """
    ref = args[0]
    sref = args[1]
    use_v = args[2]
    ljmax = args[3]
    stats = scat_operator.eval(u, Jmax=ljmax, norm='auto')
    if iso_ang:
        stats = stats.iso_mean()
    if fft_ang:
        stats = stats.fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)
    if synthesised_N > 1:
        stats = scat_operator.reduce_mean_batch(stats)
    if not use_v:
        return scat_operator.reduce_distance(stats, ref)
    return scat_operator.reduce_distance(stats, ref, sigma=sref)
def The_lossX(u, scat_operator, args):
    """Squared-error cross-covariance loss (legacy ``scattering_cov`` path).

    ``args`` holds ``(ref, sref, use_v, im2, ljmax)``. The cross statistics of
    the synthesised map ``u`` with the reference map ``im2`` are computed with
    ``scat_operator.scattering_cov(u, data2=im2, ..., use_ref=True)`` and
    batch-averaged. When ``use_v`` is truthy, ``sref`` is forwarded as
    ``ref_sigma``. Edge handling and the angular options come from the
    enclosing ``synthesis`` call.
    """
    ref = args[0]
    sref = args[1]
    use_v = args[2]
    im2 = args[3]
    ljmax = args[4]
    cov_kwargs = {
        "data2": im2,
        "edge": l_edge,
        "Jmax": ljmax,
        "use_ref": True,
        "iso_ang": iso_ang,
        "fft_ang": fft_ang,
        "fft_nharm": fft_nharm,
        "fft_imaginary": fft_imaginary,
    }
    if use_v:
        cov_kwargs["ref_sigma"] = sref
    # Cross scattering covariance of the current synthesised map u.
    learn = scat_operator.reduce_mean_batch(
        scat_operator.scattering_cov(u, **cov_kwargs)
    )
    # Mean of the squared difference with the reference coefficients.
    bk = scat_operator.backend
    return bk.bk_reduce_mean(bk.bk_square(learn - ref))
def The_lossXH(u, scat_operator, args):
    """Cross-statistics distance on the ``eval`` path (counterpart of ``The_lossH``).

    Previously this name was referenced by the ``reference is not None``
    branch but never defined. That branch covers the default
    ``scat_cov_method='eval'`` and all non-2D maps, so those calls raised
    ``NameError``.

    ``args`` holds ``(ref, sref, use_v, im2, ljmax)``, the same layout
    ``The_lossX`` receives. The statistics of ``u`` against ``im2`` are
    computed exactly as the reference was, with
    ``eval(..., image2=..., Jmax=..., norm='auto')``. They are then
    post-processed like ``The_lossH`` (``iso_mean`` / ``fft_ang`` / batch
    mean), and compared with ``reduce_distance``, weighted by ``sref`` when
    ``use_v`` is truthy.
    """
    ref = args[0]
    sref = args[1]
    use_v = args[2]
    im2 = args[3]
    ljmax = args[4]
    learn = scat_operator.eval(
        u,
        image2=im2,
        Jmax=ljmax,
        norm='auto',
    )
    if iso_ang:
        learn = learn.iso_mean()
    if fft_ang:
        learn = learn.fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)
    if synthesised_N > 1:
        learn = scat_operator.reduce_mean_batch(learn)
    if use_v:
        return scat_operator.reduce_distance(learn, ref, sigma=sref)
    return scat_operator.reduce_distance(learn, ref)
if to_gaussian:
# Change the data histogram to gaussian distribution
im_target = self.to_gaussian(image_target, in_mask=in_mask)
else:
im_target = image_target
axis = len(im_target.shape) - 1
if self.use_2D:
axis -= 1
if axis == 0:
im_target = self.backend.bk_expand_dims(im_target, 0)
# compute the number of possible steps
if self.use_2D:
jmax = int(
np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
/ np.log(2)
)
elif self.use_1D:
jmax = int(np.log(im_target.shape[1]) / np.log(2))
else:
jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
nside = 2**jmax
if nstep > jmax - 1:
nstep = jmax - 1
t1 = time.time()
tmp = {}
l_grd_mask = {}
l_in_mask = {}
l_input_image = {}
l_ref = {}
l_jmax = {}
tmp[nstep - 1] = self.backend.bk_cast(im_target)
l_jmax[nstep - 1] = Jmax
if reference is not None:
l_ref[nstep - 1] = self.backend.bk_cast(reference)
else:
l_ref[nstep - 1] = None
if grd_mask is not None:
l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
else:
l_grd_mask[nstep - 1] = None
if in_mask is not None:
l_in_mask[nstep - 1] = in_mask
else:
l_in_mask[nstep - 1] = None
if input_image is not None:
l_input_image[nstep - 1] = input_image
for ell in range(nstep - 2, -1, -1):
tmp[ell], _ = self.ud_grade_2(tmp[ell + 1], axis=1)
if grd_mask is not None:
l_grd_mask[ell], _ = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
else:
l_grd_mask[ell] = None
if in_mask is not None:
l_in_mask[ell], _ = self.ud_grade_2(l_in_mask[ell + 1])
l_in_mask[ell] = self.backend.to_numpy(l_in_mask[ell])
else:
l_in_mask[ell] = None
if input_image is not None:
l_input_image[ell], _ = self.ud_grade_2(l_input_image[ell + 1], axis=1)
if reference is not None:
l_ref[ell], _ = self.ud_grade_2(l_ref[ell + 1], axis=1)
else:
l_ref[ell] = None
if l_jmax[ell + 1] is None:
l_jmax[ell] = None
else:
l_jmax[ell] = l_jmax[ell + 1] - 1
if not self.use_2D and not self.use_1D:
l_nside = nside // (2 ** (nstep - 1))
for k in range(nstep):
if k == 0:
if input_image is None:
np.random.seed(seed)
if self.use_2D:
imap = self.backend.bk_cast(
np.random.randn(
synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
)
)
else:
imap = self.backend.bk_cast(
np.random.randn(synthesised_N, tmp[k].shape[1])
)
else:
if self.use_2D:
imap = self.backend.bk_reshape(
self.backend.bk_tile(
self.backend.bk_cast(l_input_image[k].flatten()),
synthesised_N,
),
[synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
)
else:
imap = self.backend.bk_reshape(
self.backend.bk_tile(
self.backend.bk_cast(l_input_image[k].flatten()),
synthesised_N,
),
[synthesised_N, tmp[k].shape[1]],
)
else:
# Increase the resolution between each step
if self.use_2D:
imap = self.up_grade(
self.backend.bk_cast(omap),
imap.shape[1] * 2,
axis=-2,
nouty=imap.shape[2] * 2
)
elif self.use_1D:
imap = self.up_grade(self.backend.bk_cast(omap), imap.shape[1] * 2)
else:
imap = self.up_grade(self.backend.bk_cast(omap), l_nside)
if grd_mask is not None:
imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
if self.use_2D and scat_cov_method!='eval':
# compute the coefficients for the target image
if use_variance:
ref, sref = self.scattering_cov(
tmp[k],
data2=l_ref[k],
get_variance=True,
edge=l_edge,
Jmax=l_jmax[k],
in_mask=l_in_mask[k],
iso_ang=iso_ang,
fft_ang=fft_ang,
fft_nharm=fft_nharm,
fft_imaginary=fft_imaginary,
)
else:
ref = self.scattering_cov(
tmp[k],
data2=l_ref[k],
in_mask=l_in_mask[k],
edge=l_edge,
Jmax=l_jmax[k],
iso_ang=iso_ang,
fft_ang=fft_ang,
fft_nharm=fft_nharm,
fft_imaginary=fft_imaginary,
)
sref = ref
else:
self.clean_norm()
ref = self.eval(
tmp[k],
image2=l_ref[k],
mask=l_in_mask[k],
Jmax=l_jmax[k],
norm='auto'
)
# compute the coefficients for the target image
if use_variance:
ref, sref = self.eval(
tmp[k],
image2=l_ref[k],
mask=l_in_mask[k],
Jmax=l_jmax[k],
calc_var=True,
norm='auto'
)
else:
ref = self.eval(
tmp[k],
image2=l_ref[k],
mask=l_in_mask[k],
Jmax=l_jmax[k],
norm='auto'
)
sref = ref
if iso_ang:
ref = ref.iso_mean()
sref = sref.iso_mean()
if fft_ang:
ref = ref.fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)
# sref contains per-coefficient standard deviations: propagate with A²
# (not A) to avoid near-zero sigma for harmonic components
sref = sref.fft_ang_sigma(nharm=fft_nharm, imaginary=fft_imaginary)
# compute the mean of the population does nothing if only one map is given
ref = self.reduce_mean_batch(ref)
if l_in_mask[k] is not None:
self.purge_edge_mask()
if l_ref[k] is None:
if self.use_2D and scat_cov_method!='eval':
# define a loss to minimize
loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
else:
loss = synthe.Loss(The_lossH, self, ref, sref, use_variance, l_jmax[k])
else:
# define a loss to minimize
if self.use_2D and scat_cov_method!='eval':
loss = synthe.Loss(
The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
)
else:
loss = synthe.Loss(
The_lossXH, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
)
if input_image is not None:
# define a loss to minimize
loss_input = synthe.Loss(
The_loss_ref_image,
self,
self.backend.bk_cast(l_input_image[k]),
self.backend.bk_cast(l_in_mask[k]),
)
sy = synthe.Synthesis([loss]) # ,loss_input])
else:
sy = synthe.Synthesis([loss])
# initialize the synthesised map
if self.use_2D:
print("Synthesis scale [ %d x %d ]" % (imap.shape[1], imap.shape[2]))
elif self.use_1D:
print("Synthesis scale [ %d ]" % (imap.shape[1]))
else:
print("Synthesis scale nside=%d" % (l_nside))
l_nside *= 2
# do the minimization
omap = sy.run(
imap,
EVAL_FREQUENCY=EVAL_FREQUENCY,
NUM_EPOCHS=NUM_EPOCHS,
grd_mask=l_grd_mask[k],
)
# When nstep=0 the main loop is skipped entirely. Initialise omap with
# noise (or the provided input_image) at target resolution so that the
# n_up loop can start directly at the upsampled size.
if nstep == 0 and n_up > 0 and self.use_2D:
np.random.seed(seed)
_tgt = tmp[nstep - 1]
if input_image is None:
omap = self.backend.bk_cast(
np.random.randn(
synthesised_N, _tgt.shape[1], _tgt.shape[2]
)
)
else:
omap = self.backend.bk_reshape(
self.backend.bk_tile(
self.backend.bk_cast(
self.backend.bk_cast(input_image).flatten()
),
synthesised_N,
),
[synthesised_N, _tgt.shape[1], _tgt.shape[2]],
)
if n_up > 0 and self.use_2D:
# Extra upsampling steps: synthesise at 2^n_up times the target size
# while keeping the SAME Jmax as the original target (same wavelet scales,
# same P1_dic keys). When Jmax=None, eval derives it from the map size,
# so on a 2N×2N map it would add one extra scale not present in P1_dic
# → KeyError. We therefore pin Jmax to the value that was effective for
# the NxN target and pass it explicitly to The_lossH.
target_shape = tmp[nstep - 1]
target_side = np.min([target_shape.shape[1], target_shape.shape[2]])
n_up_jmax = int(np.log(target_side - self.KERNELSZ) / np.log(2))
if self.KERNELSZ > 3:
n_up_jmax -= 1
# Respect an explicit user-supplied Jmax if it is more restrictive
if l_jmax[nstep - 1] is not None:
n_up_jmax = min(n_up_jmax, l_jmax[nstep - 1])
if self.use_2D and scat_cov_method != 'eval':
# Recompute the reference on the ORIGINAL TARGET with exactly n_up_jmax.
# The existing `ref` was built with Jmax=l_jmax[nstep-1] (often None),
# which may derive a different internal J than what scattering_cov produces
# for the upsampled image with Jmax=n_up_jmax. Building ref_up with the
# same Jmax guarantees identical tensor shapes for (ref_up, learn).
if use_variance:
ref_up, sref_up = self.scattering_cov(
tmp[nstep - 1],
data2=l_ref[nstep - 1],
get_variance=True,
edge=l_edge,
Jmax=n_up_jmax,
in_mask=l_in_mask[nstep - 1],
iso_ang=iso_ang,
fft_ang=fft_ang,
fft_nharm=fft_nharm,
fft_imaginary=fft_imaginary,
)
else:
ref_up = self.scattering_cov(
tmp[nstep - 1],
data2=l_ref[nstep - 1],
in_mask=l_in_mask[nstep - 1],
edge=l_edge,
Jmax=n_up_jmax,
iso_ang=iso_ang,
fft_ang=fft_ang,
fft_nharm=fft_nharm,
fft_imaginary=fft_imaginary,
)
sref_up = ref_up
ref_up = self.reduce_mean_batch(ref_up)
else:
if nstep > 0:
# ref was computed by the main synthesis loop above
ref_up, sref_up = ref, sref
else:
# nstep=0: main loop was skipped — build ref now via eval
self.clean_norm()
if use_variance:
ref_up, sref_up = self.eval(
tmp[nstep - 1],
image2=l_ref[nstep - 1],
mask=l_in_mask[nstep - 1],
Jmax=n_up_jmax,
norm='auto',
calc_var=True,
)
else:
ref_up = self.eval(
tmp[nstep - 1],
image2=l_ref[nstep - 1],
mask=l_in_mask[nstep - 1],
Jmax=n_up_jmax,
norm='auto',
)
sref_up = ref_up
if iso_ang:
ref_up = ref_up.iso_mean()
sref_up = sref_up.iso_mean()
if fft_ang:
ref_up = ref_up.fft_ang(nharm=fft_nharm, imaginary=fft_imaginary)
sref_up = sref_up.fft_ang_sigma(
nharm=fft_nharm, imaginary=fft_imaginary
)
ref_up = self.reduce_mean_batch(ref_up)
for up in range(n_up):
# Upsample current result by factor 2 in each spatial dimension
imap = self.up_grade(
self.backend.bk_cast(omap),
omap.shape[1] * 2,
axis=-2,
nouty=omap.shape[2] * 2,
)
if self.use_2D and scat_cov_method != 'eval':
loss_up = synthe.Loss(
The_loss, self, ref_up, sref_up, use_variance, n_up_jmax
)
else:
loss_up = synthe.Loss(
The_lossH, self, ref_up, sref_up, use_variance, n_up_jmax
)
sy_up = synthe.Synthesis([loss_up])
print(
"Synthesis scale [ %d x %d ] (n_up step %d/%d)"
% (imap.shape[1], imap.shape[2], up + 1, n_up)
)
omap = sy_up.run(
imap,
EVAL_FREQUENCY=EVAL_FREQUENCY,
NUM_EPOCHS=NUM_EPOCHS,
)
if not self.use_2D:
self.clean_norm()
t2 = time.time()
print("Total computation %.2fs" % (t2 - t1))
if to_gaussian:
omap = self.from_gaussian(omap)
if axis == 0 and synthesised_N == 1:
return omap[0]
else:
return omap
[docs]
def to_numpy(self, x):
    """Return ``x`` converted by the active backend's ``to_numpy`` helper."""
    backend = self.backend
    converted = backend.to_numpy(x)
    return converted