From a2bb80ea7111509c24caad8629533089decef430 Mon Sep 17 00:00:00 2001
From: Mohammed Suhail Munshi
Date: Mon, 19 Jun 2023 14:57:57 +0100
Subject: Use MatMul in fully connected layer with dynamic weights when supported

- Use MatMul kernels in FC layer when using dynamic weights without
  broadcasting or bias.
- Fix minor typo in IClMatMulNativeKernelConfig.h

Partially Resolves: [COMPMID-6193]

Signed-off-by: Mohammed Suhail Munshi
Change-Id: Id494062b5b4f4e75ff9714c202dde941955afa52
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9797
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
Reviewed-by: Gunes Bayir
Benchmark: Arm Jenkins
---
 .../fixtures/FullyConnectedLayerFixture.h          | 63 +++++++++++++++-------
 1 file changed, 43 insertions(+), 20 deletions(-)

diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 75bef144ad..e13c01d1e2 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -335,9 +335,9 @@ private:
     void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref)
     {
-        constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
+        constexpr AbsoluteTolerance<float>        abs_tolerance_f16(0.3f);
         const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f));
-        constexpr float tolerance_num_f16 = 0.07f;
+        constexpr float                           tolerance_num_f16 = 0.07f;
 
         validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
     }
 
@@ -360,36 +360,36 @@ public:
 
     template <typename U>
     void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
-               DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped)
+               DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped, bool remove_bias = false)
     {
         _data_type = data_type;
 
-        const bool is_quantized = is_data_type_quantized(data_type);
-
+        const bool     is_quantized   = is_data_type_quantized(data_type);
         const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
 
         const QuantizationInfo src_qinfo     = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
         const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
         const QuantizationInfo dst_qinfo     = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
 
-        // Setup tensor meta-data
+        // Configure TensorInfo Objects
         const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
-        _src.allocator()->init(src_info);
+        const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+        TensorInfo       bias_info(bias_shape, 1, bias_data_type);
+        TensorInfo       wei_info(weights_shape, 1, data_type, weights_qinfo);
 
-        TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
         if(!constant_weights && weights_reshaped)
         {
             const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
             wei_info.set_tensor_shape(tr_weights_shape);
         }
         wei_info.set_are_values_constant(constant_weights);
-        _weights.allocator()->init(wei_info);
-
-        TensorInfo bias_info(bias_shape, 1, bias_data_type);
         bias_info.set_are_values_constant(constant_bias);
-        _bias.allocator()->init(bias_info);
 
-        const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+        // Initialise Tensors
+        _src.allocator()->init(src_info);
+        _weights.allocator()->init(wei_info);
+        if(!remove_bias)
+            _bias.allocator()->init(bias_info);
         _dst.allocator()->init(dst_info);
 
         // Configure FC layer and mark the weights as non constant
@@ -401,12 +401,13 @@ public:
             fc_info.transpose_weights = !weights_reshaped;
         }
         FunctionType fc;
-        fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
+        fc.configure(&_src, &_weights, (remove_bias) ? nullptr : &_bias, &_dst, fc_info);
 
         // Allocate all the tensors
         _src.allocator()->allocate();
         _weights.allocator()->allocate();
-        _bias.allocator()->allocate();
+        if(!remove_bias)
+            _bias.allocator()->allocate();
         _dst.allocator()->allocate();
 
         // Run multiple iterations with different inputs
@@ -424,11 +425,20 @@ public:
             fill(AccessorType(_weights), 1);
             fill(weights, 1);
         }
-        if(constant_bias)
+        if(constant_bias && !remove_bias)
         {
             fill(AccessorType(_bias), 2);
             fill(bias, 2);
         }
+        // To remove bias, fill with 0
+        if(remove_bias && is_quantized)
+        {
+            library->fill_tensor_value(bias, 0);
+        }
+        else if(remove_bias)
+        {
+            library->fill_tensor_value(bias, (float)0.0);
+        }
 
         for(int i = 0; i < num_iterations; ++i)
         {
@@ -446,7 +456,7 @@ public:
                     fill(AccessorType(_weights), randomizer_offset + 1);
                 }
             }
-            if(!constant_bias)
+            if(!constant_bias && !remove_bias)
             {
                 fill(AccessorType(_bias), randomizer_offset + 2);
             }
@@ -462,7 +472,7 @@ public:
             {
                 fill(weights, randomizer_offset + 1);
             }
-            if(!constant_bias)
+            if(!constant_bias && !remove_bias)
             {
                 fill(bias, randomizer_offset + 2);
             }
@@ -491,7 +501,20 @@ public:
                DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
     {
         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
-                                                                                                   dst_shape, data_type, activation_info, false, true, weights_reshaped);
+                                                                                                   dst_shape, data_type, activation_info, false, true, weights_reshaped, false);
     }
 };
 
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedDynamicNoBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename U>
+    void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+               DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
+    {
+        FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+                                                                                                   dst_shape, data_type, activation_info, false, true, weights_reshaped, true);
+    }
+};
+
@@ -504,7 +527,7 @@ public:
                DataType data_type, ActivationLayerInfo activation_info)
     {
         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
-                                                                                                   dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
+                                                                                                   dst_shape, data_type, activation_info, true, false, false, false /* weights_reshaped (not used) */);
     }
 };
 } // namespace validation
-- 
cgit v1.2.1
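
Note: the new FullyConnectedDynamicNoBiasFixture drives exactly the case this
patch targets: non-constant weights and a null bias pointer. As a rough
standalone illustration (not part of the patch), the equivalent configuration
through the public CL runtime API is sketched below. The shapes, the data
type, and the assumption that the backend heuristic actually selects the
MatMul kernels for this configuration are illustrative only.

    // Sketch: FC layer with dynamic (non-constant) weights and no bias,
    // the configuration routed to the MatMul kernels when supported.
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // Illustrative shapes: 128 inputs, 16 outputs, batch of 4.
        CLTensor src, weights, dst;
        src.allocator()->init(TensorInfo(TensorShape(128U, 4U), 1, DataType::F32));
        dst.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::F32));

        // Marking the weights non-constant is what makes the layer eligible
        // for the dynamic-weights path, as in the fixture's wei_info setup.
        TensorInfo wei_info(TensorShape(128U, 16U), 1, DataType::F32);
        wei_info.set_are_values_constant(false);
        weights.allocator()->init(wei_info);

        // Null bias, mirroring the fixture's (remove_bias) ? nullptr : &_bias.
        CLFullyConnectedLayer fc;
        fc.configure(&src, &weights, nullptr, &dst, FullyConnectedLayerInfo());

        src.allocator()->allocate();
        weights.allocator()->allocate();
        dst.allocator()->allocate();

        // Fill src and weights here; the weights may change between runs.
        fc.run();
        CLScheduler::get().sync();
        return 0;
    }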