Source code for foscat.GCNN

import pickle

import numpy as np

import foscat.scat_cov as sc


[docs] class GCNN: def __init__( self, nparam=1, KERNELSZ=3, NORIENT=4, chanlist=[], in_nside=1, SEED=1234, filename=None, ): if filename is not None: outlist = pickle.load(open("%s.pkl" % (filename), "rb")) self.scat_operator = sc.funct(KERNELSZ=outlist[3], all_type=outlist[7]) self.KERNELSZ = self.scat_operator.KERNELSZ self.all_type = self.scat_operator.all_type self.npar = outlist[2] self.nscale = outlist[5] self.chanlist = outlist[0] self.in_nside = outlist[4] self.nbatch = outlist[1] self.NORIENT = outlist[8] self.x = self.scat_operator.backend.bk_cast(outlist[6]) self.out_nside = self.in_nside // (2**self.nscale) else: self.nscale = len(chanlist)-1 self.npar = nparam self.n_chan_in = n_chan_in self.scat_operator = scat_operator if self.scat_operator is None: self.scat_operator = sc.funct( KERNELSZ=KERNELSZ, NORIENT=NORIENT) self.chanlist = chanlist self.KERNELSZ = self.scat_operator.KERNELSZ self.NORIENT = self.scat_operator.NORIENT self.all_type = self.scat_operator.all_type self.in_nside = in_nside self.out_nside = self.in_nside * (2**self.nscale) self.backend = self.scat_operator.backend np.random.seed(SEED) self.x = self.scat_operator.backend.bk_cast( np.random.rand(self.get_number_of_weights()) / (self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT) )
[docs] def save(self, filename): outlist = [ self.chanlist, self.nbatch, self.npar, self.KERNELSZ, self.in_nside, self.nscale, self.get_weights().numpy(), self.all_type, self.NORIENT, ] myout = open("%s.pkl" % (filename), "wb") pickle.dump(outlist, myout) myout.close()
[docs] def get_number_of_weights(self): totnchan = 0 for i in range(self.nscale): totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1] return ( self.npar * 12 * self.in_nside**2 * self.chanlist[0] + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1) + self.KERNELSZ * (self.KERNELSZ//2+1) * self.chanlist[nscale] )
[docs] def set_weights(self, x): self.x = x
[docs] def get_weights(self): return self.x
[docs] def eval(self, im, indices=None, weights=None): x = self.x ww = self.backend.bk_reshape( x[0:self.npar * 12 * self.in_nside**2 * self.chanlist[0]], [self.npar,12 * self.in_nside**2 * self.chanlist[0]], ) im = self.scat_operator.backend.bk_matmul(im,ww) im = self.backend.bk_reshape(im,[im.shape[0],self.chanlist[0],12 * self.in_nside**2]) nn = self.npar * 12 * self.in_nside**2 * self.chanlist[0] for k in range(self.nscale): ww = self.scat_operator.backend.bk_reshape( x[ nn : nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.chanlist[k] * self.chanlist[k + 1] ], [self.chanlist[k], self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1]], ) nn = ( nn + self.KERNELSZ * (self.KERNELSZ//2) * self.chanlist[k] * self.chanlist[k + 1] ) if indices is None: im = self.scat_operator.healpix_layer(im, ww) else: im = self.scat_operator.healpix_layer( im, ww, indices=indices[k], weights=weights[k] ) im = self.scat_operator.backend.bk_relu(im) im = self.backend.bk_reshape(self.scat_operator.backend.bk_repeat(im,4),[im.shape[0],im.shape[1],im.shape[2]*4]) #im = self.scat_operator.backend.bk_reshape(im, [self.npar]) #im = self.scat_operator.backend.bk_relu(im) return im