diff options
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/graph/LayerDescriptors.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/Utils.h | 4 | ||||
-rw-r--r-- | arm_compute/graph/backends/FunctionHelpers.h | 4 | ||||
-rw-r--r-- | arm_compute/runtime/CL/functions/CLConcatenateLayer.h | 6 | ||||
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEConcatenateLayer.h | 4 | ||||
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h | 2 |
6 files changed, 12 insertions, 10 deletions
diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h index 79099326ec..f52beab523 100644 --- a/arm_compute/graph/LayerDescriptors.h +++ b/arm_compute/graph/LayerDescriptors.h @@ -32,7 +32,7 @@ namespace graph { namespace descriptors { -/** Common node parameters */ +/** Concatenate layer descriptor */ struct ConcatLayerDescriptor { /** Default constructor */ diff --git a/arm_compute/graph/Utils.h b/arm_compute/graph/Utils.h index 4ffccec9be..2fa2f3b627 100644 --- a/arm_compute/graph/Utils.h +++ b/arm_compute/graph/Utils.h @@ -110,12 +110,12 @@ void release_default_graph_context(GraphContext &ctx); size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension); /** Get index of a tensor's given dimension depending on its layout * - * @param[in] descriptor Descriptor + * @param[in] data_layout Data layout of the tensor * @param[in] data_layout_dimension Tensor data layout dimension * * @return Idx of given dimension */ -size_t get_dimension_idx(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension); +size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension); /** Get the list of driving nodes of a given node * * @param[in] node Node to find the driving node of diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h index e05f4bc8cf..f6e6286a19 100644 --- a/arm_compute/graph/backends/FunctionHelpers.h +++ b/arm_compute/graph/backends/FunctionHelpers.h @@ -28,6 +28,7 @@ #include "arm_compute/graph/Tensor.h" #include "arm_compute/graph/TypePrinter.h" #include "arm_compute/graph/Types.h" +#include "arm_compute/graph/Utils.h" #include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h" #include "arm_compute/graph/backends/Utils.h" #include "arm_compute/graph/nodes/Nodes.h" @@ -321,7 +322,8 @@ std::unique_ptr<arm_compute::IFunction> create_concatenate_layer(ConcatenateLaye inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i))); } typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0)); - const DataLayoutDimension concat_axis = node.concatenation_axis(); + const DataLayout data_layout = node.output(0) != nullptr ? node.output(0)->desc().layout : DataLayout::UNKNOWN; + const size_t concat_axis = get_dimension_idx(data_layout, node.concatenation_axis()); // Create and configure function auto func = support::cpp14::make_unique<ConcatenateLayerFunction>(); diff --git a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h index 5cf09c8ee0..d85a4453d8 100644 --- a/arm_compute/runtime/CL/functions/CLConcatenateLayer.h +++ b/arm_compute/runtime/CL/functions/CLConcatenateLayer.h @@ -59,7 +59,7 @@ public: * @param[out] output Output tensor. Data types supported: Same as @p input. * @param[in] axis Concatenation axis. Supported underlying concatenation axis are 0, 1 and 2. */ - void configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, DataLayoutDimension axis); + void configure(const std::vector<ICLTensor *> &inputs_vector, ICLTensor *output, size_t axis); /** Static function to check if given info will lead to a valid configuration of @ref CLConcatenateLayer * * @note Input and output tensor dimensions preconditions defer depending on the concatenation axis. @@ -71,7 +71,7 @@ public: * * @return a status */ - static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis); + static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis); // Inherited methods overridden: void run() override; @@ -81,5 +81,5 @@ private: unsigned int _num_inputs; unsigned int _axis; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLCONCATENATELAYER_H__ */ diff --git a/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h b/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h index 7dfbcf9199..f8cda326d2 100644 --- a/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h +++ b/arm_compute/runtime/NEON/functions/NEConcatenateLayer.h @@ -59,7 +59,7 @@ public: * @param[out] output Output tensor. Data types supported: Same as @p input. * @param[in] axis Concatenation axis. Supported underlying concatenation axis are 0, 1 and 2. */ - void configure(const std::vector<ITensor *> &inputs_vector, ITensor *output, DataLayoutDimension axis); + void configure(const std::vector<ITensor *> &inputs_vector, ITensor *output, size_t axis); /** Static function to check if given info will lead to a valid configuration of @ref NEConcatenateLayer * * @note Input and output tensor dimensions preconditions defer depending on the concatenation axis. @@ -71,7 +71,7 @@ public: * * @return a status */ - static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, DataLayoutDimension axis); + static Status validate(const std::vector<ITensorInfo *> &inputs_vector, const ITensorInfo *output, size_t axis); // Inherited methods overridden: void run() override; diff --git a/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h b/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h index da38151e73..e2f2c4c44c 100644 --- a/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h +++ b/arm_compute/runtime/NEON/functions/NEDepthConcatenateLayer.h @@ -89,5 +89,5 @@ private: std::unique_ptr<NEFillBorderKernel[]> _border_handlers_vector; unsigned int _num_inputs; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_NEDEPTHCONCATENATE_H__ */ |