# Post-patch form of the normalize/unnormalize hunk from:
#   diff --git a/train.py b/train.py   (index 31b853e..baf4f7f 100644)
#   @@ -16,7 +16,6 @@   removes `import numba`
#   @@ -40,27 +39,26 @@ replaces the @numba.jit, per-row loop versions of
#                       normalize/unnormalize with the vectorised ones below.


def _shift_scale(values):
    """Return the largest power of two not exceeding each entry of *values*.

    Zero entries map to 1 so that the result is always a safe divisor.  The
    zero guard has to happen *before* the logarithm: log2(0) is -inf, which
    cannot be cast to an integer exponent and made the previous
    ``pow(2, np.log2(x).astype(int))`` formulation raise ``ValueError``.

    The exponent is floored; for entries >= 1 this is identical to the
    integer truncation used previously.  Entries in (0, 1) now give the
    power of two below them rather than 1 or an error.
    """
    scale = np.array(values, dtype=float)  # np.array copies, input untouched
    scale[scale == 0] = 1.
    return np.exp2(np.floor(np.log2(scale)))


def normalize(data, rescaleInputToMax=False, shiftNormalization=False,
              bits=-1, integer=-1):
    """Scale every row of a 2-D array by its own maximum or sum.

    Parameters
    ----------
    data : 2-D numpy array, one sample per row (reductions use ``axis=1``).
    rescaleInputToMax : divide each row by its maximum if true, otherwise
        by its sum.
    shiftNormalization : if true, the divisor is first rounded down to a
        power of two (bit-shift style normalization).
    bits, integer : when both differ from -1 the result is rounded to
        multiples of ``2**-(bits - integer)``.

    Returns
    -------
    (normalized, maxes, sums) where ``maxes`` and ``sums`` are the per-row
    maximum and sum of the *input*.  Rows whose divisor would be zero are
    divided by 1 and therefore stay unchanged.
    """
    maxes = data.max(axis=1)
    sums = data.sum(axis=1)
    reference = maxes if rescaleInputToMax else sums

    if shiftNormalization:
        normalization = _shift_scale(reference)
    else:
        normalization = reference.copy()
        normalization[normalization == 0] = 1.

    # Transpose so the per-row divisor broadcasts along the last axis.
    data = (data.transpose() / normalization).transpose()

    if (bits != -1) and (integer != -1):
        step = 2 ** (bits - integer)
        data = np.round(data * step) / step
    return data, maxes, sums


def unnormalize(norm_data, maxvals, rescaleOutputToMax=False,
                shiftNormalization=False):
    """Rescale normalized rows so each row's max or sum equals ``maxvals``.

    Parameters
    ----------
    norm_data : 2-D numpy array, one sample per row.
    maxvals : 1-D numpy array with the target value for each row.
    rescaleOutputToMax : conserve the row maximum if true, otherwise the
        row sum.
    shiftNormalization : if true, the intermediate scale is ``maxvals``
        rounded down to a power of two, mirroring ``normalize``.

    Returns
    -------
    Array with the shape of ``norm_data``.  Rows whose target is zero, or
    whose content is all zero, come back as zeros.
    """
    if shiftNormalization:
        normalization = _shift_scale(maxvals)
    else:
        normalization = maxvals.copy()

    dataT = norm_data.transpose() * normalization
    # conserve either the total or max charge
    conserveDenom = dataT.max(axis=0) if rescaleOutputToMax else dataT.sum(axis=0)
    conserveDenom[conserveDenom == 0] = 1.
    conserve = maxvals / conserveDenom
    return (dataT * conserve).transpose()


# --- remaining patch text from this span, kept verbatim for reference ---
#  def StringToTextFile(fname,s):
#      with open(fname,'w') as f:
# @@ -1089,7 +1087,7 @@ def trainCNN(options, args, pam_updates=None):
#      occupancy_all = np.count_nonzero(data_values,axis=1)
#      occupancy_all_1MT = np.count_nonzero(data_values>35,axis=1)
# -    normdata,maxdata,sumdata = normalize(data_values.copy(),rescaleInputToMax=options.rescaleInputToMax)
# +    normdata,maxdata,sumdata = normalize(data_values.copy(),rescaleInputToMax=options.rescaleInputToMax,shiftNormalization=options.useShiftNormalization)
#      maxdata = maxdata / 35. # normalize to units of transverse MIPs
#      sumdata = sumdata / 35. # normalize to units of transverse MIPs
# @@ -1235,7 +1233,7 @@ def trainCNN(options, args, pam_updates=None):
#      print("Restore normalization")
#      input_Q_abs = np.array([input_Q[i]*(val_max[i] if options.rescaleInputToMax else val_sum[i]) for i in range(0,len(input_Q))])
#      input_calQ = np.array([input_calQ[i]*(val_max[i] if options.rescaleInputToMax else val_sum[i]) for i in range(0,len(input_calQ)) ]) # shape = (N,48) in CALQ order
# -    output_calQ = unnormalize(output_calQ_fr.copy(), val_max if options.rescaleOutputToMax else val_sum, rescaleOutputToMax=options.rescaleOutputToMax)
# +    output_calQ = unnormalize(output_calQ_fr.copy(), 35*val_max if options.rescaleOutputToMax else 35*val_sum, rescaleOutputToMax=options.rescaleOutputToMax,shiftNormalization=options.useShiftNormalization)/35. #shift normalization requires integer charge values, so val_max and val_sum get rescaled by 35 to convert back to ADC values, then output of unnormalize gets rescaled by 35.
#occupancy_0MT = np.count_nonzero(input_Q_abs.reshape(len(input_Q),48),axis=1) #occupancy_1MT = np.count_nonzero(input_Q_abs.reshape(len(input_Q),48)>1.,axis=1) occupancy_0MT = np.count_nonzero(input_calQ.reshape(len(input_Q),48),axis=1) @@ -1300,6 +1298,7 @@ def trainCNN(options, args, pam_updates=None): parser.add_option("--AEonly", type='int', default=1, dest="AEonly", help="run only AE algo") parser.add_option("--rescaleInputToMax", type='int', default=0, dest="rescaleInputToMax", help="recale the input images so the maximum deposit is 1. Else normalize") parser.add_option("--rescaleOutputToMax", type='int', default=0, dest="rescaleOutputToMax", help="recale the output images to match the initial sum") + parser.add_option("--useShiftNormalization", action='store_true', default=False, dest="useShiftNormalization", help="use the bit-shift style normalization") parser.add_option("--nrowsPerFile", type='int', default=500000, dest="nrowsPerFile", help="load nrowsPerFile in a directory") parser.add_option("--occReweight", action='store_true', default = False,dest="occReweight", help="Train with per-event weight on TC occupancy") (options, args) = parser.parse_args()