aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/PoolingLayerNode.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-04-27 19:07:19 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commitcac13b1cfd593889271f8e2191be2039b8d88f36 (patch)
treed1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /src/graph/nodes/PoolingLayerNode.cpp
parentad0c7388f6261989a268ffb2d042f2bd80736e3f (diff)
downloadComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/nodes/PoolingLayerNode.cpp')
-rw-r--r--src/graph/nodes/PoolingLayerNode.cpp33
1 files changed, 15 insertions, 18 deletions
diff --git a/src/graph/nodes/PoolingLayerNode.cpp b/src/graph/nodes/PoolingLayerNode.cpp
index a7b6b3679a..26c145ae31 100644
--- a/src/graph/nodes/PoolingLayerNode.cpp
+++ b/src/graph/nodes/PoolingLayerNode.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
+#include "arm_compute/graph/Utils.h"
namespace arm_compute
{
@@ -43,20 +44,24 @@ PoolingLayerInfo PoolingLayerNode::pooling_info() const
return _info;
}
-TensorShape PoolingLayerNode::compute_output_shape(TensorShape input_shape, PoolingLayerInfo info)
+TensorDescriptor PoolingLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
+ PoolingLayerInfo info)
{
- const int pool_size_x = info.is_global_pooling() ? input_shape.x() : info.pool_size().width;
- const int pool_size_y = info.is_global_pooling() ? input_shape.y() : info.pool_size().height;
-
unsigned int pooled_width = 0;
unsigned int pooled_height = 0;
- std::tie(pooled_width, pooled_height) = scaled_dimensions(input_shape.x(), input_shape.y(), pool_size_x, pool_size_y, info.pad_stride_info());
- TensorShape output_shape{ input_shape };
- output_shape.set(0, pooled_width);
- output_shape.set(1, pooled_height);
+ const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH);
+ const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT);
+ const unsigned int pool_size_x = info.is_global_pooling() ? input_width : info.pool_size().width;
+ const unsigned int pool_size_y = info.is_global_pooling() ? input_height : info.pool_size().height;
+
+ std::tie(pooled_width, pooled_height) = scaled_dimensions(input_width, input_height, pool_size_x, pool_size_y, info.pad_stride_info());
+
+ TensorDescriptor output_descriptor = input_descriptor;
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), pooled_width);
+ output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), pooled_height);
- return output_shape;
+ return output_descriptor;
}
bool PoolingLayerNode::forward_descriptors()
@@ -79,15 +84,7 @@ TensorDescriptor PoolingLayerNode::configure_output(size_t idx) const
const Tensor *src = input(0);
ARM_COMPUTE_ERROR_ON(src == nullptr);
- TensorDescriptor output_info = src->desc();
- TensorShape output_shape = compute_output_shape(src->desc().shape, _info);
- output_info.shape = output_shape;
- return output_info;
-}
-
-Status PoolingLayerNode::validate()
-{
- return Status{};
+ return compute_output_descriptor(src->desc(), _info);
}
NodeType PoolingLayerNode::type() const