diff options
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 39 |
1 files changed, 32 insertions, 7 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 3c4b625ac6..b4abebe18d 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -64,7 +64,9 @@ public: _data_type = data_type; _weights_data_type = weights_data_type; _is_quantized = is_data_type_quantized_asymmetric(data_type); - _bias_data_type = _is_quantized ? DataType::S32 : data_type; + _is_bfloat16 = data_type == DataType::BFLOAT16; + _bias_data_type = _is_quantized ? DataType::S32 : (_is_bfloat16 ? DataType::F32 : data_type); + _output_data_type = _is_bfloat16 ? DataType::F32 : data_type; _quantization_info = quantization_info; _weight_quantization_info = weight_quantization_info; _data_layout = data_layout; @@ -74,6 +76,15 @@ public: } protected: + void regularize_values(void *values, size_t size) + { + float *fvalues = static_cast<float *>(values); + for(size_t i = 0; i < size; ++i) + { + fvalues[i] = float(bfloat16(fvalues[i])); + } + } + template <typename U> void fill(U &&tensor, int i) { @@ -119,6 +130,7 @@ protected: library->fill(tensor, distribution, i); break; } + case DataType::BFLOAT16: case DataType::F16: case DataType::F32: { @@ -155,7 +167,7 @@ protected: TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout); TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout); TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout); - TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout); + TensorType dst = create_tensor<TensorType>(output_shape, _output_data_type, 1, _quantization_info, _data_layout); // Create and configure function FunctionType conv; @@ -195,16 +207,27 @@ protected: const unsigned int num_groups = input_shape[2] / weights_shape[2]; + // Setup reference data types + const DataType src_dt = _is_bfloat16 ? DataType::F32 : _data_type; + const DataType weights_dt = _is_bfloat16 ? DataType::F32 : _weights_data_type; + const DataType bias_dt = _is_bfloat16 ? DataType::F32 : _bias_data_type; + // Create reference - SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info }; - SimpleTensor<TW> weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info }; - SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info }; + SimpleTensor<T> src{ input_shape, src_dt, 1, _quantization_info }; + SimpleTensor<TW> weights{ weights_shape, weights_dt, 1, _weight_quantization_info }; + SimpleTensor<TBias> bias{ bias_shape, bias_dt, 1, _quantization_info }; - // Fill reference fill(src, 0); fill(weights, 1); fill(bias, 2); + // Fill with bfloat16 to perform the conversion and reduce the mismatches in the output + if(_is_bfloat16) + { + regularize_values(static_cast<void *>(src.data()), src.num_elements()); + regularize_values(static_cast<void *>(weights.data()), weights.num_elements()); + } + return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups), act_info) : reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups); @@ -215,10 +238,12 @@ protected: DataType _data_type{}; DataType _weights_data_type{}; DataType _bias_data_type{}; + DataType _output_data_type{}; DataLayout _data_layout{}; QuantizationInfo _quantization_info{}; QuantizationInfo _weight_quantization_info{}; bool _is_quantized = false; + bool _is_bfloat16 = false; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> |