diff options
Diffstat (limited to 'tests/validation/fixtures/L2NormalizeLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/L2NormalizeLayerFixture.h | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/tests/validation/fixtures/L2NormalizeLayerFixture.h b/tests/validation/fixtures/L2NormalizeLayerFixture.h index 574722bd88..e3e1510ff0 100644 --- a/tests/validation/fixtures/L2NormalizeLayerFixture.h +++ b/tests/validation/fixtures/L2NormalizeLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,12 +40,16 @@ namespace test { namespace validation { +namespace +{ +constexpr int max_input_tensor_dim = 3; +} // namespace template <typename TensorType, typename AccessorType, typename FunctionType, typename T> class L2NormalizeLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { _target = compute_target(shape, data_type, data_layout, axis, epsilon); _reference = compute_reference(shape, data_type, data_layout, axis, epsilon); @@ -59,7 +63,7 @@ protected: library->fill(tensor, distribution, 0); } - TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { if(data_layout == DataLayout::NHWC) { @@ -93,20 +97,21 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, int axis, float epsilon) { + uint32_t actual_axis = wrap_around(axis, max_input_tensor_dim); if(data_layout == DataLayout::NHWC) { - switch(axis) + switch(actual_axis) { case 0: - axis = 2; + actual_axis = 2; break; case 1: - axis = 0; + actual_axis = 0; break; case 2: - axis = 1; + actual_axis = 1; break; default: break; @@ -118,7 +123,7 @@ protected: // Fill reference fill(src); - return reference::l2_normalize<T>(src, axis, epsilon); + return reference::l2_normalize<T>(src, actual_axis, epsilon); } TensorType _target{}; |