aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-04-12 13:15:58 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-04-15 16:52:22 +0000
commit9e4824c909b14dbaf7106e9527b0ffa22ef09bdc (patch)
treeb1cc8f6a8b275a7e227e305f1b02870d5e0f30ec /arm_compute/graph
parentd66094e37ecd747e85f30130e1a678bdbaf30788 (diff)
downloadComputeLibrary-9e4824c909b14dbaf7106e9527b0ffa22ef09bdc.tar.gz
COMPMID-2111: ConcatenateLayer API should accept an index instead of an enum
Alters the concatenate layer to be layout agnostic and accept an index as thec concatenation axis instead of an typed layout dependent enumeration. Change-Id: I0eaaf919f66a1ba1b09bbfb47c171fc1d4045530 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/994 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph')
-rw-r--r--arm_compute/graph/LayerDescriptors.h2
-rw-r--r--arm_compute/graph/Utils.h4
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
3 files changed, 6 insertions, 4 deletions
diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h
index 79099326ec..f52beab523 100644
--- a/arm_compute/graph/LayerDescriptors.h
+++ b/arm_compute/graph/LayerDescriptors.h
@@ -32,7 +32,7 @@ namespace graph
{
namespace descriptors
{
-/** Common node parameters */
+/** Concatenate layer descriptor */
struct ConcatLayerDescriptor
{
/** Default constructor */
diff --git a/arm_compute/graph/Utils.h b/arm_compute/graph/Utils.h
index 4ffccec9be..2fa2f3b627 100644
--- a/arm_compute/graph/Utils.h
+++ b/arm_compute/graph/Utils.h
@@ -110,12 +110,12 @@ void release_default_graph_context(GraphContext &ctx);
size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension);
/** Get index of a tensor's given dimension depending on its layout
*
- * @param[in] descriptor Descriptor
+ * @param[in] data_layout Data layout of the tensor
* @param[in] data_layout_dimension Tensor data layout dimension
*
* @return Idx of given dimension
*/
-size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension);
+size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension);
/** Get the list of driving nodes of a given node
*
* @param[in] node Node to find the driving node of
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index e05f4bc8cf..f6e6286a19 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -28,6 +28,7 @@
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/TypePrinter.h"
#include "arm_compute/graph/Types.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h"
#include "arm_compute/graph/backends/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
@@ -321,7 +322,8 @@ std::unique_ptr<arm_compute::IFunction> create_concatenate_layer(ConcatenateLaye
inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i)));
}
typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0));
- const DataLayoutDimension concat_axis = node.concatenation_axis();
+ const DataLayout data_layout = node.output(0) != nullptr ? node.output(0)->desc().layout : DataLayout::UNKNOWN;
+ const size_t concat_axis = get_dimension_idx(data_layout, node.concatenation_axis());
// Create and configure function
auto func = support::cpp14::make_unique<ConcatenateLayerFunction>();