From a89794ae3ba5d6309b9816c8bb0778b04f618d12 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Thu, 12 Mar 2020 14:04:53 -0500 Subject: [PATCH 01/13] also write csv --- scan_precision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scan_precision.py b/scan_precision.py index a97f6e0..f9d6220 100644 --- a/scan_precision.py +++ b/scan_precision.py @@ -24,6 +24,7 @@ def plotScan(x,outs,name,odir,xtitle="n bits"): for metric in ['ssd','corr','emd']: plotHist(x, outs[metric], outs[metric+'_err'], name+"_"+metric, odir,xtitle=xtitle,ytitle=metric) + outs.to_csv(odir+"/"+name+".csv") return def BitScan(options, args): From ce9a5901833ec1e27fe3aa1bbd5c4ee1ee29986f Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Thu, 12 Mar 2020 14:05:10 -0500 Subject: [PATCH 02/13] do things the right way --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 7465810..3aae8c7 100644 --- a/train.py +++ b/train.py @@ -106,8 +106,9 @@ def predict(x,autoencoder,encoder,reshape=True): ### cross correlation of input/output def cross_corr(x,y): cov = np.cov(x.flatten(),y.flatten()) - std = np.sqrt(np.diag(cov)+1e-10) - corr = cov / np.multiply.outer(std, std) + std = np.sqrt(np.diag(cov)) + stdsqr = np.multiply.outer(std, std) + corr = np.divide(cov, stdsqr, out=np.zeros_like(cov), where=(stdsqr!=0)) return corr[0,1] def ssd(x,y): From 7bc842391d9af4539c09d58663b977fb1cfae6ad Mon Sep 17 00:00:00 2001 From: Ben Hawks Date: Thu, 12 Mar 2020 19:06:04 -0500 Subject: [PATCH 03/13] Minor updates for more accurate model name generation, now closes plots after saved to stop matplotlib from complaining there are to many plots open (and other wonkyness), added code to scan precision of other values --- scan_precision.py | 50 +++++++++++++++++++++++++++++++++++++++++------ train.py | 31 ++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/scan_precision.py b/scan_precision.py index a97f6e0..9a5d425 100644 --- a/scan_precision.py +++ b/scan_precision.py @@ -17,6 +17,7 @@ def plotHist(x,y,ye, name, odir,xtitle, ytitle): plt.xlabel(xtitle) plt.legend(['others 16,6'], loc='upper right') plt.savefig(odir+"/"+name+".png") + plt.close() return def plotScan(x,outs,name,odir,xtitle="n bits"): @@ -27,7 +28,9 @@ def plotScan(x,outs,name,odir,xtitle="n bits"): return def BitScan(options, args): - + og_odir = options.odir + ''' + options.odir = og_odir+"/input" # test inputs bits = [i+3 for i in range(6)] bits = [i+3 for i in range(2)] @@ -36,17 +39,52 @@ def BitScan(options, args): plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits") exit(0) - + ''' + ''' # test weights + options.odir = og_odir+"/weight" bits = [i+1 for i in range(8)] updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits] outputs = [trainCNN(options,args,u) for u in updates] - plotScan(bits,outputs,"test_weight_bits",xtitle="total input bits") + plotScan([2*b+1 for b in bits],outputs,"test_weight_bits",og_odir,xtitle="total weight bits") + ''' - emd, emde = zip(*[trainCNN(options,args,u) for u in updates]) - plotScan(bits,emd,emde,"test_weight_bits") + #exit(0) - exit(0) + # test accumulator + options.odir = og_odir+"/accum" + bits = [i+1 for i in range(8)] + updates = [{'nBits_accum': {'total': 2*b+1, 'integer': b}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan([2*b+1 for b in bits],outputs,"test_accum_bits",og_odir,xtitle="total accumulator layer bits") + + #exit(0) + + # test encoder layer bits + options.odir = og_odir+"/encod" + bits = [i+1 for i in range(8)] + updates = [{'nBits_encod': {'total': 2*b+1, 'integer': b}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan([2*b+1 for b in bits],outputs,"test_encoded_bits",og_odir,xtitle="total encoder layer bits") + + #exit(0) + + #test dense alone (conv constant) + options.odir = og_odir+"/dense" + bits = [i+1 for i in range(8)] + updates = [{'nBits_dense': {'total': 2*b+1, 'integer': b}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan([2*b+1 for b in bits],outputs,"test_dense_bits",og_odir,xtitle="total dense layer bits") + + + #exit(0) + + # test conv alone (dense constant) + options.odir = og_odir+"/conv" + bits = [i+1 for i in range(8)] + updates = [{'nBits_conv': {'total': 2*b+1, 'integer': b}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan([2*b+1 for b in bits],outputs,"test_conv_bits",og_odir,xtitle="total convolutional layer bits") diff --git a/train.py b/train.py index 7465810..028d7c7 100644 --- a/train.py +++ b/train.py @@ -68,7 +68,7 @@ def train(autoencoder,encoder,train_input,val_input,name,n_epochs=100): plt.legend(['Train', 'Test'], loc='upper right') plt.savefig("history_%s.png"%name) #plt.show() - + plt.close() save_models(autoencoder,name) return history @@ -184,6 +184,7 @@ def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): plt.tight_layout() plt.savefig("%s_examples.png"%name) + plt.close(fig) def visMetric(input_Q,decoded_Q,maxQ,name, skipPlot=False): def plothist(y,xlabel,name): @@ -198,6 +199,7 @@ def plothist(y,xlabel,name): plt.ylabel('Entry') plt.title('%s on validation set'%xlabel) plt.savefig("hist_%s.png"%name) + plt.close() cross_corr_arr = np.array([cross_corr(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))] ) ssd_arr = np.array([ssd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))]) @@ -215,6 +217,7 @@ def plothist(y,xlabel,name): plt.legend(loc='upper right') plt.xlabel('Charge fraction') plt.savefig("hist_Qfr_%s.png"%name) + plt.close() input_Q_abs = np.array([input_Q[i] * maxQ[i] for i in range(0,len(input_Q))]) decoded_Q_abs = np.array([decoded_Q[i]*maxQ[i] for i in range(0,len(decoded_Q))]) @@ -225,6 +228,7 @@ def plothist(y,xlabel,name): plt.legend(loc='upper right') plt.xlabel('Charge') plt.savefig("hist_Qabs_%s.png"%name) + plt.close() nonzeroQs = np.count_nonzero(input_Q_abs.reshape(len(input_Q_abs),48),axis=1) occbins = [0,5,10,20,48] @@ -241,14 +245,26 @@ def plothist(y,xlabel,name): plt.tight_layout() #plt.show() plt.savefig('corr_vs_occ_%s.png'%name) + plt.close() return cross_corr_arr,ssd_arr,emd_arr - -def GetBitsString(In, Accum, Weight): + +def GetBitsString(In, Accum, Weight, Encoded, Dense=False, Conv=False): s="" s += "Input{}b{}i".format(In['total'],In['integer']) s += "_Accum{}b{}i".format(Accum['total'],Accum['integer']) - s += "_Weight{}b{}i".format(Weight['total'],Weight['integer']) + if Dense: + s += "_Dense{}b{}i".format(Dense['total'], Dense['integer']) + if Conv: + s += "_Conv{}b{}i".format(Conv['total'], Conv['integer']) + else: + s += "_Conv{}b{}i".format(Weight['total'], Weight['integer']) + elif Conv: + s += "_Dense{}b{}i".format(Weight['total'], Weight['integer']) + s += "_Conv{}b{}i".format(Conv['total'], Conv['integer']) + else: + s += "_Weight{}b{}i".format(Weight['total'],Weight['integer']) + s += "_Encod{}b{}i".format(Encoded['total'], Encoded['integer']) return s def trainCNN(options, args, pam_updates=None): @@ -438,9 +454,10 @@ def trainCNN(options, args, pam_updates=None): for model in models: model_name = model['name'] if options.quantize: - bit_str = GetBitsString(m['pams']['nBits_input'], - m['pams']['nBits_accum'], - m['pams']['nBits_weight']) + bit_str = GetBitsString(model['pams']['nBits_input'], model['pams']['nBits_accum'], + model['pams']['nBits_weight'], model['pams']['nBits_encod'], + (model['pams']['nBits_dense'] if 'nBits_dense' in model['pams'] else False), + (model['pams']['nBits_conv'] if 'nBits_conv' in model['pams'] else False)) model_name += "_" + bit_str if not os.path.exists(model_name): os.mkdir(model_name) os.chdir(model_name) From 01ceda84fa383a0447cd535d43d1cd63f226da5c Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Fri, 13 Mar 2020 17:13:51 -0500 Subject: [PATCH 04/13] partially updated parameter updates and metrics --- scan_precision.py | 35 +++++++++++---------- train.py | 78 +++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 25 deletions(-) diff --git a/scan_precision.py b/scan_precision.py index f9d6220..b098429 100644 --- a/scan_precision.py +++ b/scan_precision.py @@ -29,23 +29,26 @@ def plotScan(x,outs,name,odir,xtitle="n bits"): def BitScan(options, args): - # test inputs - bits = [i+3 for i in range(6)] - bits = [i+3 for i in range(2)] - updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits] - outputs = [trainCNN(options,args,u) for u in updates] - plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits") - - exit(0) - - # test weights - bits = [i+1 for i in range(8)] - updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits] - outputs = [trainCNN(options,args,u) for u in updates] - plotScan(bits,outputs,"test_weight_bits",xtitle="total input bits") + if False: + # test inputs + bits = [i+3 for i in range(6)] + updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits") + + if False: + # test weights + bits = [i+1 for i in range(8)] + updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan(bits,outputs,"test_weight_bits",options.odir,xtitle="total weight bits") - emd, emde = zip(*[trainCNN(options,args,u) for u in updates]) - plotScan(bits,emd,emde,"test_weight_bits") + if True: + # test encoded bits + bits = [4,6,8,10,12,16] + updates = [{'nBits_encod':{'total': b, 'integer': b/2},'encoded_dim':int(64/b)} for b in bits] + outputs = [trainCNN(options,args,u) for u in updates] + plotScan(bits,outputs,"test_encod_bits",options.odir,xtitle="bits per encoded node") exit(0) diff --git a/train.py b/train.py index 3aae8c7..8399b33 100644 --- a/train.py +++ b/train.py @@ -47,7 +47,7 @@ def split(shaped_data, validation_frac=0.2): print('training shape',train_input.shape) print('validation shape',val_input.shape) - return val_input,train_input + return val_input,train_input,val_indices def train(autoencoder,encoder,train_input,val_input,name,n_epochs=100): @@ -67,7 +67,7 @@ def train(autoencoder,encoder,train_input,val_input,name,n_epochs=100): plt.xlabel('Epoch') plt.legend(['Train', 'Test'], loc='upper right') plt.savefig("history_%s.png"%name) - #plt.show() + plt.close() save_models(autoencoder,name) @@ -147,6 +147,38 @@ def emd(_x, _y, threshold=-1): return ot.emd2(x, y, hexMetric) +def d_weighted_mean(x, y): + x = 1./x.sum()*x.flatten() + y = 1./y.sum()*y.flatten() + dx = hexCoords[:,0].dot(x-y) + dy = hexCoords[:,1].dot(x-y) + return np.sqrt(dx*dx+dy*dy) + +def make_supercells(_x): + shape = _x.shape + x = _x.copy().flatten() + mask = np.array([ + [ 0, 1, 4, 5], #indices for 1 supercell + [ 2, 3, 6, 7], + [ 8, 9, 12, 13], + [10, 11, 14, 15], + [16, 17, 20, 21], + [18, 19, 22, 23], + [24, 25, 28, 29], + [26, 27, 30, 31], + [24, 25, 28, 29], + [26, 27, 30, 31], + [32, 33, 36, 37], + [34, 35, 38, 39]]) + for sc in mask: + # set max cell to sum + ii = np.argmax( x[sc] ) + mysum = np.sum( x[sc] ) + x[sc]=0 + x[ii]=mysum + return x + + def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): if index.size==0: Nevents=8 @@ -185,6 +217,7 @@ def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): plt.tight_layout() plt.savefig("%s_examples.png"%name) + plt.close() def visMetric(input_Q,decoded_Q,maxQ,name, skipPlot=False): def plothist(y,xlabel,name): @@ -199,7 +232,9 @@ def plothist(y,xlabel,name): plt.ylabel('Entry') plt.title('%s on validation set'%xlabel) plt.savefig("hist_%s.png"%name) - + plt.close() + + metrics = [cross_corr,ssd,emd,d_weighted_mean] cross_corr_arr = np.array([cross_corr(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))] ) ssd_arr = np.array([ssd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))]) emd_arr = np.array([emd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))]) @@ -216,7 +251,8 @@ def plothist(y,xlabel,name): plt.legend(loc='upper right') plt.xlabel('Charge fraction') plt.savefig("hist_Qfr_%s.png"%name) - + plt.close() + input_Q_abs = np.array([input_Q[i] * maxQ[i] for i in range(0,len(input_Q))]) decoded_Q_abs = np.array([decoded_Q[i]*maxQ[i] for i in range(0,len(decoded_Q))]) @@ -226,6 +262,7 @@ def plothist(y,xlabel,name): plt.legend(loc='upper right') plt.xlabel('Charge') plt.savefig("hist_Qabs_%s.png"%name) + plt.close() nonzeroQs = np.count_nonzero(input_Q_abs.reshape(len(input_Q_abs),48),axis=1) occbins = [0,5,10,20,48] @@ -242,7 +279,8 @@ def plothist(y,xlabel,name): plt.tight_layout() #plt.show() plt.savefig('corr_vs_occ_%s.png'%name) - + plt.close() + return cross_corr_arr,ssd_arr,emd_arr def GetBitsString(In, Accum, Weight): @@ -452,9 +490,12 @@ def trainCNN(options, args, pam_updates=None): m = denseCNN(weights_f=model['ws']) m.setpams(model['pams']) m.init() - shaped_data = m.prepInput(normdata) - val_input, train_input = split(shaped_data) - m_autoCNN , m_autoCNNen = m.get_models() + shaped_data = m.prepInput(normdata) + val_input, train_input, val_ind = split(shaped_data) + m_autoCNN , m_autoCNNen = m.get_models() + maxdata.reshape(shaped_data.shape) + val_max = maxdata[val_ind] + if model['ws']=='': history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs) else: @@ -472,7 +513,26 @@ def trainCNN(options, args, pam_updates=None): np.savetxt("verify_decoded.csv",cnn_deQ[0:N_csv].reshape(N_csv,48), delimiter=",",fmt='%.12f') index = np.random.choice(input_Q.shape[0], Nevents, replace=False) - corr_arr, ssd_arr, emd_arr = visMetric(input_Q,cnn_deQ,maxdata,name=model_name, skipPlot=options.skipPlot) + + # metrics to compute on the validation dataset + metrics = {'xcorr':cross_corr, + 'ssd':ssd, + 'emd':emd, + 'wgtd_mean':d_weighted_mean, + } + #super-cell variants + sc_metrics = {'sc_'+n : (lambda x,y : f[n](make_supercells(x),make_supercells(y))) for n in metrics} + # threshold variants + for pct in [47,69]: + thr_metrics = {'thr{}_'.format(pct)+n : (lambda x,y : f[n](threshold(x,val_max,pct),threshold(y,val_max,pct))) for n in metrics} + all_metrics = metrics.update(sc_metrics) + metric_arrays={} + for name in all_metrics: + func = all_metrics[name] + metric_arrays[name] = np.array([func(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))]) + + + corr_arr, ssd_arr, emd_arr = visMetric(input_Q,cnn_deQ,val_max,name=model_name, skipPlot=options.skipPlot) if not options.skipPlot: hi_corr_index = (np.where(corr_arr>0.9))[0] From 678162e36c595e52c66ad6fe0089a81c8eb63e36 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Fri, 13 Mar 2020 19:54:24 -0500 Subject: [PATCH 05/13] changing around metrics again --- train.py | 63 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/train.py b/train.py index 8399b33..2ce03aa 100644 --- a/train.py +++ b/train.py @@ -36,18 +36,18 @@ def split(shaped_data, validation_frac=0.2): N = round(len(shaped_data)*validation_frac) #randomly select 25% entries - index = np.random.choice(shaped_data.shape[0], N, replace=False) + val_index = np.random.choice(shaped_data.shape[0], N, replace=False) #select the indices of the other 75% full_index = np.array(range(0,len(shaped_data))) - train_index = np.logical_not(np.in1d(full_index,index)) + train_index = np.logical_not(np.in1d(full_index,val_index)) - val_input = shaped_data[index] + val_input = shaped_data[val_index] train_input = shaped_data[train_index] print('training shape',train_input.shape) print('validation shape',val_input.shape) - return val_input,train_input,val_indices + return val_input,train_input,val_index def train(autoencoder,encoder,train_input,val_input,name,n_epochs=100): @@ -135,8 +135,8 @@ def ssd(x,y): def emd(_x, _y, threshold=-1): x = np.array(_x, dtype=np.float64) y = np.array(_y, dtype=np.float64) - x = 1./x.sum()*x.flatten() - y = 1./y.sum()*y.flatten() + x = (1./x.sum() if x.sum() else 1.)*x.flatten() + y = (1./y.sum() if y.sum() else 1.)*y.flatten() if threshold > 0: # only keep entries above 2%, e.g. @@ -148,8 +148,8 @@ def emd(_x, _y, threshold=-1): return ot.emd2(x, y, hexMetric) def d_weighted_mean(x, y): - x = 1./x.sum()*x.flatten() - y = 1./y.sum()*y.flatten() + x = (1./x.sum() if x.sum() else 1.)*x.flatten() + y = (1./y.sum() if y.sum() else 1.)*y.flatten() dx = hexCoords[:,0].dot(x-y) dy = hexCoords[:,1].dot(x-y) return np.sqrt(dx*dx+dy*dy) @@ -178,6 +178,12 @@ def make_supercells(_x): x[ii]=mysum return x +def threshold(_x, norm, cut): + x = _x.copy() + # reshape to allow broadcasting to all cells + norm_shape = norm.reshape((norm.shape[0],)+(1,)*(x.ndim-1)) + x = np.where(x*norm_shape>=cut,x,0) + return x def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): if index.size==0: @@ -493,7 +499,6 @@ def trainCNN(options, args, pam_updates=None): shaped_data = m.prepInput(normdata) val_input, train_input, val_ind = split(shaped_data) m_autoCNN , m_autoCNNen = m.get_models() - maxdata.reshape(shaped_data.shape) val_max = maxdata[val_ind] if model['ws']=='': @@ -513,27 +518,35 @@ def trainCNN(options, args, pam_updates=None): np.savetxt("verify_decoded.csv",cnn_deQ[0:N_csv].reshape(N_csv,48), delimiter=",",fmt='%.12f') index = np.random.choice(input_Q.shape[0], Nevents, replace=False) - + + stc_Q = make_supercells(input_Q) + thr_nom = threshold(input_Q,val_max,47) # 1.35 transverse MIPs + thr_hi = threshold(input_Q,val_max,69) # 2.0 transverse MIPs + # metrics to compute on the validation dataset - metrics = {'xcorr':cross_corr, - 'ssd':ssd, - 'emd':emd, - 'wgtd_mean':d_weighted_mean, + metrics = {'xcorr' :cross_corr, + 'ssd' :ssd, + 'emd' :emd, + 'diff_mean':d_weighted_mean, } - #super-cell variants - sc_metrics = {'sc_'+n : (lambda x,y : f[n](make_supercells(x),make_supercells(y))) for n in metrics} - # threshold variants - for pct in [47,69]: - thr_metrics = {'thr{}_'.format(pct)+n : (lambda x,y : f[n](threshold(x,val_max,pct),threshold(y,val_max,pct))) for n in metrics} - all_metrics = metrics.update(sc_metrics) - metric_arrays={} - for name in all_metrics: - func = all_metrics[name] - metric_arrays[name] = np.array([func(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))]) + + # # super trigger cell variants + # sc_metrics = {'sc_'+n : (lambda x,y,_=None : (metrics[n])(make_supercells(x),make_supercells(y))) for n in metrics} + # # threshold variants + # thr_metrics={} + # for pct in [47,69]: #1.35 and 2.0 transverse MIPs + # thr_metrics.update({'thr{}_'.format(pct)+n : (lambda x,y,maxVal : (metrics[n])(threshold(x,maxVal,pct),threshold(y,maxVal,pct))) for n in metrics}) + # all_metrics = metrics.copy() + # all_metrics.update(sc_metrics) + # all_metrics.update(thr_metrics) + # #metric_arrays={name : np.array([(all_metrics[name])(input_Q[i],cnn_deQ[i]) for i in range(0,len(input_Q))]) for name in all_metrics} + + metric_arrays={name : np.array([(all_metrics[name])(input_Q[i],cnn_deQ[i],val_max[i]) for i in range(0,len(input_Q))]) for name in all_metrics} + corr_arr, ssd_arr, emd_arr = visMetric(input_Q,cnn_deQ,val_max,name=model_name, skipPlot=options.skipPlot) - + if not options.skipPlot: hi_corr_index = (np.where(corr_arr>0.9))[0] low_corr_index = (np.where(corr_arr<0.2))[0] From f12645e7643cdfe207d62a6815f8e2edda93e5b0 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Sat, 14 Mar 2020 17:07:14 -0500 Subject: [PATCH 06/13] wrap up metrics and algs --- scan_precision.py | 13 +--- train.py | 185 +++++++++++++++++----------------------------- utils.py | 42 +++++++++++ 3 files changed, 113 insertions(+), 127 deletions(-) create mode 100644 utils.py diff --git a/scan_precision.py b/scan_precision.py index b098429..652ab00 100644 --- a/scan_precision.py +++ b/scan_precision.py @@ -8,21 +8,12 @@ import json from train import trainCNN - -def plotHist(x,y,ye, name, odir,xtitle, ytitle): - plt.figure() - plt.errorbar(x,y,ye) - plt.title('') - plt.ylabel(ytitle) - plt.xlabel(xtitle) - plt.legend(['others 16,6'], loc='upper right') - plt.savefig(odir+"/"+name+".png") - return +from util import plotGraphErr def plotScan(x,outs,name,odir,xtitle="n bits"): outs = pd.concat(outs) for metric in ['ssd','corr','emd']: - plotHist(x, outs[metric], outs[metric+'_err'], name+"_"+metric, + plotGraphErr(x, outs[metric], outs[metric+'_err'], name+"_"+metric, odir,xtitle=xtitle,ytitle=metric) outs.to_csv(odir+"/"+name+".csv") return diff --git a/train.py b/train.py index 2ce03aa..2e8a93d 100644 --- a/train.py +++ b/train.py @@ -10,6 +10,8 @@ matplotlib.use('Agg') import matplotlib.pyplot as plt +from util import plotHist,plotHistErr + import numba import json @@ -89,19 +91,6 @@ def save_models(autoencoder, name): encoder.save_weights('%s.hdf5'%("encoder_"+name)) decoder.save_weights('%s.hdf5'%("decoder_"+name)) return - - -def predict(x,autoencoder,encoder,reshape=True): - decoded_Q = autoencoder.predict(x) - encoded_Q = encoder.predict(x) - - #need reshape for CNN layers - if reshape : - decoded_Q = np.reshape(decoded_Q,(len(decoded_Q),12,4)) - encoded_shape = encoded_Q.shape - encoded_Q = np.reshape(encoded_Q,(len(encoded_Q),encoded_shape[3],encoded_shape[1])) - - return decoded_Q, encoded_Q ### cross correlation of input/output def cross_corr(x,y): @@ -177,7 +166,6 @@ def make_supercells(_x): x[sc]=0 x[ii]=mysum return x - def threshold(_x, norm, cut): x = _x.copy() # reshape to allow broadcasting to all cells @@ -185,19 +173,13 @@ def threshold(_x, norm, cut): x = np.where(x*norm_shape>=cut,x,0) return x -def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): - if index.size==0: - Nevents=8 - #randomly pick Nevents if index is not specified - index = np.random.choice(input_Q.shape[0], Nevents, replace=False) - else: - Nevents = len(index) +def visDisplays(input_Q,decoded_Q,encoded_Q=None,index,name='model_X'): + Nevents = len(index) inputImg = input_Q[index] - encodedImg = encoded_Q[index] outputImg = decoded_Q[index] - fig, axs = plt.subplots(3, Nevents, figsize=(16, 10)) + fig, axs = plt.subplots(2+(encoded_Q!=None), Nevents, figsize=(16, 10)) for i in range(0,Nevents): if i==0: @@ -212,44 +194,24 @@ def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'): else: axs[1,i].set(xlabel='cell_x',title='CNN Ouput_%i'%i) c1=axs[1,i].imshow(outputImg[i]) - - for i in range(0,Nevents): - if i==0: - axs[2,i].set(xlabel='latent dim',ylabel='depth',title='Encoded_%i'%i) - else: - axs[2,i].set(xlabel='latent dim',title='Encoded_%i'%i) - c1=axs[2,i].imshow(encodedImg[i]) - plt.colorbar(c1,ax=axs[2,i]) + + if encoded_Q: + encodedImg = encoded_Q[index] + for i in range(0,Nevents): + if i==0: + axs[2,i].set(xlabel='latent dim',ylabel='depth',title='Encoded_%i'%i) + else: + axs[2,i].set(xlabel='latent dim',title='Encoded_%i'%i) + c1=axs[2,i].imshow(encodedImg[i]) + plt.colorbar(c1,ax=axs[2,i]) plt.tight_layout() plt.savefig("%s_examples.png"%name) plt.close() -def visMetric(input_Q,decoded_Q,maxQ,name, skipPlot=False): - def plothist(y,xlabel,name): - plt.figure(figsize=(6,4)) - plt.hist(y,50) - mu = np.mean(y) - std = np.std(y) - ax = plt.axes() - plt.text(0.1, 0.9, name,transform=ax.transAxes) - plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes) - plt.xlabel(xlabel) - plt.ylabel('Entry') - plt.title('%s on validation set'%xlabel) - plt.savefig("hist_%s.png"%name) - plt.close() +def visMetric(input_Q,decoded_Q,metric,name,odir,skipPlot=False): - metrics = [cross_corr,ssd,emd,d_weighted_mean] - cross_corr_arr = np.array([cross_corr(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))] ) - ssd_arr = np.array([ssd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))]) - emd_arr = np.array([emd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))]) - - if skipPlot: return cross_corr_arr,ssd_arr,emd_arr - - plothist(cross_corr_arr,'cross correlation',name+"_corr") - plothist(ssd_arr,'sum squared difference',name+"_ssd") - plothist(emd_arr,'earth movers distance',name+"_emd") + plotHist(vals,name,options.odir,xtitle=longMetric[mname]) plt.figure(figsize=(6,4)) plt.hist([input_Q.flatten(),decoded_Q.flatten()],20,label=['input','output']) @@ -262,14 +224,6 @@ def plothist(y,xlabel,name): input_Q_abs = np.array([input_Q[i] * maxQ[i] for i in range(0,len(input_Q))]) decoded_Q_abs = np.array([decoded_Q[i]*maxQ[i] for i in range(0,len(decoded_Q))]) - plt.figure(figsize=(6,4)) - plt.hist([input_Q_abs.flatten(),decoded_Q_abs.flatten()],20,label=['input','output']) - plt.yscale('log') - plt.legend(loc='upper right') - plt.xlabel('Charge') - plt.savefig("hist_Qabs_%s.png"%name) - plt.close() - nonzeroQs = np.count_nonzero(input_Q_abs.reshape(len(input_Q_abs),48),axis=1) occbins = [0,5,10,20,48] fig, axes = plt.subplots(1,len(occbins)-1, figsize=(16, 4)) @@ -505,10 +459,7 @@ def trainCNN(options, args, pam_updates=None): history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs) else: save_models(m_autoCNN,model_name) - - Nevents = 8 - N_verify = 50 - + input_Q,cnn_deQ ,cnn_enQ = m.predict(val_input) ## csv files for RTL verification @@ -517,63 +468,65 @@ def trainCNN(options, args, pam_updates=None): np.savetxt("verify_output.csv",cnn_enQ[0:N_csv].reshape(N_csv,m.pams['encoded_dim']), delimiter=",",fmt='%.12f') np.savetxt("verify_decoded.csv",cnn_deQ[0:N_csv].reshape(N_csv,48), delimiter=",",fmt='%.12f') - index = np.random.choice(input_Q.shape[0], Nevents, replace=False) - stc_Q = make_supercells(input_Q) - thr_nom = threshold(input_Q,val_max,47) # 1.35 transverse MIPs - thr_hi = threshold(input_Q,val_max,69) # 2.0 transverse MIPs - + thr_lo_Q = threshold(input_Q,val_max,47) # 1.35 transverse MIPs + thr_hi_Q = threshold(input_Q,val_max,69) # 2.0 transverse MIPs + occupancy = np.count_nonzero(input_Q.reshape(len(input_Q),48),axis=1) + + # compression algorithms, autoencoder and more traditional benchmarks + alg_outs = {'ae' : cnn_deQ, + 'stc': stc_Q, + 'thr_lo': thr_lo_Q, + 'thr_hi': thr_hi_Q, + } # metrics to compute on the validation dataset - metrics = {'xcorr' :cross_corr, - 'ssd' :ssd, - 'emd' :emd, - 'diff_mean':d_weighted_mean, + metrics = {'cross_corr' :cross_corr, + 'SSD' :ssd, + 'EMD' :emd, + 'dMean':d_weighted_mean, + } + longMetric = {'cross_corr' :'cross correlation', + 'SSD' :'sum of squared differences', + 'EMD' :'earth movers distance', + 'dMean':'difference in energy-weighted mean', } - + # to generate event displays + Nevents = 8 + index = np.random.choice(input_Q.shape[0], Nevents, replace=False) - # # super trigger cell variants - # sc_metrics = {'sc_'+n : (lambda x,y,_=None : (metrics[n])(make_supercells(x),make_supercells(y))) for n in metrics} - # # threshold variants - # thr_metrics={} - # for pct in [47,69]: #1.35 and 2.0 transverse MIPs - # thr_metrics.update({'thr{}_'.format(pct)+n : (lambda x,y,maxVal : (metrics[n])(threshold(x,maxVal,pct),threshold(y,maxVal,pct))) for n in metrics}) - # all_metrics = metrics.copy() - # all_metrics.update(sc_metrics) - # all_metrics.update(thr_metrics) - # #metric_arrays={name : np.array([(all_metrics[name])(input_Q[i],cnn_deQ[i]) for i in range(0,len(input_Q))]) for name in all_metrics} - - metric_arrays={name : np.array([(all_metrics[name])(input_Q[i],cnn_deQ[i],val_max[i]) for i in range(0,len(input_Q))]) for name in all_metrics} - - corr_arr, ssd_arr, emd_arr = visMetric(input_Q,cnn_deQ,val_max,name=model_name, skipPlot=options.skipPlot) - - if not options.skipPlot: - hi_corr_index = (np.where(corr_arr>0.9))[0] - low_corr_index = (np.where(corr_arr<0.2))[0] - visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name) - if len(hi_corr_index)>0: - index = np.random.choice(hi_corr_index, min(Nevents,len(hi_corr_index)), replace=False) - visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name+"_corr0.9") + # compute metrics for each alg + for algname, alg_out in alg_outs.items(): + # charge fraction comparison + plotHist([input_Q.flatten(),alg_out.flatten()],"hist_chargeFrac_"+algname,options.odir,xtitle="charge fraction",ytitle="Cells") + # event displays + if(not options.skipPlot): visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None), index, name=algname) + for mname, metric in metrics.items(): + name = mname+"_"+algname + vals = [metric(input_Q[i],alg_out[i]) for i in range(0,len(input_Q))] + model[name] = np.round(np.mean(vals),3) + model[name+'_err'] = np.round(np.std(vals),3) + if(not options.skipPlot): + plotHist(vals,"hist_"+name,options.odir,xtitle=longMetric[mname]) + sort = np.sort(vals) + hi_index = (np.where(vals>vals.quantile(0.9)))[0] + lo_index = (np.where(vals0: + hi_index = np.random.choice(hi_index, min(Nevents,len(hi_index)), replace=False) + visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None),hi_ index, name=algname) + if len(lo_index)>0: + lo_index = np.random.choice(lo_index, min(Nevents,len(lo_index)), replace=False) + visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None),lo_ index, name=algname) - if len(low_corr_index)>0: - index = np.random.choice(low_corr_index,min(Nevents,len(low_corr_index)), replace=False) - visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name+"_corr0.2") - - model['corr'] = np.round(np.mean(corr_arr),3) - model['ssd'] = np.round(np.mean(ssd_arr),3) - model['emd'] = np.round(np.mean(emd_arr),3) - model['corr_err'] = np.round(np.std(corr_arr),3) - model['ssd_err'] = np.round(np.std(ssd_arr),3) - model['emd_err'] = np.round(np.std(emd_arr),3) - summary = summary.append( {'name':model_name, - 'corr':model['corr'], - 'ssd':model['ssd'], - 'emd':model['emd'], - 'corr_err':model['corr_err'], - 'ssd_err':model['ssd_err'], - 'emd_err':model['emd_err'], + # 'corr':model['corr'], + # 'ssd':model['ssd'], + # 'emd':model['emd'], + # 'corr_err':model['corr_err'], + # 'ssd_err':model['ssd_err'], + # 'emd_err':model['emd_err'], 'en_pams' : m_autoCNNen.count_params(), 'tot_pams': m_autoCNN.count_params(),}, ignore_index=True) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..811c04d --- /dev/null +++ b/utils.py @@ -0,0 +1,42 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import json + +from train import trainCNN + +def plotGraph(x, y, name, odir, xtitle, ytitle, leg=None): + plt.figure() + plt.plot(x,y) + plt.title('') + plt.ylabel(ytitle) + plt.xlabel(xtitle) + if leg: plt.legend(leg, loc='upper right') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + +def plotGraphErr(x, y, ye, name, odir, xtitle, ytitle, leg=None): + plt.figure() + plt.errorbar(x,y,ye) + plt.title('') + plt.ylabel(ytitle) + plt.xlabel(xtitle) + if leg: plt.legend(leg, loc='upper right') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + +def plotHist(vals,name,odir,xtitle="",ytitle="",nbins=40): + plt.figure() + plt.hist(vals,nbins) + mu = np.mean(vals) + std = np.std(vals) + ax = plt.axes() + plt.text(0.1, 0.9, name,transform=ax.transAxes) + plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes) + plt.xlabel(xtitle) + plt.ylabel(ytitle if ytitle else 'Entries') + plt.savefig(odir+"/"+name+".png") + plt.close() + return From dddf7b3373d7e157c44a6f8be275065a28785fe2 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Mon, 16 Mar 2020 22:19:28 -0500 Subject: [PATCH 07/13] running w more metrics --- train.py | 167 +++++++++++++++++++++++++++++++++++-------------------- utils.py | 4 +- 2 files changed, 108 insertions(+), 63 deletions(-) diff --git a/train.py b/train.py index 2e8a93d..385d2a4 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ matplotlib.use('Agg') import matplotlib.pyplot as plt -from util import plotHist,plotHistErr +##from utils import plotHist import numba import json @@ -34,6 +34,26 @@ def normalize(data,rescaleInputToMax=False): data[i] = 1.*data[i]/data[i].sum() return data,np.array(norm) +def plotHist(vals,name,odir='.',xtitle="",ytitle="",nbins=40, + stats=True, logy=False, leg=None): + plt.figure(figsize=(6,4)) + if leg: + plt.hist(vals,nbins,label=leg) + else: + plt.hist(vals,nbins) + ax = plt.axes() + plt.text(0.1, 0.9, name,transform=ax.transAxes) + if stats: + mu = np.mean(vals) + std = np.std(vals) + plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes) + plt.xlabel(xtitle) + plt.ylabel(ytitle if ytitle else 'Entries') + if logy: plt.yscale('log') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + def split(shaped_data, validation_frac=0.2): N = round(len(shaped_data)*validation_frac) @@ -101,6 +121,7 @@ def cross_corr(x,y): return corr[0,1] def ssd(x,y): + if (np.sum(x)==0 or np.sum(y)==0): return 1. ssd=np.sum(((x-y)**2).flatten()) ssd = ssd/(np.sum(x**2)*np.sum(y**2))**0.5 return ssd @@ -121,7 +142,9 @@ def ssd(x,y): [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696], [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]]) hexMetric = ot.dist(hexCoords, hexCoords, 'euclidean') +MAXDIST = 16.08806614 def emd(_x, _y, threshold=-1): + if (np.sum(_x)==0 or np.sum(_y)==0): return MAXDIST x = np.array(_x, dtype=np.float64) y = np.array(_y, dtype=np.float64) x = (1./x.sum() if x.sum() else 1.)*x.flatten() @@ -137,15 +160,16 @@ def emd(_x, _y, threshold=-1): return ot.emd2(x, y, hexMetric) def d_weighted_mean(x, y): + if (np.sum(x)==0 or np.sum(y)==0): return MAXDIST/2. x = (1./x.sum() if x.sum() else 1.)*x.flatten() y = (1./y.sum() if y.sum() else 1.)*y.flatten() dx = hexCoords[:,0].dot(x-y) dy = hexCoords[:,1].dot(x-y) return np.sqrt(dx*dx+dy*dy) -def make_supercells(_x): - shape = _x.shape - x = _x.copy().flatten() +def make_supercells(inQ, shareQ=False): + outQ = inQ.copy() + inshape = inQ[0].shape mask = np.array([ [ 0, 1, 4, 5], #indices for 1 supercell [ 2, 3, 6, 7], @@ -155,17 +179,26 @@ def make_supercells(_x): [18, 19, 22, 23], [24, 25, 28, 29], [26, 27, 30, 31], - [24, 25, 28, 29], - [26, 27, 30, 31], [32, 33, 36, 37], - [34, 35, 38, 39]]) - for sc in mask: - # set max cell to sum - ii = np.argmax( x[sc] ) - mysum = np.sum( x[sc] ) - x[sc]=0 - x[ii]=mysum - return x + [34, 35, 38, 39], + [40, 41, 44, 45], + [43, 43, 46, 47]]) + for i in range(len(inQ)): + inFlat = inQ[i].flatten() + outFlat = outQ[i].flatten() + for sc in mask: + # set max cell to sum + if shareQ: + mysum = np.sum( inFlat[sc] ) + outFlat[sc]=mysum/4. + else: + ii = np.argmax( inFlat[sc] ) + mysum = np.sum( inFlat[sc] ) + outFlat[sc]=0 + outFlat[sc[ii]]=mysum + outQ[i] = outFlat.reshape(inshape) + return outQ + def threshold(_x, norm, cut): x = _x.copy() # reshape to allow broadcasting to all cells @@ -173,13 +206,14 @@ def threshold(_x, norm, cut): x = np.where(x*norm_shape>=cut,x,0) return x -def visDisplays(input_Q,decoded_Q,encoded_Q=None,index,name='model_X'): +def visDisplays(index,input_Q,decoded_Q,encoded_Q=np.array([]),name='model_X'): Nevents = len(index) inputImg = input_Q[index] outputImg = decoded_Q[index] - - fig, axs = plt.subplots(2+(encoded_Q!=None), Nevents, figsize=(16, 10)) + + nrows = 3 if len(encoded_Q) else 2 + fig, axs = plt.subplots(nrows, Nevents, figsize=(16, 10)) for i in range(0,Nevents): if i==0: @@ -195,7 +229,7 @@ def visDisplays(input_Q,decoded_Q,encoded_Q=None,index,name='model_X'): axs[1,i].set(xlabel='cell_x',title='CNN Ouput_%i'%i) c1=axs[1,i].imshow(outputImg[i]) - if encoded_Q: + if len(encoded_Q): encodedImg = encoded_Q[index] for i in range(0,Nevents): if i==0: @@ -427,9 +461,26 @@ def trainCNN(options, args, pam_updates=None): m['pams'].update(pam_updates) print ('updated parameters for model',m['name']) - summary = pd.DataFrame(columns=['name','en_pams','tot_pams', - 'corr','ssd','emd', - 'corr_err','ssd_err','emd_err',]) + # compression algorithms, autoencoder and more traditional benchmarks + algnames = ['ae','stc1','stc2','thr_lo','thr_hi'] + # metrics to compute on the validation dataset + metrics = {'cross_corr' :cross_corr, + 'SSD' :ssd, + 'EMD' :emd, + 'dMean':d_weighted_mean, + 'zero_frac':(lambda x,y: np.all(y==0)),} + longMetric = {'cross_corr' :'cross correlation', + 'SSD' :'sum of squared differences', + 'EMD' :'earth movers distance', + 'dMean':'difference in energy-weighted mean', + 'zero_frac':'zero fraction',} + summary_entries=['name','en_pams','tot_pams'] + for algname in algnames: + for mname in metrics: + name = mname+"_"+algname + summary_entries.append(mname+"_"+algname) + summary_entries.append(mname+"_"+algname+"_err") + summary = pd.DataFrame(columns=summary_entries) orig_dir = os.getcwd() if not os.path.exists(options.odir): os.mkdir(options.odir) @@ -456,9 +507,15 @@ def trainCNN(options, args, pam_updates=None): val_max = maxdata[val_ind] if model['ws']=='': + if options.quickTrain: train_input = train_input[:5000] history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs) else: save_models(m_autoCNN,model_name) + + summary_dict = { + 'name':model_name, + 'en_pams' : m_autoCNNen.count_params(), + 'tot_pams': m_autoCNN.count_params(),} input_Q,cnn_deQ ,cnn_enQ = m.predict(val_input) @@ -468,28 +525,17 @@ def trainCNN(options, args, pam_updates=None): np.savetxt("verify_output.csv",cnn_enQ[0:N_csv].reshape(N_csv,m.pams['encoded_dim']), delimiter=",",fmt='%.12f') np.savetxt("verify_decoded.csv",cnn_deQ[0:N_csv].reshape(N_csv,48), delimiter=",",fmt='%.12f') - stc_Q = make_supercells(input_Q) + stc1_Q = make_supercells(input_Q) + stc2_Q = make_supercells(input_Q,shareQ=True) thr_lo_Q = threshold(input_Q,val_max,47) # 1.35 transverse MIPs thr_hi_Q = threshold(input_Q,val_max,69) # 2.0 transverse MIPs occupancy = np.count_nonzero(input_Q.reshape(len(input_Q),48),axis=1) - - # compression algorithms, autoencoder and more traditional benchmarks alg_outs = {'ae' : cnn_deQ, - 'stc': stc_Q, + 'stc1': stc1_Q, + 'stc2': stc2_Q, 'thr_lo': thr_lo_Q, 'thr_hi': thr_hi_Q, } - # metrics to compute on the validation dataset - metrics = {'cross_corr' :cross_corr, - 'SSD' :ssd, - 'EMD' :emd, - 'dMean':d_weighted_mean, - } - longMetric = {'cross_corr' :'cross correlation', - 'SSD' :'sum of squared differences', - 'EMD' :'earth movers distance', - 'dMean':'difference in energy-weighted mean', - } # to generate event displays Nevents = 8 @@ -498,38 +544,38 @@ def trainCNN(options, args, pam_updates=None): # compute metrics for each alg for algname, alg_out in alg_outs.items(): # charge fraction comparison - plotHist([input_Q.flatten(),alg_out.flatten()],"hist_chargeFrac_"+algname,options.odir,xtitle="charge fraction",ytitle="Cells") + if(not options.skipPlot): plotHist([input_Q.flatten(),alg_out.flatten()], + algname+"_fracQ",xtitle="charge fraction",ytitle="Cells", + stats=False,logy=True,leg=['input','output']) + input_Q_abs = np.array([input_Q[i]*val_max[i] for i in range(0,len(input_Q))]) + alg_out_abs = np.array([alg_out[i]*val_max[i] for i in range(0,len(alg_out))]) + if(not options.skipPlot): plotHist([input_Q_abs.flatten(),alg_out_abs.flatten()], + algname+"_absQ",xtitle="absolute charge",ytitle="Cells", + stats=False,logy=True,leg=['input','output']) # event displays - if(not options.skipPlot): visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None), index, name=algname) + if(not options.skipPlot): visDisplays(index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=algname) for mname, metric in metrics.items(): name = mname+"_"+algname - vals = [metric(input_Q[i],alg_out[i]) for i in range(0,len(input_Q))] - model[name] = np.round(np.mean(vals),3) - model[name+'_err'] = np.round(np.std(vals),3) - if(not options.skipPlot): - plotHist(vals,"hist_"+name,options.odir,xtitle=longMetric[mname]) - sort = np.sort(vals) - hi_index = (np.where(vals>vals.quantile(0.9)))[0] - lo_index = (np.where(valsnp.quantile(vals,0.9)))[0] + lo_index = (np.where(vals0: hi_index = np.random.choice(hi_index, min(Nevents,len(hi_index)), replace=False) - visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None),hi_ index, name=algname) + visDisplays(hi_index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=algname) if len(lo_index)>0: lo_index = np.random.choice(lo_index, min(Nevents,len(lo_index)), replace=False) - visDisplays(input_Q, alg_out, (cnn_enQ if algname=='ae' else None),lo_ index, name=algname) + visDisplays(lo_index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=algname) - summary = summary.append( - {'name':model_name, - # 'corr':model['corr'], - # 'ssd':model['ssd'], - # 'emd':model['emd'], - # 'corr_err':model['corr_err'], - # 'ssd_err':model['ssd_err'], - # 'emd_err':model['emd_err'], - 'en_pams' : m_autoCNNen.count_params(), - 'tot_pams': m_autoCNN.count_params(),}, - ignore_index=True) + print('summary_dict',summary_dict) + summary = summary.append(summary_dict, ignore_index=True) with open(model_name+"_pams.json",'w') as f: f.write(json.dumps(m.get_pams(),indent=4)) @@ -548,6 +594,7 @@ def trainCNN(options, args, pam_updates=None): parser.add_option("--dryRun", action='store_true', default = False,dest="dryRun", help="dryRun") parser.add_option("--epochs", type='int', default = 100, dest="epochs", help="n epoch to train") parser.add_option("--skipPlot", action='store_true', default = False,dest="skipPlot", help="skip the plotting step") + parser.add_option("--quickTrain", action='store_true', default = False,dest="quickTrain", help="train w only 5k events for testing purposes") parser.add_option("--nCSV", type='int', default = 50, dest="nCSV", help="n of validation events to write to csv") parser.add_option("--rescaleInputToMax", action='store_true', default = False,dest="rescaleInputToMax", help="recale the input images so the maximum deposit is 1. Else normalize") (options, args) = parser.parse_args() diff --git a/utils.py b/utils.py index 811c04d..f7f7e85 100644 --- a/utils.py +++ b/utils.py @@ -1,9 +1,7 @@ +import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt -import json - -from train import trainCNN def plotGraph(x, y, name, odir, xtitle, ytitle, leg=None): plt.figure() From 4dedf0a189e63ce7a48f1cd981c461761faf1a61 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Tue, 17 Mar 2020 13:49:54 -0500 Subject: [PATCH 08/13] multi-file reading --- train.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 385d2a4..742aa56 100644 --- a/train.py +++ b/train.py @@ -302,8 +302,17 @@ def trainCNN(options, args, pam_updates=None): # from tensorflow.keras import backend # backend.set_image_data_format('channels_first') - - data = pd.read_csv(options.inputFile,dtype=np.float64) ## big 300k file + if os.path.isdir(options.inputFile): + df_arr = [] + for infile in os.listdir(options.inputFile): + infile = os.path.join(options.inputFile,infile) + df_arr.append(pd.read_csv(infile, dtype=np.float64, header=0)) + data = pd.concat(df_arr) + print(data.shape) + data.describe() + else: + data = pd.read_csv(options.inputFile,dtype=np.float64) + #data = pd.read_csv(options.inputFile,dtype=np.float64) ## big 300k file normdata,maxdata = normalize(data.values.copy(),rescaleInputToMax=options.rescaleInputToMax) arrange8x8 = np.array([ From d0bcff6a0d10d1e2db86b0a6eb33a2242155e278 Mon Sep 17 00:00:00 2001 From: Ben Hawks Date: Tue, 17 Mar 2020 17:03:31 -0500 Subject: [PATCH 09/13] Small changees/fixes, selects the correct subset of data from the nELinks dataset and drops rows that have an occupancy of 0 before training, along with updating model definitions --- scan_precision.py | 2 +- train.py | 30 ++++++++++++++++++++---------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/scan_precision.py b/scan_precision.py index cd040a9..f083306 100644 --- a/scan_precision.py +++ b/scan_precision.py @@ -8,7 +8,7 @@ import json from train import trainCNN -from util import plotGraphErr +from utils import plotGraphErr def plotScan(x,outs,name,odir,xtitle="n bits"): outs = pd.concat(outs) diff --git a/train.py b/train.py index 5da628c..e95ef5d 100644 --- a/train.py +++ b/train.py @@ -302,10 +302,10 @@ def trainCNN(options, args, pam_updates=None): print("Is GPU available? ", tf.test.is_gpu_available()) # default precisions for quantized training - nBits_input = {'total': 16, 'integer': 6} - nBits_accum = {'total': 16, 'integer': 6} - nBits_weight = {'total': 16, 'integer': 6} - nBits_encod = {'total': 16, 'integer': 6} + nBits_input = {'total': 32, 'integer': 4} + nBits_accum = {'total': 32, 'integer': 4} + nBits_weight = {'total': 32, 'integer': 4} + nBits_encod = {'total': 32, 'integer': 4} # model-dependent -- use common weights unless overridden conv_qbits = nBits_weight dense_qbits = nBits_weight @@ -317,8 +317,9 @@ def trainCNN(options, args, pam_updates=None): df_arr = [] for infile in os.listdir(options.inputFile): infile = os.path.join(options.inputFile,infile) - df_arr.append(pd.read_csv(infile, dtype=np.float64, header=0)) + df_arr.append(pd.read_csv(infile, dtype=np.float64, header=0, usecols=[*range(1, 49)])) data = pd.concat(df_arr) + data = data.loc[(data.sum(axis=1) != 0)] #drop rows where occupancy = 0 print(data.shape) data.describe() else: @@ -364,12 +365,21 @@ def trainCNN(options, args, pam_updates=None): 15,31, 47]) models = [ - {'name': '4x4_norm_d10', 'ws': '', - 'pams': {'shape': (4, 4, 3), + #{'name': '4x4_norm_d10', 'ws': '', + # 'pams': {'shape': (4, 4, 3), + # 'channels_first': False, + # 'arrange': arrange443, + # 'encoded_dim': 10, + # 'loss': 'weightedMSE'}}, + {'name': '4x4_norm_v7', 'ws': '', + 'pams': {'shape': (4, 4, 3), 'channels_first': False, - 'arrange': arrange443, - 'encoded_dim': 10, - 'loss': 'weightedMSE'}}, + 'arrange': arrange443, + 'loss': 'weightedMSE', + 'CNN_layer_nodes': [4, 4, 4], + 'CNN_kernel_size': [5, 5, 3], + 'CNN_pool': [False, False, False], }}, + ] #{'name':'denseCNN', 'ws':'denseCNN.hdf5', 'pams':{'shape':(1,8,8) } }, From a12ea8d9d87e949072a521b2a596b4bb713ef29a Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Wed, 18 Mar 2020 12:47:04 -0500 Subject: [PATCH 10/13] progress towards sinkhorn training --- qDenseCNN.py | 8 ++++++++ train.py | 3 ++- utils.py | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/qDenseCNN.py b/qDenseCNN.py index 26a7580..494230b 100644 --- a/qDenseCNN.py +++ b/qDenseCNN.py @@ -199,6 +199,14 @@ def init(self, printSummary=True): # keep_negitive = 0 on inputs, otherwise for self.autoencoder.compile(loss=self.weightedMSE, optimizer='adam') self.encoder.compile(loss=self.weightedMSE, optimizer='adam') + elif self.pams['loss'] == 'sink': + import ot_tf + x_tf = tf.compat.v1.placeholder(dtype=tf.float32, shape=[48, 2]) + y_tf = tf.compat.v1.placeholder(dtype=tf.float32, shape=[48, 2]) + M_tf = ot_tf.dmat(x_tf, y_tf) + tf_sinkhorn_loss = ot_tf.sink(M_tf, (48,48), 0.5) + self.autoencoder.compile(loss=tf_sinkhorn_loss, optimizer='adam') + self.encoder.compile(loss=tf_sinkhorn_loss, optimizer='adam') elif self.pams['loss'] != '': self.autoencoder.compile(loss=self.pams['loss'], optimizer='adam') self.encoder.compile(loss=self.pams['loss'], optimizer='adam') diff --git a/train.py b/train.py index e95ef5d..9a58db1 100644 --- a/train.py +++ b/train.py @@ -375,7 +375,8 @@ def trainCNN(options, args, pam_updates=None): 'pams': {'shape': (4, 4, 3), 'channels_first': False, 'arrange': arrange443, - 'loss': 'weightedMSE', + #'loss': 'weightedMSE', + 'loss': 'sink', 'CNN_layer_nodes': [4, 4, 4], 'CNN_kernel_size': [5, 5, 3], 'CNN_pool': [False, False, False], }}, diff --git a/utils.py b/utils.py index f7f7e85..7fa1517 100644 --- a/utils.py +++ b/utils.py @@ -38,3 +38,24 @@ def plotHist(vals,name,odir,xtitle="",ytitle="",nbins=40): plt.savefig(odir+"/"+name+".png") plt.close() return + +def decode_ECON(mantissa, exp, n_mantissa=3,n_exp=4): + if exp==0: return mantissa + mantissa += (1<>exp) - (1<<(n_mantissa-1)) + return (mantissa,exp) + +def test_econ(): + for m in range(1<<3): + for e in range(1<<4): + val = decode_ECON(m,e) + m1, e1 = encode_ECON(val) + print(m,e,'-->',val,'-->',m1,e1) From 54acd5e316118ce832f00fbd9ca25a88cd6ba0ab Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Thu, 19 Mar 2020 21:17:50 -0500 Subject: [PATCH 11/13] implemented sinkhorn func in a loop --- ot_tf.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ qDenseCNN.py | 53 +++++++++++++++++++++++++++++++++++++++++++++------- train.py | 2 ++ 3 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 ot_tf.py diff --git a/ot_tf.py b/ot_tf.py new file mode 100644 index 0000000..fa58ae1 --- /dev/null +++ b/ot_tf.py @@ -0,0 +1,51 @@ +import tensorflow as tf +#tf.compat.v1.disable_eager_execution() + +def sink(a, b, M, m_size, reg, numItermax=1000, stopThr=1e-9): + # we assume that no distances are null except those of the diagonal of distances + + # a = tf.expand_dims(tf.ones(shape=(m_size[0],)) / m_size[0], axis=1) # (na, 1) + # b = tf.expand_dims(tf.ones(shape=(m_size[1],)) / m_size[1], axis=1) # (nb, 1) + + # init data + Nini = m_size[0] + Nfin = m_size[1] + + u = tf.expand_dims(tf.ones(Nini) / Nini, axis=1) # (na, 1) + v = tf.expand_dims(tf.ones(Nfin) / Nfin, axis=1) # (nb, 1) + + K = tf.exp(-M / reg) # (na, nb) + + Kp = (1.0 / a) * K # (na, 1) * (na, nb) = (na, nb) + + cpt = tf.constant(0) + err = tf.constant(1.0) + + c = lambda cpt, u, v, err: tf.logical_and(cpt < numItermax, err > stopThr) + + def err_f1(): + # we can speed up the process by checking for the error only all the 10th iterations + transp = u * (K * tf.squeeze(v)) # (na, 1) * ((na, nb) * (nb,)) = (na, nb) + err_ = tf.pow(tf.norm(tensor=tf.reduce_sum(input_tensor=transp) - b, ord=1), 2) # (,) + return err_ + + def err_f2(): + return err + + def loop_func(cpt, u, v, err): + KtransposeU = tf.matmul(tf.transpose(a=K, perm=(1, 0)), u) # (nb, na) x (na, 1) = (nb, 1) + v = tf.compat.v1.div(b, KtransposeU) # (nb, 1) + u = 1.0 / tf.matmul(Kp, v) # (na, 1) + + err = tf.cond(pred=tf.equal(cpt % 10, 0), true_fn=err_f1, false_fn=err_f2) + + cpt = tf.add(cpt, 1) + return cpt, u, v, err + + _, u, v, _ = tf.while_loop(cond=c, body=loop_func, loop_vars=[cpt, u, v, err]) + + result = tf.reduce_sum(input_tensor=u * K * tf.reshape(v, (1, -1)) * M) + + return result + + diff --git a/qDenseCNN.py b/qDenseCNN.py index 494230b..374a2f2 100644 --- a/qDenseCNN.py +++ b/qDenseCNN.py @@ -10,6 +10,50 @@ import numpy as np import json +# for sinkhorn metric +import ot_tf +import ot + + +hexCoords = np.array([ + [0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044], + [2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794], + [4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198], + [6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603], + [-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945], + [-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393], + [-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992], + [-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023], + [4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895], + [2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299], + [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696], + [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]]) +hexMetric = tf.constant( ot.dist(hexCoords, hexCoords, 'euclidean'), tf.float32) + +def myfunc(a): + reg=0.5 + y_true, y_pred = tf.split(a,num_or_size_splits=2,axis=1) + tf_sinkhorn_loss = ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg) + return tf_sinkhorn_loss + +def sinkhorn_loss(y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + y_pred = K.reshape(y_pred, (-1,48,1)) + y_true = K.reshape(y_true, (-1,48,1)) + cc = tf.concat([y_true, y_pred], axis=2) + return K.mean( tf.map_fn(myfunc, cc), axis=(-1) ) + + # return K.mean( tf.map_fn(myfunc, y_true), axis=(-1) ) + # return K.mean( tf.map_fn(myfunc, [y_true, y_pred]), axis=(-1) ) + # tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) ) + # return tf_sinkhorn_loss + # sy_true = tf.split(y_true,num_or_size_splits=K.shape(y_true)[0],axis=0) + # sy_pred = tf.split(y_pred,num_or_size_splits=K.shape(y_pred)[0],axis=0) + # losses = [ ot_tf.sink(sy_true[i], sy_pred[i], hexMetric, (48, 48), reg) for r in range(len(sy_true))] + # return losses[0] + #tf_sinkhorn_loss = K.mean( ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg), axis=(-1) ) + # tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) ) + # return tf_sinkhorn_loss class qDenseCNN: def __init__(self, name='', weights_f=''): @@ -200,13 +244,8 @@ def init(self, printSummary=True): # keep_negitive = 0 on inputs, otherwise for self.encoder.compile(loss=self.weightedMSE, optimizer='adam') elif self.pams['loss'] == 'sink': - import ot_tf - x_tf = tf.compat.v1.placeholder(dtype=tf.float32, shape=[48, 2]) - y_tf = tf.compat.v1.placeholder(dtype=tf.float32, shape=[48, 2]) - M_tf = ot_tf.dmat(x_tf, y_tf) - tf_sinkhorn_loss = ot_tf.sink(M_tf, (48,48), 0.5) - self.autoencoder.compile(loss=tf_sinkhorn_loss, optimizer='adam') - self.encoder.compile(loss=tf_sinkhorn_loss, optimizer='adam') + self.autoencoder.compile(loss=sinkhorn_loss, optimizer='adam') + self.encoder.compile(loss=sinkhorn_loss, optimizer='adam') elif self.pams['loss'] != '': self.autoencoder.compile(loss=self.pams['loss'], optimizer='adam') self.encoder.compile(loss=self.pams['loss'], optimizer='adam') diff --git a/train.py b/train.py index 9a58db1..1cf6139 100644 --- a/train.py +++ b/train.py @@ -540,6 +540,8 @@ def trainCNN(options, args, pam_updates=None): if model['ws']=='': if options.quickTrain: train_input = train_input[:5000] + # train_input = train_input.reshape(len(train_input),48) + # val_input = val_input.reshape(len(val_input),48) history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs) else: save_models(m_autoCNN,model_name) From 51f404f951752ba98108a180a783214768161ee5 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Thu, 19 Mar 2020 21:36:38 -0500 Subject: [PATCH 12/13] add sinkhorn test --- tests/test_sinkhorn.py | 49 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/test_sinkhorn.py diff --git a/tests/test_sinkhorn.py b/tests/test_sinkhorn.py new file mode 100644 index 0000000..6e369bc --- /dev/null +++ b/tests/test_sinkhorn.py @@ -0,0 +1,49 @@ +import tensorflow as tf +import numpy as np +import sys +sys.path.append("/home/therwig/data/sandbox/hgcal/Ecoder/") +import ot_tf +import ot +#tf.compat.v1.disable_eager_execution() + +def main(): + + na=48 + nb=48 + reg=0.5 + a = tf.expand_dims(tf.ones(shape=(na,)) / na, axis=1) # (na, 1) + b = tf.expand_dims(tf.ones(shape=(nb,)) / nb, axis=1) # (nb, 1) + m = tf.constant( hexMetric(), tf.float32 ) + tf_sinkhorn_loss = ot_tf.sink(a, b, m, (na, nb), reg) + + # print('finish loads') + # print(a) + # print(b) + # print(m) + print(tf_sinkhorn_loss) + + x = tf.ones(shape=(100,48,1,)) + y = tf.split(x,num_or_size_splits=100,axis=0) + print(y[0]) + + return + + +hexCoords = np.array([ + [0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044], + [2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794], + [4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198], + [6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603], + [-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945], + [-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393], + [-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992], + [-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023], + [4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895], + [2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299], + [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696], + [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]]) +def hexMetric(): + return ot.dist(hexCoords, hexCoords, 'euclidean') + +if __name__== "__main__": + main() From f80852768fa5d14ff9a66e3d85c59db0c0549626 Mon Sep 17 00:00:00 2001 From: Christian Herwig Date: Thu, 19 Mar 2020 22:50:31 -0500 Subject: [PATCH 13/13] working on pool --- qDenseCNN.py | 31 +++++++++++++++++++++++++++---- train.py | 2 -- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/qDenseCNN.py b/qDenseCNN.py index 374a2f2..2aaafa9 100644 --- a/qDenseCNN.py +++ b/qDenseCNN.py @@ -14,7 +14,6 @@ import ot_tf import ot - hexCoords = np.array([ [0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044], [2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794], @@ -51,10 +50,32 @@ def sinkhorn_loss(y_true, y_pred): # sy_pred = tf.split(y_pred,num_or_size_splits=K.shape(y_pred)[0],axis=0) # losses = [ ot_tf.sink(sy_true[i], sy_pred[i], hexMetric, (48, 48), reg) for r in range(len(sy_true))] # return losses[0] - #tf_sinkhorn_loss = K.mean( ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg), axis=(-1) ) + # tf_sinkhorn_loss = K.mean( ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg), axis=(-1) ) # tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) ) # return tf_sinkhorn_loss +def other_loss(y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + loss1 = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1)) + + # y_pred_rs = K.reshape(y_pred, (-1,48)) + # y_true_rs = K.reshape(y_true, (-1,48)) + # y_pred_x = + + y_pred_pool = tf.nn.pool(y_pred,(2,2),'AVG',strides=[1,1]) + y_true_pool = tf.nn.pool(y_true,(2,2),'AVG',strides=[1,1]) + loss2 = K.mean(K.square(y_true_pool - y_pred_pool) * K.maximum(y_true_pool, y_pred_pool), axis=(-1)) + #return loss1 + loss2 + return loss1 + + # return K.mean( tf.map_fn(myfunc, cc), axis=(-1) ) + +def weightedMSE(self, y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + loss = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1)) + return loss + + class qDenseCNN: def __init__(self, name='', weights_f=''): self.name = name @@ -244,8 +265,10 @@ def init(self, printSummary=True): # keep_negitive = 0 on inputs, otherwise for self.encoder.compile(loss=self.weightedMSE, optimizer='adam') elif self.pams['loss'] == 'sink': - self.autoencoder.compile(loss=sinkhorn_loss, optimizer='adam') - self.encoder.compile(loss=sinkhorn_loss, optimizer='adam') + self.autoencoder.compile(loss=other_loss, optimizer='adam') + self.encoder.compile(loss=other_loss, optimizer='adam') + # self.autoencoder.compile(loss=sinkhorn_loss, optimizer='adam') + # self.encoder.compile(loss=sinkhorn_loss, optimizer='adam') elif self.pams['loss'] != '': self.autoencoder.compile(loss=self.pams['loss'], optimizer='adam') self.encoder.compile(loss=self.pams['loss'], optimizer='adam') diff --git a/train.py b/train.py index 1cf6139..9a58db1 100644 --- a/train.py +++ b/train.py @@ -540,8 +540,6 @@ def trainCNN(options, args, pam_updates=None): if model['ws']=='': if options.quickTrain: train_input = train_input[:5000] - # train_input = train_input.reshape(len(train_input),48) - # val_input = val_input.reshape(len(val_input),48) history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs) else: save_models(m_autoCNN,model_name)