From a3e57c20a0b7a174f0c357676a4da40a248d04db Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Mon, 13 Mar 2023 16:20:04 +0000 Subject: Add dynamic weights for CPU fully connected layer Resolves: COMPMID-5917 Signed-off-by: Viet-Hoa Do Change-Id: I073067b490f2a1b96b81a037ea431c9a2e5c7503 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9322 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- src/cpu/operators/CpuFullyConnected.cpp | 50 ++++++++++++---- src/cpu/operators/CpuFullyConnected.h | 6 ++ src/cpu/operators/CpuGemm.cpp | 15 ++--- .../NEON/functions/NEFullyConnectedLayer.cpp | 12 +++- tests/validation/CL/FullyConnectedLayer.cpp | 7 ++- tests/validation/NEON/FullyConnectedLayer.cpp | 25 +++++++- .../fixtures/FullyConnectedLayerFixture.h | 68 ++++++++++++++-------- 7 files changed, 133 insertions(+), 50 deletions(-) diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp index 1e1598a8ee..af630154cf 100644 --- a/src/cpu/operators/CpuFullyConnected.cpp +++ b/src/cpu/operators/CpuFullyConnected.cpp @@ -166,7 +166,8 @@ CpuFullyConnected::CpuFullyConnected() _is_prepared(false), _enable_fast_math(false), _fixed_format(false), - _weight_format(arm_compute::WeightFormat::UNSPECIFIED) + _weight_format(arm_compute::WeightFormat::UNSPECIFIED), + _dynamic_weights(false) { } @@ -189,7 +190,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo * const Status status = get_gemmlowp_output_stage_info(&src_info, &weights_info, dst, act, gemmlowp_output_stage_info); ARM_COMPUTE_ERROR_ON(status.error_code() != ErrorCode::OK); - GEMMInfo gemm_info; + GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */); gemm_info.set_gemmlowp_output_stage(gemmlowp_output_stage_info); gemm_info.set_activation_info(act); gemm_info.set_fast_math(_enable_fast_math); @@ -199,7 +200,7 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo * else { // Configure matrix multiply kernel - GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); + GEMMInfo gemm_info(false, false, !_dynamic_weights /* Reshape weights only for the first run */); gemm_info.set_activation_info(act); gemm_info.set_fast_math(_enable_fast_math); gemm_info.set_fixed_format(_fixed_format); @@ -256,6 +257,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei _enable_fast_math = fc_info.enable_fast_math; _fixed_format = weights_info.weight_format() != WeightFormat::UNSPECIFIED; _weight_format = weights_info.weight_format(); + _dynamic_weights = !weights->are_values_constant() && _needs_weights_reshape; // With the Fully Connected layer we can have 4 different cases: // 1) Convolution layer -> Fully Connected layer without batches @@ -329,15 +331,32 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei { // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation - _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases - && !(biases->are_values_constant())) ? 
MemoryLifetime::Persistent : MemoryLifetime::Prepare, - _reshaped_weights.total_size()); - _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); + // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time. + _aux_mem[TransposedWeights] = MemoryInfo( + offset_int_vec(TransposedWeights), + _dynamic_weights ? MemoryLifetime::Temporary : + (_is_quantized_asymmetric && biases && !(biases->are_values_constant())) ? MemoryLifetime::Persistent : + MemoryLifetime::Prepare, + _reshaped_weights.total_size()); + + _aux_mem[ConvertedWeights] = MemoryInfo( + offset_int_vec(ConvertedWeights), + _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare, + _converted_weights.total_size()); } else { - _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), _needs_weights_conversion ? MemoryLifetime::Prepare : MemoryLifetime::Persistent, _reshaped_weights.total_size()); - _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Persistent, _converted_weights.total_size()); + _aux_mem[TransposedWeights] = MemoryInfo( + offset_int_vec(TransposedWeights), + _dynamic_weights ? MemoryLifetime::Temporary : + _needs_weights_conversion ? MemoryLifetime::Prepare : + MemoryLifetime::Persistent, + _reshaped_weights.total_size()); + + _aux_mem[ConvertedWeights] = MemoryInfo( + offset_int_vec(ConvertedWeights), + _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Persistent, + _converted_weights.total_size()); } _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size()); } @@ -375,7 +394,6 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); - ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); bool weights_reshaped = fc_info.transpose_weights ? 
fc_info.are_weights_reshaped : true; bool is_fc_after_conv = true; @@ -459,6 +477,11 @@ void CpuFullyConnected::run(ITensorPack &tensors) { prepare(tensors); +#ifdef ARM_COMPUTE_ASSERTS_ENABLED + ++_asrt_run_count; + ARM_COMPUTE_ERROR_ON(_dynamic_weights && _asrt_prepare_count != _asrt_run_count); +#endif // ARM_COMPUTE_ASSERTS_ENABLED + auto src = tensors.get_const_tensor(ACL_SRC_0); CpuAuxTensorHandler flattened_src(offset_int_vec(FlattenedSrc), _flattened_src, tensors, false); @@ -491,8 +514,13 @@ void CpuFullyConnected::run(ITensorPack &tensors) void CpuFullyConnected::prepare(ITensorPack &tensors) { - if(!_is_prepared) + if(!_is_prepared || _dynamic_weights) { +#ifdef ARM_COMPUTE_ASSERTS_ENABLED + ++_asrt_prepare_count; + ARM_COMPUTE_ERROR_ON(!_dynamic_weights && _asrt_prepare_count > 1); +#endif // ARM_COMPUTE_ASSERTS_ENABLED + auto weights = tensors.get_const_tensor(ACL_SRC_1); CpuAuxTensorHandler reshaped_weights(offset_int_vec(TransposedWeights), _reshaped_weights, tensors, false); diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h index 9cd67f2ca6..a5a464f67a 100644 --- a/src/cpu/operators/CpuFullyConnected.h +++ b/src/cpu/operators/CpuFullyConnected.h @@ -155,6 +155,12 @@ private: bool _enable_fast_math; bool _fixed_format; arm_compute::WeightFormat _weight_format; + bool _dynamic_weights; + +#ifdef ARM_COMPUTE_ASSERTS_ENABLED + int _asrt_run_count{}; + int _asrt_prepare_count{}; +#endif // ARM_COMPUTE_ASSERTS_ENABLED }; } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 545d59f410..f914bceec3 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -64,7 +64,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info); const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; bool run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? 
c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run(); // Check if we need to reshape the matrix B only on the first run @@ -72,8 +72,8 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); _run_vector_matrix_multiplication = a->dimension(1) < 2; _run_alpha_scale = alpha != 1.f; - _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run(); - _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run(); + _run_bias_addition = is_c_bias; + _run_addition = beta != 0 && beta != 1 && c != nullptr; _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info()))); if(run_optimised) @@ -153,12 +153,13 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); - const bool is_c_bias = gemm_info.reshape_b_only_on_first_run(); + const bool is_c_bias = beta == 1 && c != nullptr; + const bool run_addition = c != nullptr && beta != 0 && beta != 1; ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32); - + if (is_fixed_format_fast_math(gemm_info.weight_format())) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32); @@ -177,7 +178,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d); } - if(c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0); ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d()); @@ -265,7 +266,7 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens } // Validate matrix addition kernel - if(beta != 0 && c != nullptr && !is_c_bias) + if(run_addition) { ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta)); } diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 919e5ed84f..891487efd3 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -49,6 +49,7 @@ struct NEFullyConnectedLayer::Impl experimental::MemoryRequirements aux_mem_req{}; bool is_prepared{ false }; + bool dynamic_weights{ false }; }; NEFullyConnectedLayer::~NEFullyConnectedLayer() = default; @@ -87,6 +88,12 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->aux_mem_req = _impl->op->workspace(); _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } }; _impl->workspace = manage_workspace(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack); + + _impl->dynamic_weights = + !weights->info()->are_values_constant() && + fc_info.transpose_weights && + !fc_info.are_weights_reshaped && + !fc_info.retain_internal_weights; } Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights, @@ -104,7 +111,10 @@ 
Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn void NEFullyConnectedLayer::run() { - prepare(); + if(!_impl->dynamic_weights) + { + prepare(); + } MemoryGroupResourceScope scope_mg(_impl->memory_group); _impl->op->run(_impl->run_pack); diff --git a/tests/validation/CL/FullyConnectedLayer.cpp b/tests/validation/CL/FullyConnectedLayer.cpp index 09da519c51..fcfae4e156 100644 --- a/tests/validation/CL/FullyConnectedLayer.cpp +++ b/tests/validation/CL/FullyConnectedLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -172,9 +172,10 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerMixedDataLayoutF // Validate output validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))), + framework::dataset::make("WeightsReshaped", { true }))) { } FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters), diff --git a/tests/validation/NEON/FullyConnectedLayer.cpp b/tests/validation/NEON/FullyConnectedLayer.cpp index 5639fb47da..882ad04cd2 100644 --- a/tests/validation/NEON/FullyConnectedLayer.cpp +++ b/tests/validation/NEON/FullyConnectedLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2023 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -323,6 +323,12 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture, framework:: // Validate output validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); } +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::F16)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))), + framework::dataset::make("WeightsReshaped", { false, true }))) +{ +} TEST_SUITE_END() #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ @@ -362,9 +368,10 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEFullyConnectedLayerFixture, framework: // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32); } -FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallFullyConnectedLayerDataset(), +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)))) + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))), + framework::dataset::make("WeightsReshaped", { false, true }))) { } TEST_SUITE_END() @@ -428,6 +435,12 @@ FIXTURE_DATA_TEST_CASE(RunDynamicBias, NEFullyConnectedLayerDynamicBiasFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))), + framework::dataset::make("WeightsReshaped", { false, true }))) +{ +} TEST_SUITE_END() TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEFullyConnectedLayerQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine( @@ -464,6 +477,12 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, NEFullyConnectedLayerQuantizedFixture< // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8_signed); } +FIXTURE_DATA_TEST_CASE(RunDynamicWeights, NEFullyConnectedLayerDynamicWeightsFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("ActivationInfo", ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU))), + framework::dataset::make("WeightsReshaped", { false, true }))) +{ +} TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE_END() // Quantized TEST_SUITE_END() // FullyConnectedLayer diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index b5efccdf70..7d1aa494ba 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -326,23 +326,34 @@ private: } } - void validate_with_tolerance(TensorType &target, SimpleTensor &ref) + void validate_with_tolerance(TensorType &target, SimpleTensor &ref) { - if(_data_type == DataType::F32) - { - constexpr RelativeTolerance rel_tolerance_f32(0.05f); - constexpr AbsoluteTolerance abs_tolerance_f32(0.0001f); - validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); - } - else if(_data_type == DataType::QASYMM8) - { - constexpr AbsoluteTolerance tolerance_qasymm8(1); - validate(AccessorType(target), ref, tolerance_qasymm8); - } - else - { - validate(AccessorType(target), ref); - } + constexpr RelativeTolerance rel_tolerance_f32(0.01f); + constexpr AbsoluteTolerance abs_tolerance_f32(0.001f); + validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); + } + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + void validate_with_tolerance(TensorType &target, SimpleTensor &ref) + { + constexpr AbsoluteTolerance abs_tolerance_f16(0.3f); + const RelativeTolerance rel_tolerance_f16(half_float::half(0.2f)); + constexpr float tolerance_num_f16 = 0.07f; + + validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + + void validate_with_tolerance(TensorType &target, SimpleTensor &ref) + { + constexpr AbsoluteTolerance tolerance_qasymm8(1); + validate(AccessorType(target), ref, tolerance_qasymm8); + } + + void validate_with_tolerance(TensorType &target, SimpleTensor &ref) + { + constexpr AbsoluteTolerance tolerance_qasymm8_signed(1); + validate(AccessorType(target), ref, tolerance_qasymm8_signed); } public: @@ -351,7 +362,7 @@ public: template void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias) + DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped) { _data_type = data_type; @@ -368,7 +379,7 @@ public: _src.allocator()->init(src_info); TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); - if(!constant_weights) + if(!constant_weights && weights_reshaped) { const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; wei_info.set_tensor_shape(tr_weights_shape); @@ -388,8 +399,8 @@ public: fc_info.activation_info = activation_info; if(!constant_weights) { - fc_info.are_weights_reshaped = true; - fc_info.transpose_weights = false; + fc_info.are_weights_reshaped = weights_reshaped; + fc_info.transpose_weights = !weights_reshaped; } FunctionType fc; fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); @@ -428,7 +439,14 @@ public: fill(AccessorType(_src), randomizer_offset); if(!constant_weights) { - fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); + if(weights_reshaped) + { + fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); + } + else + { + fill(AccessorType(_weights), randomizer_offset + 1); + } } if(!constant_bias) { @@ -472,10 +490,10 @@ class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamic public: template void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, - DataType data_type, ActivationLayerInfo activation_info) + DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped) { 
         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
-                                                                                                   dst_shape, data_type, activation_info, false, true);
+                                                                                                   dst_shape, data_type, activation_info, false, true, weights_reshaped);
     }
 };
@@ -488,7 +506,7 @@ public:
                DataType data_type, ActivationLayerInfo activation_info)
     {
         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
-                                                                                                   dst_shape, data_type, activation_info, true, false);
+                                                                                                   dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
     }
 };
 } // namespace validation
-- 
cgit v1.2.1
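
Note, not part of the patch: the change lets CpuFullyConnected re-run its weight preparation (transpose, conversion, reshape) on every execution when the weights tensor is marked as non-constant, instead of preparing the weights once and caching them. The sketch below shows how a caller might opt in through the public NEON runtime API. It is a minimal illustration with made-up shapes, and it assumes ITensorInfo::set_are_values_constant() is the way to flag the weights as dynamic, matching the are_values_constant() checks used by the operator in this patch.

// Illustrative sketch only: shapes are made up and error handling is omitted.
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // 32 input features -> 16 outputs (fully connected layer following another fully connected layer).
    Tensor src, weights, bias, dst;
    src.allocator()->init(TensorInfo(TensorShape(32U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(32U, 16U), 1, DataType::F32));
    bias.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(16U), 1, DataType::F32));

    // Flag the weights as non-constant so the operator re-prepares them on every run()
    // instead of reshaping them once and caching the result.
    weights.info()->set_are_values_constant(false);

    FullyConnectedLayerInfo fc_info{};
    fc_info.transpose_weights    = true;  // weights are supplied untransposed
    fc_info.are_weights_reshaped = false;

    NEFullyConnectedLayer fc{};
    fc.configure(&src, &weights, &bias, &dst, fc_info);

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    for(int iteration = 0; iteration < 3; ++iteration)
    {
        // ... fill src and write a fresh set of weight values here ...
        fc.run(); // with dynamic weights the prepare step is repeated inside run()
    }
    return 0;
}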
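
Note, not part of the patch: because dynamic weights are re-transposed and re-converted on every run, the reshaped and converted weight buffers can no longer be released after prepare() or kept as prepared persistent state, so the patch downgrades them to temporary workspace. The function below is a simplified restatement of the lifetime choice made in CpuFullyConnected::configure() for the branch where the transposed weights are later consumed by the assembly dispatch; the enum and function names are invented for illustration and are not library API.

#include <cstdio>

enum class Lifetime
{
    Temporary,  // alive only while run() executes
    Prepare,    // alive until prepare() has finished, then released
    Persistent  // kept for the whole lifetime of the operator
};

// Mirrors the selection for the transposed-weights buffer when the assembly
// dispatch transposes the weights again during prepare().
Lifetime transposed_weights_lifetime(bool dynamic_weights, bool is_quantized_asymmetric, bool dynamic_bias)
{
    if(dynamic_weights)
    {
        return Lifetime::Temporary; // rebuilt on every run, scratch memory is enough
    }
    if(is_quantized_asymmetric && dynamic_bias)
    {
        return Lifetime::Persistent; // still needed at run time for the bias-offset calculation
    }
    return Lifetime::Prepare; // only needed until the assembly dispatch has consumed it
}

int main()
{
    std::printf("dynamic weights -> %d\n", static_cast<int>(transposed_weights_lifetime(true, false, false)));
    std::printf("static weights  -> %d\n", static_cast<int>(transposed_weights_lifetime(false, false, false)));
    return 0;
}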
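
Note, not part of the patch: in CpuGemm the patch stops using gemm_info.reshape_b_only_on_first_run() as a stand-in for "c is a bias", presumably because the fully connected layer now passes that flag as false when the weights are dynamic; the role of the c operand is instead derived from beta directly. The snippet below restates that classification as standalone code; the struct and function are invented for illustration and are not library API.

#include <cassert>

// d = alpha * (a * b) + beta * c
struct COperandUse
{
    bool fused_bias;        // beta == 1: c is folded into the matrix multiplication as a bias
    bool separate_addition; // beta != 0 and beta != 1: c is added with the dedicated matrix-addition kernel
};

COperandUse classify_c_operand(const float *c, float beta)
{
    const bool has_c = (c != nullptr);
    return COperandUse{ has_c && beta == 1.f, has_c && beta != 0.f && beta != 1.f };
}

int main()
{
    float c_data[4] = { 0.f, 0.f, 0.f, 0.f };
    assert(classify_c_operand(c_data, 1.f).fused_bias);
    assert(classify_c_operand(c_data, 0.5f).separate_addition);
    assert(!classify_c_operand(nullptr, 1.f).fused_bias);
    return 0;
}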