aboutsummaryrefslogtreecommitdiff
path: root/arm_compute
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute')
-rw-r--r--arm_compute/graph/LayerDescriptors.h2
-rw-r--r--arm_compute/graph/Utils.h4
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
-rw-r--r--arm_compute/runtime/CL/functions/CLConcatenateLayer.h6
-rw-r--r--arm_compute/runtime/NEON/functions/NEConcatenateLayer.h4
-rw-r--r--arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h2
6 files changed, 12 insertions, 10 deletions
diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h
index 79099326ec..f52beab523 100644
--- a/arm_compute/graph/LayerDescriptors.h
+++ b/arm_compute/graph/LayerDescriptors.h
@@ -32,7 +32,7 @@ namespace graph
{
namespace descriptors
{
-/** Common node parameters */
+/** Concatenate layer descriptor */
struct ConcatLayerDescriptor
{
/** Default constructor */
diff --git a/arm_compute/graph/Utils.h b/arm_compute/graph/Utils.h
index 4ffccec9be..2fa2f3b627 100644
--- a/arm_compute/graph/Utils.h
+++ b/arm_compute/graph/Utils.h
@@ -110,12 +110,12 @@ void release_default_graph_context(GraphContext &ctx);
size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension);
/** Get index of a tensor's given dimension depending on its layout
*
- * @param[in] descriptor Descriptor
+ * @param[in] data_layout Data layout of the tensor
* @param[in] data_layout_dimension Tensor data layout dimension
*
* @return Idx of given dimension
*/
-size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension);
+size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension);
/** Get the list of driving nodes of a given node
*
* @param[in] node Node to find the driving node of
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index e05f4bc8cf..f6e6286a19 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -28,6 +28,7 @@
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/TypePrinter.h"
#include "arm_compute/graph/Types.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h"
#include "arm_compute/graph/backends/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
@@ -321,7 +322,8 @@ std::unique_ptr<arm_compute::IFunction> create_concatenate_layer(ConcatenateLaye
inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i)));
}
typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0));
- const DataLayoutDimension concat_axis = node.concatenation_axis();
+ const DataLayout data_layout = node.output(0) != nullptr ? node.output(0)->desc().layout : DataLayout::UNKNOWN;
+ const size_t concat_axis = get_dimension_idx(data_layout, node.concatenation_axis());
// Create and configure function
auto func = support::cpp14::make_unique<ConcatenateLayerFunction>();
diff --git a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h
index 5cf09c8ee0..d85a4453d8 100644
--- a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h
+++ b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h
@@ -59,7 +59,7 @@ public:
* @param[out] output Output tensor. Data types supported: Same as @p input.
* @param[in] axis Concatenation axis. Supported underlying concatenation axis are 0, 1 and 2.
*/
- void configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, DataLayoutDimension axis);
+ void configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis);
/** Static function to check if given info will lead to a valid configuration of @ref CLConcatenateLayer
*
* @note Input and output tensor dimensions preconditions defer depending on the concatenation axis.
@@ -71,7 +71,7 @@ public:
*
* @return a status
*/
- static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis);
+ static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis);
// Inherited methods overridden:
void run() override;
@@ -81,5 +81,5 @@ private:
unsigned int _num_inputs;
unsigned int _axis;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLCONCATENATELAYER_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h b/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h
index 7dfbcf9199..f8cda326d2 100644
--- a/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h
@@ -59,7 +59,7 @@ public:
* @param[out] output Output tensor. Data types supported: Same as @p input.
* @param[in] axis Concatenation axis. Supported underlying concatenation axis are 0, 1 and 2.
*/
- void configure(const std::vector<ITensor *> &inputs_vector, ITensor *output, DataLayoutDimension axis);
+ void configure(const std::vector<ITensor *> &inputs_vector, ITensor *output, size_t axis);
/** Static function to check if given info will lead to a valid configuration of @ref NEConcatenateLayer
*
* @note Input and output tensor dimensions preconditions defer depending on the concatenation axis.
@@ -71,7 +71,7 @@ public:
*
* @return a status
*/
- static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis);
+ static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis);
// Inherited methods overridden:
void run() override;
diff --git a/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h b/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h
index da38151e73..e2f2c4c44c 100644
--- a/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h
@@ -89,5 +89,5 @@ private:
std::unique_ptr<NEFillBorderKernel[]> _border_handlers_vector;
unsigned int _num_inputs;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEDEPTHCONCATENATE_H__ */