From b6c8f7914ad409e2e3879a80f684408283ef4595 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Tue, 21 Jul 2020 13:58:56 -0700 Subject: [PATCH 1/2] Remove "static inputs" for reduction ops --- ngraph_bridge/ngraph_builder.cc | 73 ++++++--------------- ngraph_bridge/ngraph_mark_for_clustering.cc | 7 -- 2 files changed, 21 insertions(+), 59 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 24529aefd..9f219d0a9 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2079,45 +2079,8 @@ static Status TranslateNonMaxSuppressionV4Op( return Status::OK(); } -static Status TranslateReduceOp( - const Node* op, const std::vector& static_input_map, - Builder::OpMap& ng_op_map, - std::function(ng::Output, - ng::Output, const bool)> - create_ng_node) { - ng::Output ng_input; - TF_RETURN_IF_ERROR(GetInputNode(ng_op_map, op, 0, ng_input)); - bool tf_keep_dims; - if (GetNodeAttr(op->attrs(), "keep_dims", &tf_keep_dims) != Status::OK()) { - tf_keep_dims = false; - } - - std::vector axes; - TF_RETURN_IF_ERROR(GetStaticInputVector(op, 1, static_input_map, &axes)); - - ng::Shape input_shape = ng_input.get_shape(); - size_t input_rank = input_shape.size(); - - TF_RETURN_IF_ERROR(CheckAxisDimInRange(axes, input_rank)); - - std::vector ng_reduction_axes_vect(axes.size()); - std::transform( - axes.begin(), axes.end(), ng_reduction_axes_vect.begin(), - [input_rank](int idx) { return idx + (idx < 0 ? (int)input_rank : 0); }); - auto ng_reduction_axes = ConstructNgNode( - op->name(), ng::element::i64, ng::Shape{ng_reduction_axes_vect.size()}, - ng_reduction_axes_vect); - - ng::Output ng_node = - create_ng_node(ng_input, ng_reduction_axes, tf_keep_dims); - Builder::SetTracingInfo(op->name(), ng_node); - - SaveNgOp(ng_op_map, op->name(), ng_node); - return Status::OK(); -} - template -static Status TranslateDirectReduceOp( +static Status TranslateReduceOp( const Node* op, const std::vector& static_input_map, Builder::OpMap& ng_op_map) { // ensure its either an arithmetic or a logical reduction @@ -2127,13 +2090,19 @@ static Status TranslateDirectReduceOp( "Expected node to be either a valid logical or arithmetic reduction " "type"); } - return TranslateReduceOp( - op, static_input_map, ng_op_map, - [&op](ng::Output ng_input, - ng::Output ng_reduction_axes, const bool keep_dims) { - return ConstructNgNode(op->name(), ng_input, ng_reduction_axes, - keep_dims); - }); + + shared_ptr ng_input, ng_reduction_indices; + TF_RETURN_IF_ERROR( + GetInputNodes(ng_op_map, op, &ng_input, &ng_reduction_indices)); + bool keep_dims; + if (GetNodeAttr(op->attrs(), "keep_dims", &keep_dims) != Status::OK()) { + keep_dims = false; + } + + auto ng_node = + ConstructNgNode(op->name(), ng_input, ng_reduction_indices, keep_dims); + SaveNgOp(ng_op_map, op->name(), ng_node); + return Status::OK(); } static Status TranslateOneHotOp( @@ -3002,8 +2971,8 @@ const static std::map< {"Add", TranslateBinaryOp}, {"AddN", TranslateAddNOp}, {"AddV2", TranslateBinaryOp}, - {"Any", TranslateDirectReduceOp}, - {"All", TranslateDirectReduceOp}, + {"Any", TranslateReduceOp}, + {"All", TranslateReduceOp}, {"ArgMax", TranslateArgMaxOp}, {"ArgMin", TranslateArgMinOp}, {"Asin", TranslateUnaryOp}, @@ -3053,13 +3022,13 @@ const static std::map< {"LogicalNot", TranslateUnaryOp}, {"LogicalOr", TranslateBinaryOp}, {"MatMul", TranslateMatMulOp}, - {"Max", TranslateDirectReduceOp}, + {"Max", TranslateReduceOp}, {"Maximum", TranslateBinaryOp}, {"MaxPool", TranslateMaxPoolOp}, {"MaxPool3D", TranslateMaxPool3DOp}, {"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op}, - {"Mean", TranslateDirectReduceOp}, - {"Min", TranslateDirectReduceOp}, + {"Mean", TranslateReduceOp}, + {"Min", TranslateReduceOp}, {"Minimum", TranslateBinaryOp}, {"MirrorPad", TranslatePadOp}, {"Mul", TranslateBinaryOp}, @@ -3077,7 +3046,7 @@ const static std::map< {"Pow", TranslateBinaryOp}, // PreventGradient is just Identity in dataflow terms, so reuse that. {"PreventGradient", TranslateIdentityOp}, - {"Prod", TranslateDirectReduceOp}, + {"Prod", TranslateReduceOp}, {"Rank", TranslateRankOp}, {"RealDiv", TranslateBinaryOp}, {"Reciprocal", TranslateReciprocalOp}, @@ -3106,7 +3075,7 @@ const static std::map< {"Squeeze", TranslateSqueezeOp}, {"StridedSlice", TranslateStridedSliceOp}, {"Sub", TranslateBinaryOp}, - {"Sum", TranslateDirectReduceOp}, + {"Sum", TranslateReduceOp}, {"Tan", TranslateUnaryOp}, {"Tanh", TranslateUnaryOp}, {"Tile", TranslateTileOp}, diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index 7b3133c6f..c733dad7c 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -195,8 +195,6 @@ const std::map& GetAttributeSetters() { if (!initialized) { // Set Additional Attributes (if any) - set_attributes_map["Any"] = SetStaticInputs({1}); - set_attributes_map["All"] = SetStaticInputs({1}); set_attributes_map["ArgMax"] = SetStaticInputs({1}); set_attributes_map["ArgMin"] = SetStaticInputs({1}); set_attributes_map["ConcatV2"] = SetStaticInputs({-1}); @@ -204,15 +202,11 @@ const std::map& GetAttributeSetters() { set_attributes_map["ExpandDims"] = SetStaticInputs({1}); set_attributes_map["Fill"] = SetStaticInputs({0}); set_attributes_map["GatherV2"] = SetStaticInputs({2}); - set_attributes_map["Max"] = SetStaticInputs({1}); - set_attributes_map["Mean"] = SetStaticInputs({1}); - set_attributes_map["Min"] = SetStaticInputs({1}); set_attributes_map["MirrorPad"] = SetStaticInputs({1}); set_attributes_map["NonMaxSuppressionV4"] = SetStaticInputs({2, 3, 4}); set_attributes_map["OneHot"] = SetStaticInputs({1}); set_attributes_map["Pad"] = SetStaticInputs({1}); set_attributes_map["PadV2"] = SetStaticInputs({1, 2}); - set_attributes_map["Prod"] = SetStaticInputs({1}); set_attributes_map["Reshape"] = SetStaticInputs({1}); set_attributes_map["Shape"] = SetStaticInputs({0}); set_attributes_map["ScatterNd"] = SetStaticInputs({2}); @@ -220,7 +214,6 @@ const std::map& GetAttributeSetters() { set_attributes_map["Split"] = SetStaticInputs({0}); set_attributes_map["SplitV"] = SetStaticInputs({1, 2}); set_attributes_map["StridedSlice"] = SetStaticInputs({1, 2, 3}); - set_attributes_map["Sum"] = SetStaticInputs({1}); set_attributes_map["TopKV2"] = SetStaticInputs({1}); set_attributes_map["Tile"] = SetStaticInputs({1}); set_attributes_map["Transpose"] = SetStaticInputs({1}); From c741fa829271808d7f753e859e64c11e826d76a4 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni Date: Mon, 21 Sep 2020 17:38:40 -0700 Subject: [PATCH 2/2] Compilation fix --- ngraph_bridge/ngraph_builder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ngraph_bridge/ngraph_builder.cc b/ngraph_bridge/ngraph_builder.cc index 9f219d0a9..f0c072ef8 100644 --- a/ngraph_bridge/ngraph_builder.cc +++ b/ngraph_bridge/ngraph_builder.cc @@ -2091,9 +2091,9 @@ static Status TranslateReduceOp( "type"); } - shared_ptr ng_input, ng_reduction_indices; + ng::Output ng_input, ng_reduction_indices; TF_RETURN_IF_ERROR( - GetInputNodes(ng_op_map, op, &ng_input, &ng_reduction_indices)); + GetInputNodes(ng_op_map, op, ng_input, ng_reduction_indices)); bool keep_dims; if (GetNodeAttr(op->attrs(), "keep_dims", &keep_dims) != Status::OK()) { keep_dims = false;