Source code for foscat.heal_NN

import pickle

import numpy as np

import foscat.scat_cov as sc

[docs] class CNN: def __init__( self, nparam=1, KERNELSZ=3, NORIENT=4, chanlist=[], in_nside=1, n_chan_in=1, SEED=1234, add_undersample_data=False, all_type='float32', filename=None, scat_operator=None, BACKEND='tensorflow' ): if filename is not None: outlist = pickle.load(open("%s.pkl" % (filename), "rb")) self.scat_operator = sc.funct(KERNELSZ=outlist[3], NORIENT= outlist[9], all_type=outlist[7]) self.KERNELSZ = self.scat_operator.KERNELSZ self.all_type = self.scat_operator.all_type self.npar = outlist[2] self.nscale = outlist[5] self.chanlist = outlist[0] self.in_nside = outlist[4] self.nbatch = outlist[1] self.n_chan_in = outlist[8] self.NORIENT = outlist[9] self.x = self.scat_operator.backend.bk_cast(outlist[6]) self.out_nside = self.in_nside // (2**(self.nscale+1)) else: self.add_undersample_data=add_undersample_data self.nscale = len(chanlist)-1 self.npar = nparam self.n_chan_in = n_chan_in if scat_operator is None: self.scat_operator = sc.funct( KERNELSZ=KERNELSZ, NORIENT=NORIENT, all_type=all_type) else: self.scat_operator = scat_operator self.chanlist = chanlist self.KERNELSZ = self.scat_operator.KERNELSZ self.NORIENT = self.scat_operator.NORIENT self.all_type = self.scat_operator.all_type self.in_nside = in_nside self.out_nside = self.in_nside // (2**(self.nscale+1)) self.backend = self.scat_operator.backend np.random.seed(SEED) self.x = self.scat_operator.backend.bk_cast( np.random.rand(self.get_number_of_weights()) ) self.mpi_size = self.scat_operator.mpi_size self.mpi_rank = self.scat_operator.mpi_rank self.BACKEND = BACKEND self.gpupos = self.scat_operator.gpupos self.ngpu = self.scat_operator.ngpu self.gpulist = self.scat_operator.gpulist
[docs] def save(self, filename): outlist = [ self.chanlist, self.nbatch, self.npar, self.KERNELSZ, self.in_nside, self.nscale, self.get_weights().numpy(), self.all_type, self.n_chan_in, self.NORIENT, ] myout = open("%s.pkl" % (filename), "wb") pickle.dump(outlist, myout) myout.close()
[docs] def get_number_of_weights(self): totnchan = 0 for i in range(self.nscale): totnchan = totnchan + (self.chanlist[i]+int(self.add_undersample_data)* self.n_chan_in) * self.chanlist[i + 1] return ( self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT*self.NORIENT + self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT )
[docs] def set_weights(self, x): self.x = x
[docs] def get_weights(self): return self.x
[docs] def init_wave(self): w0=np.zeros([self.n_chan_in, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[0], self.NORIENT]) if self.KERNELSZ==3: w0[:,0]=-0.2 w0[:,1]=-0.5 w0[:,2]=-0.2 w0[:,3]=0.2 w0[:,4]=0.5 w0[:,5]=0.2 if self.KERNELSZ==5: w0[:,0]=-0.1 w0[:,1]=-0.2 w0[:,2]=-0.5 w0[:,3]=-0.2 w0[:,4]=-0.1 w0[:,10]=0.1 w0[:,11]=0.2 w0[:,12]=0.5 w0[:,13]=0.2 w0[:,14]=0.1 a=2*np.sqrt(6/(12 * self.out_nside**2 * self.chanlist[self.nscale]*self.NORIENT*self.npar)) x=(np.random.rand(self.get_number_of_weights())-0.5)*a w0=w0.flatten() x[0:w0.shape[0]]=w0 nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT for k in range(self.nscale): ww = np.zeros([self.chanlist[k], self.NORIENT, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1], self.NORIENT]) if self.KERNELSZ==3: ww[:,:,0]=-0.2 ww[:,:,1]=-0.5 ww[:,:,2]=-0.2 ww[:,:,3]=0.2 ww[:,:,4]=0.5 ww[:,:,5]=0.2 if self.KERNELSZ==5: ww[:,:,0]=-0.1 ww[:,:,1]=-0.2 ww[:,:,2]=-0.5 ww[:,:,3]=-0.2 ww[:,:,4]=-0.1 ww[:,:,10]=0.1 ww[:,:,11]=0.2 ww[:,:,12]=0.5 ww[:,:,13]=0.2 ww[:,:,14]=0.1 x[nn : nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT*self.NORIENT * self.chanlist[k] * self.chanlist[k + 1] ]=ww.flatten() nn = nn + (self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT*self.NORIENT * self.chanlist[k] * self.chanlist[k + 1]) self.x = self.scat_operator.backend.bk_cast(x)
[docs] def calc_matrix_first_layer(self,noise_map): # Circular shift via permutation matrix def circ_shift_matrix(N,k): return np.roll(np.eye(N), shift=-k, axis=1) im=self.scat_operator.convol(noise_map) mm=np.mean(abs(im.cpu().numpy()),0) Norient=mm.shape[1] xx=np.cos(np.arange(Norient)/Norient*2*np.pi) yy=np.sin(np.arange(Norient)/Norient*2*np.pi) a=np.sum(mm*xx[None,:,None],1) b=np.sum(mm*yy[None,:,None],1) o=np.fmod(Norient*np.arctan2(-b,a)/(2*np.pi)+Norient,Norient) xx=np.arange(Norient) alpha = o[:,None,:]-xx[None,:,None] beta = np.fmod(1+o[:,None,:]-xx[None,:,None],Norient) alpha=(1-alpha)*(alpha<1)*(alpha>0)+beta*(beta<1)*(beta>0) m=np.zeros([mm.shape[0],4,4,mm.shape[2]]) for k in range(4): m[:,k,:,:]=np.roll(alpha,k,1) m=np.mean(m,0) return self.scat_operator.backend.bk_cast(m[None,None,...])
[docs] def eval(self, in_im, indices=None, weights=None, out_map=False, first_layer_rot=None, activation='relu'): x = self.x ww = self.backend.bk_reshape( x[0 : self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT], [self.n_chan_in, 1 , self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[0], self.NORIENT], ) nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]*self.NORIENT im = self.scat_operator.healpix_layer(in_im[:,:,None,:], ww) if first_layer_rot is not None: im = self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,1,im.shape[3]]) im = self.backend.bk_reduce_sum(im*first_layer_rot,2) if activation=='relu': im = self.backend.bk_relu(im) elif activation=='abs': im = self.backend.bk_abs(im) im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4) if self.add_undersample_data: l_im=self.backend.bk_repeat(self.backend.bk_reshape(in_im,[in_im.shape[0],in_im.shape[1],1,in_im.shape[2]]),self.NORIENT,2) l_im=self.backend.bk_reduce_sum( self.backend.bk_reshape(l_im,[in_im.shape[0],in_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4) im=self.backend.bk_concat([im,l_im],1) if out_map: return im for k in range(self.nscale): ww = self.scat_operator.backend.bk_reshape( x[ nn : nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT*self.NORIENT * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in) * self.chanlist[k + 1] ], [self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in, self.NORIENT, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1], self.NORIENT], ) nn = ( nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT*self.NORIENT * (self.chanlist[k]+int(self.add_undersample_data)* self.n_chan_in) * self.chanlist[k + 1] ) if indices is None: im = self.scat_operator.healpix_layer(im, ww) else: im = self.scat_operator.healpix_layer( im, ww, indices=indices[k], weights=weights[k] ) if activation=='relu': im = self.backend.bk_relu(im) elif activation=='abs': im = self.backend.bk_abs(im) im = self.backend.bk_reduce_sum(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],self.NORIENT,im.shape[3]//4,4]),4) if self.add_undersample_data: l_im=self.backend.bk_reduce_sum( self.backend.bk_reshape(l_im,[l_im.shape[0],l_im.shape[1],self.NORIENT,l_im.shape[3]//4,4]), 4) im=self.backend.bk_concat([im,l_im],1) ww = self.scat_operator.backend.bk_reshape( x[ nn : nn + self.npar * 12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT ], [12 * self.out_nside**2 * (self.chanlist[self.nscale]+int(self.add_undersample_data)* self.n_chan_in)*self.NORIENT, self.npar], ) im = self.scat_operator.backend.bk_matmul( self.scat_operator.backend.bk_reshape( im, [im.shape[0], im.shape[1] * im.shape[2] * im.shape[3]] ), ww, ) #im = self.scat_operator.backend.bk_reshape(im, [self.npar]) #im = self.scat_operator.backend.bk_relu(im) return im
[docs] class GCNN: def __init__( self, nparam=1, KERNELSZ=3, NORIENT=4, chanlist=[], in_nside=1, out_chan=1, SEED=1234, all_type='float32', filename=None, scat_operator=None, BACKEND='tensorflow' ): if filename is not None: outlist = pickle.load(open("%s.pkl" % (filename), "rb")) self.scat_operator = sc.funct(KERNELSZ=outlist[3],NORIENT=outlist[8], all_type=outlist[7]) self.KERNELSZ = self.scat_operator.KERNELSZ self.all_type = self.scat_operator.all_type self.npar = outlist[2] self.nscale = outlist[5] self.chanlist = outlist[0] self.in_nside = outlist[4] self.nbatch = outlist[1] self.NORIENT = outlist[8] self.out_chan = outlist[9] self.x = self.scat_operator.backend.bk_cast(outlist[6]) self.out_nside = self.in_nside // (2**self.nscale) else: self.nscale = len(chanlist)-1 self.npar = nparam if scat_operator is None: self.scat_operator = sc.funct( KERNELSZ=KERNELSZ, NORIENT=NORIENT, all_type=all_type) else: self.scat_operator = scat_operator self.chanlist = chanlist self.KERNELSZ = self.scat_operator.KERNELSZ self.NORIENT = self.scat_operator.NORIENT self.all_type = self.scat_operator.all_type self.in_nside = in_nside self.out_nside = self.in_nside * (2**self.nscale) self.out_chan = out_chan self.backend = self.scat_operator.backend np.random.seed(SEED) self.x = self.scat_operator.backend.bk_cast( np.random.rand(self.get_number_of_weights()) ) self.mpi_size = self.scat_operator.mpi_size self.mpi_rank = self.scat_operator.mpi_rank self.BACKEND = BACKEND self.gpupos = self.scat_operator.gpupos self.ngpu = self.scat_operator.ngpu self.gpulist = self.scat_operator.gpulist
[docs] def save(self, filename): outlist = [ self.chanlist, self.nbatch, self.npar, self.KERNELSZ, self.in_nside, self.nscale, self.get_weights().numpy(), self.all_type, self.NORIENT, self.out_chan ] myout = open("%s.pkl" % (filename), "wb") pickle.dump(outlist, myout) myout.close()
[docs] def get_number_of_weights(self): totnchan = 0 for i in range(self.nscale): totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1] return ( self.npar * 12 * self.in_nside**2 * self.chanlist[0]*self.NORIENT + totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT*self.NORIENT + self.chanlist[-1]*self.out_chan*self.NORIENT )
[docs] def set_weights(self, x): self.x = x
[docs] def get_weights(self): return self.x
[docs] def eval(self, im, indices=None, weights=None): x = self.x ww = self.backend.bk_reshape( x[0:self.npar * 12 * self.in_nside**2 * self.chanlist[0]*self.NORIENT], [self.npar,12 * self.in_nside**2 * self.chanlist[0]*self.NORIENT], ) im = self.scat_operator.backend.bk_matmul(im,ww) im = self.backend.bk_reshape(im,[im.shape[0],self.chanlist[0],self.NORIENT,12 * self.in_nside**2]) nn = self.npar * 12 * self.in_nside**2 * self.chanlist[0] for k in range(self.nscale): im = self.scat_operator.backend.bk_relu(im) im = self.backend.bk_reshape( self.scat_operator.backend.bk_repeat(im,4,axis=-1), [im.shape[0],im.shape[1],self.NORIENT,im.shape[3]*4]) ww = self.scat_operator.backend.bk_reshape( x[ nn : nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT *self.NORIENT * self.chanlist[k] * self.chanlist[k + 1] ], [self.chanlist[k] , self.NORIENT, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1],self.NORIENT], ) nn = ( nn + self.KERNELSZ * (self.KERNELSZ//2+1) * self.NORIENT *self.NORIENT * self.chanlist[k] * self.chanlist[k + 1] ) if indices is None: im = self.scat_operator.healpix_layer(im, ww) else: im = self.scat_operator.healpix_layer( im, ww, indices=indices[k], weights=weights[k] ) ww = self.scat_operator.backend.bk_reshape( x[ nn : nn + self.chanlist[-1]*self.NORIENT * self.out_chan ], [1,self.chanlist[-1],self.NORIENT, self.out_chan, 1], ) im = self.backend.bk_reduce_mean(im[:,:,:,None]*ww,[1,2]) #im = self.scat_operator.backend.bk_relu(im) return im