From e2220551b7a64b929650ba9a60529c31e70c13c5 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 20 Jul 2018 13:23:44 +0100 Subject: COMPMID-1367: Enable NHWC in graph examples Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena Reviewed-by: Anthony Barbier Tested-by: Jenkins --- src/graph/mutators/DepthConcatSubTensorMutator.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'src/graph/mutators') diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp index c56f4c5106..241c07b367 100644 --- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp +++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp @@ -25,8 +25,9 @@ #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/Utils.h" #include "arm_compute/graph/backends/BackendRegistry.h" -#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h" +#include "arm_compute/graph/nodes/ConcatenateLayerNode.h" #include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/core/utils/misc/Iterable.h" @@ -45,11 +46,18 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) // Should be in reverse order of execution for(auto &node : arm_compute::utils::iterable::reverse_iterate(g.nodes())) { - if(node && node->type() == NodeType::DepthConcatenateLayer && node->output(0) != nullptr) + if(node && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr) { // Get output tensor auto output_tensor = node->output(0); + // Check concatenation axis (Sub-tensor optimization is support for concatenation axis >=2) + auto *concat_node = arm_compute::utils::cast::polymorphic_downcast(node.get()); + if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc(), concat_node->concatenation_axis()) < 2) + { + continue; + } + // Check that all tensor have the same target and valid inputs bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(), [&](const EdgeID & eid) @@ -76,7 +84,7 @@ void DepthConcatSubTensorMutator::mutate(Graph &g) depth += input_shape.z(); } - auto *dc_node = arm_compute::utils::cast::polymorphic_downcast(node.get()); + auto *dc_node = arm_compute::utils::cast::polymorphic_downcast(node.get()); dc_node->set_enabled(false); } } -- cgit v1.2.1