From 4b5c588ed5bbf635bfb4d20b662db417caa4558f Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 14 May 2019 10:38:30 +0100 Subject: COMPMID-2248 L2NormalizeLayer: negative axis Change-Id: Ic164d7a9ddf1615a2e3b0e10430c34194a70f221 Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/1127 Tested-by: Arm Jenkins Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins --- .../validation/fixtures/L2NormalizeLayerFixture.h | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/L2NormalizeLayerFixture.h b/tests/validation/fixtures/L2NormalizeLayerFixture.h index 574722bd88..e3e1510ff0 100644 --- a/tests/validation/fixtures/L2NormalizeLayerFixture.h +++ b/tests/validation/fixtures/L2NormalizeLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,12 +40,16 @@ namespace test { namespace validation { +namespace +{ +constexpr int max_input_tensor_dim = 3; +} // namespace template class L2NormalizeLayerValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { _target = compute_target(shape, data_type, data_layout, axis, epsilon); _reference = compute_reference(shape, data_type, data_layout, axis, epsilon); @@ -59,7 +63,7 @@ protected: library->fill(tensor, distribution, 0); } - TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { if(data_layout == DataLayout::NHWC) { @@ -93,20 +97,21 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + SimpleTensor compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { + uint32_t actual_axis = wrap_around(axis, max_input_tensor_dim); if(data_layout == DataLayout::NHWC) { - switch(axis) + switch(actual_axis) { case 0: - axis = 2; + actual_axis = 2; break; case 1: - axis = 0; + actual_axis = 0; break; case 2: - axis = 1; + actual_axis = 1; break; default: break; @@ -118,7 +123,7 @@ protected: // Fill reference fill(src); - return reference::l2_normalize(src, axis, epsilon); + return reference::l2_normalize(src, actual_axis, epsilon); } TensorType _target{}; -- cgit v1.2.1