From a7989bbdda4552caead67df674e28fddfac831aa Mon Sep 17 00:00:00 2001 From: huangjun12 <12272008@bjtu.edu.cn> Date: Tue, 3 Nov 2020 11:48:26 +0000 Subject: [PATCH] update bmn to 2.0rc --- bmn/README.md | 12 +- bmn/bmn.yaml | 2 +- bmn/eval.py | 10 +- bmn/modeling.py | 361 ++++++++++++++++++++++++------------------------ bmn/predict.py | 12 +- bmn/run.sh | 4 +- bmn/train.py | 52 ++++--- 7 files changed, 225 insertions(+), 228 deletions(-) diff --git a/bmn/README.md b/bmn/README.md index 4182409..886a638 100644 --- a/bmn/README.md +++ b/bmn/README.md @@ -32,7 +32,7 @@ BMN Overview git clone https://github.com/PaddlePaddle/hapi cd hapi export PYTHONPATH=`pwd`:$PYTHONPATH - cd examples/bmn + cd bmn ``` @@ -69,9 +69,9 @@ BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理 python train.py -默认使用静态图训练,若使用动态图训练只需要在运行脚本添加`-d`参数即可,如: +默认使用动态图训练,若使用静态图训练只需要在运行脚本添加`-s`参数即可,如: - python train.py -d + python train.py -s - 代码运行需要先安装pandas @@ -84,7 +84,7 @@ BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理 python eval.py --weights=$PATH_TO_WEIGHTS -- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,若未指定,脚本会下载已发布的模型[model](https://paddlemodels.bj.bcebos.com/hapi/bmn.pdparams)进行评估。 +- 进行评估时,可修改命令行中的`weights`参数指定需要评估的权重,如--weights='./checkpoint/final.pdparams'。若未指定,脚本会下载已发布的模型[model](https://paddlemodels.bj.bcebos.com/hapi/bmn.pdparams)进行评估。 - 上述程序会将运行结果保存在`--output_path`参数指定的文件夹下,默认为output/EVAL/BMN\_results;测试结果保存在`--result_path`参数指定的文件夹下,默认为evaluate\_results。 @@ -110,7 +110,7 @@ BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理 | AR@1 | AR@5 | AR@10 | AR@100 | AUC | | :---: | :---: | :---: | :---: | :---: | -| 33.10 | 49.18 | 56.54 | 75.12 | 67.16% | +| 33.23 | 49.16 | 56.59 | 75.27 | 67.18% | ## 模型推断 @@ -120,7 +120,7 @@ BMN的训练数据采用ActivityNet1.3提供的数据集,我们提供了处理 python predict.py --weights=$PATH_TO_WEIGHTS \ --filelist=$FILELIST -- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数为训练好的权重参数,若未指定,脚本会下载已发布的模型[model](https://paddlemodels.bj.bcebos.com/hapi/bmn.pdparams)进行预测。 +- 使用python命令行启动程序时,`--filelist`参数指定待推断的文件列表,如果不设置,默认为./infer.list。`--weights`参数指定训练好的权重参数,如--weights='./checkpoint/final.pdparams'。若未指定,脚本会下载已发布的模型[model](https://paddlemodels.bj.bcebos.com/hapi/bmn.pdparams)进行预测。 - 上述程序会将运行结果保存在`--output_path`参数指定的文件夹下,默认为output/INFER/BMN\_results;测试结果保存在`--result_path`参数指定的文件夹下,默认为predict\_results。 diff --git a/bmn/bmn.yaml b/bmn/bmn.yaml index 1cc8995..6599767 100644 --- a/bmn/bmn.yaml +++ b/bmn/bmn.yaml @@ -7,7 +7,7 @@ MODEL: num_sample: 32 num_sample_perbin: 3 anno_file: "./activitynet_1.3_annotations.json" - feat_path: './fix_feat_100' + feat_path: "./fix_feat_100" TRAIN: subset: "train" diff --git a/bmn/eval.py b/bmn/eval.py index 2992458..1f720a3 100644 --- a/bmn/eval.py +++ b/bmn/eval.py @@ -17,7 +17,6 @@ import os import sys import logging -import paddle.fluid as fluid from modeling import bmn, BmnLoss from bmn_metric import BmnMetric @@ -35,10 +34,7 @@ def parse_args(): parser = argparse.ArgumentParser("BMN test for performance evaluation.") parser.add_argument( - "-d", - "--dynamic", - action='store_true', - help="enable dygraph mode, only support dynamic mode at present time") + "-s", "--static", action='store_true', help="enable static mode") parser.add_argument( '--config_file', type=str, @@ -77,8 +73,8 @@ def parse_args(): # Performance Evaluation def test_bmn(args): + paddle.enable_static() if args.static else None device = paddle.set_device(args.device) - paddle.disable_static(device) if args.dynamic else None #config setting config = parse_config(args.config_file) @@ -110,7 +106,7 @@ def test_bmn(args): #load checkpoint if args.weights is not None: - assert os.path.exists(args.weights + '.pdparams'), \ + assert os.path.exists(args.weights), \ "Given weight dir {} not exist.".format(args.weights) logger.info('load test weights from {}'.format(args.weights)) model.load(args.weights) diff --git a/bmn/modeling.py b/bmn/modeling.py index 96883d9..401e39d 100644 --- a/bmn/modeling.py +++ b/bmn/modeling.py @@ -13,8 +13,8 @@ #limitations under the License. import paddle -import paddle.fluid as fluid -from paddle.fluid import ParamAttr +import paddle.nn.functional as F +from paddle import ParamAttr import numpy as np import math @@ -86,54 +86,24 @@ def get_interp1d_mask(tscale, dscale, prop_boundary_ratio, num_sample, return sample_mask -# Net -class Conv1D(fluid.dygraph.Layer): - def __init__(self, - prefix, - num_channels=256, - num_filters=256, - size_k=3, - padding=1, - groups=1, - act="relu"): - super(Conv1D, self).__init__() - fan_in = num_channels * size_k * 1 - k = 1. / math.sqrt(fan_in) - param_attr = ParamAttr( - name=prefix + "_w", - initializer=fluid.initializer.Uniform( - low=-k, high=k)) - bias_attr = ParamAttr( - name=prefix + "_b", - initializer=fluid.initializer.Uniform( - low=-k, high=k)) - - self._conv2d = fluid.dygraph.Conv2D( - num_channels=num_channels, - num_filters=num_filters, - filter_size=(1, size_k), - stride=1, - padding=(0, padding), - groups=groups, - act=act, - param_attr=param_attr, - bias_attr=bias_attr) - - def forward(self, x): - x = fluid.layers.unsqueeze(input=x, axes=[2]) - x = self._conv2d(x) - x = fluid.layers.squeeze(input=x, axes=[2]) - return x +def init_params(name, in_channels, kernel_size): + fan_in = in_channels * kernel_size * 1 + k = 1. / math.sqrt(fan_in) + param_attr = ParamAttr( + name=name, initializer=paddle.nn.initializer.Uniform( + low=-k, high=k)) + return param_attr -class BMN(fluid.dygraph.Layer): +# Net +class BMN(paddle.nn.Layer): """BMN model from `"BMN: Boundary-Matching Network for Temporal Action Proposal Generation" `_ Args: tscale (int): sequence length, default 100. dscale (int): max duration length, default 100. - prop_boundary_ratio (float): ratio of expanded temporal region in proposal boundary, default 0.5. + prop_boundary_ratio (float): ratio of expanded temporal region in proposal boundary, default 0.5. num_sample (int): number of samples betweent starting boundary and ending boundary of each propoasl, default 32. num_sample_perbin (int): number of selected points in each sample, default 3. """ @@ -154,141 +124,179 @@ def __init__(self, tscale, dscale, prop_boundary_ratio, num_sample, self.hidden_dim_3d = 512 # Base Module - self.b_conv1 = Conv1D( - prefix="Base_1", - num_channels=400, - num_filters=self.hidden_dim_1d, - size_k=3, + self.b_conv1 = paddle.nn.Conv1D( + in_channels=400, + out_channels=self.hidden_dim_1d, + kernel_size=3, padding=1, groups=4, - act="relu") - self.b_conv2 = Conv1D( - prefix="Base_2", - num_filters=self.hidden_dim_1d, - size_k=3, + weight_attr=init_params('Base_1_w', 400, 3), + bias_attr=init_params('Base_1_b', 400, 3)) + self.b_conv1_act = paddle.nn.ReLU() + self.b_conv2 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=self.hidden_dim_1d, + kernel_size=3, padding=1, groups=4, - act="relu") + weight_attr=init_params('Base_2_w', self.hidden_dim_1d, 3), + bias_attr=init_params('Base_2_b', self.hidden_dim_1d, 3)) + self.b_conv2_act = paddle.nn.ReLU() # Temporal Evaluation Module - self.ts_conv1 = Conv1D( - prefix="TEM_s1", - num_filters=self.hidden_dim_1d, - size_k=3, + self.ts_conv1 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=self.hidden_dim_1d, + kernel_size=3, padding=1, groups=4, - act="relu") - self.ts_conv2 = Conv1D( - prefix="TEM_s2", num_filters=1, size_k=1, padding=0, act="sigmoid") - self.te_conv1 = Conv1D( - prefix="TEM_e1", - num_filters=self.hidden_dim_1d, - size_k=3, + weight_attr=init_params('TEM_s1_w', self.hidden_dim_1d, 3), + bias_attr=init_params('TEM_s1_b', self.hidden_dim_1d, 3)) + self.ts_conv1_act = paddle.nn.ReLU() + self.ts_conv2 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=1, + kernel_size=1, + padding=0, + groups=1, + weight_attr=init_params('TEM_s2_w', self.hidden_dim_1d, 1), + bias_attr=init_params('TEM_s2_b', self.hidden_dim_1d, 1)) + self.ts_conv2_act = paddle.nn.Sigmoid() + + self.te_conv1 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=self.hidden_dim_1d, + kernel_size=3, padding=1, groups=4, - act="relu") - self.te_conv2 = Conv1D( - prefix="TEM_e2", num_filters=1, size_k=1, padding=0, act="sigmoid") + weight_attr=init_params('TEM_e1_w', self.hidden_dim_1d, 3), + bias_attr=init_params('TEM_e1_b', self.hidden_dim_1d, 3)) + self.te_conv1_act = paddle.nn.ReLU() + self.te_conv2 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=1, + kernel_size=1, + padding=0, + groups=1, + weight_attr=init_params('TEM_e2_w', self.hidden_dim_1d, 1), + bias_attr=init_params('TEM_e2_b', self.hidden_dim_1d, 1)) + self.te_conv2_act = paddle.nn.Sigmoid() #Proposal Evaluation Module - self.p_conv1 = Conv1D( - prefix="PEM_1d", - num_filters=self.hidden_dim_2d, - size_k=3, + self.p_conv1 = paddle.nn.Conv1D( + in_channels=self.hidden_dim_1d, + out_channels=self.hidden_dim_2d, + kernel_size=3, padding=1, - act="relu") + groups=1, + weight_attr=init_params('PEM_1d_w', self.hidden_dim_1d, 3), + bias_attr=init_params('PEM_1d_b', self.hidden_dim_1d, 3)) + self.p_conv1_act = paddle.nn.ReLU() - # get sample mask + # get sample mask sample_mask_array = get_interp1d_mask( self.tscale, self.dscale, self.prop_boundary_ratio, self.num_sample, self.num_sample_perbin) - self.sample_mask = fluid.layers.create_parameter( + self.sample_mask = paddle.static.create_parameter( shape=[self.tscale, self.num_sample * self.dscale * self.tscale], dtype=DATATYPE, - attr=fluid.ParamAttr( + attr=ParamAttr( name="sample_mask", trainable=False), - default_initializer=fluid.initializer.NumpyArrayInitializer( + default_initializer=paddle.nn.initializer.Assign( sample_mask_array)) - self.sample_mask.stop_gradient = True - self.p_conv3d1 = fluid.dygraph.Conv3D( - num_channels=128, - num_filters=self.hidden_dim_3d, - filter_size=(self.num_sample, 1, 1), + self.p_conv3d1 = paddle.nn.Conv3D( + in_channels=128, + out_channels=self.hidden_dim_3d, + kernel_size=(self.num_sample, 1, 1), stride=(self.num_sample, 1, 1), padding=0, - act="relu", - param_attr=ParamAttr(name="PEM_3d1_w"), + weight_attr=ParamAttr(name="PEM_3d1_w"), bias_attr=ParamAttr(name="PEM_3d1_b")) + self.p_conv3d1_act = paddle.nn.ReLU() - self.p_conv2d1 = fluid.dygraph.Conv2D( - num_channels=512, - num_filters=self.hidden_dim_2d, - filter_size=1, + self.p_conv2d1 = paddle.nn.Conv2D( + in_channels=512, + out_channels=self.hidden_dim_2d, + kernel_size=1, stride=1, padding=0, - act="relu", - param_attr=ParamAttr(name="PEM_2d1_w"), + weight_attr=ParamAttr(name="PEM_2d1_w"), bias_attr=ParamAttr(name="PEM_2d1_b")) - self.p_conv2d2 = fluid.dygraph.Conv2D( - num_channels=128, - num_filters=self.hidden_dim_2d, - filter_size=3, + self.p_conv2d1_act = paddle.nn.ReLU() + + self.p_conv2d2 = paddle.nn.Conv2D( + in_channels=128, + out_channels=self.hidden_dim_2d, + kernel_size=3, stride=1, padding=1, - act="relu", - param_attr=ParamAttr(name="PEM_2d2_w"), + weight_attr=ParamAttr(name="PEM_2d2_w"), bias_attr=ParamAttr(name="PEM_2d2_b")) - self.p_conv2d3 = fluid.dygraph.Conv2D( - num_channels=128, - num_filters=self.hidden_dim_2d, - filter_size=3, + self.p_conv2d2_act = paddle.nn.ReLU() + + self.p_conv2d3 = paddle.nn.Conv2D( + in_channels=128, + out_channels=self.hidden_dim_2d, + kernel_size=3, stride=1, padding=1, - act="relu", - param_attr=ParamAttr(name="PEM_2d3_w"), + weight_attr=ParamAttr(name="PEM_2d3_w"), bias_attr=ParamAttr(name="PEM_2d3_b")) - self.p_conv2d4 = fluid.dygraph.Conv2D( - num_channels=128, - num_filters=2, - filter_size=1, + self.p_conv2d3_act = paddle.nn.ReLU() + + self.p_conv2d4 = paddle.nn.Conv2D( + in_channels=128, + out_channels=2, + kernel_size=1, stride=1, padding=0, - act="sigmoid", - param_attr=ParamAttr(name="PEM_2d4_w"), + weight_attr=ParamAttr(name="PEM_2d4_w"), bias_attr=ParamAttr(name="PEM_2d4_b")) + self.p_conv2d4_act = paddle.nn.Sigmoid() def forward(self, x): #Base Module x = self.b_conv1(x) + x = self.b_conv1_act(x) x = self.b_conv2(x) + x = self.b_conv2_act(x) #TEM xs = self.ts_conv1(x) + xs = self.ts_conv1_act(xs) xs = self.ts_conv2(xs) - xs = fluid.layers.squeeze(xs, axes=[1]) + xs = self.ts_conv2_act(xs) + xs = paddle.squeeze(xs, axis=[1]) xe = self.te_conv1(x) + xe = self.te_conv1_act(xe) xe = self.te_conv2(xe) - xe = fluid.layers.squeeze(xe, axes=[1]) + xe = self.te_conv2_act(xe) + xe = paddle.squeeze(xe, axis=[1]) #PEM xp = self.p_conv1(x) + xp = self.p_conv1_act(xp) #BM layer - xp = fluid.layers.matmul(xp, self.sample_mask) - xp = fluid.layers.reshape( - xp, shape=[0, 0, -1, self.dscale, self.tscale]) + xp = paddle.matmul(xp, self.sample_mask) + xp = paddle.reshape(xp, shape=[0, 0, -1, self.dscale, self.tscale]) xp = self.p_conv3d1(xp) - xp = fluid.layers.squeeze(xp, axes=[2]) + xp = self.p_conv3d1_act(xp) + xp = paddle.squeeze(xp, axis=[2]) xp = self.p_conv2d1(xp) + xp = self.p_conv2d1_act(xp) xp = self.p_conv2d2(xp) + xp = self.p_conv2d2_act(xp) xp = self.p_conv2d3(xp) + xp = self.p_conv2d3_act(xp) xp = self.p_conv2d4(xp) + xp = self.p_conv2d4_act(xp) return xp, xs, xe -class BmnLoss(fluid.dygraph.Layer): +class BmnLoss(paddle.nn.Layer): """Loss for BMN model Args: @@ -302,44 +310,40 @@ def __init__(self, tscale, dscale): self.dscale = dscale def _get_mask(self): - bm_mask = [] + bm_mask_list = [] for idx in range(self.dscale): mask_vector = [1 for i in range(self.tscale - idx) ] + [0 for i in range(idx)] - bm_mask.append(mask_vector) - bm_mask = np.array(bm_mask, dtype=np.float32) - self_bm_mask = fluid.layers.create_global_var( + bm_mask_list.append(mask_vector) + bm_mask_array = np.array(bm_mask_list, dtype=np.float32) + self_bm_mask = paddle.static.create_global_var( shape=[self.dscale, self.tscale], value=0, dtype=DATATYPE, persistable=True) - fluid.layers.assign(bm_mask, self_bm_mask) + paddle.assign(bm_mask_array, self_bm_mask) + self_bm_mask.stop_gradient = True return self_bm_mask def tem_loss_func(self, pred_start, pred_end, gt_start, gt_end): def bi_loss(pred_score, gt_label): - pred_score = fluid.layers.reshape( - x=pred_score, shape=[-1], inplace=False) - gt_label = fluid.layers.reshape( - x=gt_label, shape=[-1], inplace=False) + pred_score = paddle.reshape(x=pred_score, shape=[-1]) + gt_label = paddle.reshape(x=gt_label, shape=[-1]) gt_label.stop_gradient = True - pmask = fluid.layers.cast(x=(gt_label > 0.5), dtype=DATATYPE) - num_entries = fluid.layers.cast( - fluid.layers.shape(pmask), dtype=DATATYPE) - num_positive = fluid.layers.cast( - fluid.layers.reduce_sum(pmask), dtype=DATATYPE) + pmask = paddle.cast(x=(gt_label > 0.5), dtype=DATATYPE) + num_entries = paddle.cast(paddle.shape(pmask), dtype=DATATYPE) + num_positive = paddle.cast(paddle.sum(pmask), dtype=DATATYPE) ratio = num_entries / num_positive coef_0 = 0.5 * ratio / (ratio - 1) coef_1 = 0.5 * ratio epsilon = 0.000001 - temp = fluid.layers.log(pred_score + epsilon) - loss_pos = fluid.layers.elementwise_mul( - fluid.layers.log(pred_score + epsilon), pmask) - loss_pos = coef_1 * fluid.layers.reduce_mean(loss_pos) - loss_neg = fluid.layers.elementwise_mul( - fluid.layers.log(1.0 - pred_score + epsilon), (1.0 - pmask)) - loss_neg = coef_0 * fluid.layers.reduce_mean(loss_neg) + temp = paddle.log(pred_score + epsilon) + loss_pos = paddle.multiply(paddle.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * paddle.mean(loss_pos) + loss_neg = paddle.multiply( + paddle.log(1.0 - pred_score + epsilon), (1.0 - pmask)) + loss_neg = coef_0 * paddle.mean(loss_neg) loss = -1 * (loss_pos + loss_neg) return loss @@ -349,69 +353,62 @@ def bi_loss(pred_score, gt_label): return loss def pem_reg_loss_func(self, pred_score, gt_iou_map, mask): + gt_iou_map = paddle.multiply(gt_iou_map, mask) - gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + u_hmask = paddle.cast(x=gt_iou_map > 0.7, dtype=DATATYPE) + u_mmask = paddle.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3) + u_mmask = paddle.cast(x=u_mmask, dtype=DATATYPE) + u_lmask = paddle.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.) + u_lmask = paddle.cast(x=u_lmask, dtype=DATATYPE) + u_lmask = paddle.multiply(u_lmask, mask) - u_hmask = fluid.layers.cast(x=gt_iou_map > 0.7, dtype=DATATYPE) - u_mmask = fluid.layers.logical_and(gt_iou_map <= 0.7, gt_iou_map > 0.3) - u_mmask = fluid.layers.cast(x=u_mmask, dtype=DATATYPE) - u_lmask = fluid.layers.logical_and(gt_iou_map <= 0.3, gt_iou_map >= 0.) - u_lmask = fluid.layers.cast(x=u_lmask, dtype=DATATYPE) - u_lmask = fluid.layers.elementwise_mul(u_lmask, mask) - - num_h = fluid.layers.cast( - fluid.layers.reduce_sum(u_hmask), dtype=DATATYPE) - num_m = fluid.layers.cast( - fluid.layers.reduce_sum(u_mmask), dtype=DATATYPE) - num_l = fluid.layers.cast( - fluid.layers.reduce_sum(u_lmask), dtype=DATATYPE) + num_h = paddle.cast(paddle.sum(u_hmask), dtype=DATATYPE) + num_m = paddle.cast(paddle.sum(u_mmask), dtype=DATATYPE) + num_l = paddle.cast(paddle.sum(u_lmask), dtype=DATATYPE) r_m = num_h / num_m - u_smmask = fluid.layers.uniform_random( + u_smmask = paddle.uniform( shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], dtype=DATATYPE, min=0.0, max=1.0) - u_smmask = fluid.layers.elementwise_mul(u_mmask, u_smmask) - u_smmask = fluid.layers.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE) + u_smmask = paddle.multiply(u_mmask, u_smmask) + u_smmask = paddle.cast(x=(u_smmask > (1. - r_m)), dtype=DATATYPE) r_l = num_h / num_l - u_slmask = fluid.layers.uniform_random( + u_slmask = paddle.uniform( shape=[gt_iou_map.shape[1], gt_iou_map.shape[2]], dtype=DATATYPE, min=0.0, max=1.0) - u_slmask = fluid.layers.elementwise_mul(u_lmask, u_slmask) - u_slmask = fluid.layers.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE) + u_slmask = paddle.multiply(u_lmask, u_slmask) + u_slmask = paddle.cast(x=(u_slmask > (1. - r_l)), dtype=DATATYPE) weights = u_hmask + u_smmask + u_slmask weights.stop_gradient = True - loss = fluid.layers.square_error_cost(pred_score, gt_iou_map) - loss = fluid.layers.elementwise_mul(loss, weights) - loss = 0.5 * fluid.layers.reduce_sum(loss) / fluid.layers.reduce_sum( - weights) - + loss = F.square_error_cost(pred_score, gt_iou_map) + loss = paddle.multiply(loss, weights) + loss = 0.5 * paddle.sum(loss) / paddle.sum(weights) return loss def pem_cls_loss_func(self, pred_score, gt_iou_map, mask): - gt_iou_map = fluid.layers.elementwise_mul(gt_iou_map, mask) + gt_iou_map = paddle.multiply(gt_iou_map, mask) gt_iou_map.stop_gradient = True - pmask = fluid.layers.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE) - nmask = fluid.layers.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE) - nmask = fluid.layers.elementwise_mul(nmask, mask) + pmask = paddle.cast(x=(gt_iou_map > 0.9), dtype=DATATYPE) + nmask = paddle.cast(x=(gt_iou_map <= 0.9), dtype=DATATYPE) + nmask = paddle.multiply(nmask, mask) - num_positive = fluid.layers.reduce_sum(pmask) - num_entries = num_positive + fluid.layers.reduce_sum(nmask) + num_positive = paddle.sum(pmask) + num_entries = num_positive + paddle.sum(nmask) ratio = num_entries / num_positive coef_0 = 0.5 * ratio / (ratio - 1) coef_1 = 0.5 * ratio epsilon = 0.000001 - loss_pos = fluid.layers.elementwise_mul( - fluid.layers.log(pred_score + epsilon), pmask) - loss_pos = coef_1 * fluid.layers.reduce_sum(loss_pos) - loss_neg = fluid.layers.elementwise_mul( - fluid.layers.log(1.0 - pred_score + epsilon), nmask) - loss_neg = coef_0 * fluid.layers.reduce_sum(loss_neg) + loss_pos = paddle.multiply(paddle.log(pred_score + epsilon), pmask) + loss_pos = coef_1 * paddle.sum(loss_pos) + loss_neg = paddle.multiply( + paddle.log(1.0 - pred_score + epsilon), nmask) + loss_neg = coef_0 * paddle.sum(loss_neg) loss = -1 * (loss_pos + loss_neg) / num_entries return loss @@ -423,14 +420,12 @@ def forward(self, gt_start, gt_end, video_index=None): - pred_bm_reg = fluid.layers.squeeze( - fluid.layers.slice( - pred_bm, axes=[1], starts=[0], ends=[1]), - axes=[1]) - pred_bm_cls = fluid.layers.squeeze( - fluid.layers.slice( - pred_bm, axes=[1], starts=[1], ends=[2]), - axes=[1]) + pred_bm_reg = paddle.squeeze( + paddle.slice( + pred_bm, axes=[1], starts=[0], ends=[1]), axis=[1]) + pred_bm_cls = paddle.squeeze( + paddle.slice( + pred_bm, axes=[1], starts=[1], ends=[2]), axis=[1]) bm_mask = self._get_mask() @@ -452,11 +447,11 @@ def bmn(tscale, mode, pretrained=True): """BMN model - + Args: tscale (int): sequence length, default 100. dscale (int): max duration length, default 100. - prop_boundary_ratio (float): ratio of expanded temporal region in proposal boundary, default 0.5. + prop_boundary_ratio (float): ratio of expanded temporal region in proposal boundary, default 0.5. num_sample (int): number of samples betweent starting boundary and ending boundary of each propoasl, default 32. num_sample_perbin (int): number of selected points in each sample, default 3. pretrained (bool): If True, returns a model with pre-trained model, default True. @@ -483,6 +478,6 @@ def bmn(tscale, if pretrained: weight_path = get_weights_path_from_url(*(pretrain_infos['bmn'])) assert weight_path.endswith('.pdparams'), \ - "suffix of weight must be .pdparams" + "suffix of weight must be .pdparams" model.load(weight_path) return model diff --git a/bmn/predict.py b/bmn/predict.py index d585050..9c106ef 100644 --- a/bmn/predict.py +++ b/bmn/predict.py @@ -17,7 +17,6 @@ import os import logging import paddle -import paddle.fluid as fluid from modeling import bmn, BmnLoss from bmn_metric import BmnMetric @@ -35,10 +34,7 @@ def parse_args(): parser = argparse.ArgumentParser("BMN inference.") parser.add_argument( - "-d", - "--dynamic", - action='store_true', - help="enable dygraph mode, only support dynamic mode at present time") + "-s", "--static", action='store_true', help="enable static mode") parser.add_argument( '--config_file', type=str, @@ -82,8 +78,8 @@ def parse_args(): # Prediction def infer_bmn(args): + paddle.enable_static() if args.static else None device = paddle.set_device(args.device) - paddle.disable_static(device) if args.dynamic else None #config setting config = parse_config(args.config_file) @@ -114,8 +110,8 @@ def infer_bmn(args): # load checkpoint if args.weights is not None: assert os.path.exists( - args.weights + - ".pdparams"), "Given weight dir {} not exist.".format(args.weights) + args.weights), "Given weight dir {} not exist.".format( + args.weights) logger.info('load test weights from {}'.format(args.weights)) model.load(args.weights) diff --git a/bmn/run.sh b/bmn/run.sh index 5c840a2..c8043b5 100644 --- a/bmn/run.sh +++ b/bmn/run.sh @@ -2,8 +2,8 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 start_time=$(date +%s) -python -m paddle.distributed.launch train.py -d +python3.7 train.py end_time=$(date +%s) cost_time=$[ $end_time-$start_time ] -echo "4 card static training time is $(($cost_time/60))min $(($cost_time%60))s" +echo "4 card dynamic training time is $(($cost_time/60))min $(($cost_time%60))s" diff --git a/bmn/train.py b/bmn/train.py index f15444c..8da5d8e 100644 --- a/bmn/train.py +++ b/bmn/train.py @@ -13,7 +13,7 @@ #limitations under the License. import paddle -import paddle.fluid as fluid +import paddle.distributed as dist import argparse import logging import sys @@ -34,7 +34,7 @@ def parse_args(): parser = argparse.ArgumentParser("Paddle high level api of BMN.") parser.add_argument( - "-d", "--dynamic", action='store_true', help="enable dygraph mode") + "-s", "--static", action='store_true', help="enable static mode") parser.add_argument( '--config_file', type=str, @@ -92,19 +92,18 @@ def optimizer(cfg, parameter_list): lr_decay = cfg.TRAIN.learning_rate_decay l2_weight_decay = cfg.TRAIN.l2_weight_decay lr = [base_lr, base_lr * lr_decay] - optimizer = fluid.optimizer.Adam( - fluid.layers.piecewise_decay( - boundaries=bd, values=lr), - parameter_list=parameter_list, - regularization=fluid.regularizer.L2DecayRegularizer( - regularization_coeff=l2_weight_decay)) + scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr) + optimizer = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=parameter_list, + weight_decay=l2_weight_decay) return optimizer # TRAIN def train_bmn(args): + paddle.enable_static() if args.static else None device = paddle.set_device(args.device) - paddle.disable_static(device) if args.dynamic else None if not os.path.isdir(args.save_dir): os.makedirs(args.save_dir) @@ -139,19 +138,30 @@ def train_bmn(args): # if resume weights is given, load resume weights directly if args.resume is not None: - model.load(args.resume) - model.fit(train_data=train_dataset, - eval_data=val_dataset, - batch_size=train_cfg.TRAIN.batch_size, - epochs=train_cfg.TRAIN.epoch, - eval_freq=args.valid_interval, - log_freq=args.log_interval, - save_dir=args.save_dir, - shuffle=train_cfg.TRAIN.use_shuffle, - num_workers=train_cfg.TRAIN.num_workers, - drop_last=True) + assert os.path.exists( + args.resume + '.pdparams' + ), "Given weight dir {}.pdparams not exist.".format(args.resume) + assert os.path.exists(args.resume + '.pdopt' + ), "Given weight dir {}.pdopt not exist.".format( + args.resume) + model.load(args.resume + '.pdparams') + optim.load(args.resume + '.pdopt') + + model.fit( + train_data=train_dataset, + eval_data=val_dataset, + batch_size=train_cfg.TRAIN.batch_size, #batch_size of one card + epochs=train_cfg.TRAIN.epoch, + eval_freq=args.valid_interval, + log_freq=args.log_interval, + save_dir=args.save_dir, + shuffle=train_cfg.TRAIN.use_shuffle, + num_workers=train_cfg.TRAIN.num_workers, + drop_last=True) if __name__ == "__main__": args = parse_args() - train_bmn(args) + dist.spawn( + train_bmn, args=(args, ), + nprocs=4) # if single-card training please set "nprocs=1"