aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-20 13:23:44 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commite2220551b7a64b929650ba9a60529c31e70c13c5 (patch)
tree5d609887f15b4392cdade7bb388710ceafc62260 /src/graph/nodes
parenteff8d95991205e874091576e2d225f63246dd0bb (diff)
downloadComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/graph/nodes')
-rw-r--r--src/graph/nodes/ConcatenateLayerNode.cpp (renamed from src/graph/nodes/DepthConcatenateLayerNode.cpp)62
1 files changed, 39 insertions, 23 deletions
diff --git a/src/graph/nodes/DepthConcatenateLayerNode.cpp b/src/graph/nodes/ConcatenateLayerNode.cpp
index 08cccc1ff1..ade3f6e1a9 100644
--- a/src/graph/nodes/DepthConcatenateLayerNode.cpp
+++ b/src/graph/nodes/ConcatenateLayerNode.cpp
@@ -21,58 +21,74 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#include "arm_compute/graph/nodes/DepthConcatenateLayerNode.h"
+#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
+
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
namespace arm_compute
{
namespace graph
{
-DepthConcatenateLayerNode::DepthConcatenateLayerNode(unsigned int total_nodes)
- : _total_nodes(total_nodes), _is_enabled(true)
+ConcatenateLayerNode::ConcatenateLayerNode(unsigned int total_nodes, DataLayoutDimension axis)
+ : _total_nodes(total_nodes), _axis(axis), _is_enabled(true)
{
_input_edges.resize(_total_nodes, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
-void DepthConcatenateLayerNode::set_enabled(bool is_enabled)
+void ConcatenateLayerNode::set_enabled(bool is_enabled)
{
_is_enabled = is_enabled;
}
-bool DepthConcatenateLayerNode::is_enabled() const
+bool ConcatenateLayerNode::is_enabled() const
{
return _is_enabled;
}
-TensorDescriptor DepthConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors)
+DataLayoutDimension ConcatenateLayerNode::concatenation_axis() const
+{
+ return _axis;
+}
+
+TensorDescriptor ConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors,
+ DataLayoutDimension axis)
{
ARM_COMPUTE_ERROR_ON(input_descriptors.size() == 0);
TensorDescriptor output_descriptor = input_descriptors[0];
+ const int axis_idx = get_dimension_idx(output_descriptor, axis);
- size_t max_x = 0;
- size_t max_y = 0;
- size_t depth = 0;
-
- for(const auto &input_descriptor : input_descriptors)
+ // Extract shapes
+ std::vector<const TensorShape *> shapes;
+ for(auto &input_descriptor : input_descriptors)
{
- max_x = std::max(input_descriptor.shape.x(), max_x);
- max_y = std::max(input_descriptor.shape.y(), max_y);
- depth += input_descriptor.shape.z();
+ shapes.emplace_back(&input_descriptor.shape);
}
- output_descriptor.shape.set(0, max_x);
- output_descriptor.shape.set(1, max_y);
- output_descriptor.shape.set(2, depth);
+ // Calculate output shape
+ if(axis_idx == 0)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_width_concatenate_shape(shapes);
+ }
+ else if(axis_idx == 2)
+ {
+ output_descriptor.shape = arm_compute::misc::shape_calculator::calculate_depth_concatenate_shape(shapes);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Unsupported concatenation axis!");
+ }
return output_descriptor;
}
-bool DepthConcatenateLayerNode::forward_descriptors()
+bool ConcatenateLayerNode::forward_descriptors()
{
if(_outputs[0] != NullTensorID)
{
@@ -84,7 +100,7 @@ bool DepthConcatenateLayerNode::forward_descriptors()
return false;
}
-TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
+TensorDescriptor ConcatenateLayerNode::configure_output(size_t idx) const
{
ARM_COMPUTE_UNUSED(idx);
ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
@@ -106,18 +122,18 @@ TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
ARM_COMPUTE_ERROR_ON(t == nullptr);
inputs_descriptors.push_back(t->desc());
}
- output_info = compute_output_descriptor(inputs_descriptors);
+ output_info = compute_output_descriptor(inputs_descriptors, _axis);
}
return output_info;
}
-NodeType DepthConcatenateLayerNode::type() const
+NodeType ConcatenateLayerNode::type() const
{
- return NodeType::DepthConcatenateLayer;
+ return NodeType::ConcatenateLayer;
}
-void DepthConcatenateLayerNode::accept(INodeVisitor &v)
+void ConcatenateLayerNode::accept(INodeVisitor &v)
{
v.visit(*this);
}