From 9e4824c909b14dbaf7106e9527b0ffa22ef09bdc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 12 Apr 2019 13:15:58 +0100 Subject: COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum Alters the concatenate layer to be layout agnostic and accept an index as thec concatenation axis instead of an typed layout dependent enumeration. Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/994 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- arm_compute/graph/LayerDescriptors.h | 2 +- arm_compute/graph/Utils.h | 4 ++-- arm_compute/graph/backends/FunctionHelpers.h | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) (limited to 'arm_compute/graph') diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h index 79099326ec..f52beab523 100644 --- a/arm_compute/graph/LayerDescriptors.h +++ b/arm_compute/graph/LayerDescriptors.h @@ -32,7 +32,7 @@ namespace graph { namespace descriptors { -/** Common node parameters */ +/** Concatenate layer descriptor */ struct ConcatLayerDescriptor { /** Default constructor */ diff --git a/arm_compute/graph/Utils.h b/arm_compute/graph/Utils.h index 4ffccec9be..2fa2f3b627 100644 --- a/arm_compute/graph/Utils.h +++ b/arm_compute/graph/Utils.h @@ -110,12 +110,12 @@ void release_default_graph_context(GraphContext &ctx); size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension); /** Get index of a tensor's given dimension depending on its layout * - * @param[in] descriptor Descriptor + * @param[in] data_layout Data layout of the tensor * @param[in] data_layout_dimension Tensor data layout dimension * * @return Idx of given dimension */ -size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension); +size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension); /** Get the list of driving nodes of a given node * * @param[in] node Node to find the driving node of diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h index e05f4bc8cf..f6e6286a19 100644 --- a/arm_compute/graph/backends/FunctionHelpers.h +++ b/arm_compute/graph/backends/FunctionHelpers.h @@ -28,6 +28,7 @@ #include "arm_compute/graph/Tensor.h" #include "arm_compute/graph/TypePrinter.h" #include "arm_compute/graph/Types.h" +#include "arm_compute/graph/Utils.h" #include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h" #include "arm_compute/graph/backends/Utils.h" #include "arm_compute/graph/nodes/Nodes.h" @@ -321,7 +322,8 @@ std::unique_ptr create_concatenate_layer(ConcatenateLaye inputs.push_back(get_backing_tensor(node.input(i))); } typename TargetInfo::TensorType *output = get_backing_tensor(node.output(0)); - const DataLayoutDimension concat_axis = node.concatenation_axis(); + const DataLayout data_layout = node.output(0) != nullptr ? node.output(0)->desc().layout : DataLayout::UNKNOWN; + const size_t concat_axis = get_dimension_idx(data_layout, node.concatenation_axis()); // Create and configure function auto func = support::cpp14::make_unique(); -- cgit v1.2.1