aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-20 13:23:44 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commite2220551b7a64b929650ba9a60529c31e70c13c5 (patch)
tree5d609887f15b4392cdade7bb388710ceafc62260 /arm_compute/core/utils
parenteff8d95991205e874091576e2d225f63246dd0bb (diff)
downloadComputeLibrary-e2220551b7a64b929650ba9a60529c31e70c13c5.tar.gz
COMPMID-1367: Enable NHWC in graph examples
Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h30
1 files changed, 14 insertions, 16 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index e5516ba154..dbf26a423d 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -201,15 +201,8 @@ inline TensorShape compute_im2col_fc_shape(const ITensorInfo *input, const int n
inline TensorShape compute_im2col_flatten_shape(const ITensorInfo *input)
{
// The output shape will be the flatten version of the input (i.e. [ width * height * channels, 1, 1, ... ] ). Used for FlattenLayer.
-
- ARM_COMPUTE_ERROR_ON(input->num_dimensions() < 3);
-
TensorShape output_shape{ input->tensor_shape() };
-
- const size_t flatten_shape = input->dimension(0) * input->dimension(1) * input->dimension(2);
- output_shape.set(0, flatten_shape);
- output_shape.remove_dimension(1);
- output_shape.remove_dimension(1);
+ output_shape.collapse(3, 0);
return output_shape;
}
@@ -403,20 +396,25 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
}
template <typename T>
-inline TensorShape get_shape_from_info(T *info)
+inline TensorShape extract_shape(T *data)
+{
+ return data->info()->tensor_shape();
+}
+
+inline TensorShape extract_shape(ITensorInfo *data)
{
- return info->info()->tensor_shape();
+ return data->tensor_shape();
}
-inline TensorShape get_shape_from_info(ITensorInfo *info)
+inline TensorShape extract_shape(const TensorShape *data)
{
- return info->tensor_shape();
+ return *data;
}
template <typename T>
inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inputs_vector)
{
- TensorShape out_shape = get_shape_from_info(inputs_vector[0]);
+ TensorShape out_shape = extract_shape(inputs_vector[0]);
size_t max_x = 0;
size_t max_y = 0;
@@ -425,7 +423,7 @@ inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inp
for(const auto &tensor : inputs_vector)
{
ARM_COMPUTE_ERROR_ON(tensor == nullptr);
- const TensorShape shape = get_shape_from_info(tensor);
+ const TensorShape shape = extract_shape(tensor);
max_x = std::max(shape.x(), max_x);
max_y = std::max(shape.y(), max_y);
depth += shape.z();
@@ -441,13 +439,13 @@ inline TensorShape calculate_depth_concatenate_shape(const std::vector<T *> &inp
template <typename T>
inline TensorShape calculate_width_concatenate_shape(const std::vector<T *> &inputs_vector)
{
- TensorShape out_shape = get_shape_from_info(inputs_vector[0]);
+ TensorShape out_shape = extract_shape(inputs_vector[0]);
size_t width = 0;
for(const auto &tensor : inputs_vector)
{
ARM_COMPUTE_ERROR_ON(tensor == nullptr);
- const TensorShape shape = get_shape_from_info(tensor);
+ const TensorShape shape = extract_shape(tensor);
width += shape.x();
}