import pickle
import numpy as np
import foscat.scat_cov as sc
class CNN:
    """Small convolutional network built on a foscat scattering operator.

    All trainable parameters live in one flat vector (``self.x``) laid out as:

    1. input layer kernel: ``n_chan_in * K * chanlist[0]`` values,
    2. one kernel per scale ``k``: ``chanlist[k] * K * chanlist[k + 1]`` values,
    3. dense layer: ``12 * out_nside**2 * chanlist[nscale] * npar`` values,

    where ``K = KERNELSZ * (KERNELSZ // 2 + 1)``.  ``get_number_of_weights``
    and ``eval`` both rely on this exact layout.
    """

    def __init__(
        self,
        nparam=1,
        KERNELSZ=3,
        NORIENT=4,
        chanlist=None,
        in_nside=1,
        n_chan_in=1,
        SEED=1234,
        filename=None,
        scat_operator=None,
        nbatch=1,
    ):
        """Build a new network, or restore one written by :meth:`save`.

        Parameters
        ----------
        nparam : int
            Number of output values produced by the final dense layer.
        KERNELSZ, NORIENT : int
            Forwarded to ``sc.funct`` when no ``scat_operator`` is supplied.
        chanlist : list of int, optional
            Channel count after each convolution.  It must contain at least
            one entry; ``len(chanlist) - 1`` becomes ``nscale``.
        in_nside : int
            Reference nside used to derive ``out_nside``.
        n_chan_in : int
            Number of channels of the input map.
        SEED : int
            Seed passed to ``np.random.seed`` before drawing initial weights.
        filename : str, optional
            When given, every other argument is ignored and the model is
            restored from ``"<filename>.pkl"``.
        scat_operator : object, optional
            Existing scattering operator to reuse; one is created with
            ``sc.funct`` when omitted.
        nbatch : int
            Stored as-is and written out by :meth:`save`.
        """
        if filename is not None:
            # NOTE: unpickling can execute arbitrary code -- only load files
            # that come from a trusted source.
            with open("%s.pkl" % (filename), "rb") as handle:
                outlist = pickle.load(handle)

            self.scat_operator = sc.funct(KERNELSZ=outlist[3], all_type=outlist[7])
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.all_type = self.scat_operator.all_type
            self.backend = self.scat_operator.backend
            self.chanlist = outlist[0]
            self.nbatch = outlist[1]
            self.npar = outlist[2]
            self.in_nside = outlist[4]
            self.nscale = outlist[5]
            self.n_chan_in = outlist[8]
            self.NORIENT = outlist[9]
            self.x = self.backend.bk_cast(outlist[6])
            self.out_nside = self.in_nside // (2**self.nscale)
        else:
            self.chanlist = [] if chanlist is None else chanlist
            self.nscale = len(self.chanlist) - 1
            self.npar = nparam
            self.nbatch = nbatch
            self.n_chan_in = n_chan_in

            if scat_operator is None:
                scat_operator = sc.funct(KERNELSZ=KERNELSZ, NORIENT=NORIENT)
            self.scat_operator = scat_operator

            # The operator is the source of truth for the kernel geometry.
            self.KERNELSZ = self.scat_operator.KERNELSZ
            self.NORIENT = self.scat_operator.NORIENT
            self.all_type = self.scat_operator.all_type
            self.backend = self.scat_operator.backend
            self.in_nside = in_nside
            self.out_nside = self.in_nside // (2**self.nscale)

            # Uniform random start, scaled down by kernel size and orientations.
            np.random.seed(SEED)
            self.x = self.backend.bk_cast(
                np.random.rand(self.get_number_of_weights())
                / (self._kernel_size() * self.NORIENT)
            )

    def _kernel_size(self):
        """Return the number of coefficients of one convolution kernel."""
        return self.KERNELSZ * (self.KERNELSZ // 2 + 1)

    def _pool4(self, im):
        """Average groups of 4 consecutive values along the last axis of a 3-D tensor."""
        grouped = self.backend.bk_reshape(
            im, [im.shape[0], im.shape[1], im.shape[2] // 4, 4]
        )
        return self.backend.bk_reduce_mean(grouped, 3)

    def save(self, filename):
        """Pickle the model configuration and weights to ``"<filename>.pkl"``.

        The list order is the on-disk format read back by ``__init__``.
        """
        outlist = [
            self.chanlist,
            self.nbatch,
            self.npar,
            self.KERNELSZ,
            self.in_nside,
            self.nscale,
            self.get_weights().numpy(),
            self.all_type,
            self.n_chan_in,
            self.NORIENT,
        ]
        with open("%s.pkl" % (filename), "wb") as handle:
            pickle.dump(outlist, handle)

    def get_number_of_weights(self):
        """Return the length of the flat weight vector expected by :meth:`eval`."""
        ksize = self._kernel_size()
        hidden_pairs = sum(
            self.chanlist[i] * self.chanlist[i + 1] for i in range(self.nscale)
        )
        dense = self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
        first = ksize * self.n_chan_in * self.chanlist[0]
        return dense + hidden_pairs * ksize + first

    def set_weights(self, x):
        """Replace the flat weight vector."""
        self.x = x

    def get_weights(self):
        """Return the flat weight vector."""
        return self.x

    def eval(self, im, indices=None, weights=None):
        """Run the network on ``im`` and return the dense layer output.

        Parameters
        ----------
        im : tensor
            Input passed to ``scat_operator.healpix_layer``.  The layer output
            is indexed on three axes, presumably ``(batch, channel, pixel)``
            -- TODO confirm against ``healpix_layer``.
        indices, weights : sequence, optional
            When ``indices`` is given, ``indices[k]`` and ``weights[k]`` are
            forwarded to ``healpix_layer`` for scale ``k`` (not for the input
            layer).

        Returns
        -------
        tensor
            ReLU of the dense layer output; its last dimension is ``npar``.
        """
        bk = self.backend
        x = self.x
        ksize = self._kernel_size()

        # NOTE(review): there is one pooling after the input layer plus one
        # per scale (nscale + 1 in total), while out_nside is
        # in_nside // 2**nscale.  For the dense layer shapes to match, the
        # input apparently needs 12 * (2 * in_nside)**2 pixels -- confirm.

        # Input layer: n_chan_in -> chanlist[0].
        nn = ksize * self.n_chan_in * self.chanlist[0]
        ww = bk.bk_reshape(x[0:nn], [self.n_chan_in, ksize, self.chanlist[0]])
        im = self.scat_operator.healpix_layer(im, ww)
        im = self._pool4(bk.bk_relu(im))

        # Hidden layers: chanlist[k] -> chanlist[k + 1].
        for k in range(self.nscale):
            c_in = self.chanlist[k]
            c_out = self.chanlist[k + 1]
            n_layer = ksize * c_in * c_out
            ww = bk.bk_reshape(x[nn : nn + n_layer], [c_in, ksize, c_out])
            # Advance by exactly the number of weights consumed, so the next
            # layer starts where this one ends.
            nn += n_layer

            if indices is None:
                im = self.scat_operator.healpix_layer(im, ww)
            else:
                im = self.scat_operator.healpix_layer(
                    im, ww, indices=indices[k], weights=weights[k]
                )
            im = self._pool4(bk.bk_relu(im))

        # Dense layer: flatten (channel, pixel) and project onto npar outputs.
        n_feat = 12 * self.out_nside**2 * self.chanlist[self.nscale]
        ww = bk.bk_reshape(x[nn : nn + self.npar * n_feat], [n_feat, self.npar])
        flat = bk.bk_reshape(im, [im.shape[0], im.shape[1] * im.shape[2]])
        return bk.bk_relu(bk.bk_matmul(flat, ww))