From f450caa7d2ac9a2a90407fb81203228dc82ef4a1 Mon Sep 17 00:00:00 2001 From: Chunosov Date: Wed, 8 Nov 2017 16:09:35 +0700 Subject: COMPMID-661: softmax-uint8 implementation (#16) Change-Id: Iad11ce70a8a0878a48e445a092035c49c926cece Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94855 Tested-by: Kaizen Reviewed-by: Anthony Barbier --- tests/validation/fixtures/SoftmaxLayerFixture.h | 61 ++++++++++++++++++------- 1 file changed, 45 insertions(+), 16 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h index 9c8f044e81..9836502cd2 100644 --- a/tests/validation/fixtures/SoftmaxLayerFixture.h +++ b/tests/validation/fixtures/SoftmaxLayerFixture.h @@ -43,27 +43,33 @@ namespace test namespace validation { template -class SoftmaxValidationFixedPointFixture : public framework::Fixture +class SoftmaxValidationGenericFixture : public framework::Fixture { public: template - void setup(TensorShape shape, DataType data_type, int fractional_bits) + void setup(TensorShape shape, DataType data_type, int fractional_bits, QuantizationInfo quantization_info) { - _fractional_bits = fractional_bits; + _fractional_bits = fractional_bits; + _quantization_info = quantization_info; - _target = compute_target(shape, data_type, fractional_bits); - _reference = compute_reference(shape, data_type, fractional_bits); + _target = compute_target(shape, data_type, fractional_bits, quantization_info); + _reference = compute_reference(shape, data_type, fractional_bits, quantization_info); } protected: template void fill(U &&tensor) { - if(_fractional_bits == 0) + if(!is_data_type_quantized(tensor.data_type())) { std::uniform_real_distribution<> distribution(-1000.f, 1000.f); library->fill(tensor, distribution, 0); } + else if(is_data_type_quantized_asymmetric(tensor.data_type())) + { + std::uniform_int_distribution<> distribution(0, 100); + library->fill(tensor, distribution, 0); + } else { const int one_fixed = 1 << _fractional_bits; @@ -72,11 +78,11 @@ protected: } } - TensorType compute_target(const TensorShape &shape, DataType data_type, int fixed_point_position = 0) + TensorType compute_target(const TensorShape &shape, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info) { // Create tensors - TensorType src = create_tensor(shape, data_type, 1, fixed_point_position); - TensorType dst = create_tensor(shape, data_type, 1, fixed_point_position); + TensorType src = create_tensor(shape, data_type, 1, fixed_point_position, quantization_info); + TensorType dst = create_tensor(shape, data_type, 1, fixed_point_position, QuantizationInfo(1.f / 256, 0)); // Create and configure function FunctionType smx_layer; @@ -101,10 +107,10 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType data_type, int fixed_point_position = 0) + SimpleTensor compute_reference(const TensorShape &shape, DataType data_type, int fixed_point_position, QuantizationInfo quantization_info) { // Create reference - SimpleTensor src{ shape, data_type, 1, fixed_point_position }; + SimpleTensor src{ shape, data_type, 1, fixed_point_position, quantization_info }; // Fill reference fill(src); @@ -112,19 +118,42 @@ protected: return reference::softmax_layer(src); } - TensorType _target{}; - SimpleTensor _reference{}; - int _fractional_bits{}; + TensorType _target{}; + SimpleTensor _reference{}; + int _fractional_bits{}; + QuantizationInfo _quantization_info{}; }; template -class SoftmaxValidationFixture : public SoftmaxValidationFixedPointFixture +class SoftmaxValidationFixture : public SoftmaxValidationGenericFixture { public: template void setup(TensorShape shape, DataType data_type) { - SoftmaxValidationFixedPointFixture::setup(shape, data_type, 0); + SoftmaxValidationGenericFixture::setup(shape, data_type, 0, QuantizationInfo()); + } +}; + +template +class SoftmaxValidationFixedPointFixture : public SoftmaxValidationGenericFixture +{ +public: + template + void setup(TensorShape shape, DataType data_type, int fixed_point_position) + { + SoftmaxValidationGenericFixture::setup(shape, data_type, fixed_point_position, QuantizationInfo()); + } +}; + +template +class SoftmaxValidationQuantizedFixture : public SoftmaxValidationGenericFixture +{ +public: + template + void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info) + { + SoftmaxValidationGenericFixture::setup(shape, data_type, 0, quantization_info); } }; } // namespace validation -- cgit v1.2.1