aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/ConcatenateLayerNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/ConcatenateLayerNode.cpp')
-rw-r--r--src/graph/nodes/ConcatenateLayerNode.cpp9
1 files changed, 4 insertions, 5 deletions
diff --git a/src/graph/nodes/ConcatenateLayerNode.cpp b/src/graph/nodes/ConcatenateLayerNode.cpp
index ade3f6e1a9..3ce09d0073 100644
--- a/src/graph/nodes/ConcatenateLayerNode.cpp
+++ b/src/graph/nodes/ConcatenateLayerNode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -71,10 +71,9 @@ TensorDescriptor ConcatenateLayerNode::compute_output_descriptor(const std::vect
shapes.emplace_back(&input_descriptor.shape);
}
- // Calculate output shape
- if(axis_idx == 0)
+ if(axis_idx < 2)
{
- output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(shapes);
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_concatenate_shape(shapes, axis_idx);
}
else if(axis_idx == 2)
{
@@ -138,4 +137,4 @@ void ConcatenateLayerNode::accept(INodeVisitor &v)
v.visit(*this);
}
} // namespace graph
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute