diff options
Diffstat (limited to 'tests/validation/fixtures/ElementWiseUnaryFixture.h')
-rw-r--r-- | tests/validation/fixtures/ElementWiseUnaryFixture.h | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/tests/validation/fixtures/ElementWiseUnaryFixture.h b/tests/validation/fixtures/ElementWiseUnaryFixture.h index f508bc1d34..ba131630a3 100644 --- a/tests/validation/fixtures/ElementWiseUnaryFixture.h +++ b/tests/validation/fixtures/ElementWiseUnaryFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,7 +53,7 @@ public: protected: template <typename U> - void fill(U &&tensor, int i) + void fill(U &&tensor, int i, DataType data_type) { switch(_op) { @@ -69,6 +69,28 @@ protected: library->fill(tensor, distribution, i); break; } + case ElementWiseUnary::NEG: + { + switch(data_type) + { + case DataType::F32: + case DataType::F16: + { + std::uniform_real_distribution<> distribution(-2.0f, 2.0f); + library->fill(tensor, distribution, i); + break; + } + case DataType::S32: + { + std::uniform_int_distribution<int32_t> distribution(-100, 100); + library->fill(tensor, distribution, i); + break; + } + default: + ARM_COMPUTE_ERROR("DataType for Elementwise Negation Not implemented"); + } + break; + } default: ARM_COMPUTE_ERROR("Not implemented"); } @@ -95,7 +117,7 @@ protected: ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); // Fill tensors - fill(AccessorType(src), 0); + fill(AccessorType(src), 0, data_type); // Compute function elwiseunary_layer.run(); @@ -109,7 +131,7 @@ protected: SimpleTensor<T> src{ shape, data_type }; // Fill reference - fill(src, 0); + fill(src, 0, data_type); return reference::elementwise_unary<T>(src, _op); } @@ -140,6 +162,17 @@ public: ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::EXP); } }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class NegValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType data_type) + { + ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, ElementWiseUnary::NEG); + } +}; } // namespace validation } // namespace test } // namespace arm_compute |