author     Viet-Hoa Do <viet-hoa.do@arm.com>    2023-03-13 16:20:04 +0000
committer  Viet-Hoa Do <viet-hoa.do@arm.com>    2023-03-21 10:33:53 +0000
commit     a3e57c20a0b7a174f0c357676a4da40a248d04db (patch)
tree       d92b2316a00db6ce07dd2af476791281fcc98de6 /tests/validation/fixtures/FullyConnectedLayerFixture.h
parent     8918b23073851417e8be6e5e53c6380dbdedf201 (diff)
Add dynamic weights for CPU fully connected layer
Resolves: COMPMID-5917

Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/FullyConnectedLayerFixture.h | 68
1 file changed, 43 insertions(+), 25 deletions(-)
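
The test-side refactoring in this patch replaces a runtime `if(_data_type == ...)` chain with per-type overloads of `validate_with_tolerance`, so the right tolerance is selected at compile time from the reference tensor's element type. Below is a minimal, self-contained sketch of that dispatch pattern; the `SimpleTensor` stand-in and the printed output are illustrative only, not the ComputeLibrary implementation.

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-in for SimpleTensor<T>; the element type T is
    // what drives overload resolution.
    template <typename T>
    struct SimpleTensor { };

    // One overload per element type, each encoding its own tolerance.
    void validate_with_tolerance(const SimpleTensor<float> &)
    {
        constexpr float rel = 0.01f, abs = 0.001f; // F32 tolerances from the patch
        std::cout << "F32: rel=" << rel << " abs=" << abs << '\n';
    }

    void validate_with_tolerance(const SimpleTensor<uint8_t> &)
    {
        constexpr uint32_t abs = 1; // QASYMM8: off-by-one allowed
        std::cout << "QASYMM8: abs=" << abs << '\n';
    }

    void validate_with_tolerance(const SimpleTensor<int8_t> &)
    {
        constexpr uint32_t abs = 1; // QASYMM8_SIGNED: off-by-one allowed
        std::cout << "QASYMM8_SIGNED: abs=" << abs << '\n';
    }

    int main()
    {
        validate_with_tolerance(SimpleTensor<float>{});   // picks the F32 overload
        validate_with_tolerance(SimpleTensor<uint8_t>{}); // picks the QASYMM8 overload
    }

The FP16 overload in the patch is additionally guarded by `#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC`, so it only exists on targets with FP16 vector support.
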
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index b5efccdf70..7d1aa494ba 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -326,23 +326,34 @@ private:
}
}
- void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref)
+ void validate_with_tolerance(TensorType &target, SimpleTensor<float> &ref)
{
- if(_data_type == DataType::F32)
- {
- constexpr RelativeTolerance<float> rel_tolerance_f32(0.05f);
- constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
- validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
- }
- else if(_data_type == DataType::QASYMM8)
- {
- constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1);
- validate(AccessorType(target), ref, tolerance_qasymm8);
- }
- else
- {
- validate(AccessorType(target), ref);
- }
+ constexpr RelativeTolerance<float> rel_tolerance_f32(0.01f);
+ constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.001f);
+ validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
+ }
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ void validate_with_tolerance(TensorType &target, SimpleTensor<half_float::half> &ref)
+ {
+ constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
+ const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f));
+ constexpr float tolerance_num_f16 = 0.07f;
+
+ validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
+ }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+ void validate_with_tolerance(TensorType &target, SimpleTensor<uint8_t> &ref)
+ {
+ constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1);
+ validate(AccessorType(target), ref, tolerance_qasymm8);
+ }
+
+ void validate_with_tolerance(TensorType &target, SimpleTensor<int8_t> &ref)
+ {
+ constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8_signed(1);
+ validate(AccessorType(target), ref, tolerance_qasymm8_signed);
}
public:
@@ -351,7 +362,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
+ DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped)
{
_data_type = data_type;
@@ -368,7 +379,7 @@ public:
_src.allocator()->init(src_info);
TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
- if(!constant_weights)
+ if(!constant_weights && weights_reshaped)
{
const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
wei_info.set_tensor_shape(tr_weights_shape);
@@ -388,8 +399,8 @@ public:
fc_info.activation_info = activation_info;
if(!constant_weights)
{
- fc_info.are_weights_reshaped = true;
- fc_info.transpose_weights = false;
+ fc_info.are_weights_reshaped = weights_reshaped;
+ fc_info.transpose_weights = !weights_reshaped;
}
FunctionType fc;
fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
@@ -428,7 +439,14 @@ public:
fill(AccessorType(_src), randomizer_offset);
if(!constant_weights)
{
- fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ if(weights_reshaped)
+ {
+ fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ }
+ else
+ {
+ fill(AccessorType(_weights), randomizer_offset + 1);
+ }
}
if(!constant_bias)
{
@@ -472,10 +490,10 @@ class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamic
public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info)
+ DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, false, true);
+ dst_shape, data_type, activation_info, false, true, weights_reshaped);
}
};
@@ -488,7 +506,7 @@ public:
DataType data_type, ActivationLayerInfo activation_info)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, true, false);
+ dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
}
};
} // namespace validation
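
For context on the new `weights_reshaped` flag: when it is true, the fixture supplies pre-transposed weights and tells the layer so (`are_weights_reshaped = true`, `transpose_weights = false`); when it is false, the weights are filled in their natural layout and the layer performs the transpose itself. A short sketch of that decision under hypothetical names (`FCInfo` and `prepare_weights` are illustrative stand-ins, not the ComputeLibrary API):

    #include <cassert>

    struct FCInfo
    {
        bool are_weights_reshaped = false;
        bool transpose_weights    = true;
    };

    FCInfo prepare_weights(bool constant_weights, bool weights_reshaped)
    {
        FCInfo info;
        if(!constant_weights)
        {
            // Dynamic weights: either the test supplies them pre-transposed
            // (weights_reshaped == true) or the layer must transpose them itself.
            info.are_weights_reshaped = weights_reshaped;
            info.transpose_weights    = !weights_reshaped;
        }
        return info;
    }

    int main()
    {
        assert(prepare_weights(false, true).transpose_weights == false);  // pre-reshaped path
        assert(prepare_weights(false, false).transpose_weights == true);  // layer transposes
    }

This mirrors why the fixture's fill step branches as well: the pre-reshaped path calls `fill_transposed_weights`, while the dynamic, non-reshaped path fills the weights tensor directly.
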