aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h31
1 files changed, 31 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index dfccec8b37..bc85c6986f 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -809,6 +809,37 @@ inline TensorShape compute_pool_shape(const ITensorInfo &input, PoolingLayerInfo
return output_shape;
}
+/** Calculate the output unpool shape of a tensor
+ *
+ * @param[in] input Input tensor info
+ * @param[in] pool_info Pooling layer info
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_unpool_shape(const ITensorInfo &input, PoolingLayerInfo pool_info)
+{
+ const unsigned int idx_width = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::WIDTH);
+ const unsigned int idx_height = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::HEIGHT);
+ const TensorShape input_shape = input.tensor_shape();
+ ARM_COMPUTE_ERROR_ON(input_shape[idx_height] <= 1 || input_shape[idx_width] <= 1);
+ const PadStrideInfo pad_stride_info = pool_info.pad_stride_info;
+ const unsigned int stride_x = pad_stride_info.stride().first;
+ const unsigned int stride_y = pad_stride_info.stride().second;
+
+ const int pad_left = pad_stride_info.pad_left();
+ const int pad_top = pad_stride_info.pad_top();
+ const int pad_right = pad_stride_info.pad_right();
+ const int pad_bottom = pad_stride_info.pad_bottom();
+
+ TensorShape output_shape = input_shape;
+ const unsigned int out_width = (input_shape[idx_width] - 1) * stride_x - pad_left - pad_right + pool_info.pool_size.width;
+ const unsigned int out_height = (input_shape[idx_height] - 1) * stride_y - pad_top - pad_bottom + pool_info.pool_size.height;
+
+ output_shape.set(idx_width, out_width);
+ output_shape.set(idx_height, out_height);
+ return output_shape;
+}
+
/** Calculate the output roi align shape of a tensor
*
* @param[in] input Input tensor info