aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/FlattenLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/FlattenLayerFixture.h')
-rw-r--r--tests/validation/fixtures/FlattenLayerFixture.h23
1 files changed, 14 insertions, 9 deletions
diff --git a/tests/validation/fixtures/FlattenLayerFixture.h b/tests/validation/fixtures/FlattenLayerFixture.h
index 3de0ba45ae..ef94ea83b0 100644
--- a/tests/validation/fixtures/FlattenLayerFixture.h
+++ b/tests/validation/fixtures/FlattenLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -43,6 +44,8 @@ namespace test
{
namespace validation
{
+using namespace arm_compute::misc::shape_calculator;
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class FlattenLayerValidationFixture : public framework::Fixture
{
@@ -51,8 +54,13 @@ public:
void setup(TensorShape shape, DataType data_type)
{
_fractional_bits = is_data_type_fixed_point(data_type) ? 4 : 0;
- _target = compute_target(shape, data_type);
- _reference = compute_reference(shape, data_type);
+
+ TensorShape shape_flatten;
+ TensorInfo input_info(shape, 1, data_type, _fractional_bits);
+ shape_flatten = compute_im2col_flatten_shape(&input_info);
+
+ _target = compute_target(shape, shape_flatten, data_type);
+ _reference = compute_reference(shape, shape_flatten, data_type);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(_target.info()->tensor_shape(), _reference.shape());
}
@@ -73,11 +81,8 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, DataType data_type)
+ TensorType compute_target(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type)
{
- TensorShape shape_flatten(shape);
- shape_flatten.collapse(3);
-
// Create tensors
TensorType src = create_tensor<TensorType>(shape, data_type, 1, _fractional_bits);
TensorType dst = create_tensor<TensorType>(shape_flatten, data_type, 1, _fractional_bits);
@@ -105,7 +110,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type)
{
// Create reference
SimpleTensor<T> src{ shape, data_type, 1, _fractional_bits };
@@ -113,7 +118,7 @@ protected:
// Fill reference
fill(src);
- return reference::flatten_layer<T>(src);
+ return reference::flatten_layer<T>(src, shape_flatten);
}
TensorType _target{};