Source code for foscat.Synthesis

import os
import shutil
import sys
import time
from threading import Event, Thread

import numpy as np
import scipy.optimize as opt


class Loss:
    """Bundle a loss function with the operator and extra arguments it needs.

    The wrapped ``function`` is called by :meth:`eval`; its exact call
    signature depends on whether a batch provider was given and whether
    ``info_callback`` is enabled (see :meth:`eval`).
    """

    def __init__(
        self,
        function,
        scat_operator,
        *param,
        name="",
        batch=None,
        batch_data=None,
        batch_update=None,
        info_callback=False,
    ):
        """Store the loss function and everything needed to evaluate it.

        Args:
            function: Callable invoked by :meth:`eval`.
            scat_operator: Operator object forwarded to ``function``; it must
                expose ``backend.to_numpy``.
            *param: Extra positional values, forwarded to ``function`` as a
                single tuple.
            name: Free-form label for this loss.
            batch: Optional batch provider; when not ``None`` a batch is
                inserted into the call made by :meth:`eval`.
            batch_data: Opaque data stored alongside ``batch``.
            batch_update: Optional callable stored for use by the caller.
            info_callback: When true, ``return_all`` is forwarded to
                ``function`` as a keyword argument.
        """
        self.loss_function = function
        self.scat_operator = scat_operator
        self.to_numpy = scat_operator.backend.to_numpy
        self.args = param
        self.name = name
        self.batch = batch
        self.batch_data = batch_data
        self.batch_update = batch_update
        self.info = info_callback
        self.id_loss = 0

    def eval(self, x, batch, return_all=False):
        """Evaluate the wrapped loss function on ``x``.

        ``batch`` is only passed on when this loss was created with a batch
        provider, and ``return_all`` only when ``info_callback`` was enabled.

        Args:
            x: Value to evaluate the loss on.
            batch: Current batch; ignored when ``self.batch`` is ``None``.
            return_all: Forwarded as a keyword when ``self.info`` is true.

        Returns:
            Whatever the wrapped loss function returns.
        """
        call_args = [x] if self.batch is None else [x, batch]
        call_args.extend((self.scat_operator, self.args))
        if self.info:
            return self.loss_function(*call_args, return_all=return_all)
        return self.loss_function(*call_args)

    def set_id_loss(self, id_loss):
        """Record the index of this loss inside its owning list."""
        self.id_loss = id_loss

    def get_id_loss(self, id_loss=None):
        """Return the index previously stored with :meth:`set_id_loss`.

        Args:
            id_loss: Ignored. It used to be a required (but unused) argument
                and is kept, now optional, so existing callers still work.
        """
        return self.id_loss
class Synthesis:
    """Drive an L-BFGS-B minimisation of a list of ``Loss`` objects.

    The gradient of every loss is obtained from a backend-specific helper
    (``self.bk``) selected from the ``BACKEND`` attribute of the operator
    attached to the first loss.
    """

    def __init__(
        self,
        loss_list,
        eta=0.03,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-7,
        decay_rate=0.999,
    ):
        """Register the losses and select the gradient backend.

        Args:
            loss_list: Non-empty sequence of loss objects; each receives its
                index through ``set_id_loss``. The operator of the first one
                supplies the backend and the MPI rank/size.
            eta, beta1, beta2, epsilon, decay_rate: Stored as attributes; they
                are not otherwise used in this class.
        """
        self.loss_class = loss_list
        self.number_of_loss = len(loss_list)
        for index, loss in enumerate(loss_list):
            loss.set_id_loss(index)

        self.__iteration__ = 1234
        self.nlog = 0
        self.m_dw, self.v_dw = 0.0, 0.0
        self.beta1 = beta1
        self.beta2 = beta2
        self.pbeta1 = beta1
        self.pbeta2 = beta2
        self.epsilon = epsilon
        self.eta = eta
        # Previously this argument was silently discarded.
        self.decay_rate = decay_rate
        self.history = np.zeros([10])
        self.curr_gpu = 0
        self.event = Event()
        self.operation = loss_list[0].scat_operator
        self.to_numpy = self.operation.backend.to_numpy
        self.mpi_size = self.operation.mpi_size
        self.mpi_rank = self.operation.mpi_rank
        self.KEEP_TRACK = None
        self.MAXNUMLOSS = len(loss_list)

        if self.operation.BACKEND == "tensorflow":
            import foscat.loss_backend_tens as fbk

            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)

        if self.operation.BACKEND == "torch":
            import foscat.loss_backend_torch as fbk

            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)

        if self.operation.BACKEND == "numpy":
            # NOTE(review): ``self.bk`` stays undefined in this case, so a
            # later ``run`` will fail; only the message is emitted here.
            print(
                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
            )
            return

    def get_gpu(self, event, delay):
        """Periodically dump ``nvidia-smi`` figures into ``smi_tmp.txt``.

        Intended to run in a background thread. Loops until ``event`` is set,
        waking up every ``delay`` seconds.

        Args:
            event: ``threading.Event`` used as the stop signal.
            delay: Seconds between two queries.
        """
        # ``shutil.which`` replaces ``which ... &> /dev/null``: ``&>`` is a
        # bashism that a POSIX ``sh`` parses as "run in background", which
        # made the probe report success even without nvidia-smi installed.
        if shutil.which("nvidia-smi") is None:
            return

        # ``Event.wait`` returns True as soon as the event is set, so a stop
        # request is honoured immediately instead of after a full sleep.
        while not event.wait(delay):
            try:
                os.system(
                    "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
                )
            except Exception:
                print("No nvidia GPU: Impossible to trace")
                self.nogpu = 1

    def stop_synthesis(self):
        """Signal the GPU monitoring thread to stop and wait for it."""
        self.event.set()

        try:
            self.gpu_thrd.join()
        except (AttributeError, RuntimeError):
            # AttributeError: no thread was ever created.
            # RuntimeError: the thread object exists but was never started.
            print("No thread to stop, everything is ok")
        sys.stdout.flush()

    def getgpumem(self):
        """Read the figures last written to ``smi_tmp.txt``.

        Returns:
            A 2-D array with one row per line of the file and at least three
            columns, or ``np.zeros([1, 3])`` when the file is missing, empty
            or malformed.
        """
        try:
            # ``loadtxt`` yields a 1-D array for a single-line file; the
            # caller indexes ``[k, 0..2]`` and therefore needs 2-D.
            info = np.atleast_2d(np.loadtxt("smi_tmp.txt"))
        except (OSError, ValueError):
            return np.zeros([1, 3])

        if info.shape[1] < 3:
            return np.zeros([1, 3])
        return info

    def info_back(self, x):
        """Per-iteration callback: count the iteration and print progress.

        A progress line is printed on rank 0 every ``EVAL_FREQUENCY``
        iterations.

        Args:
            x: Current parameter vector (unused; required by the optimiser's
                callback signature).
        """
        self.nlog = self.nlog + 1
        self.itt2 = 0

        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
            end = time.time()

            # Entries equal to -1 are unused slots of the loss log.
            valid = self.ltot[self.ltot != -1]
            cur_loss = (
                "%10.3g (" % (valid.mean())
                + "".join("%10.3g " % (value) for value in valid)
                + ")"
            )

            mess = ""
            if self.SHOWGPU:
                info_gpu = self.getgpumem()
                mess = "".join(
                    "[GPU%d %.0f/%.0f MB %.0f%%]"
                    % (k, info_gpu[k, 0], info_gpu[k, 1], info_gpu[k, 2])
                    for k in range(info_gpu.shape[0])
                )

            print(
                "%sItt %6d L=%s %.3fs %s"
                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
            )
            sys.stdout.flush()

            if self.KEEP_TRACK is not None:
                print(self.last_info)
                sys.stdout.flush()

            self.start = time.time()

        self.itt = self.itt + 1

    def calc_grad(self, in_x):
        """Evaluate all losses and their summed gradient at ``in_x``.

        This is the objective handed to the optimiser in :meth:`run`.

        Args:
            in_x: Flat parameter vector; it is cast by the backend and
                reshaped to ``self.oshape`` before evaluation.

        Returns:
            ``(loss, gradient)`` where ``loss`` is the mean of the logged
            per-loss values as ``float64`` and ``gradient`` is flattened
            (cast to ``float64`` unless the backend reports it as complex).
        """
        backend = self.operation.backend
        g_tot = None

        if self.do_all_noise and self.totalsz > self.batchsz:
            nstep = self.totalsz // self.batchsz
        else:
            nstep = 1

        x = backend.bk_reshape(backend.bk_cast(in_x), self.oshape)
        if self.idx_grd is not None:
            x = x[self.idx_grd]

        # Reset this rank's slice of the loss log; -1 marks "not yet filled".
        first = self.mpi_rank * self.MAXNUMLOSS
        self.l_log[first : first + self.MAXNUMLOSS] = -1.0

        for istep in range(nstep):
            for k in range(self.number_of_loss):
                loss_obj = self.loss_class[k]

                if loss_obj.batch is None:
                    l_batch = None
                else:
                    l_batch = loss_obj.batch(loss_obj.batch_data, istep)

                if self.KEEP_TRACK is not None:
                    l_loss, g, linfo = self.bk.loss(
                        x, l_batch, loss_obj, self.KEEP_TRACK
                    )
                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                else:
                    l_loss, g = self.bk.loss(x, l_batch, loss_obj, self.KEEP_TRACK)

                g_tot = g if g_tot is None else g_tot + g

                # Average the loss over the steps, slot by slot.
                contribution = self.to_numpy(l_loss) / nstep
                slot = first + k
                if self.l_log[slot] == -1:
                    self.l_log[slot] = contribution
                else:
                    self.l_log[slot] = self.l_log[slot] + contribution

        grd_mask = self.grd_mask
        if grd_mask is not None:
            g_tot = backend.to_numpy(g_tot * grd_mask)
        else:
            g_tot = backend.to_numpy(g_tot)

        g_tot[np.isnan(g_tot)] = 0.0

        if self.idx_grd is not None:
            # Scatter the partial gradient back; untouched entries stay 0.
            lg_tot = np.zeros(in_x.shape)
            lg_tot[self.idx_grd] = g_tot
            g_tot = lg_tot

        self.imin = self.imin + self.batchsz

        if self.mpi_size == 1:
            self.ltot = self.l_log
            grad = g_tot
        else:
            local_log = self.l_log.astype("float64")
            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
            self.comm.Allreduce(
                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
            )

            if backend.bk_is_complex(g_tot):
                grad = np.zeros(self.oshape, dtype=g_tot.dtype)
                self.comm.Allreduce((g_tot), (grad))
            else:
                grad = np.zeros(self.oshape, dtype="float64")
                self.comm.Allreduce(
                    (g_tot.astype("float64"), self.MPI.DOUBLE),
                    (grad, self.MPI.DOUBLE),
                )

        # Double the history buffer when it is full.
        if self.nlog == self.history.shape[0]:
            new_log = np.zeros([self.history.shape[0] * 2])
            new_log[0 : self.nlog] = self.history
            self.history = new_log

        l_tot = self.ltot[self.ltot != -1].mean()
        self.history[self.nlog] = l_tot

        g_tot = grad.flatten()

        if backend.bk_is_complex(g_tot):
            return l_tot.astype("float64"), g_tot

        return l_tot.astype("float64"), g_tot.astype("float64")

    def xtractmap(self, x, axis):
        """Reshape ``x`` back to the shape recorded at the start of ``run``.

        Args:
            x: Flat parameter vector.
            axis: Unused.
        """
        x = self.operation.backend.bk_reshape(x, self.oshape)
        return x

    def run(
        self,
        in_x,
        NUM_EPOCHS=100,
        DECAY_RATE=0.95,
        EVAL_FREQUENCY=100,
        DEVAL_STAT_FREQUENCY=1000,
        NUM_STEP_BIAS=1,
        LEARNING_RATE=0.03,
        EPSILON=1e-7,
        KEEP_TRACK=None,
        grd_mask=None,
        SHOWGPU=False,
        MESSAGE="",
        factr=10.0,
        batchsz=1,
        totalsz=1,
        do_lbfgs=True,
        idx_grd=None,
        axis=0,
    ):
        """Minimise the registered losses starting from ``in_x``.

        Runs ``scipy.optimize.fmin_l_bfgs_b`` ``NUM_STEP_BIAS`` times; between
        two runs every loss with a ``batch_update`` hook is given the current
        solution.

        Args:
            in_x: Starting point (array or backend tensor).
            NUM_EPOCHS: ``maxiter`` of each L-BFGS-B run.
            EVAL_FREQUENCY: Progress is printed every this many iterations.
            NUM_STEP_BIAS: Number of successive L-BFGS-B runs.
            KEEP_TRACK: Optional callable collecting per-loss information.
            grd_mask: Optional mask multiplied into the gradient.
            SHOWGPU: When true, rank 0 starts a GPU monitoring thread.
            MESSAGE: Prefix of every progress line.
            factr: Forwarded to ``fmin_l_bfgs_b``.
            batchsz, totalsz: Control the number of batch steps per gradient.
            idx_grd: Optional index restricting the optimised entries.
            axis: Forwarded to :meth:`xtractmap`.
            DECAY_RATE, LEARNING_RATE, EPSILON: Stored as attributes only.
            DEVAL_STAT_FREQUENCY, do_lbfgs: Accepted but unused here.

        Returns:
            The optimised parameters reshaped to the shape of ``in_x``.
        """
        self.KEEP_TRACK = KEEP_TRACK
        self.track = {}
        self.ntrack = 0
        self.eta = LEARNING_RATE
        self.epsilon = EPSILON
        self.decay_rate = DECAY_RATE
        self.nlog = 0
        self.itt2 = 0
        self.batchsz = batchsz
        self.totalsz = totalsz
        self.grd_mask = grd_mask
        self.idx_grd = idx_grd
        self.EVAL_FREQUENCY = EVAL_FREQUENCY
        self.MESSAGE = MESSAGE
        self.SHOWGPU = SHOWGPU
        self.axis = axis
        self.in_x_nshape = in_x.shape[0]
        self.seed = 1234

        np.random.seed(self.mpi_rank * 7 + 1234)

        x = in_x

        self.curr_gpu = self.curr_gpu + self.mpi_rank

        if self.mpi_size > 1:
            from mpi4py import MPI

            comm = MPI.COMM_WORLD
            self.comm = comm
            self.MPI = MPI
            if self.mpi_rank == 0:
                print("Work with MPI")
                sys.stdout.flush()

        monitor_gpu = self.mpi_rank == 0 and SHOWGPU
        if monitor_gpu:
            # A previous run leaves the stop event set; without clearing it
            # the new monitoring thread would exit immediately.
            self.event.clear()
            try:
                # Daemon thread: it must never keep the interpreter alive.
                self.gpu_thrd = Thread(
                    target=self.get_gpu,
                    args=(self.event, 1),
                    daemon=True,
                )
                self.gpu_thrd.start()
            except Exception:
                print("Error: unable to start thread for GPU survey")

        try:
            if self.mpi_size > 1:
                num_loss = np.zeros([1], dtype="int32")
                total_num_loss = np.zeros([1], dtype="int32")
                num_loss[0] = self.number_of_loss
                comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
                total_num_loss = total_num_loss[0]
            else:
                total_num_loss = self.number_of_loss

            if self.mpi_rank == 0:
                print("Total number of loss ", total_num_loss)
                sys.stdout.flush()

            # One slot per loss and per rank; this rank's slots start at -1
            # ("unset"), the others at 0.
            l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
            l_log[
                self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
            ] = -1.0
            self.ltot = l_log.copy()
            self.l_log = l_log

            self.imin = 0
            self.start = time.time()
            self.itt = 0

            self.oshape = list(x.shape)

            if not isinstance(x, np.ndarray):
                x = self.to_numpy(x)

            x = x.flatten()

            self.do_all_noise = True
            self.noise_idx = None

            # Evaluate once up front so the first progress line has a loss.
            self.calc_grad(x)
            self.info_back(x)

            maxitt = NUM_EPOCHS

            for iteration in range(NUM_STEP_BIAS):
                x, loss, _ = opt.fmin_l_bfgs_b(
                    self.calc_grad,
                    x.astype("float64"),
                    callback=self.info_back,
                    pgtol=1e-32,
                    factr=factr,
                    maxiter=maxitt,
                )
                print("Final Loss ", loss)

                # Let the losses refresh their batch data before restarting.
                if iteration < NUM_STEP_BIAS - 1:
                    omap = self.xtractmap(x, axis)
                    for loss_obj in self.loss_class[: self.number_of_loss]:
                        if loss_obj.batch_update is not None:
                            loss_obj.batch_update(loss_obj.batch_data, omap)
        finally:
            # Stop the monitor even when the optimisation raised.
            if monitor_gpu:
                self.stop_synthesis()

        if self.KEEP_TRACK is not None:
            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)

        x = self.xtractmap(x, axis)
        return x

    def get_history(self):
        """Return the loss values logged so far (one per counted iteration)."""
        return self.history[0 : self.nlog]