diff options
Diffstat (limited to 'tests/validation/Reference.cpp')
-rw-r--r-- | tests/validation/Reference.cpp | 36 |
1 files changed, 24 insertions, 12 deletions
diff --git a/tests/validation/Reference.cpp b/tests/validation/Reference.cpp index 62dfcba37e..04362f0dc1 100644 --- a/tests/validation/Reference.cpp +++ b/tests/validation/Reference.cpp @@ -506,18 +506,30 @@ RawTensor Reference::compute_reference_convolution_layer(const TensorShape &inpu RawTensor ref_dst = library->get(output_shape, dt, 1, fixed_point_position); // Fill reference - if(dt == DataType::F16 || dt == DataType::F32) - { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); - library->fill(ref_src, distribution, 0); - library->fill(ref_weights, distribution, 1); - library->fill(ref_bias, distribution, 2); - } - else + switch(dt) { - library->fill_tensor_uniform(ref_src, 0); - library->fill_tensor_uniform(ref_weights, 1); - library->fill_tensor_uniform(ref_bias, 2); + case DataType::F32: + case DataType::F16: + { + std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + library->fill(ref_src, distribution, 0); + library->fill(ref_weights, distribution, 1); + library->fill(ref_bias, distribution, 2); + break; + } + case DataType::QS16: + case DataType::QS8: + { + library->fill_tensor_uniform(ref_src, 0); + library->fill_tensor_uniform(ref_weights, 1); + library->fill_tensor_uniform(ref_bias, 2); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported"); + break; + } } // Compute reference @@ -546,7 +558,7 @@ RawTensor Reference::compute_reference_fully_connected_layer(const TensorShape & RawTensor ref_weights = library->get(ws, dt, 1, fixed_point_position); // Fill reference - if(dt == DataType::F32) + if(dt == DataType::F16 || dt == DataType::F32) { std::uniform_real_distribution<> distribution(-1.0f, 1.0f); library->fill(ref_src, distribution, 0); |