diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2018-10-11 17:33:32 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:55:45 +0000 |
commit | 8aaf93e8c12ce93d3d0082d4f4b70376f15536da (patch) | |
tree | 0922f3dde6fafae181e101df315ef36007801850 /tests/validation/fixtures | |
parent | c93691717a6e7ca67e32b4dedd233b8c63b6daf2 (diff) | |
download | ComputeLibrary-8aaf93e8c12ce93d3d0082d4f4b70376f15536da.tar.gz |
COMPMID-1632 Add CLL2NormalizationLayer for NHWC and FP32
Change-Id: Iae22554d5fe893fd22a000eab5bfd8275ea06eb3
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/154102
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/L2NormalizeLayerFixture.h | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/tests/validation/fixtures/L2NormalizeLayerFixture.h b/tests/validation/fixtures/L2NormalizeLayerFixture.h index 6f11dcb658..097d1c4ec2 100644 --- a/tests/validation/fixtures/L2NormalizeLayerFixture.h +++ b/tests/validation/fixtures/L2NormalizeLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -45,10 +45,10 @@ class L2NormalizeLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type, unsigned int axis, float epsilon) + void setup(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) { - _target = compute_target(shape, data_type, axis, epsilon); - _reference = compute_reference(shape, data_type, axis, epsilon); + _target = compute_target(shape, data_type, data_layout, axis, epsilon); + _reference = compute_reference(shape, data_type, data_layout, axis, epsilon); } protected: @@ -58,11 +58,16 @@ protected: library->fill_tensor_uniform(tensor, 0); } - TensorType compute_target(const TensorShape &shape, DataType data_type, unsigned int axis, float epsilon) + TensorType compute_target(TensorShape shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) { + if(data_layout == DataLayout::NHWC) + { + permute(shape, PermutationVector(2U, 0U, 1U)); + } + // Create tensors - TensorType src = create_tensor<TensorType>(shape, data_type); - TensorType dst = create_tensor<TensorType>(shape, data_type); + TensorType src = create_tensor<TensorType>(shape, data_type, 1, QuantizationInfo(), data_layout); + TensorType dst = create_tensor<TensorType>(shape, data_type, 1, QuantizationInfo(), data_layout); // Create and configure function FunctionType l2_norm_func; @@ -87,8 +92,25 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, unsigned int axis, float epsilon) + SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, DataLayout data_layout, unsigned int axis, float epsilon) { + if(data_layout == DataLayout::NHWC) + { + switch(axis) + { + case 0: + axis = 2; + break; + case 1: + axis = 0; + break; + case 2: + axis = 1; + break; + default: + break; + } + } // Create reference SimpleTensor<T> src{ shape, data_type }; |