diff options
Diffstat (limited to 'src/graph/nodes/PoolingLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/PoolingLayerNode.cpp | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/src/graph/nodes/PoolingLayerNode.cpp b/src/graph/nodes/PoolingLayerNode.cpp index a7b6b3679a..26c145ae31 100644 --- a/src/graph/nodes/PoolingLayerNode.cpp +++ b/src/graph/nodes/PoolingLayerNode.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/Utils.h" #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/INodeVisitor.h" +#include "arm_compute/graph/Utils.h" namespace arm_compute { @@ -43,20 +44,24 @@ PoolingLayerInfo PoolingLayerNode::pooling_info() const return _info; } -TensorShape PoolingLayerNode::compute_output_shape(TensorShape input_shape, PoolingLayerInfo info) +TensorDescriptor PoolingLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + PoolingLayerInfo info) { - const int pool_size_x = info.is_global_pooling() ? input_shape.x() : info.pool_size().width; - const int pool_size_y = info.is_global_pooling() ? input_shape.y() : info.pool_size().height; - unsigned int pooled_width = 0; unsigned int pooled_height = 0; - std::tie(pooled_width, pooled_height) = scaled_dimensions(input_shape.x(), input_shape.y(), pool_size_x, pool_size_y, info.pad_stride_info()); - TensorShape output_shape{ input_shape }; - output_shape.set(0, pooled_width); - output_shape.set(1, pooled_height); + const unsigned int input_width = get_dimension_size(input_descriptor, DataLayoutDimension::WIDTH); + const unsigned int input_height = get_dimension_size(input_descriptor, DataLayoutDimension::HEIGHT); + const unsigned int pool_size_x = info.is_global_pooling() ? input_width : info.pool_size().width; + const unsigned int pool_size_y = info.is_global_pooling() ? input_height : info.pool_size().height; + + std::tie(pooled_width, pooled_height) = scaled_dimensions(input_width, input_height, pool_size_x, pool_size_y, info.pad_stride_info()); + + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::WIDTH), pooled_width); + output_descriptor.shape.set(get_dimension_idx(output_descriptor, DataLayoutDimension::HEIGHT), pooled_height); - return output_shape; + return output_descriptor; } bool PoolingLayerNode::forward_descriptors() @@ -79,15 +84,7 @@ TensorDescriptor PoolingLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); - TensorDescriptor output_info = src->desc(); - TensorShape output_shape = compute_output_shape(src->desc().shape, _info); - output_info.shape = output_shape; - return output_info; -} - -Status PoolingLayerNode::validate() -{ - return Status{}; + return compute_output_descriptor(src->desc(), _info); } NodeType PoolingLayerNode::type() const |