diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2019-05-14 10:38:30 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2019-05-14 16:13:58 +0000 |
commit | 4b5c588ed5bbf635bfb4d20b662db417caa4558f (patch) | |
tree | 25d33b5020ebfa6b19b9a9870f682df51b17ebc3 /tests/validation/fixtures/L2NormalizeLayerFixture.h | |
parent | 2388de12aa5e71f4e295179b4ea344e3e306556a (diff) | |
download | ComputeLibrary-4b5c588ed5bbf635bfb4d20b662db417caa4558f.tar.gz |
COMPMID-2248
L2NormalizeLayer: negative axis
Change-Id: Ic164d7a9ddf1615a2e3b0e10430c34194a70f221
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1127
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/L2NormalizeLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/L2NormalizeLayerFixture.h | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/tests/validation/fixtures/L2NormalizeLayerFixture.h b/tests/validation/fixtures/L2NormalizeLayerFixture.h index 574722bd88..e3e1510ff0 100644 --- a/tests/validation/fixtures/L2NormalizeLayerFixture.h +++ b/tests/validation/fixtures/L2NormalizeLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,12 +40,16 @@ namespace test { namespace validation { +namespace +{ +constexpr int max_input_tensor_dim = 3; +} // namespace template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class L2NormalizeLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { _target = compute_target(shape, data_type, data_layout, axis, epsilon); _reference = compute_reference(shape, data_type, data_layout, axis, epsilon); @@ -59,7 +63,7 @@ protected: library->fill(tensor, distribution, 0); } - TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { if(data_layout == DataLayout::NHWC) { @@ -93,20 +97,21 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { + uint32_t actual_axis = wrap_around(axis, max_input_tensor_dim); if(data_layout == DataLayout::NHWC) { - switch(axis) + switch(actual_axis) { case 0: - axis = 2; + actual_axis = 2; break; case 1: - axis = 0; + actual_axis = 0; break; case 2: - axis = 1; + actual_axis = 1; break; default: break; @@ -118,7 +123,7 @@ protected: // Fill reference fill(src); - return reference::l2_normalize<T>(src, axis, epsilon); + return reference::l2_normalize<T>(src, actual_axis, epsilon); } TensorType _target{}; |