Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h  39
1 file changed, 32 insertions(+), 7 deletions(-)
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 3c4b625ac6..b4abebe18d 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -64,7 +64,9 @@ public:
_data_type = data_type;
_weights_data_type = weights_data_type;
_is_quantized = is_data_type_quantized_asymmetric(data_type);
- _bias_data_type = _is_quantized ? DataType::S32 : data_type;
+ _is_bfloat16 = data_type == DataType::BFLOAT16;
+ _bias_data_type = _is_quantized ? DataType::S32 : (_is_bfloat16 ? DataType::F32 : data_type);
+ _output_data_type = _is_bfloat16 ? DataType::F32 : data_type;
_quantization_info = quantization_info;
_weight_quantization_info = weight_quantization_info;
_data_layout = data_layout;
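
The new _is_bfloat16 flag widens both the bias and the output to F32: bfloat16 keeps float32's 8-bit exponent but only 7 mantissa bits, so bias addition and accumulation stay in full precision. A minimal standalone sketch of that selection logic, using a hypothetical DataType enum rather than the library's own type:

#include <cassert>

enum class DataType { BFLOAT16, F16, F32, S32, QASYMM8 };

struct FixtureTypes
{
    DataType bias;
    DataType output;
};

// Quantized types accumulate into S32; bfloat16 inputs are widened to F32
// for both the bias and the output, mirroring the fixture change above.
FixtureTypes select_types(DataType data_type, bool is_quantized)
{
    const bool is_bfloat16 = (data_type == DataType::BFLOAT16);
    FixtureTypes t{};
    t.bias   = is_quantized ? DataType::S32 : (is_bfloat16 ? DataType::F32 : data_type);
    t.output = is_bfloat16 ? DataType::F32 : data_type;
    return t;
}

int main()
{
    const FixtureTypes t = select_types(DataType::BFLOAT16, /*is_quantized=*/false);
    assert(t.bias == DataType::F32 && t.output == DataType::F32);
}
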
@@ -74,6 +76,15 @@ public:
}
protected:
+ void regularize_values(void *values, size_t size)
+ {
+ float *fvalues = static_cast<float *>(values);
+ for(size_t i = 0; i < size; ++i)
+ {
+ fvalues[i] = float(bfloat16(fvalues[i]));
+ }
+ }
+
template <typename U>
void fill(U &&tensor, int i)
{
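
The new regularize_values helper snaps each float to the nearest value representable in bfloat16 by converting to bfloat16 and back. A self-contained sketch of what that round trip does, assuming plain truncation of the low 16 bits (the library's bfloat16 class may instead round to nearest even, so this is illustrative only):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float through bfloat16 and back: keep the sign bit, the 8
// exponent bits and the top 7 mantissa bits, zero the rest.
float to_bfloat16_and_back(float value)
{
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits)); // safe type pun
    bits &= 0xFFFF0000u;                      // drop the low 16 mantissa bits
    std::memcpy(&value, &bits, sizeof(bits));
    return value;
}

int main()
{
    const float original = 1.2345678f;
    const float snapped  = to_bfloat16_and_back(original);
    // Only about 2-3 decimal digits of precision survive the round trip.
    std::printf("%.9f -> %.9f\n", original, snapped);
}
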
@@ -119,6 +130,7 @@ protected:
library->fill(tensor, distribution, i);
break;
}
+ case DataType::BFLOAT16:
case DataType::F16:
case DataType::F32:
{
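
BFLOAT16 now falls through to the same fill path as F16 and F32, so bfloat16 tensors are seeded from an ordinary float distribution before any conversion happens. A sketch of that kind of seeded fill, assuming a uniform [-1, 1] range purely for illustration (the fixture's actual bounds live in fill()):

#include <cstddef>
#include <random>
#include <vector>

// Deterministic fill for the float-family types; the seed plays the role
// of the fixture's per-tensor index i.
std::vector<float> make_float_values(std::size_t n, unsigned seed)
{
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> dist(-1.f, 1.f);
    std::vector<float> values(n);
    for(float &v : values)
    {
        v = dist(gen);
    }
    return values;
}

int main()
{
    const std::vector<float> values = make_float_values(16, /*seed=*/0);
    return values.empty() ? 1 : 0;
}
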
@@ -155,7 +167,7 @@ protected:
TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout);
TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout);
TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout);
- TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, _output_data_type, 1, _quantization_info, _data_layout);
// Create and configure function
FunctionType conv;
@@ -195,16 +207,27 @@ protected:
const unsigned int num_groups = input_shape[2] / weights_shape[2];
+ // Set up the reference data types
+ const DataType src_dt = _is_bfloat16 ? DataType::F32 : _data_type;
+ const DataType weights_dt = _is_bfloat16 ? DataType::F32 : _weights_data_type;
+ const DataType bias_dt = _is_bfloat16 ? DataType::F32 : _bias_data_type;
+
// Create reference
- SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info };
- SimpleTensor<TW> weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info };
- SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info };
+ SimpleTensor<T> src{ input_shape, src_dt, 1, _quantization_info };
+ SimpleTensor<TW> weights{ weights_shape, weights_dt, 1, _weight_quantization_info };
+ SimpleTensor<TBias> bias{ bias_shape, bias_dt, 1, _quantization_info };
- // Fill reference
fill(src, 0);
fill(weights, 1);
fill(bias, 2);
+ // Round-trip the values through bfloat16 so the reference sees the same converted inputs, reducing mismatches in the output
+ if(_is_bfloat16)
+ {
+ regularize_values(static_cast<void *>(src.data()), src.num_elements());
+ regularize_values(static_cast<void *>(weights.data()), weights.num_elements());
+ }
+
return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups),
act_info) :
reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation, num_groups);
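
After filling, src and weights are regularized so the F32 reference convolution consumes exactly the values the bfloat16 kernel sees after conversion; the bias is left untouched because it stays in F32 on both paths. A tiny dot product shows why this reduces mismatches, reusing the truncation assumption from the earlier sketch (hypothetical helper names, not library API):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncation-based bfloat16 snap, same assumption as the earlier sketch.
static float bf16(float v)
{
    std::uint32_t b;
    std::memcpy(&b, &v, sizeof(b));
    b &= 0xFFFF0000u; // drop the low 16 mantissa bits
    std::memcpy(&v, &b, sizeof(b));
    return v;
}

int main()
{
    const float x[4] = {0.1f, 0.2f, 0.3f, 0.4f};
    const float w[4] = {1.1f, 1.2f, 1.3f, 1.4f};

    float full = 0.f, regularized = 0.f;
    for(int i = 0; i < 4; ++i)
    {
        full        += x[i] * w[i];             // pure F32 reference
        regularized += bf16(x[i]) * bf16(w[i]); // inputs snapped to bfloat16 first
    }
    // A kernel that converts its inputs to bfloat16 lands much closer to
    // 'regularized' than to 'full', so the regularized reference mismatches less.
    std::printf("full=%.9f regularized=%.9f\n", full, regularized);
}
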
@@ -215,10 +238,12 @@ protected:
DataType _data_type{};
DataType _weights_data_type{};
DataType _bias_data_type{};
+ DataType _output_data_type{};
DataLayout _data_layout{};
QuantizationInfo _quantization_info{};
QuantizationInfo _weight_quantization_info{};
bool _is_quantized = false;
+ bool _is_bfloat16 = false;
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>