diff options
Diffstat (limited to 'tests/validation/fixtures/DepthwiseConvolutionFixture.h')
-rw-r--r-- | tests/validation/fixtures/DepthwiseConvolutionFixture.h | 89 |
1 files changed, 71 insertions, 18 deletions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionFixture.h b/tests/validation/fixtures/DepthwiseConvolutionFixture.h index f49e76c70c..b1d31d657a 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionFixture.h @@ -43,14 +43,22 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T> -class DepthwiseConvolutionValidationFixture : public framework::Fixture +class DepthwiseConvolutionValidationGenericFixture : public framework::Fixture { public: + using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type; + +public: template <typename...> - void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info) + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) { - _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info); - _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info); + _quantization_info = quantization_info; + _data_type = data_type; + + const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; + + _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); + _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info); } protected: @@ -59,28 +67,46 @@ protected: { switch(tensor.data_type()) { + case DataType::QASYMM8: + { + std::uniform_int_distribution<uint8_t> distribution(0, 10); + library->fill(tensor, distribution, i); + break; + } case DataType::F32: { std::uniform_real_distribution<> distribution(-1.0f, 1.0f); library->fill(tensor, distribution, i); break; } + case DataType::S32: + { + std::uniform_int_distribution<int32_t> distribution(-1000, 1000); + library->fill(tensor, distribution, i); + break; + } default: library->fill_tensor_uniform(tensor, i); } } - TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info) + TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info, + const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor<TensorType>(input_shape, DataType::F32); - TensorType weights = create_tensor<TensorType>(weights_shape, DataType::F32); - TensorType biases = create_tensor<TensorType>(biases_shape, DataType::F32); - TensorType dst = create_tensor<TensorType>(output_shape, DataType::F32); + TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, quantization_info); + TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, 0, quantization_info); + TensorType biases = create_tensor<TensorType>(biases_shape, bias_data_type, 1, 0, quantization_info); + TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, quantization_info); // Create Depthwise Convolution configure function - FunctionType depthwise_convolution; - depthwise_convolution.configure(&src, &weights, &biases, &dst, pad_stride_info); + FunctionType dwc; + dwc.configure(&src, &weights, &biases, &dst, pad_stride_info); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(biases.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); // Allocate tensors src.allocator()->allocate(); @@ -99,16 +125,17 @@ protected: fill(AccessorType(biases), 2); // Compute function - depthwise_convolution.run(); + dwc.run(); return dst; } - SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info) + SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info, + const DataType data_type, const DataType bias_data_type, QuantizationInfo quantization_info) { - SimpleTensor<T> src(in_shape, DataType::F32); - SimpleTensor<T> weights(weights_shape, DataType::F32); - SimpleTensor<T> biases(biases_shape, DataType::F32); + SimpleTensor<T> src{ in_shape, data_type, 1, 0, quantization_info }; + SimpleTensor<T> weights{ weights_shape, data_type, 1, 0, quantization_info }; + SimpleTensor<TBias> biases{ biases_shape, data_type, 1, 0, quantization_info }; fill(src, 0); fill(weights, 1); @@ -117,8 +144,34 @@ protected: return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info); } - TensorType _target{}; - SimpleTensor<T> _reference{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; + DataType _data_type{}; + QuantizationInfo _quantization_info{}; +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class DepthwiseConvolutionValidationFixture : public DepthwiseConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type) + { + DepthwiseConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, + data_type, QuantizationInfo()); + } +}; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class DepthwiseConvolutionValidationQuantizedFixture : public DepthwiseConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T> +{ +public: + template <typename...> + void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info) + { + DepthwiseConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, + data_type, quantization_info); + } }; } // namespace validation } // namespace test |