diff options
Diffstat (limited to 'tests/validation/fixtures/RemapFixture.h')
-rw-r--r-- | tests/validation/fixtures/RemapFixture.h | 40 |
1 files changed, 32 insertions, 8 deletions
diff --git a/tests/validation/fixtures/RemapFixture.h b/tests/validation/fixtures/RemapFixture.h index 2cb8e67f62..03cb6aef42 100644 --- a/tests/validation/fixtures/RemapFixture.h +++ b/tests/validation/fixtures/RemapFixture.h @@ -50,7 +50,7 @@ public: { std::mt19937 gen(library->seed()); std::uniform_int_distribution<uint8_t> distribution(0, 255); - const T constant_border_value = static_cast<T>(distribution(gen)); + PixelValue constant_border_value{ static_cast<T>(distribution(gen)) }; _data_layout = data_layout; _target = compute_target(shape, policy, data_type, border_mode, constant_border_value); @@ -59,13 +59,35 @@ public: protected: template <typename U> - void fill(U &&tensor, int i, float min, float max) + void fill(U &&tensor, int i, int min, int max) { - std::uniform_int_distribution<> distribution((int)min, (int)max); - library->fill(tensor, distribution, i); + switch(tensor.data_type()) + { + case DataType::F32: + { + // map_x,y as integer values + std::uniform_int_distribution<int> distribution(min, max); + library->fill(tensor, distribution, i); + break; + } + case DataType::F16: + { + arm_compute::utils::uniform_real_distribution_16bit<half> distribution(static_cast<float>(min), static_cast<float>(max)); + library->fill(tensor, distribution, i); + break; + } + case DataType::U8: + { + std::uniform_int_distribution<uint8_t> distribution(min, max); + library->fill(tensor, distribution, i); + break; + } + default: + ARM_COMPUTE_ERROR("DataType for Remap not supported"); + } } - TensorType compute_target(TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, T constant_border_value) + TensorType compute_target(TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, PixelValue constant_border_value) { if(_data_layout == DataLayout::NHWC) { @@ -111,14 +133,16 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, T constant_border_value) + SimpleTensor<T> compute_reference(const TensorShape shape, InterpolationPolicy policy, DataType data_type, BorderMode border_mode, PixelValue constant_border_value) { - ARM_COMPUTE_ERROR_ON(data_type != DataType::U8); + ARM_COMPUTE_ERROR_ON(data_type != DataType::U8 && data_type != DataType::F16); // Create reference SimpleTensor<T> src{ shape, data_type }; SimpleTensor<float> map_x{ shape, DataType::F32 }; SimpleTensor<float> map_y{ shape, DataType::F32 }; + T border_value{}; + constant_border_value.get(border_value); // Create the valid mask Tensor _valid_mask = SimpleTensor<T> { shape, data_type }; @@ -131,7 +155,7 @@ protected: fill(map_y, 2, -5, max_val); // Compute reference - return reference::remap<T>(src, map_x, map_y, _valid_mask, policy, border_mode, constant_border_value); + return reference::remap<T>(src, map_x, map_y, _valid_mask, policy, border_mode, border_value); } TensorType _target{}; |