From a3e57c20a0b7a174f0c357676a4da40a248d04db Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do <viet-hoa.do@arm.com> Date: Mon, 13 Mar 2023 16:20:04 +0000 Subject: Add dynamic weights for CPU fully connected layer Resolves: COMPMID-5917 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com> --- .../fixtures/FullyConnectedLayerFixture.h | 68 ++++++++++++++-------- 1 file changed, 43 insertions(+), 25 deletions(-) (limited to 'tests/validation/fixtures') diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index b5efccdf70..7d1aa494ba 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -326,23 +326,34 @@ private: } } - void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref) + void validate_with_tolerance(TensorType &target, SimpleTensor<float> &ref) { - if(_data_type == DataType::F32) - { - constexpr RelativeTolerance<float> rel_tolerance_f32(0.05f); - constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f); - validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); - } - else if(_data_type == DataType::QASYMM8) - { - constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); - validate(AccessorType(target), ref, tolerance_qasymm8); - } - else - { - validate(AccessorType(target), ref); - } + constexpr RelativeTolerance<float> rel_tolerance_f32(0.01f); + constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.001f); + validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); + } + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + void validate_with_tolerance(TensorType &target, SimpleTensor<half_float::half> &ref) + { + constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f); + const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f)); + constexpr float tolerance_num_f16 = 0.07f; + + validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + void validate_with_tolerance(TensorType &target, SimpleTensor<uint8_t> &ref) + { + constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); + validate(AccessorType(target), ref, tolerance_qasymm8); + } + + void validate_with_tolerance(TensorType &target, SimpleTensor<int8_t> &ref) + { + constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_signed(1); + validate(AccessorType(target), ref, tolerance_qasymm8_signed); } public: @@ -351,7 +362,7 @@ public: template <typename T> void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias) + DataType data_type, ActivationLayerInfo 
activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped) { _data_type = data_type; @@ -368,7 +379,7 @@ public: _src.allocator()->init(src_info); TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); - if(!constant_weights) + if(!constant_weights && weights_reshaped) { const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; wei_info.set_tensor_shape(tr_weights_shape); @@ -388,8 +399,8 @@ public: fc_info.activation_info = activation_info; if(!constant_weights) { - fc_info.are_weights_reshaped = true; - fc_info.transpose_weights = false; + fc_info.are_weights_reshaped = weights_reshaped; + fc_info.transpose_weights = !weights_reshaped; } FunctionType fc; fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); @@ -428,7 +439,14 @@ public: fill(AccessorType(_src), randomizer_offset); if(!constant_weights) { - fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); + if(weights_reshaped) + { + fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); + } + else + { + fill(AccessorType(_weights), randomizer_offset + 1); + } } if(!constant_bias) { @@ -472,10 +490,10 @@ class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamic public: template <typename T> void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info) + DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped) { FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape, - dst_shape, data_type, activation_info, false, true); + dst_shape, data_type, activation_info, false, true, weights_reshaped); } }; @@ -488,7 +506,7 @@ public: DataType data_type, ActivationLayerInfo activation_info) { FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape, - dst_shape, data_type, activation_info, 
true, false, false /* weights_reshaped (not used) */); } }; } // namespace validation -- cgit v1.2.1