From c7b183ab741650653289f8ce3bdeb4926521fdbd Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Fri, 6 Mar 2020 18:12:09 +0000
Subject: COMPMID-3160: Add Bfloat16 support in NEGEMMConvolutionLayer

Signed-off-by: Georgios Pinitas
Change-Id: I0e449306c138a562ffc1455e76ec44b2fd059d85
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2860
Tested-by: Arm Jenkins
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
---
 .../validation/fixtures/ConvolutionLayerFixture.h | 39 ++++++++++++++++++----
 1 file changed, 32 insertions(+), 7 deletions(-)

(limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')

diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 3c4b625ac6..b4abebe18d 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -64,7 +64,9 @@ public:
         _data_type                = data_type;
         _weights_data_type        = weights_data_type;
         _is_quantized             = is_data_type_quantized_asymmetric(data_type);
-        _bias_data_type           = _is_quantized ? DataType::S32 : data_type;
+        _is_bfloat16              = data_type == DataType::BFLOAT16;
+        _bias_data_type           = _is_quantized ? DataType::S32 : (_is_bfloat16 ? DataType::F32 : data_type);
+        _output_data_type         = _is_bfloat16 ? DataType::F32 : data_type;
         _quantization_info        = quantization_info;
         _weight_quantization_info = weight_quantization_info;
         _data_layout              = data_layout;
@@ -74,6 +76,15 @@ public:
     }
 
 protected:
+    void regularize_values(void *values, size_t size)
+    {
+        float *fvalues = static_cast<float *>(values);
+        for(size_t i = 0; i < size; ++i)
+        {
+            fvalues[i] = float(bfloat16(fvalues[i]));
+        }
+    }
+
     template <typename U>
     void fill(U &&tensor, int i)
     {
@@ -119,6 +130,7 @@ protected:
                 library->fill(tensor, distribution, i);
                 break;
             }
+            case DataType::BFLOAT16:
             case DataType::F16:
             case DataType::F32:
             {
@@ -155,7 +167,7 @@ protected:
         TensorType src     = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout);
         TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout);
         TensorType bias    = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout);
-        TensorType dst     = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout);
+        TensorType dst     = create_tensor<TensorType>(output_shape, _output_data_type, 1, _quantization_info, _data_layout);
 
         // Create and configure function
         FunctionType conv;
@@ -195,16 +207,27 @@ protected:
 
         const unsigned int num_groups = input_shape[2] / weights_shape[2];
 
+        // Setup reference data types
+        const DataType src_dt     = _is_bfloat16 ? DataType::F32 : _data_type;
+        const DataType weights_dt = _is_bfloat16 ? DataType::F32 : _weights_data_type;
+        const DataType bias_dt    = _is_bfloat16 ? DataType::F32 : _bias_data_type;
+
         // Create reference
-        SimpleTensor<T>     src{ input_shape, _data_type, 1, _quantization_info };
-        SimpleTensor<TW>    weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info };
-        SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info };
+        SimpleTensor<T>     src{ input_shape, src_dt, 1, _quantization_info };
+        SimpleTensor<TW>    weights{ weights_shape, weights_dt, 1, _weight_quantization_info };
+        SimpleTensor<TBias> bias{ bias_shape, bias_dt, 1, _quantization_info };
 
-        // Fill reference
         fill(src, 0);
         fill(weights, 1);
         fill(bias, 2);
 
+        // Fill with bfloat16 to perform the conversion and reduce the mismatches in the output
+        if(_is_bfloat16)
+        {
+            regularize_values(static_cast<void *>(src.data()), src.num_elements());
+            regularize_values(static_cast<void *>(weights.data()), weights.num_elements());
+        }
+
         return (act_info.enabled()) ? reference::activation_layer(reference::convolution_layer(src, weights, bias, output_shape, info, dilation, num_groups), act_info) :
                reference::convolution_layer(src, weights, bias, output_shape, info, dilation, num_groups);
     }
@@ -215,10 +238,12 @@ protected:
     DataType         _data_type{};
     DataType         _weights_data_type{};
    	DataType         _bias_data_type{};
+    DataType         _output_data_type{};
     DataLayout       _data_layout{};
     QuantizationInfo _quantization_info{};
     QuantizationInfo _weight_quantization_info{};
     bool             _is_quantized = false;
+    bool             _is_bfloat16  = false;
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
--
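
For context on the regularization trick this patch relies on: the fixture computes the reference in FP32, so the reference inputs must first be degraded to bfloat16 precision, otherwise the reference would be computed on values the bfloat16 kernel never sees. Below is a minimal standalone sketch of that round-trip. It is not the library's `bfloat16` type; it emulates the conversion by masking off the low 16 mantissa bits (plain truncation), whereas the real conversion may round to nearest even, so exact values can differ.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Round-trip a float through a bfloat16 representation by keeping only
    // the top 16 bits (sign, 8-bit exponent, 7-bit mantissa). Truncation is
    // used here for simplicity; a production conversion may round instead.
    static float bfloat16_round_trip(float value)
    {
        uint32_t bits = 0;
        std::memcpy(&bits, &value, sizeof(bits));
        bits &= 0xFFFF0000u;
        float out = 0.0f;
        std::memcpy(&out, &bits, sizeof(out));
        return out;
    }

    // Mirrors what the patch's regularize_values() achieves: degrade every
    // element in place so the FP32 reference operates on the same values the
    // bfloat16 kernel actually consumes.
    static void regularize_values(float *values, size_t size)
    {
        for(size_t i = 0; i < size; ++i)
        {
            values[i] = bfloat16_round_trip(values[i]);
        }
    }

    int main()
    {
        float data[] = { 0.1f, 0.2f, 0.3f };
        regularize_values(data, 3);
        // Prints the bfloat16-representable neighbours of the inputs,
        // e.g. 0.1 -> 0.099609375 under truncation.
        std::printf("%.9f %.9f %.9f\n", data[0], data[1], data[2]);
        return 0;
    }

Only src and weights are regularized; the bias stays in FP32 because the accumulation path of the bfloat16 GEMM is FP32, which is also why the fixture forces both the bias and output data types to F32 when the input is BFLOAT16.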