Source code for foscat.FoCUS

import os
import sys

import healpy as hp
import numpy as np
import foscat.HealSpline as HS
from scipy.interpolate import griddata
from foscat.SphereDownGeo import SphereDownGeo
from foscat.SphereUpGeo import SphereUpGeo
import torch

TMPFILE_VERSION = "V14_0"


[docs] class FoCUS: def __init__( self, NORIENT=4, LAMBDA=1.2, KERNELSZ=3, slope=1.0, all_type="float32", nstep_max=20, padding="SAME", gpupos=0, mask_thres=None, mask_norm=False, isMPI=False, TEMPLATE_PATH=None, BACKEND="torch", use_2D=False, use_1D=False, return_data=False, DODIV=False, use_median=False, InitWave=None, silent=True, mpi_size=1, mpi_rank=0 ): self.__version__ = "2026.06.1" # P00 coeff for normalization for scat_cov self.TMPFILE_VERSION = TMPFILE_VERSION self.P1_dic = None self.P2_dic = None self.isMPI = isMPI self.mask_thres = mask_thres self.mask_norm = mask_norm self.InitWave = InitWave self.mask_mask = None self.mpi_size = mpi_size self.mpi_rank = mpi_rank self.return_data = return_data self.silent = silent self.use_median = use_median self.kernel_smooth = {} self.padding_smooth = {} self.kernelR_conv = {} self.kernelI_conv = {} self.padding_conv = {} self.down = {} self.up = {} if not self.silent: print("================================================") print(" START FOSCAT CONFIGURATION") print("================================================") sys.stdout.flush() home_dir = os.environ["HOME"] if TEMPLATE_PATH is None: TEMPLATE_PATH=home_dir+"/.FOSCAT/data" self.TEMPLATE_PATH = TEMPLATE_PATH if not os.path.exists(self.TEMPLATE_PATH): if not self.silent: print( "The directory %s to store temporary information for FoCUS does not exist: Try to create it" % (self.TEMPLATE_PATH) ) try: os.system("mkdir -p %s" % (self.TEMPLATE_PATH)) if not self.silent: print("The directory %s is created") except: print("Impossible to create the directory %s" % (self.TEMPLATE_PATH)) return None self.number_of_loss = 0 self.history = np.zeros([10]) self.nlog = 0 self.padding = padding self.use_2D = use_2D self.use_1D = use_1D if isMPI: from mpi4py import MPI self.comm = MPI.COMM_WORLD if all_type == "float32": self.MPI_ALL_TYPE = MPI.FLOAT else: self.MPI_ALL_TYPE = MPI.DOUBLE else: self.MPI_ALL_TYPE = None self.all_type = all_type self.BACKEND = BACKEND if BACKEND == "torch": from foscat.BkTorch import BkTorch self.backend = BkTorch( all_type=all_type, mpi_rank=mpi_rank, gpupos=gpupos, silent=self.silent, ) elif BACKEND == "tensorflow": ''' from foscat.BkTensorflow import BkTensorflow self.backend = BkTensorflow( all_type=all_type, mpi_rank=mpi_rank, gpupos=gpupos, silent=self.silent, ) ''' raise 'TENSORFLOW BACKEND Not anymore maintained' else: ''' from foscat.BkNumpy import BkNumpy self.backend = BkNumpy( all_type=all_type, mpi_rank=mpi_rank, gpupos=gpupos, silent=self.silent, ) ''' raise 'NUMPY BACKEND Not anymore maintained' self.all_bk_type = self.backend.all_bk_type self.all_cbk_type = self.backend.all_cbk_type self.gpulist = self.backend.gpulist self.ngpu = self.backend.ngpu self.rank = mpi_rank self.gpupos = (gpupos + mpi_rank) % self.backend.ngpu if not self.silent: print("============================================================") print("== ==") print("== ==") print( "== RUN ON GPU Rank %d : %s ==" % (mpi_rank, self.gpulist[self.gpupos % self.ngpu]) ) print("== ==") print("== ==") print("============================================================") sys.stdout.flush() l_NORIENT = NORIENT if DODIV: l_NORIENT = NORIENT + 2 self.NORIENT = l_NORIENT self.LAMBDA = LAMBDA self.slope = slope self.R_off = (KERNELSZ - 1) // 2 if (self.R_off // 2) * 2 < self.R_off: self.R_off += 1 self.ww_Real = {} self.ww_Imag = {} self.ww_CNN_Transpose = {} self.ww_CNN = {} self.X_CNN = {} self.Y_CNN = {} self.Z_CNN = {} self.Idx_CNN = {} self.Idx_WCNN = {} self.filters_set = {} self.edge_masks = {} wwc = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type) wws = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type) x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape( KERNELSZ, KERNELSZ ) y = x.T if NORIENT == 1: xx = (3 / float(KERNELSZ)) * LAMBDA * x yy = (3 / float(KERNELSZ)) * LAMBDA * y if KERNELSZ == 5: # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2)) w_smooth = np.exp(-(xx**2 + yy**2)) tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp( -0.5 * (xx**2 + yy**2) ) else: w_smooth = np.exp(-0.5 * (xx**2 + yy**2)) tmp = np.exp(-2 * (xx**2 + yy**2)) - 0.25 * np.exp( -0.5 * (xx**2 + yy**2) ) wwc[0] = tmp.flatten() - tmp.mean() tmp = 0 * w_smooth wws[0] = tmp.flatten() sigma = np.sqrt((wwc[:, 0] ** 2).mean()) wwc[0] /= sigma wws[0] /= sigma w_smooth = w_smooth.flatten() else: for i in range(NORIENT): a = ( (NORIENT - 1 - i) / float(NORIENT) * np.pi ) # get the same angle number than scattering lib if KERNELSZ < 5: xx = ( (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a)) ) yy = ( (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a)) ) else: xx = (3 / 5) * LAMBDA * (x * np.cos(a) + y * np.sin(a)) yy = (3 / 5) * LAMBDA * (x * np.sin(a) - y * np.cos(a)) if KERNELSZ == 5: w_smooth = np.exp( -2 * ( (3.0 / float(KERNELSZ) * xx) ** 2 + (3.0 / float(KERNELSZ) * yy) ** 2 ) ) else: w_smooth = np.exp(-0.5 * (xx**2 + yy**2)) tmp1 = np.cos(yy * np.pi) * w_smooth tmp2 = np.sin(yy * np.pi) * w_smooth wwc[i] = tmp1.flatten() - tmp1.mean() wws[i] = tmp2.flatten() - tmp2.mean() # sigma = np.sqrt((wwc[:, i] ** 2).mean()) sigma = np.mean(w_smooth) wwc[i] /= sigma wws[i] /= sigma if DODIV and i == 0: r = xx**2 + yy**2 theta = np.arctan2(yy, xx) theta[KERNELSZ // 2, KERNELSZ // 2] = 0.0 tmp1 = r * np.cos(2 * theta) * w_smooth tmp2 = r * np.sin(2 * theta) * w_smooth wwc[NORIENT] = tmp1.flatten() - tmp1.mean() wws[NORIENT] = tmp2.flatten() - tmp2.mean() # sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean()) sigma = np.mean(w_smooth) wwc[NORIENT] /= sigma wws[NORIENT] /= sigma tmp1 = r * np.cos(2 * theta + np.pi) tmp2 = r * np.sin(2 * theta + np.pi) wwc[NORIENT + 1] = tmp1.flatten() - tmp1.mean() wws[NORIENT + 1] = tmp2.flatten() - tmp2.mean() # sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean()) sigma = np.mean(w_smooth) wwc[NORIENT + 1] /= sigma wws[NORIENT + 1] /= sigma w_smooth = w_smooth.flatten() if self.use_1D: KERNELSZ = 5 self.KERNELSZ = KERNELSZ self.Idx_Neighbours = {} self.w_smooth = {} if self.use_1D: self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type) self.ww_RealT = {} self.ww_ImagT = {} self.ww_SmoothT = {} if KERNELSZ == 5: xx = np.arange(5) - 2 w = np.exp(-0.25 * (xx) ** 2) c = w * np.cos((xx) * np.pi / 2) s = w * np.sin((xx) * np.pi / 2) w = w / np.sum(w) c = c - np.mean(c) s = s - np.mean(s) r = np.sum(np.sqrt(c * c + s * s)) c = c / r s = s / r self.ww_RealT[1] = self.backend.bk_cast( self.backend.bk_constant(np.array(c).reshape(xx.shape[0])) ) self.ww_ImagT[1] = self.backend.bk_cast( self.backend.bk_constant(np.array(s).reshape(xx.shape[0])) ) self.ww_SmoothT[1] = self.backend.bk_cast( self.backend.bk_constant(np.array(w).reshape(xx.shape[0])) ) if self.use_2D: self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type) self.ww_RealT = {} self.ww_ImagT = {} self.ww_SmoothT = {} self.ww_SmoothT[1] = self.backend.bk_constant( self.w_smooth.reshape(1, KERNELSZ, KERNELSZ) ) self.ww_RealT[1] = self.backend.bk_constant( self.backend.bk_reshape( wwc.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ] ) ) self.ww_ImagT[1] = self.backend.bk_constant( self.backend.bk_reshape( wws.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ] ) ) def doorientw(x): y = np.zeros( [KERNELSZ, KERNELSZ, NORIENT, NORIENT * NORIENT], dtype=self.all_type, ) for k in range(NORIENT): y[:, :, k, k * NORIENT : k * NORIENT + NORIENT] = x.reshape( KERNELSZ, KERNELSZ, NORIENT ) return y self.ww_RealT[NORIENT] = self.backend.bk_constant( doorientw(wwc.astype(self.all_type)) ) self.ww_ImagT[NORIENT] = self.backend.bk_constant( doorientw(wws.astype(self.all_type)) ) self.pix_interp_val = {} self.weight_interp_val = {} self.ring2nest = {} self.ampnorm = {} self.loss = {} self.dtype_dcode_map = { 0: np.int64, 1: np.int32, 2: np.float32, 3: np.float64, 4: np.complex64, 5: np.complex128 } self.dtype_code_map = { np.int64: 0, np.int32: 1, np.float32: 2, np.float64: 3, np.complex64: 4, np.complex128: 5 } # this is for the storage only
[docs] def get_dtype_code(self, dtype): for key, code in self.dtype_code_map.items(): if np.dtype(dtype) == np.dtype(key): return code raise ValueError(f"Unsupported data type: {dtype}")
[docs] def get_type(self): return self.all_type
[docs] def get_mpi_type(self): return self.MPI_ALL_TYPE
# ---------------------------------------------−--------- # -- COMPUTE 3X3 INDEX FOR HEALPIX WORK -- # ---------------------------------------------−---------
[docs] def conv_to_FoCUS(self, x, axis=0): if self.use_2D and isinstance(x, np.ndarray): return self.to_R(x, axis, chans=self.chans) return x
[docs] def diffang(self, a, b): return np.arctan2(np.sin(a) - np.sin(b), np.cos(a) - np.cos(b))
[docs] def corr_idx_wXX(self, x, y): idx = np.where(x == -1)[0] res = x res[idx] = y[idx] return res
# ---------------------------------------------−--------- # make the CNN working : index reporjection of the kernel on healpix
[docs] def calc_indices_convol(self, nside, kernel, rotation=None): to, po = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True) x, y, z = hp.pix2vec(nside, np.arange(12 * nside * nside), nest=True) idx = np.argsort((x - 1.0) ** 2 + y**2 + z**2)[0:kernel] x0, y0, z0 = hp.pix2vec(nside, idx[0], nest=True) t0, p0 = hp.pix2ang(nside, idx[0], nest=True) idx = np.argsort((x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2)[0:kernel] im = np.ones([12 * nside**2]) * -1 im[idx] = np.arange(len(idx)) xc, yc, zc = hp.pix2vec(nside, idx, nest=True) xc -= x0 yc -= y0 zc -= z0 vec = np.concatenate( [np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(z, -1)], 1 ) indices = np.zeros([12 * nside**2 * 250, 2], dtype="int") weights = np.zeros([12 * nside**2 * 250]) nn = 0 for k in range(12 * nside * nside): if k % (nside * nside) == nside * nside - 1: print( "Nside=%d KenelSZ=%d %.2f%%" % (nside, kernel, k / (12 * nside**2) * 100) ) if nside < 4: idx2 = np.arange(12 * nside**2) else: idx2 = hp.query_disc( nside, vec[k], np.pi / nside, inclusive=True, nest=True ) t2, p2 = hp.pix2ang(nside, idx2, nest=True) if rotation is None: rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0] else: rot = [po[k] / np.pi * 180.0, (t0 - to[k]) / np.pi * 180.0, rotation[k]] r = hp.Rotator(rot=rot) t2, p2 = r(t2, p2) ii, ww = hp.get_interp_weights(nside, t2, p2, nest=True) ii = im[ii] for l_rotation in range(4): iii = np.where(ii[l_rotation] != -1)[0] if len(iii) > 0: indices[nn : nn + len(iii), 1] = idx2[iii] indices[nn : nn + len(iii), 0] = k * kernel + ii[l_rotation, iii] weights[nn : nn + len(iii)] = ww[l_rotation, iii] nn += len(iii) indices = indices[0:nn] weights = weights[0:nn] if k % (nside * nside) == nside * nside - 1: print( "Nside=%d KenelSZ=%d Total Number of value=%d Ratio of the matrix %.2g%%" % ( nside, kernel, nn, 100 * nn / (kernel * 12 * nside**2 * 12 * nside**2), ) ) return indices, weights, xc, yc, zc
#====================================================================================== # The next two functions prepare the ability of FOSCAT to work with large indexed file #======================================================================================
[docs] def save_index(self, filepath, data, offset=0, count=None): """ Save an N-dimensional NumPy array with shape (N, ...) to binary file. A 12x int64 header is written, describing dtype and shape beyond axis 0. Header layout (12 x int64): [0] = dtype code (0=int64, 1=int32, 2=float32, 3=float64, 4=complex64, 5=complex128) [1] = number of extra dimensions (i.e., data.ndim - 1) [2:12] = shape[1:] padded with zeros Parameters: - filepath: target binary file path - data: NumPy array with shape (N, ...) - offset: number of items to skip on axis 0 - count: number of items to write on axis 0 (default: rest of the array) """ if filepath is None: raise ValueError("No filepath specified for writing.") data = np.asarray(data) if data.ndim < 1: raise ValueError("Data must have at least one dimension.") extra_dims = data.shape[1:] if len(extra_dims) > 10: raise ValueError(f"Too many dimensions: {data.ndim}. Max supported is 11 (1 + 10 extra).") dtype_code = self.get_dtype_code(data.dtype) itemsize = data.dtype.itemsize item_shape = data.shape[1:] item_count = np.prod(item_shape, dtype=np.int64) if item_shape else 1 if count is None: count = data.shape[0] header = np.zeros(12, dtype=np.int64) header[0] = dtype_code header[1] = len(extra_dims) header[2:2 + len(extra_dims)] = extra_dims mode = 'r+b' if os.path.exists(filepath) else 'w+b' with open(filepath, mode) as f: if os.path.getsize(filepath) == 0: f.write(header.tobytes()) byte_offset = 12 * 8 + offset * itemsize * item_count # header is 96 bytes f.seek(byte_offset) f.write(data[offset:offset + count].tobytes())
[docs] def read_index(self, filepath, offset=0, count=None): """ Load a NumPy array from a binary file with a 12x int64 header. Header layout: [0] = dtype code [1] = number of extra dimensions (D) [2:2+D] = shape[1:] of each sample (shape after axis 0) Parameters: - filepath: path to the binary file - offset: number of samples to skip on axis 0 - count: number of samples to read (default: all remaining) Returns: - data: NumPy array with shape (count, ...) and correct dtype """ if not os.path.exists(filepath): raise FileNotFoundError(f"File not found: {filepath}") with open(filepath, 'rb') as f: header_bytes = f.read(12 * 8) if len(header_bytes) != 96: raise ValueError("Invalid or missing header (expected 96 bytes).") header = np.frombuffer(header_bytes, dtype=np.int64) dtype_code = header[0] ndim_extra = header[1] if dtype_code not in self.dtype_dcode_map: raise ValueError(f"Unknown dtype code in header: {dtype_code}") dtype = self.dtype_dcode_map[dtype_code] shape1 = tuple(header[2:2 + ndim_extra]) itemsize = np.dtype(dtype).itemsize item_count = np.prod(shape1, dtype=np.int64) if shape1 else 1 bytes_per_sample = itemsize * item_count # Seek to data block f.seek(12 * 8 + offset * bytes_per_sample) # Determine number of items if count is None: remaining_bytes = os.path.getsize(filepath) - (12 * 8 + offset * bytes_per_sample) count = remaining_bytes // bytes_per_sample raw = f.read(count * bytes_per_sample) data = np.frombuffer(raw, dtype=dtype) if shape1: data = data.reshape((count,) + shape1) else: data = data.reshape((count,)) return data
# ---------------------------------------------−--------- # ---------------------------------------------−---------
[docs] def healpix_layer(self, im, ww, indices=None, weights=None): #ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT] #im [N_batch,N_i, NORIENT,N] nside=int(np.sqrt(im.shape[-1]//12)) if indices is None: if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN: self.init_index_cnn(nside,self.NORIENT) indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)] mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)] wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w] wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...] #win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w] #ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT] wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3]) wim = self.backend.bk_reduce_sum(wim,-1) return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
# ---------------------------------------------−--------- # ---------------------------------------------−---------
[docs] def get_rank(self): return self.rank
# ---------------------------------------------−---------
[docs] def get_size(self): return self.size
# ---------------------------------------------−---------
[docs] def barrier(self): if self.isMPI: self.comm.Barrier()
# ---------------------------------------------−---------
[docs] def toring(self, image, axis=0): lout = int(np.sqrt(image.shape[axis] // 12)) if lout not in self.ring2nest: self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2)) return image.numpy()[self.ring2nest[lout]]
# --------------------------------------------------------
[docs] def ud_grade(self, im, j, axis=0, cell_ids=None, nside=None): rim = im for k in range(j): # rim = self.smooth(rim, axis=axis) rim = self.ud_grade_2(rim, axis=axis, cell_ids=cell_ids, nside=nside) return rim
# --------------------------------------------------------
[docs] def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None,max_poll=False): if self.use_2D: ishape = list(im.shape) if len(ishape) < axis + 2: if not self.silent: print("Use of 2D scat with data that has less than 2D") return None, None npix = im.shape[-2] npiy = im.shape[-1] ndata = 1 for k in range(len(im.shape)-2): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, npix, npiy, 1] ) ''' tim = self.backend.bk_reshape( tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :], [ndata, npix // 2, 2, npiy // 2, 2, 1], ) #res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4 ''' if self.use_median: res = self.backend.downsample_median_2x2(tim) else: res = self.backend.downsample_mean_2x2(tim) if len(ishape) == 2: return ( self.backend.bk_reshape( res, [npix // 2, npiy // 2] ), None, ) else: return ( self.backend.bk_reshape( res, ishape[0:-2] + [npix // 2, npiy // 2], ), None, ) return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None elif self.use_1D: ishape = list(im.shape) npix = ishape[-1] ndata = 1 for k in range(len(ishape) - 1): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, npix // 2, 2] ) if self.use_median: res=self.backend.bk_reduce_median(tim,axis=-1) else: res=self.backend.bk_reduce_mean(tim,axis=-1) return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None else: shape = list(im.shape) if nside is None: l_nside=int(np.sqrt(shape[-1]//12)) else: l_nside=nside nbatch=1 for k in range(len(shape)-1): nbatch*=shape[k] if l_nside not in self.down: print('initialise down', l_nside) self.down[l_nside] = SphereDownGeo(nside_in=l_nside, dtype=self.all_bk_type,mode="smooth", in_cell_ids=cell_ids) res,out_cell=self.down[l_nside](self.backend.bk_reshape(im,[nbatch,1,shape[-1]])) return self.backend.bk_reshape(res,shape[:-1]+[out_cell.shape[0]]),out_cell ''' if self.use_median: if cell_ids is not None: sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='median') return sim, new_cell_ids return self.backend.bk_reduce_median( self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1 ),None elif max_poll: if cell_ids is not None: sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='max') return sim, new_cell_ids return self.backend.bk_reduce_max( self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1 ),None else: if cell_ids is not None: sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='mean') return sim, new_cell_ids return self.backend.bk_reduce_mean( self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1 ),None '''
# --------------------------------------------------------
[docs] def up_grade(self, im, nout, axis=-1, nouty=None, cell_ids=None, o_cell_ids=None, force_init_index=False, nside=None): ishape = list(im.shape) if self.use_2D: if len(ishape) < 2: if not self.silent: print("Use of 2D scat with data that has less than 2D") return None if nouty is None: nouty = nout if ishape[-2] == nout and ishape[-1] == nouty: return im npix = im.shape[-2] npiy = im.shape[-1] ndata = 1 for k in range(len(im.shape)-2): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, npix, npiy,1] ) res = self.backend.bk_resize_image(tim, [nout, nouty]) if len(ishape) == 2: return self.backend.bk_reshape(res, [nout, nouty]) else: return self.backend.bk_reshape( res, ishape[0:-2] + [nout, nouty] ) elif self.use_1D: if len(ishape) < axis + 1: if not self.silent: print("Use of 1D scat with data that has less than 1D") return None if ishape[axis] == nout: return im npix = im.shape[axis] odata = 1 ndata = 1 if len(ishape)>1: for k in range(len(ishape)-1): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, npix, odata] ) while tim.shape[1] != nout: res2 = self.backend.bk_expand_dims( self.backend.bk_concat( [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1 ), -2, ) res1 = self.backend.bk_expand_dims( self.backend.bk_concat( [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1 ), -2, ) tim = self.backend.bk_reshape( self.backend.bk_concat([res1, res2], -2), [ndata, tim.shape[1] * 2, odata], ) return self.backend.bk_reshape(tim, ishape[0:-1] + [nout]) else: if nside is None: lout = int(np.sqrt(im.shape[-1] // 12)) else: lout = nside ''' if (lout,nout) not in self.pix_interp_val or force_init_index: if not self.silent: print("compute lout nout", lout, nout) if cell_ids is None: o_cell_ids=np.arange(12 * nout**2, dtype="int") i_npix=12*lout**2 #level=int(np.log2(lout)) # nside=128 #sp = HS.heal_spline(level,gamma=2.0) th, ph = hp.pix2ang( nout, o_cell_ids, nest=True ) all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True) #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True) w=www.T p=all_idx.T w=w.flatten() p=p.flatten() indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int") indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4) indice[:, 0] = p self.pix_interp_val[(lout,nout)] = 1 self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor( self.backend.bk_constant(indice.T), self.backend.bk_constant(self.backend.bk_cast(w)), dense_shape=[i_npix,o_cell_ids.shape[0]], ) else: ratio=(nout//lout)**2 if o_cell_ids is None: o_cell_ids=np.tile(cell_ids,ratio)*ratio+np.repeat(np.arange(ratio),cell_ids.shape[0]) i_npix=cell_ids.shape[0] th, ph = hp.pix2ang( nout, self.backend.to_numpy(o_cell_ids), nest=True ) all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True) #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True) hidx,inv_idx = np.unique(all_idx, return_inverse=True) all_idx = inv_idx sorter = np.argsort(hidx) index=sorter[np.searchsorted(hidx, self.backend.to_numpy(cell_ids), sorter=sorter)] mask = -np.ones([hidx.shape[0]]) mask[index] = np.arange(index.shape[0],dtype='int') all_idx=mask[all_idx] www[all_idx==-1]=0.0 www/=np.sum(www,0)[None,:] all_idx[all_idx==-1]=0 w=www.T p=all_idx.T w=w.flatten() p=p.flatten() indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int") indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4) indice[:, 0] = p self.pix_interp_val[(lout,nout)] = 1 self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor( self.backend.bk_constant(indice.T), self.backend.bk_constant(self.backend.bk_cast(w)), dense_shape=[i_npix,o_cell_ids.shape[0]], ) del w del p ''' shape=list(im.shape) nbatch=1 for k in range(len(shape)-1): nbatch*=shape[k] im=self.backend.bk_reshape(im,[nbatch,1,shape[-1]]) while lout<nout: if lout not in self.up: if o_cell_ids is None: l_o_cell_ids=torch.tensor(np.arange(12*(lout**2),dtype='int'),device=im.device) else: l_o_cell_ids=o_cell_ids self.up[lout] = SphereUpGeo(nside_out=lout, dtype=self.all_bk_type, cell_ids_out=l_o_cell_ids, up_norm="col_l1") im, fine_ids = self.up[lout](self.backend.bk_cast(im)) lout*=2 if lout<nout and o_cell_ids is not None: o_cell_ids=torch.repeat(fine_ids,4)*4+ \ torch.tile(torch.tensor([0,1,2,3],device=fine_ids.device,dtype=fine_ids.dtype),fine_ids.shape[0]) return self.backend.bk_reshape(im,shape[:-1]+[im.shape[-1]]) ''' ndata = 1 for k in range(len(ishape)-1): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, ishape[-1]] ) if tim.dtype == self.all_cbk_type: rr = self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), self.weight_interp_val[(lout,nout)], ) ii = self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), self.weight_interp_val[(lout,nout)], ) imout = self.backend.bk_complex(rr, ii) else: imout = self.backend.bk_sparse_dense_matmul( tim, self.weight_interp_val[(lout,nout)], ) if len(ishape) == 1: return self.backend.bk_reshape(imout, [imout.shape[-1]]) else: return self.backend.bk_reshape( imout, ishape[0:-1]+[imout.shape[-1]] ) ''' return imout
# --------------------------------------------------------
[docs] def fill_1d(self, i_arr, nullval=0): arr = i_arr.copy() # Indices of non-zero elements non_zero_indices = np.where(arr != nullval)[0] # Indices of all elements all_indices = np.arange(len(arr)) # Linearly interpolate using np.interp # np.interp(x, xp, fp) : x sont les indices pour lesquels on veut obtenir des valeurs # xp are the indices of existing data, fp are the values of existing data interpolated_values = np.interp( all_indices, non_zero_indices, arr[non_zero_indices] ) # Update the original array arr[arr == nullval] = interpolated_values[arr == nullval] return arr
[docs] def fill_2d(self, i_arr, nullval=0): arr = i_arr.copy() # Create a coordinate grid matching the array indices x, y = np.indices(arr.shape) # Extract coordinates of non-zero points and their values non_zero_points = np.array((x[arr != nullval], y[arr != nullval])).T non_zero_values = arr[arr != nullval] # Extract coordinates of zero points zero_points = np.array((x[arr == nullval], y[arr == nullval])).T # Linear interpolation interpolated_values = griddata( non_zero_points, non_zero_values, zero_points, method="linear" ) # Replace zero values with interpolated values arr[arr == nullval] = interpolated_values return arr
[docs] def fill_healpy(self, i_map, nmax=10, nullval=hp.UNSEEN): map = 1 * i_map # Trouver les pixels nuls nside = hp.npix2nside(len(map)) null_indices = np.where(map == nullval)[0] itt = 0 while null_indices.shape[0] > 0 and itt < nmax: # Find theta, phi coordinates for zero pixels theta, phi = hp.pix2ang(nside, null_indices) # Interpoler les valeurs en utilisant les pixels voisins # get_interp_val can be used to obtain interpolated values # at positions given in theta and phi. i_idx = hp.get_all_neighbours(nside, theta, phi) i_w = (map[i_idx] != nullval) * (i_idx != -1) vv = np.sum(i_w, 0) interpolated_values = np.sum(i_w * map[i_idx], 0) # Replace zero values with interpolated values map[null_indices[vv > 0]] = interpolated_values[vv > 0] / vv[vv > 0] null_indices = np.where(map == nullval)[0] itt += 1 return map
# --------------------------------------------------------
[docs] def ud_grade_1d(self, im, nout, axis=0): npix = im.shape[axis] ishape = list(im.shape) odata = 1 for k in range(axis + 1, len(ishape)): odata = odata * ishape[k] ndata = 1 for k in range(axis): ndata = ndata * ishape[k] nscale = npix // nout if npix % nscale == 0: tim = self.backend.bk_reshape( self.backend.bk_cast(im), [ndata, npix // nscale, nscale, odata] ) else: im = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata]) tim = self.backend.bk_reshape( self.backend.bk_cast(im[:, 0 : nscale * (npix // nscale), :]), [ndata, npix // nscale, nscale, odata], ) res = self.backend.bk_reduce_mean(tim, 2) if axis == 0: if len(ishape) == 1: return self.backend.bk_reshape(res, [nout]) else: return self.backend.bk_reshape(res, [nout] + ishape[axis + 1 :]) else: if len(ishape) == axis + 1: return self.backend.bk_reshape(res, ishape[0:axis] + [nout]) else: return self.backend.bk_reshape( res, ishape[0:axis] + [nout] + ishape[axis + 1 :] ) return self.backend.bk_reshape(res, [nout])
# --------------------------------------------------------
[docs] def up_grade_2_1d(self, im, axis=0): npix = im.shape[axis] ishape = list(im.shape) odata = 1 for k in range(axis + 1, len(ishape)): odata = odata * ishape[k] ndata = 1 for k in range(axis): ndata = ndata * ishape[k] tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata]) res2 = self.backend.bk_expand_dims( self.backend.bk_concat( [(tim[:, 1:, :] + 3 * tim[:, :-1, :]) / 4, tim[:, -1:, :]], 1 ), -2, ) res1 = self.backend.bk_expand_dims( self.backend.bk_concat( [tim[:, 0:1, :], (tim[:, 1:, :] * 3 + tim[:, :-1, :]) / 4], 1 ), -2, ) res = self.backend.bk_concat([res1, res2], -2) if axis == 0: if len(ishape) == 1: return self.backend.bk_reshape(res, [npix * 2]) else: return self.backend.bk_reshape(res, [npix * 2] + ishape[axis + 1 :]) else: if len(ishape) == axis + 1: return self.backend.bk_reshape(res, ishape[0:axis] + [npix * 2]) else: return self.backend.bk_reshape( res, ishape[0:axis] + [npix * 2] + ishape[axis + 1 :] ) return self.backend.bk_reshape(res, [npix * 2])
# --------------------------------------------------------
[docs] def convol_1d(self, im, axis=0): xx = np.arange(5) - 2 w = np.exp(-0.17328679514 * (xx) ** 2) c = np.cos((xx) * np.pi / 2) s = np.sin((xx) * np.pi / 2) wr = np.array(w * c).reshape(xx.shape[0], 1, 1) wi = np.array(w * s).reshape(xx.shape[0], 1, 1) npix = im.shape[axis] ishape = list(im.shape) odata = 1 for k in range(axis + 1, len(ishape)): odata = odata * ishape[k] ndata = 1 for k in range(axis): ndata = ndata * ishape[k] if odata > 1: wr = np.repeat(wr, odata, 2) wi = np.repeat(wi, odata, 2) wr = self.backend.bk_cast(self.backend.bk_constant(wr)) wi = self.backend.bk_cast(self.backend.bk_constant(wi)) tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata]) if tim.dtype == self.all_cbk_type: rr1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wr) ii1 = self.backend.bk_conv1d(self.backend.bk_real(tim), wi) rr2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wr) ii2 = self.backend.bk_conv1d(self.backend.bk_imag(tim), wi) res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2) else: rr = self.backend.bk_conv1d(tim, wr) ii = self.backend.bk_conv1d(tim, wi) res = self.backend.bk_complex(rr, ii) if axis == 0: if len(ishape) == 1: return self.backend.bk_reshape(res, [npix]) else: return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :]) else: if len(ishape) == axis + 1: return self.backend.bk_reshape(res, ishape[0:axis] + [npix]) else: return self.backend.bk_reshape( res, ishape[0:axis] + [npix] + ishape[axis + 1 :] ) return self.backend.bk_reshape(res, [npix])
# --------------------------------------------------------
[docs] def smooth_1d(self, im, axis=0): xx = np.arange(5) - 2 w = np.exp(-0.17328679514 * (xx) ** 2) w = w / w.sum() w = np.array(w).reshape(xx.shape[0], 1, 1) npix = im.shape[axis] ishape = list(im.shape) odata = 1 for k in range(axis + 1, len(ishape)): odata = odata * ishape[k] ndata = 1 for k in range(axis): ndata = ndata * ishape[k] if odata > 1: w = np.repeat(w, odata, 2) w = self.backend.bk_cast(self.backend.bk_constant(w)) tim = self.backend.bk_reshape(self.backend.bk_cast(im), [ndata, npix, odata]) if tim.dtype == self.all_cbk_type: rr = self.backend.bk_conv1d(self.backend.bk_real(tim), w) ii = self.backend.bk_conv1d(self.backend.bk_real(tim), w) res = self.backend.bk_complex(rr, ii) else: res = self.backend.bk_conv1d(tim, w) if axis == 0: if len(ishape) == 1: return self.backend.bk_reshape(res, [npix]) else: return self.backend.bk_reshape(res, [npix] + ishape[axis + 1 :]) else: if len(ishape) == axis + 1: return self.backend.bk_reshape(res, ishape[0:axis] + [npix]) else: return self.backend.bk_reshape( res, ishape[0:axis] + [npix] + ishape[axis + 1 :] ) return self.backend.bk_reshape(res, [npix])
# --------------------------------------------------------
[docs] def up_grade_1d(self, im, nout, axis=0): lout = int(im.shape[axis]) nscale = int(np.log(nout // lout) / np.log(2)) res = self.backend.bk_cast(im) for k in range(nscale): res = self.up_grade_2_1d(res, axis=axis) return res
# ---------------------------------------------−---------
[docs] def init_index(self, nside, kernel=-1, cell_ids=None, spin=0): if kernel == -1: l_kernel = self.KERNELSZ else: l_kernel = kernel if cell_ids is not None: ncell = cell_ids.shape[0] else: ncell = 12 * nside * nside try: if self.use_2D: tmp = self.read_index("%s/W%d_%s_%d_IDX.fst" % (self.TEMPLATE_PATH, l_kernel**2,TMPFILE_VERSION, nside) ) else: if cell_ids is not None and spin==0: tmp = self.read_index( "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside, # if cell_ids computes the index ) ) else: ''' print('LOAD ',"%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside,spin # if cell_ids computes the index )) ''' tmp = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside,spin # if cell_ids computes the index ) ) except: if cell_ids is not None and spin!=0: self.init_index(nside, kernel=kernel, spin=spin) if not self.use_2D: if spin!=0: # keep the print here as spin!=0 can be long print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst' % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside,spin # if cell_ids computes the index ) ) try: tmp = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside ) ) except: ''' print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst' % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside ) ) ''' self.init_index(nside, kernel=kernel, spin=0) tmp = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, self.NORIENT, nside ) ) tmpw = self.read_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.fst"% ( self.TEMPLATE_PATH, self.TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, ) ) import foscat.HOrientedConvol as hs hconvol=hs.HOrientedConvol(nside,4*self.KERNELSZ+1,cell_ids=cell_ids) if cell_ids is None: l_cell_ids=np.arange(12*nside**2) else: l_cell_ids=cell_ids nvalid=4*self.KERNELSZ**2 if nvalid>12*nside**2: nvalid=12*nside**2 idxEB=hconvol.idx_nn[:,0:nvalid] tmpEB=np.zeros([self.NORIENT,4,l_cell_ids.shape[0],nvalid],dtype='complex') tmpS=np.zeros([4,l_cell_ids.shape[0],nvalid],dtype='float') idx={} nn=0 nn2=1 if nside<64: pp=10 else: pp=1 while nn2>0: idx2={} nn2=0 im=np.zeros([12*nside**2]) for n in range(l_cell_ids.shape[0]): if im[hconvol.idx_nn[n,0]]==0 and n not in idx: im[hconvol.idx_nn[n,:]]=1.0 idx[hconvol.idx_nn[n,0]]=1.0 idx2[hconvol.idx_nn[n,0]]=1.0 nn+=1 nn2+=1 im=np.zeros([12*nside**2]) for k in idx2: im[k]=1.0 r=self.convol(im) for k in range(self.NORIENT): ralm=hp.map2alm(hp.reorder(r[k].cpu().numpy().real,n2r=True))[None,:] ialm=hp.map2alm(hp.reorder(r[k].cpu().numpy().imag,n2r=True))[None,:] alm=np.concatenate([ralm,0*ralm,0*ralm],0) rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1) alm=np.concatenate([ialm,0*ialm,0*ialm],0) iqe,iue,iie=hp.alm2map_spin(alm,nside,spin,3*nside-1) alm=np.concatenate([0*ralm,ralm,0*ralm],0) rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1) alm=np.concatenate([0*ialm,ialm,0*ialm],0) iqb,iub,iib=hp.alm2map_spin(alm,nside,spin,3*nside-1) rqe=hp.reorder(rqe,r2n=True) rue=hp.reorder(rue,r2n=True) rqb=hp.reorder(rqb,r2n=True) rub=hp.reorder(rub,r2n=True) iqe=hp.reorder(iqe,r2n=True) iue=hp.reorder(iue,r2n=True) iqb=hp.reorder(iqb,r2n=True) iub=hp.reorder(iub,r2n=True) for l in idx2: tmpEB[k,0,l]=rqe[idxEB[l,:]]+1J*iqe[idxEB[l,:]] tmpEB[k,1,l]=rue[idxEB[l,:]]+1J*iue[idxEB[l,:]] tmpEB[k,2,l]=rqb[idxEB[l,:]]+1J*iqb[idxEB[l,:]] tmpEB[k,3,l]=rub[idxEB[l,:]]+1J*iub[idxEB[l,:]] r=self.smooth(im) ralm=hp.map2alm(hp.reorder(r.cpu().numpy(),n2r=True))[None,:] alm=np.concatenate([ralm,0*ralm,0*ralm],0) rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1) alm=np.concatenate([0*ralm,ralm,0*ralm],0) rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1) rqe=hp.reorder(rqe,r2n=True) rue=hp.reorder(rue,r2n=True) rqb=hp.reorder(rqb,r2n=True) rub=hp.reorder(rub,r2n=True) for l in idx2: tmpS[0,l,:]=rqe[idxEB[l,:]] tmpS[1,l,:]=rue[idxEB[l,:]] tmpS[2,l,:]=rqb[idxEB[l,:]] tmpS[3,l,:]=rub[idxEB[l,:]] if 100*nn/(l_cell_ids.shape[0])>pp: if nside<64: pp+=10 else: pp+=1 print('%.2f%% Done'%(100*nn/(l_cell_ids.shape[0]))) wav=tmpEB.flatten() wwav=tmpS.flatten() ndata=l_cell_ids.shape[0]*nvalid indice_1_1=np.tile(idxEB.flatten(),4*self.NORIENT) for k in range(self.NORIENT): indice_1_1[(4*k+1)*ndata:(4*k+2)*ndata]+=l_cell_ids.shape[0] indice_1_1[(4*k+3)*ndata:(4*k+4)*ndata]+=l_cell_ids.shape[0] indice_1_0=np.tile(np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4),self.NORIENT) for k in range(self.NORIENT): indice_1_0[(4*k+2)*ndata:(4*k+4)*ndata]+=self.NORIENT*l_cell_ids.shape[0] indice_1_0[(4*k)*ndata:(4*k+4)*ndata]+=k*l_cell_ids.shape[0] indice=np.concatenate([indice_1_1[:,None],indice_1_0[:,None]],1) indice_2_1=np.tile(idxEB.flatten(),4) indice_2_1[ndata:2*ndata]+=l_cell_ids.shape[0] indice_2_1[3*ndata:4*ndata]+=l_cell_ids.shape[0] indice_2_0=np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4) indice_2_0[2*ndata:]+=l_cell_ids.shape[0] indice2=np.concatenate([indice_2_1[:,None],indice_2_0[:,None]],1) self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"% (self.TEMPLATE_PATH, self.TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin ), indice.T ) self.save_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"% (self.TEMPLATE_PATH, self.TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), wav ) self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"% (self.TEMPLATE_PATH, self.TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin ), indice2.T ) self.save_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"% (self.TEMPLATE_PATH, self.TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), wwav ) else: if l_kernel == 5: pw = 0.75 pw2 = 0.5 threshold = 2e-5 elif l_kernel == 3: pw = 1.0 / np.sqrt(2) pw2 = 1.0 threshold = 1e-3 elif l_kernel == 7: pw = 0.5 pw2 = 0.25 threshold = 2e-5 import foscat.SphericalStencil as hs import torch if cell_ids is None: l_cell_ids=np.arange(12*nside**2) else: l_cell_ids=cell_ids if isinstance(l_cell_ids,torch.Tensor): l_cell_ids=self.backend.to_numpy(l_cell_ids) hconvol=hs.SphericalStencil(nside, l_kernel, cell_ids=l_cell_ids, n_gauges=self.NORIENT, gauge_type='cosmo') xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ) if nside>0: wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ) wwr-=wwr.mean() wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ) wwi-=wwi.mean() amp=np.sum(abs(wwr+1J*wwi)) else: #asymetric kernels wwr=(np.exp(-2*pw2*(xx**2+self.NORIENT*(xx.T)**2))).reshape(1,1,self.KERNELSZ*self.KERNELSZ) #wwr-=wwr.mean() #wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ) #wwi-=wwi.mean() wwi=0.0*wwr amp=self.NORIENT*np.sum(abs(wwr+1J*wwi)) wwr/=amp wwi/=amp wwr=hconvol.to_tensor(wwr) wwi=hconvol.to_tensor(wwi) wavr,indice,mshape=hconvol.make_matrix(wwr) wavi,indice,mshape=hconvol.make_matrix(wwi) wav=hconvol.to_numpy(wavr)+1J*hconvol.to_numpy(wavi) indice=hconvol.to_numpy(indice) hconvol=hs.SphericalStencil(nside, l_kernel, cell_ids=l_cell_ids, n_gauges=1, gauge_type='cosmo') ww=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))).reshape(1,1,self.KERNELSZ*self.KERNELSZ)) ww/=ww.sum() wwav,indice2,mshape=hconvol.make_matrix(ww) wwav=hconvol.to_numpy(wwav) indice2=hconvol.to_numpy(indice2) if cell_ids is None: if not self.silent: print( "Write %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin) ) self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), indice ) self.save_index( "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), wav, ) self.save_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), indice2, ) self.save_index( "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ), wwav, ) if self.use_2D: if l_kernel**2 == 9: if self.rank == 0: self.comp_idx_w9(nside) elif l_kernel**2 == 25: if self.rank == 0: self.comp_idx_w25(nside) else: if self.rank == 0: if not self.silent: print( "Only 3x3 and 5x5 kernel have been developped for Healpix and you ask for %dx%d" % (self.KERNELSZ, self.KERNELSZ) ) return None if cell_ids is None or spin!=0: self.barrier() if self.use_2D: tmp = self.read_index( "%s/W%d_%s_%d_IDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside, spin) ) else: tmp = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ) ) tmp2 = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ) ) wr = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ) ).real wi = self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ) ).imag ws = self.slope * self.read_index( "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside, spin, ) ) ''' if cell_ids is not None: idx_map=-np.ones([12*nside**2],dtype='int32') lcell_ids=cell_ids try: idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32') except: lcell_ids=self.to_numpy(cell_ids) idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32') lidx=np.where(idx_map[tmp[:,1]%(12*nside**2)]!=-1)[0] orientation=tmp[lidx,1]//(12*nside**2) orientation2=tmp[lidx,0]//(12*nside**2) tmp=tmp[lidx] wr=wr[lidx] wi=wi[lidx] tmp=idx_map[tmp%(12*nside**2)] lidx=np.where(tmp[:,0]==-1)[0] wr[lidx]=0.0 wi[lidx]=0.0 tmp[lidx,0]=0 tmp[:,1]+=orientation*lcell_ids.shape[0] tmp[:,0]+=orientation2*lcell_ids.shape[0] idx_map=-np.ones([12*nside**2],dtype='int32') idx_map[lcell_ids]=np.arange(cell_ids.shape[0],dtype='int32') lidx=np.where(idx_map[tmp2[:,1]%(12*nside**2)]!=-1)[0] i_id=tmp2[lidx,1]//(12*nside**2) i_id2=tmp2[lidx,0]//(12*nside**2) tmp2=tmp2[lidx] ws=ws[lidx] tmp2=idx_map[tmp2%(12*nside**2)] lidx=np.where(tmp2[:,0]==-1)[0] ws[lidx]=0.0 tmp2[lidx,0]=0 tmp2[:,1]+=i_id*lcell_ids.shape[0] tmp2[:,0]+=i_id2*lcell_ids.shape[0] #add normalisation ww=np.bincount(tmp2[:,1],weights=ws) ws/=ww[tmp2[:,1]] wh=np.bincount(tmp[:,1]) ww=np.bincount(tmp[:,1],weights=wr) wr-=(ww/wh)[tmp[:,1]] ww=np.bincount(tmp[:,1],weights=wi) wi-=(ww/wh)[tmp[:,1]] ww=np.bincount(tmp[:,1],weights=np.sqrt(wr*wr+wi*wi)) wr/=ww[tmp[:,1]] wi/=ww[tmp[:,1]] ''' else: tmp = indice tmp2 = indice2 wr = wav.real wi = wav.imag ws = self.slope * wwav if spin==0: wr = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp), self.backend.bk_constant(self.backend.bk_cast(wr)), dense_shape=[ncell, self.NORIENT * ncell], ) wi = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp), self.backend.bk_constant(self.backend.bk_cast(wi)), dense_shape=[ncell, self.NORIENT * ncell], ) ws = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp2), self.backend.bk_constant(self.backend.bk_cast(ws)), dense_shape=[ncell, ncell], ) else: wr = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp), self.backend.bk_constant(self.backend.bk_cast(wr)), dense_shape=[2*ncell, 2*self.NORIENT * ncell], ) wi = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp), self.backend.bk_constant(self.backend.bk_cast(wi)), dense_shape=[2*ncell, 2*self.NORIENT * ncell], ) ws = self.backend.bk_SparseTensor( self.backend.bk_constant(tmp2), self.backend.bk_constant(self.backend.bk_cast(ws)), dense_shape=[2*ncell, 2*ncell], ) if kernel == -1: self.Idx_Neighbours[nside] = tmp if self.use_2D: if kernel != -1: return tmp return wr, wi, ws,tmp
# ---------------------------------------------−---------
[docs] def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None): if kernel == -1: l_kernel = self.KERNELSZ else: l_kernel = kernel if cell_ids is not None: ncell = cell_ids.shape[0] else: ncell = 12 * nside * nside try: if cell_ids is not None: tmp = self.read_index( "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, NORIENT, nside, # if cell_ids computes the index ) ) else: tmp = self.read_index( "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel**2, NORIENT, nside, # if cell_ids computes the index ) ) except: pw = 8.0 pw2 = 1.0 threshold = 1e-3 if l_kernel == 5: pw = 8.0 pw2 = 0.5 threshold = 2e-4 elif l_kernel == 3: pw = 8.0 pw2 = 1.0 threshold = 1e-3 elif l_kernel == 7: pw = 8.0 pw2 = 0.25 threshold = 4e-5 n_weights = self.KERNELSZ*(self.KERNELSZ//2+1) if cell_ids is not None: if not isinstance(cell_ids, np.ndarray): cell_ids = self.backend.to_numpy(cell_ids) th, ph = hp.pix2ang(nside, cell_ids, nest=True) x, y, z = hp.pix2vec(nside, cell_ids, nest=True) t, p = hp.pix2ang(nside, cell_ids, nest=True) phi = [p[k] / np.pi * 180 for k in range(ncell)] thi = [t[k] / np.pi * 180 for k in range(ncell)] indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int") wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float") else: th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True) x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True) t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True) phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)] thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)] indice = np.zeros( [n_weights, NORIENT, 12 * nside * nside,4], dtype="int" ) wav = np.zeros( [n_weights, NORIENT, 12 * nside * nside,4], dtype="float" ) iv = 0 iv2 = 0 for iii in range(ncell): if cell_ids is None: if iii % (nside * nside) == nside * nside - 1: if not self.silent: print( "Pre-compute nside=%6d %.2f%%" % (nside, 100 * iii / (12 * nside * nside)) ) if cell_ids is not None: hidx = np.where( (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2 < (2 * np.pi / nside) ** 2 )[0] else: hidx = hp.query_disc( nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True, ) R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ") t2, p2 = R(th[hidx], ph[hidx]) vec2 = hp.ang2vec(t2, p2) x2 = vec2[:, 0] y2 = vec2[:, 1] z2 = vec2[:, 2] for l_rotation in range(NORIENT): angle = ( l_rotation / 4.0 * np.pi - phi[iii] / 180.0 * np.pi * (z[hidx] > 0) - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0) ) axes = y2 * np.cos(angle) - x2 * np.sin(angle) axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle) for k_weights in range(self.KERNELSZ//2+1): for l_weights in range(self.KERNELSZ): val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \ np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2)) idx = np.argsort(-val) idx = idx[0:4] nval = len(idx) val=val[idx] r = abs(val).sum() if r > 0: val = val / r indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx] wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val if not self.silent: print("Kernel Size ", iv / (NORIENT * 12 * nside * nside)) if cell_ids is None: if not self.silent: print( "Write FOSCAT_%s_W%d_%d_%d_PIDX.fst" % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside) ) self.save_index( "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside, ), indice, ) self.save_index( "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside, ), wav, ) if cell_ids is None: self.barrier() if self.use_2D: tmp = self.read_index( "%s/W%d_%s_%d_IDX.fst" % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside) ) else: tmp = self.read_index( "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside, ) ) wav = self.read_index( "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst" % ( self.TEMPLATE_PATH, TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside, ) ) else: tmp = indice self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav) return wav, tmp
# ---------------------------------------------−--------- # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
[docs] def swapaxes(self, x, axis1, axis2): shape = list(x.shape) if axis1 < 0: laxis1 = len(shape) + axis1 else: laxis1 = axis1 if axis2 < 0: laxis2 = len(shape) + axis2 else: laxis2 = axis2 naxes = len(shape) thelist = [i for i in range(naxes)] thelist[laxis1] = laxis2 thelist[laxis2] = laxis1 return self.backend.bk_transpose(x, thelist)
# ---------------------------------------------−--------- # Mean using mask x [n_b,....,Npix], mask[Nmask,Npix] to [n_b,Nmask,....] # if use_2D # Mean using mask x [n_b,....,N_1,N_2], mask[Nmask,N_1,N_2] to [n_b,Nmask,....]
[docs] def masked_mean(self, x, mask, rank=0, calc_var=False): # ========================================================================== # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] # in input mask=[Nmask,X[,Y]] # if self.use_2D : X[,Y]] = [X,Y] # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT # ========================================================================== shape = list(x.shape) if not self.use_2D and not self.use_1D: nside = int(np.sqrt(x.shape[-1] // 12)) l_mask = mask if self.mask_norm: sum_mask = self.backend.bk_reduce_sum( self.backend.bk_reshape( l_mask, [l_mask.shape[0], np.prod(np.array(l_mask.shape[1:]))] ), 1, ) if not self.use_2D: l_mask = ( 12 * nside * nside * l_mask / self.backend.bk_reshape( sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]] ) ) elif self.use_2D: l_mask = ( mask.shape[1] * mask.shape[2] * l_mask / self.backend.bk_reshape( sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]] ) ) else: l_mask = ( mask.shape[1] * l_mask / self.backend.bk_reshape( sum_mask, [l_mask.shape[0]] + [1 for i in l_mask.shape[1:]] ) ) if self.use_2D: if self.padding == "VALID": l_mask = l_mask[ :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, ] if shape[axis] != l_mask.shape[1]: l_mask = l_mask[ :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, ] ichannel = 1 for i in range(1, len(shape) - 2): ichannel *= shape[i] l_x = self.backend.bk_reshape( x, [shape[0], 1, ichannel, shape[-2], shape[-1]] ) if self.padding == "VALID": oshape = [k for k in shape] oshape[axis] = oshape[axis] - self.KERNELSZ + 1 oshape[axis + 1] = oshape[axis + 1] - self.KERNELSZ + 1 l_x = self.backend.bk_reshape( l_x[ :, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :, ], oshape, ) elif self.use_1D: if self.padding == "VALID": l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] if shape[axis] != l_mask.shape[1]: l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] ichannel = 1 for i in range(1, len(shape) - 1): ichannel *= shape[i] l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel,shape[-1]]) if self.padding == "VALID": oshape = [k for k in shape] oshape[axis] = oshape[axis] - self.KERNELSZ + 1 l_x = self.backend.bk_reshape( l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape ) else: ichannel = 1 if len(shape)>1: ichannel = shape[0] ochannel = 1 for i in range(1,len(shape)-1): ochannel *= shape[i] l_x = self.backend.bk_reshape(x, [ichannel,1,ochannel,shape[-1]]) # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,...,1,NORIENT[,NORIENT],X[,Y]] # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]] if self.use_2D: l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-3) else: l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-2) if l_x.dtype == self.all_cbk_type: l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask)) if self.use_2D: # if self.padding == "VALID": mtmp = l_mask vtmp = l_x # else: # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:] # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:] if self.use_median: res,res2 = self.backend.bk_masked_median_2d_weiszfeld(vtmp, mtmp) else: v1 = self.backend.bk_reduce_sum( self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1 ) v2 = self.backend.bk_reduce_sum( self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1 ) vh = self.backend.bk_reduce_sum( self.backend.bk_reduce_sum(mtmp, axis=-1), -1 ) res = v1 / vh res2= v2 / vh oshape = [x.shape[0]] + [mask.shape[0]] if len(x.shape) > 3: oshape = oshape + list(x.shape[1:-2]) else: oshape = oshape + [1] if calc_var: if self.use_median: vh = self.backend.bk_reduce_sum( self.backend.bk_reduce_sum(mtmp, axis=-1), -1 ) if self.backend.bk_is_complex(vtmp): res2 = self.backend.bk_sqrt( ( ( self.backend.bk_real(res2) - self.backend.bk_real(res) * self.backend.bk_real(res) ) + ( self.backend.bk_imag(res2) - self.backend.bk_imag(res) * self.backend.bk_imag(res) ) ) / self.backend.bk_real(vh) ) else: res2 = self.backend.bk_sqrt((res2 - res * res) / (vh)) res = self.backend.bk_reshape(res, oshape) res2 = self.backend.bk_reshape(res2, oshape) return res, res2 else: res = self.backend.bk_reshape(res, oshape) return res elif self.use_1D: mtmp = l_mask vtmp = l_x if self.use_median: res,res2 = self.backend.bk_masked_median(l_x, l_mask) else: v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1) v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1) vh = self.backend.bk_reduce_sum(l_mask, axis=-1) res = v1 / vh res2= v2 / vh oshape = [x.shape[0]] + [mask.shape[0]] if len(x.shape) > 1: oshape = oshape + list(x.shape[1:-1]) else: oshape = oshape + [1] if calc_var: if self.use_median: vh = self.backend.bk_reduce_sum(l_mask, axis=-1) if self.backend.bk_is_complex(vtmp): res2 = self.backend.bk_sqrt( ( ( self.backend.bk_real(res2) - self.backend.bk_real(res) * self.backend.bk_real(res) ) + ( self.backend.bk_imag(res2) - self.backend.bk_imag(res) * self.backend.bk_imag(res) ) ) / self.backend.bk_real(vh) ) else: res2 = self.backend.bk_sqrt((res2 - res * res) / (vh)) res = self.backend.bk_reshape(res, oshape) res2 = self.backend.bk_reshape(res2, oshape) return res, res2 else: res = self.backend.bk_reshape(res, oshape) return res else: if self.use_median: res,res2 = self.backend.bk_masked_median(l_x, l_mask) else: v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1) v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1) vh = self.backend.bk_reduce_sum(l_mask, axis=-1) res = v1 / vh res2= v2 / vh oshape = [] if len(shape) > 1: oshape = [x.shape[0]] else: oshape = [1] oshape = oshape + [mask.shape[0]] if len(shape) > 2: oshape = oshape + shape[1:-1] else: oshape = oshape + [1] if calc_var: if self.use_median: vh = self.backend.bk_reduce_sum(l_mask, axis=-1) if self.backend.bk_is_complex(l_x): res2 = self.backend.bk_sqrt( ( self.backend.bk_real(res2) - self.backend.bk_real(res) * self.backend.bk_real(res) + self.backend.bk_imag(res2) - self.backend.bk_imag(res) * self.backend.bk_imag(res) ) / self.backend.bk_real(vh) ) else: res2 = self.backend.bk_sqrt((res2 - res * res) / (vh)) res = self.backend.bk_reshape(res, oshape) res2 = self.backend.bk_reshape(res2, oshape) return res, res2 else: res = self.backend.bk_reshape(res, oshape) return res
# ---------------------------------------------−--------- # convert tensor x [....,a,b,....] to [....,a*b,....]
[docs] def reduce_dim(self, x, axis=0): shape = list(x.shape) if axis < 0: laxis = len(shape) + axis else: laxis = axis if laxis > 0: oshape = shape[0:laxis] oshape.append(shape[laxis] * shape[laxis + 1]) else: oshape = [shape[laxis] * shape[laxis + 1]] if laxis < len(shape) - 1: oshape.extend(shape[laxis + 2 :]) return self.backend.bk_reshape(x, oshape)
# ---------------------------------------------−---------
[docs] def conv2d(self, image, ww, axis=0): if len(ww.shape) == 2: norient = ww.shape[1] else: norient = ww.shape[2] shape = image.shape if axis > 0: o_shape = shape[0] for k in range(1, axis + 1): o_shape = o_shape * shape[k] else: o_shape = image.shape[0] if len(shape) > axis + 3: ishape = shape[axis + 3] for k in range(axis + 4, len(shape)): ishape = ishape * shape[k] oshape = [o_shape, shape[axis + 1], shape[axis + 2], ishape] # l_image=self.swapaxes(self.bk_reshape(image,oshape),-1,-3) l_image = self.backend.bk_reshape(image, oshape) l_ww = np.zeros([self.KERNELSZ, self.KERNELSZ, ishape, ishape * norient]) for k in range(ishape): l_ww[:, :, k, k * norient : (k + 1) * norient] = ww.reshape( self.KERNELSZ, self.KERNELSZ, norient ) if self.backend.bk_is_complex(l_image): r = self.backend.conv2d( self.backend.bk_real(l_image), l_ww, strides=[1, 1, 1, 1], padding=self.padding, ) i = self.backend.conv2d( self.backend.bk_imag(l_image), l_ww, strides=[1, 1, 1, 1], padding=self.padding, ) res = self.backend.bk_complex(r, i) else: res = self.backend.conv2d( l_image, l_ww, strides=[1, 1, 1, 1], padding=self.padding ) res = self.backend.bk_reshape( res, [o_shape, shape[axis + 1], shape[axis + 2], ishape, norient] ) else: oshape = [o_shape, shape[axis + 1], shape[axis + 2], 1] l_ww = self.backend.bk_reshape( ww, [self.KERNELSZ, self.KERNELSZ, 1, norient] ) tmp = self.backend.bk_reshape(image, oshape) if self.backend.bk_is_complex(tmp): r = self.backend.conv2d( self.backend.bk_real(tmp), l_ww, strides=[1, 1, 1, 1], padding=self.padding, ) i = self.backend.conv2d( self.backend.bk_imag(tmp), l_ww, strides=[1, 1, 1, 1], padding=self.padding, ) res = self.backend.bk_complex(r, i) else: res = self.backend.conv2d( tmp, l_ww, strides=[1, 1, 1, 1], padding=self.padding ) return self.backend.bk_reshape(res, shape + [norient])
[docs] def diff_data(self, x, y, is_complex=True, sigma=None): if sigma is None: if self.backend.bk_is_complex(x): r = self.backend.bk_square( self.backend.bk_real(x) - self.backend.bk_real(y) ) i = self.backend.bk_square( self.backend.bk_imag(x) - self.backend.bk_imag(y) ) return self.backend.bk_reduce_sum(r + i) else: r = self.backend.bk_square(x - y) return self.backend.bk_reduce_sum(r) else: if self.backend.bk_is_complex(x): r = self.backend.bk_square( (self.backend.bk_real(x) - self.backend.bk_real(y)) / sigma ) i = self.backend.bk_square( (self.backend.bk_imag(x) - self.backend.bk_imag(y)) / sigma ) return self.backend.bk_reduce_sum(r + i) else: r = self.backend.bk_square((x - y) / sigma) return self.backend.bk_reduce_sum(r)
# ---------------------------------------------−---------
[docs] def convol(self, in_image, axis=0, cell_ids=None, nside=None, spin=0): image = self.backend.bk_cast(in_image) if self.use_2D: ishape = list(in_image.shape) if len(ishape) < axis + 2: if not self.silent: print("Use of 2D scat with data that has less than 2D") return None npix = ishape[-2] npiy = ishape[-1] ndata = 1 for k in range(len(ishape) - 2): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(in_image), [ndata, npix, npiy] ) if self.backend.bk_is_complex(tim): rr1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_RealT[1]) ii1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_ImagT[1]) rr2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_RealT[1]) ii2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_ImagT[1]) res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2) else: rr = self.backend.conv2d(tim, self.ww_RealT[1]) ii = self.backend.conv2d(tim, self.ww_ImagT[1]) res = self.backend.bk_complex(rr, ii) return self.backend.bk_reshape( res, ishape[0:-2] + [self.NORIENT, npix, npiy] ) elif self.use_1D: ishape = list(in_image.shape) npix = ishape[-1] ndata = 1 for k in range(len(ishape) - 1): ndata = ndata * ishape[k] tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix]) if self.backend.bk_is_complex(tim): rr1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_RealT[1]) ii1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_ImagT[1]) rr2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_RealT[1]) ii2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_ImagT[1]) res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2) else: rr = self.backend.conv1d(tim, self.ww_RealT[1]) ii = self.backend.conv1d(tim, self.ww_ImagT[1]) res = self.backend.bk_complex(rr, ii) return self.backend.bk_reshape(res, ishape) else: ishape = list(image.shape) if nside is None: nside = int(np.sqrt(image.shape[-1] // 12)) if (spin,nside) not in self.Idx_Neighbours: if self.InitWave is None: wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin) else: wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin) self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp) self.ww_Real[(spin,nside)] = wr self.ww_Imag[(spin,nside)] = wi self.w_smooth[(spin,nside)] = ws l_ww_real = self.ww_Real[(spin,nside)] l_ww_imag = self.ww_Imag[(spin,nside)] # always convolve the last dimension ndata = 1 if len(ishape) > 1: for k in range(len(ishape) - 1): ndata = ndata * ishape[k] if spin>0: tim = self.backend.bk_reshape( self.backend.bk_cast(image), [ndata//2,2*ishape[-1]] ) else: tim = self.backend.bk_reshape( self.backend.bk_cast(image), [ndata, ishape[-1]] ) if tim.dtype == self.all_cbk_type: rr1 = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), l_ww_real, ), [ndata, self.NORIENT, ishape[-1]], ) ii1 = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), l_ww_imag, ), [ndata, self.NORIENT, ishape[-1]], ) rr2 = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul( self.backend.bk_imag(tim), l_ww_real, ), [ndata, self.NORIENT, ishape[-1]], ) ii2 = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul( self.backend.bk_imag(tim), l_ww_imag, ), [ndata, self.NORIENT, ishape[-1]], ) res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2) else: rr = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul(tim, l_ww_real), [ndata, self.NORIENT, ishape[-1]], ) ii = self.backend.bk_reshape( self.backend.bk_sparse_dense_matmul(tim, l_ww_imag), [ndata, self.NORIENT, ishape[-1]], ) res = self.backend.bk_complex(rr, ii) if spin==0: if len(ishape) > 1: return self.backend.bk_reshape( res, ishape[0:-1] + [self.NORIENT, ishape[-1]] ) else: return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]]) else: if len(ishape) > 2: return self.backend.bk_reshape( res, ishape[0:-2] + [2,self.NORIENT, ishape[-1]] ) else: return self.backend.bk_reshape(res, [2,self.NORIENT, ishape[-1]]) return res
# ---------------------------------------------−---------
[docs] def smooth(self, in_image, axis=0, cell_ids=None, nside=None, spin=0): image = self.backend.bk_cast(in_image) if self.use_2D: ishape = list(in_image.shape) if len(ishape) < axis + 2: if not self.silent: print("Use of 2D scat with data that has less than 2D") return None npix = ishape[-2] npiy = ishape[-1] odata = 1 if len(ishape) > 1: for k in range(len(ishape)-2): odata = odata * ishape[k] ndata = 1 for k in range(len(ishape)-2): ndata = ndata * ishape[k] tim = self.backend.bk_reshape( self.backend.bk_cast(in_image), [ndata, npix, npiy] ) if self.backend.bk_is_complex(tim): rr = self.backend.conv2d(self.backend.bk_real(tim), self.ww_SmoothT[1]) ii = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_SmoothT[1]) res = self.backend.bk_complex(rr, ii) else: res = self.backend.conv2d(tim, self.ww_SmoothT[1]) return self.backend.bk_reshape(res, ishape) elif self.use_1D: ishape = list(in_image.shape) npix = ishape[-1] ndata = 1 for k in range(len(ishape) - 1): ndata = ndata * ishape[k] tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix]) if self.backend.bk_is_complex(tim): rr = self.backend.conv1d(self.backend.bk_real(tim), self.ww_SmoothT[1]) ii = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_SmoothT[1]) res = self.backend.bk_complex(rr, ii) else: res = self.backend.conv1d(tim, self.ww_SmoothT[1]) return self.backend.bk_reshape(res, ishape) else: ishape = list(image.shape) if nside is None: nside = int(np.sqrt(image.shape[-1] // 12)) if (spin,nside) not in self.Idx_Neighbours: if self.InitWave is None: wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin) else: wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin) self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp) self.ww_Real[(spin,nside)] = wr self.ww_Imag[(spin,nside)] = wi self.w_smooth[(spin,nside)] = ws l_w_smooth = self.w_smooth[(spin,nside)] odata = 1 for k in range(0, len(ishape) - 1): odata = odata * ishape[k] tim = self.backend.bk_reshape(image, [odata, ishape[-1]]) if spin==0: if tim.dtype == self.all_cbk_type: rr = self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), l_w_smooth ) ri = self.backend.bk_sparse_dense_matmul( self.backend.bk_imag(tim), l_w_smooth ) res = self.backend.bk_complex(rr, ri) else: res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth) else: tim=self.backend.bk_reshape(tim,[odata//2,2*tim.shape[-1]]) if tim.dtype == self.all_cbk_type: rr = self.backend.bk_sparse_dense_matmul( self.backend.bk_real(tim), l_w_smooth ) ri = self.backend.bk_sparse_dense_matmul( self.backend.bk_imag(tim), l_w_smooth ) res = self.backend.bk_complex(rr, ri) else: res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth) if len(ishape) == 1: return self.backend.bk_reshape(res, [ishape[-1]]) else: return self.backend.bk_reshape(res, ishape[0:-1] + [ishape[-1]]) return res
# ---------------------------------------------−---------
[docs] def get_kernel_size(self): return self.KERNELSZ
# ---------------------------------------------−---------
[docs] def get_nb_orient(self): return self.NORIENT
# ---------------------------------------------−---------
[docs] def get_ww(self, nside=1): if self.use_2D: return ( self.ww_RealT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT), self.ww_ImagT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT), ) else: return (self.ww_Real[nside], self.ww_Imag[nside])
# ---------------------------------------------−---------
[docs] def plot_ww(self): c, s = self.get_ww() import matplotlib.pyplot as plt plt.figure(figsize=(16, 6)) npt = int(np.sqrt(c.shape[0])) for i in range(c.shape[1]): plt.subplot(2, c.shape[1], 1 + i) plt.imshow( c[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max() ) plt.subplot(2, c.shape[1], 1 + i + c.shape[1]) plt.imshow( s[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max() ) sys.stdout.flush() plt.show()