diff --git a/src/ngraph/pass/reshape_sinking.cpp b/src/ngraph/pass/reshape_sinking.cpp index 78ecc8b9eca..66a4077daf6 100644 --- a/src/ngraph/pass/reshape_sinking.cpp +++ b/src/ngraph/pass/reshape_sinking.cpp @@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr node) return ss.str(); } +static shared_ptr + make_reshape(shared_ptr arg, const AxisVector& input_order, const Shape& output_shape) +{ + auto reshape = make_shared(arg, input_order, output_shape); + NGRAPH_DEBUG << "Make Reshape " << describe_reshape(reshape); + return reshape; +} + +static void + write_reshapemap(ReshapeMap& reorders, shared_ptr target, shared_ptr reshape) +{ + NGRAPH_DEBUG << "Write ReshapeMap[" << target->get_name() + << "] = " << describe_reshape(reshape); + reorders[target] = reshape; +} + +static shared_ptr read_reshapemap(ReshapeMap& reorders, shared_ptr target) +{ + auto reorder = reorders.at(target); + NGRAPH_DEBUG << "Read ReshapeMap[" << target->get_name() << "] -> " + << describe_reshape(reorder); + return reorder; +} + static shared_ptr combine_reshapes(shared_ptr r1, shared_ptr r2) { auto default_order = ngraph::get_default_order(r1->get_shape()); auto perm_r1 = apply_permutation(default_order, r1->get_input_order()); auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order()); - auto rreshape = make_shared(r2->get_argument(0), perm_r2, r2->get_shape()); + auto rreshape = make_reshape(r2->get_argument(0), perm_r2, r2->get_shape()); + NGRAPH_DEBUG << "Combining " << describe_reshape(r1) << " and " << describe_reshape(r2) + << " into " << describe_reshape(rreshape); return rreshape; } static void insert_reshape(shared_ptr target, shared_ptr reshape, size_t input_index) { + NGRAPH_DEBUG << "Inserting reshape at input " << target->get_name() << " input index " + << input_index; auto arg = target->input(input_index).get_source_output(); + NGRAPH_DEBUG << "Arg shape: " << arg.get_shape(); auto new_reshape = reshape->copy_with_new_inputs({arg}); + NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input " + << target->get_name() << " input index " << input_index; target->input(input_index).replace_source_output(new_reshape->output(0)); } @@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr reshape, static shared_ptr create_default_reshape(shared_ptr n) { auto default_order = ngraph::get_default_order(n->get_shape()); - auto default_reshape = make_shared(n, default_order, n->get_shape()); + auto default_reshape = make_reshape(n, default_order, n->get_shape()); + NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape); return default_reshape; } @@ -187,7 +219,7 @@ void swim(Input input, shared_ptr reshape) auto new_arg_shape = ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order); broadcast_input = - make_shared(broadcast_input, new_source_axis_order, new_arg_shape); + make_reshape(broadcast_input, new_source_axis_order, new_arg_shape); } auto new_broadcast = make_shared( @@ -209,12 +241,11 @@ void swim(Input input, shared_ptr reshape) //of a binary op isn't in the default format (i.e. nhwc instead of nchw) //We have to normalize this other argument to nchw by swimming nchw towards parameters //as far as we can -static void convert_binary_to_default_order( - shared_ptr binary, - const Input& input, - shared_ptr right, - unordered_map, shared_ptr>& reorders, - set>& reshapes_to_delete) +static void convert_binary_to_default_order(shared_ptr binary, + const Input& input, + shared_ptr right, + ReshapeMap& reorders, + set>& reshapes_to_delete) { auto left = input.get_source_output().get_node_shared_ptr(); auto perm_to_def = @@ -222,13 +253,13 @@ static void convert_binary_to_default_order( auto new_shape = apply_permutation(left->get_shape(), perm_to_def); NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", " << right->get_name(); - auto new_reshape = make_shared(left, perm_to_def, new_shape); + auto new_reshape = make_reshape(left, perm_to_def, new_shape); NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to " << left->get_name(); //this should now insert and swim reshape on right swim(input, new_reshape); mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete); - reorders[binary] = reorders.at(right); + write_reshapemap(reorders, binary, read_reshapemap(reorders, right)); } static void materialize_shapes(shared_ptr n, @@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr n, auto arg = n->get_argument(i); if (reorders.count(arg) != 0) { - NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for " + auto arg_reshape = reorders.at(arg); + NGRAPH_DEBUG << "Materializing " << describe_reshape(arg_reshape) << " for " << arg->get_name(); - mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete); - if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape())) + mark_reshape_for_deletion(arg_reshape, reshapes_to_delete); + auto arg_shape = arg->get_shape(); + if (arg_reshape->get_input_order() != get_default_order(arg->get_shape())) { // Insert if arg needs to be transposed. - insert_reshape(n, reorders.at(arg), i); + insert_reshape(n, arg_reshape, i); } //no swimming up } } - reorders[n] = create_default_reshape(n); + write_reshapemap(reorders, n, create_default_reshape(n)); } static void sink_reshape(shared_ptr reshape, ReshapeMap& reorders, set>& reshapes_to_delete) { + NGRAPH_DEBUG << "Sinking Reshape :" << describe_reshape(reshape); auto orig_reshape = reorders.at(reshape->get_argument(0)); - if (!reshape->get_is_transpose()) + // 1) Not a Transpose or 2) Rank changing operation. + if ((reshape->get_output_shape().size() != reshape->get_input_order().size()) || + (!reshape->get_is_transpose())) { NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape " - << reshape->get_name(); + << describe_reshape(reshape); insert_reshape(reshape, orig_reshape, 0); mark_reshape_for_deletion(orig_reshape, reshapes_to_delete); - reorders[reshape] = create_default_reshape(reshape); + write_reshapemap(reorders, reshape, create_default_reshape(reshape)); } else { @@ -284,9 +320,7 @@ static void sink_reshape(shared_ptr reshape, //replace reshape with combined one ngraph::replace_node(reshape, new_reshape); mark_reshape_for_deletion(new_reshape, reshapes_to_delete); - reorders[new_reshape] = new_reshape; - NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and" - << describe_reshape(reshape) << " into " << describe_reshape(new_reshape); + write_reshapemap(reorders, new_reshape, new_reshape); } } @@ -294,9 +328,9 @@ static void sink_unary(shared_ptr n, ReshapeMap& reorders, set>& reshapes_to_delete) { - auto arg_reshape = reorders.at(n->get_argument(0)); + auto arg_reshape = read_reshapemap(reorders, n->get_argument(0)); NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name(); - reorders[n] = reorders[n->get_argument(0)]; + write_reshapemap(reorders, n, arg_reshape); } static void sink_binary(shared_ptr binary, @@ -310,7 +344,7 @@ static void sink_binary(shared_ptr binary { NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for " << binary->get_name(); - reorders[binary] = reorders.at(left); + write_reshapemap(reorders, binary, read_reshapemap(reorders, left)); //at this point, both reshapes will be eventually removed mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete); mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete); @@ -360,9 +394,9 @@ static void sink_slice(shared_ptr n, NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name(); ngraph::replace_node(n, new_slice); - auto new_reshape = make_shared(new_slice, order, n->get_shape()); + auto new_reshape = make_reshape(new_slice, order, n->get_shape()); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); - reorders[new_slice] = new_reshape; + write_reshapemap(reorders, new_slice, new_reshape); } static void @@ -385,9 +419,9 @@ static void ngraph::replace_node(dummy_correct_shape, n->get_argument(0)); NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name(); ngraph::replace_node(n, new_pad); - auto new_reshape = make_shared(new_pad, order, n->get_shape()); + auto new_reshape = make_reshape(new_pad, order, n->get_shape()); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); - reorders[new_pad] = new_reshape; + write_reshapemap(reorders, new_pad, new_reshape); } static void sink_quantize(shared_ptr quantize, ReshapeMap& reorders, @@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr quantize, quantize->get_round_mode()); ngraph::replace_node(quantize, new_quantize); - reorders[new_quantize] = arg_reshape; + write_reshapemap(reorders, new_quantize, arg_reshape); } static void sink_concat(shared_ptr n, @@ -451,9 +485,9 @@ static void sink_concat(shared_ptr n, NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name(); ngraph::replace_node(n, new_concat); - auto new_reshape = make_shared(new_concat, order, n->get_shape()); + auto new_reshape = make_reshape(new_concat, order, n->get_shape()); NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name(); - reorders[new_concat] = new_reshape; + write_reshapemap(reorders, new_concat, new_reshape); } static void sink_dequantize(shared_ptr dequantize, @@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr dequantize, axes_in_def_order); ngraph::replace_node(dequantize, new_dequantize); - reorders[new_dequantize] = arg_reshape; + write_reshapemap(reorders, new_dequantize, arg_reshape); } //The goal of ReshapeSinking is to remove @@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr //STEP 1 : Sink or Swim reshapes away for op clusters for (auto n : f->get_ordered_ops()) { - NGRAPH_DEBUG << "Processing node " << n->get_name(); + NGRAPH_DEBUG << "Start: Processing node " << n->get_name(); //collect all Result nodes for a sanity check if (n->is_output()) { @@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr } else if (auto goe = dynamic_pointer_cast(n)) { - reorders[goe] = create_default_reshape(goe); + write_reshapemap(reorders, goe, create_default_reshape(goe)); } else if (auto quantize = dynamic_pointer_cast(n)) { @@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr { materialize_shapes(n, reorders, reshapes_to_delete); } + NGRAPH_DEBUG << "End: Processing node " << n->get_name(); } //STEP 2: purge all the reshapes we either sunk or swam.