aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/FlattenLayerFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-03-09 15:30:43 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commit156fcf3f36f6168e47d65db167bba3af5037e3d9 (patch)
tree89240783068a72b918791cf18a613eb43b93035d /tests/validation/fixtures/FlattenLayerFixture.h
parent8de92619e223225aabdca873c02f231d8e941fd1 (diff)
downloadComputeLibrary-156fcf3f36f6168e47d65db167bba3af5037e3d9.tar.gz
COMPMID-802 Add NHWC data format support for NEON im2col.
Change-Id: I86e678179106a2b83d1c6a7cfe562df91b0f9eb2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/124000 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'tests/validation/fixtures/FlattenLayerFixture.h')
-rw-r--r--tests/validation/fixtures/FlattenLayerFixture.h23
1 files changed, 14 insertions, 9 deletions
diff --git a/tests/validation/fixtures/FlattenLayerFixture.h b/tests/validation/fixtures/FlattenLayerFixture.h
index 3de0ba45ae..ef94ea83b0 100644
--- a/tests/validation/fixtures/FlattenLayerFixture.h
+++ b/tests/validation/fixtures/FlattenLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/Tensor.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -43,6 +44,8 @@ namespace test
{
namespace validation
{
+using namespace arm_compute::misc::shape_calculator;
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class FlattenLayerValidationFixture : public framework::Fixture
{
@@ -51,8 +54,13 @@ public:
void setup(TensorShape shape, DataType data_type)
{
_fractional_bits = is_data_type_fixed_point(data_type) ? 4 : 0;
- _target = compute_target(shape, data_type);
- _reference = compute_reference(shape, data_type);
+
+ TensorShape shape_flatten;
+ TensorInfo input_info(shape, 1, data_type, _fractional_bits);
+ shape_flatten = compute_im2col_flatten_shape(&input_info);
+
+ _target = compute_target(shape, shape_flatten, data_type);
+ _reference = compute_reference(shape, shape_flatten, data_type);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(_target.info()->tensor_shape(), _reference.shape());
}
@@ -73,11 +81,8 @@ protected:
}
}
- TensorType compute_target(const TensorShape &shape, DataType data_type)
+ TensorType compute_target(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type)
{
- TensorShape shape_flatten(shape);
- shape_flatten.collapse(3);
-
// Create tensors
TensorType src = create_tensor<TensorType>(shape, data_type, 1, _fractional_bits);
TensorType dst = create_tensor<TensorType>(shape_flatten, data_type, 1, _fractional_bits);
@@ -105,7 +110,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, const TensorShape &shape_flatten, DataType data_type)
{
// Create reference
SimpleTensor<T> src{ shape, data_type, 1, _fractional_bits };
@@ -113,7 +118,7 @@ protected:
// Fill reference
fill(src);
- return reference::flatten_layer<T>(src);
+ return reference::flatten_layer<T>(src, shape_flatten);
}
TensorType _target{};