author     Viet-Hoa Do <viet-hoa.do@arm.com>    2023-03-13 16:20:04 +0000
committer  Viet-Hoa Do <viet-hoa.do@arm.com>    2023-03-21 10:33:53 +0000
commit     a3e57c20a0b7a174f0c357676a4da40a248d04db (patch)
tree       d92b2316a00db6ce07dd2af476791281fcc98de6
parent     8918b23073851417e8be6e5e53c6380dbdedf201 (diff)
Add dynamic weights for CPU fully connected layer
Resolves: COMPMID-5917

Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  src/cpu/operators/CpuFullyConnected.cpp                 | 50
-rw-r--r--  src/cpu/operators/CpuFullyConnected.h                   |  6
-rw-r--r--  src/cpu/operators/CpuGemm.cpp                           | 15
-rw-r--r--  src/runtime/NEON/functions/NEFullyConnectedLayer.cpp    | 12
-rw-r--r--  tests/validation/CL/FullyConnectedLayer.cpp             |  7
-rw-r--r--  tests/validation/NEON/FullyConnectedLayer.cpp           | 25
-rw-r--r--  tests/validation/fixtures/FullyConnectedLayerFixture.h  | 68
7 files changed, 133 insertions(+), 50 deletions(-)
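
For readers who want to exercise the feature, here is a minimal usage sketch (not part of the patch). It assumes the public NEFullyConnectedLayer API and keys off the same condition the new _dynamic_weights flag checks: the weights tensor info is marked non-constant before configure(). Shapes and the fill step are illustrative only.

#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void fully_connected_with_dynamic_weights()
{
    Tensor src{}, weights{}, bias{}, dst{};
    src.allocator()->init(TensorInfo(TensorShape(128U, 1U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(128U, 128U), 1, DataType::F32));
    bias.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(128U, 1U), 1, DataType::F32));

    // Non-constant weights is what enables the dynamic-weights path added by this patch.
    weights.info()->set_are_values_constant(false);

    NEFullyConnectedLayer fc{};
    fc.configure(&src, &weights, &bias, &dst, FullyConnectedLayerInfo());

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    for(int i = 0; i < 2; ++i)
    {
        // ... write new values into `weights` (and `src`) between iterations ...
        fc.run(); // the weights are re-reshaped on every run
    }
}
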
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 1e1598a8ee..af630154cf 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -166,7 +166,8 @@ CpuFullyConnected::CpuFullyConnected()
_is_prepared(false),
_enable_fast_math(false),
_fixed_format(false),
- _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
+ _weight_format(arm_compute::WeightFormat::UNSPECIFIED),
+ _dynamic_weights(false)
{
}
@@ -189,7 +190,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo *
const Status status = get_gemmlowp_output_stage_info(&src_info, &weights_info, dst, act, gemmlowp_output_stage_info);
ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK);
- GEMMInfo gemm_info;
+ GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info);
gemm_info.set_activation_info(act);
gemm_info.set_fast_math(_enable_fast_math);
@@ -199,7 +200,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo *
else
{
// Configure matrix multiply kernel
- GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */);
+ GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */);
gemm_info.set_activation_info(act);
gemm_info.set_fast_math(_enable_fast_math);
gemm_info.set_fixed_format(_fixed_format);
@@ -256,6 +257,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
_enable_fast_math = fc_info.enable_fast_math;
_fixed_format = weights_info.weight_format() != WeightFormat::UNSPECIFIED;
_weight_format = weights_info.weight_format();
+ _dynamic_weights = !weights->are_values_constant() && _needs_weights_reshape;
// With the Fully Connected layer we can have 4 different cases:
// 1) Convolution layer -> Fully Connected layer without batches
@@ -329,15 +331,32 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei
{
// Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
// Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation
- _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases
- && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare,
- _reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size());
+ // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time.
+ _aux_mem[TransposedWeights] = MemoryInfo(
+ offset_int_vec(TransposedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary :
+ (_is_quantized_asymmetric && biases && !(biases->are_values_constant())) ? MemoryLifetime::Persistent :
+ MemoryLifetime::Prepare,
+ _reshaped_weights.total_size());
+
+ _aux_mem[ConvertedWeights] = MemoryInfo(
+ offset_int_vec(ConvertedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+ _converted_weights.total_size());
}
else
{
- _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), _needs_weights_conversion ? MemoryLifetime::Prepare : MemoryLifetime::Persistent, _reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Persistent, _converted_weights.total_size());
+ _aux_mem[TransposedWeights] = MemoryInfo(
+ offset_int_vec(TransposedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary :
+ _needs_weights_conversion ? MemoryLifetime::Prepare :
+ MemoryLifetime::Persistent,
+ _reshaped_weights.total_size());
+
+ _aux_mem[ConvertedWeights] = MemoryInfo(
+ offset_int_vec(ConvertedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Persistent,
+ _converted_weights.total_size());
}
_aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
}
@@ -375,7 +394,6 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
- ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights));
bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
bool is_fc_after_conv = true;
@@ -459,6 +477,11 @@ void CpuFullyConnected::run(ITensorPack &tensors)
{
prepare(tensors);
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+ ++_asrt_run_count;
+ ARM_COMPUTE_ERROR_ON(_dynamic_weights && _asrt_prepare_count != _asrt_run_count);
+#endif // ARM_COMPUTE_ASSERTS_ENABLED
+
auto src = tensors.get_const_tensor(ACL_SRC_0);
CpuAuxTensorHandler flattened_src(offset_int_vec(FlattenedSrc), _flattened_src, tensors, false);
@@ -491,8 +514,13 @@ void CpuFullyConnected::run(ITensorPack &tensors)
void CpuFullyConnected::prepare(ITensorPack &tensors)
{
- if(!_is_prepared)
+ if(!_is_prepared || _dynamic_weights)
{
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+ ++_asrt_prepare_count;
+ ARM_COMPUTE_ERROR_ON(!_dynamic_weights && _asrt_prepare_count > 1);
+#endif // ARM_COMPUTE_ASSERTS_ENABLED
+
auto weights = tensors.get_const_tensor(ACL_SRC_1);
CpuAuxTensorHandler reshaped_weights(offset_int_vec(TransposedWeights), _reshaped_weights, tensors, false);
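
As an aside (not part of the patch), the debug counters introduced above reduce to a simple invariant when ARM_COMPUTE_ASSERTS_ENABLED is defined:

// run():     ++_asrt_run_count;     with dynamic weights, prepare() must have executed exactly once per run()
// prepare(): ++_asrt_prepare_count; with static weights, prepare() must not execute more than once overall
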
diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h
index 9cd67f2ca6..a5a464f67a 100644
--- a/src/cpu/operators/CpuFullyConnected.h
+++ b/src/cpu/operators/CpuFullyConnected.h
@@ -155,6 +155,12 @@ private:
bool _enable_fast_math;
bool _fixed_format;
arm_compute::WeightFormat _weight_format;
+ bool _dynamic_weights;
+
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+ int _asrt_run_count{};
+ int _asrt_prepare_count{};
+#endif // ARM_COMPUTE_ASSERTS_ENABLED
};
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 545d59f410..f914bceec3 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info);
const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();
// Check if we need to reshape the matrix B only on the first run
@@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
_reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
_run_vector_matrix_multiplication = a->dimension(1) < 2;
_run_alpha_scale = alpha != 1.f;
- _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run();
- _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
+ _run_bias_addition = is_c_bias;
+ _run_addition = beta != 0 && beta != 1 && c != nullptr;
_run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));
if(run_optimised)
@@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso
Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
- const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();
+ const bool is_c_bias = beta == 1 && c != nullptr;
+ const bool run_addition = c != nullptr && beta != 0 && beta != 1;
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
-
+
if (is_fixed_format_fast_math(gemm_info.weight_format()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
@@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
}
- if(c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
@@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens
}
// Validate matrix addition kernel
- if(beta != 0 && c != nullptr && !is_c_bias)
+ if(run_addition)
{
ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
}
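
A brief aside (not part of the patch) on the CpuGemm change above: how the C operand participates in d = alpha * a * b + beta * c no longer depends on reshape_b_only_on_first_run but only on beta and on whether c is provided:

// beta == 1 && c != nullptr              -> is_c_bias: c goes through the bias-addition path
// beta != 0 && beta != 1 && c != nullptr -> run_addition: c is added afterwards by CpuGemmMatrixAdditionKernel
// beta == 0 || c == nullptr              -> c does not contribute
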
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 919e5ed84f..891487efd3 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -49,6 +49,7 @@ struct NEFullyConnectedLayer::Impl
experimental::MemoryRequirements aux_mem_req{};
bool is_prepared{ false };
+ bool dynamic_weights{ false };
};
NEFullyConnectedLayer::~NEFullyConnectedLayer() = default;
@@ -87,6 +88,12 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
_impl->aux_mem_req = _impl->op->workspace();
_impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } };
_impl->workspace = manage_workspace<Tensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack);
+
+ _impl->dynamic_weights =
+ !weights->info()->are_values_constant() &&
+ fc_info.transpose_weights &&
+ !fc_info.are_weights_reshaped &&
+ !fc_info.retain_internal_weights;
}
Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
@@ -104,7 +111,10 @@ Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
void NEFullyConnectedLayer::run()
{
- prepare();
+ if(!_impl->dynamic_weights)
+ {
+ prepare();
+ }
MemoryGroupResourceScope scope_mg(_impl->memory_group);
_impl->op->run(_impl->run_pack);
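
To make the interplay explicit, an illustrative call flow (not part of the patch) when dynamic_weights is true:

// NEFullyConnectedLayer::run()
//   -> skips the function-level one-shot prepare()            (change above)
//   -> CpuFullyConnected::run(tensors)
//        -> prepare(tensors)                                  (re-entered: !_is_prepared || _dynamic_weights)
//             -> weights are re-transposed / re-converted into aux buffers with MemoryLifetime::Temporary
//        -> matrix multiply runs on the freshly reshaped weights
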
diff --git a/tests/validation/CL/FullyConnectedLayer.cpp b/tests/validation/CL/FullyConnectedLayer.cpp
index 09da519c51..fcfae4e156 100644
--- a/tests/validation/CL/FullyConnectedLayer.cpp
+++ b/tests/validation/CL/FullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -172,9 +172,10 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerMixedDataLayoutF
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
+ framework::dataset::make("WeightsReshaped", { true })))
{
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters),
diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp
index 5639fb47da..882ad04cd2 100644
--- a/tests/validation/NEON/FullyConnectedLayer.cpp
+++ b/tests/validation/NEON/FullyConnectedLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -323,6 +323,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture<half>, framework::
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
}
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
+ framework::dataset::make("WeightsReshaped", { false, true })))
+{
+}
TEST_SUITE_END()
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -362,9 +368,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture<float>, framework:
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32);
}
-FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
+ framework::dataset::make("WeightsReshaped", { false, true })))
{
}
TEST_SUITE_END()
@@ -428,6 +435,12 @@ FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture<u
framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))))
{
}
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
+ framework::dataset::make("WeightsReshaped", { false, true })))
+{
+}
TEST_SUITE_END()
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(
@@ -464,6 +477,12 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture<
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
}
+FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))),
+ framework::dataset::make("WeightsReshaped", { false, true })))
+{
+}
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE_END() // Quantized
TEST_SUITE_END() // FullyConnectedLayer
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index b5efccdf70..7d1aa494ba 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -326,23 +326,34 @@ private:
}
}
- void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref)
+ void validate_with_tolerance(TensorType &target, SimpleTensor<float> &ref)
{
- if(_data_type == DataType::F32)
- {
- constexpr RelativeTolerance<float> rel_tolerance_f32(0.05f);
- constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
- validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
- }
- else if(_data_type == DataType::QASYMM8)
- {
- constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1);
- validate(AccessorType(target), ref, tolerance_qasymm8);
- }
- else
- {
- validate(AccessorType(target), ref);
- }
+ constexpr RelativeTolerance<float> rel_tolerance_f32(0.01f);
+ constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.001f);
+ validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
+ }
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ void validate_with_tolerance(TensorType &target, SimpleTensor<half_float::half> &ref)
+ {
+ constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
+ const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f));
+ constexpr float tolerance_num_f16 = 0.07f;
+
+ validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
+ }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+ void validate_with_tolerance(TensorType &target, SimpleTensor<uint8_t> &ref)
+ {
+ constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1);
+ validate(AccessorType(target), ref, tolerance_qasymm8);
+ }
+
+ void validate_with_tolerance(TensorType &target, SimpleTensor<int8_t> &ref)
+ {
+ constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8_signed(1);
+ validate(AccessorType(target), ref, tolerance_qasymm8_signed);
}
public:
@@ -351,7 +362,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
+ DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped)
{
_data_type = data_type;
@@ -368,7 +379,7 @@ public:
_src.allocator()->init(src_info);
TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
- if(!constant_weights)
+ if(!constant_weights && weights_reshaped)
{
const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
wei_info.set_tensor_shape(tr_weights_shape);
@@ -388,8 +399,8 @@ public:
fc_info.activation_info = activation_info;
if(!constant_weights)
{
- fc_info.are_weights_reshaped = true;
- fc_info.transpose_weights = false;
+ fc_info.are_weights_reshaped = weights_reshaped;
+ fc_info.transpose_weights = !weights_reshaped;
}
FunctionType fc;
fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
@@ -428,7 +439,14 @@ public:
fill(AccessorType(_src), randomizer_offset);
if(!constant_weights)
{
- fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ if(weights_reshaped)
+ {
+ fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ }
+ else
+ {
+ fill(AccessorType(_weights), randomizer_offset + 1);
+ }
}
if(!constant_bias)
{
@@ -472,10 +490,10 @@ class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamic
public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info)
+ DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, false, true);
+ dst_shape, data_type, activation_info, false, true, weights_reshaped);
}
};
@@ -488,7 +506,7 @@ public:
DataType data_type, ActivationLayerInfo activation_info)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, true, false);
+ dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
}
};
} // namespace validation
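
For reference, an illustrative summary (not part of the patch) of what the new WeightsReshaped fixture parameter selects:

// WeightsReshaped == true:  the test hands the operator pre-transposed weights
//                           (fc_info.are_weights_reshaped = true, fc_info.transpose_weights = false,
//                            filled via fill_transposed_weights())
// WeightsReshaped == false: the test hands the operator weights in their natural layout and lets it
//                           transpose them on every run
//                           (fc_info.are_weights_reshaped = false, fc_info.transpose_weights = true,
//                            filled via fill())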