diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2020-06-03 13:20:41 +0100 |
---|---|---|
committer | Manuel Bottini <manuel.bottini@arm.com> | 2020-06-22 16:23:58 +0000 |
commit | 80feed5193de6b10d8ab65b42fb988c241c5d09d (patch) | |
tree | d7b695d0c3d099e7bbdbd10a82fd355d77f0fdfc /tests/validation/fixtures | |
parent | 0028d7c7230d3fda419db5c6d5d8141830bd13f9 (diff) | |
download | ComputeLibrary-80feed5193de6b10d8ab65b42fb988c241c5d09d.tar.gz |
COMPMID-3479: Perform in-place computations in NEElementwiseUnaryKernel
Change-Id: I2102bfe95c2c2335bb587842f9d860cf939a9026
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3315
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ElementWiseUnaryFixture.h | 58 |
1 files changed, 39 insertions, 19 deletions
diff --git a/tests/validation/fixtures/ElementWiseUnaryFixture.h b/tests/validation/fixtures/ElementWiseUnaryFixture.h index 3f6d5b3cb3..b11b802d11 100644 --- a/tests/validation/fixtures/ElementWiseUnaryFixture.h +++ b/tests/validation/fixtures/ElementWiseUnaryFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -44,10 +44,10 @@ class ElementWiseUnaryValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, DataType input_data_type, ElementWiseUnary op) + void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op) { _op = op; - _target = compute_target(input_shape, input_data_type); + _target = compute_target(input_shape, input_data_type, in_place); _reference = compute_reference(input_shape, input_data_type); } @@ -115,25 +115,27 @@ protected: } } - TensorType compute_target(const TensorShape &shape, DataType data_type) + TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place) { // Create tensors TensorType src = create_tensor<TensorType>(shape, data_type); TensorType dst = create_tensor<TensorType>(shape, data_type); + TensorType *actual_dst = in_place ? &src : &dst; + // Create and configure function FunctionType elwiseunary_layer; - - elwiseunary_layer.configure(&src, &dst); + elwiseunary_layer.configure(&src, actual_dst); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); - - // Allocate tensors src.allocator()->allocate(); - dst.allocator()->allocate(); ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + if(!in_place) + { + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + dst.allocator()->allocate(); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + } // Fill tensors fill(AccessorType(src), 0, data_type); @@ -141,7 +143,14 @@ protected: // Compute function elwiseunary_layer.run(); - return dst; + if(in_place) + { + return src; + } + else + { + return dst; + } } SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type) @@ -167,7 +176,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::RSQRT); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT); } }; @@ -178,7 +187,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::EXP); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::EXP); } }; @@ -189,7 +198,18 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::NEG); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::NEG); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class NegValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType data_type, bool in_place) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, in_place, ElementWiseUnary::NEG); } }; @@ -200,7 +220,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::LOG); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::LOG); } }; @@ -211,7 +231,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::ABS); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ABS); } }; @@ -222,7 +242,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::SIN); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::SIN); } }; @@ -233,7 +253,7 @@ public: template <typename...> void setup(const TensorShape &shape, DataType data_type) { - ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::ROUND); + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND); } }; } // namespace validation |