aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends/FunctionHelpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/backends/FunctionHelpers.h')
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
1 files changed, 3 insertions, 1 deletions
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index e05f4bc8cf..f6e6286a19 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -28,6 +28,7 @@
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/TypePrinter.h"
#include "arm_compute/graph/Types.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/FusedConvolutionBatchNormalizationFunction.h"
#include "arm_compute/graph/backends/Utils.h"
#include "arm_compute/graph/nodes/Nodes.h"
@@ -321,7 +322,8 @@ std::unique_ptr<arm_compute::IFunction> create_concatenate_layer(ConcatenateLaye
inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i)));
}
typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0));
- const DataLayoutDimension concat_axis = node.concatenation_axis();
+ const DataLayout data_layout = node.output(0) != nullptr ? node.output(0)->desc().layout : DataLayout::UNKNOWN;
+ const size_t concat_axis = get_dimension_idx(data_layout, node.concatenation_axis());
// Create and configure function
auto func = support::cpp14::make_unique<ConcatenateLayerFunction>();