aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConvolutionLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h77
1 files changed, 61 insertions, 16 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 52fa8da60b..c5cddc28db 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -48,7 +48,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationGenericFixture : public framework::Fixture
{
public:
@@ -57,13 +57,15 @@ public:
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights,
- DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
+ DataType data_type, DataType weights_data_type, DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info, ActivationLayerInfo act_info)
{
- _data_type = data_type;
- _is_quantized = is_data_type_quantized_asymmetric(data_type);
- _bias_data_type = _is_quantized ? DataType::S32 : data_type;
- _quantization_info = quantization_info;
- _data_layout = data_layout;
+ _data_type = data_type;
+ _weights_data_type = weights_data_type;
+ _is_quantized = is_data_type_quantized_asymmetric(data_type);
+ _bias_data_type = _is_quantized ? DataType::S32 : data_type;
+ _quantization_info = quantization_info;
+ _weight_quantization_info = weight_quantization_info;
+ _data_layout = data_layout;
_target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info);
_reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info);
@@ -82,6 +84,26 @@ protected:
library->fill(tensor, distribution, i);
break;
}
+ case DataType::QSYMM8_PER_CHANNEL:
+ {
+ int min_bound = 128;
+ int max_bound = -127;
+ for(size_t i = 0; i < _weight_quantization_info.scale().size(); i++)
+ {
+ std::pair<int, int> bounds = get_symm_quantized_per_channel_bounds(tensor.quantization_info(), -1.0f, 1.0f, i);
+ if(bounds.first < min_bound)
+ {
+ min_bound = bounds.first;
+ }
+ if(bounds.second > max_bound)
+ {
+ max_bound = bounds.second;
+ }
+ }
+ std::uniform_int_distribution<int8_t> distribution(min_bound, max_bound);
+ library->fill(tensor, distribution, i);
+ break;
+ }
case DataType::S32:
{
std::uniform_int_distribution<int32_t> distribution(-100, 100);
@@ -122,7 +144,7 @@ protected:
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info, _data_layout);
- TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _data_type, 1, _quantization_info, _data_layout);
+ TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _weights_data_type, 1, _weight_quantization_info, _data_layout);
TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info, _data_layout);
TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info, _data_layout);
@@ -166,7 +188,7 @@ protected:
// Create reference
SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info };
- SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info };
+ SimpleTensor<TW> weights{ weights_shape, _weights_data_type, 1, _weight_quantization_info };
SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info };
// Fill reference
@@ -182,36 +204,59 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
DataType _data_type{};
+ DataType _weights_data_type{};
DataType _bias_data_type{};
DataLayout _data_layout{};
QuantizationInfo _quantization_info{};
+ QuantizationInfo _weight_quantization_info{};
bool _is_quantized = false;
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
DataLayout data_layout, ActivationLayerInfo act_info)
{
- ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
- data_type, data_layout,
- QuantizationInfo(), act_info);
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
+ data_type, data_type, data_layout,
+ QuantizationInfo(), QuantizationInfo(), act_info);
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
{
- ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
- data_type, data_layout, quantization_info, act_info);
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights,
+ data_type, data_type, data_layout, quantization_info, quantization_info, act_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
+ DataLayout data_layout, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataType weights_data_type)
+ {
+ std::vector<float> weights_scales{};
+ std::mt19937 gen(library->seed());
+ std::uniform_real_distribution<> dis(0.01f, 1);
+ for(size_t i = 0; i < output_shape[2]; ++i)
+ {
+ weights_scales.push_back(dis(gen));
+ }
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation,
+ reshape_weights, data_type, weights_data_type, data_layout,
+ quantization_info, QuantizationInfo(weights_scales), act_info);
}
};
} // namespace validation