From b3be45759bdd0749ae3a16fe470820f0d9830ea9 Mon Sep 17 00:00:00 2001
From: Sang-Hoon Park
Date: Tue, 18 May 2021 10:46:00 +0100
Subject: Implement memory injection in CpuDirectGemmConv2d

The following operators are now stateless, as memory injection
has been implemented for them:
 - CpuDirectGemmConv2d
 - CpuGemmAssemblyDispatch

A test case is added to verify that CpuDirectGemmConv2d can run on
different groups of tensors with a single configure() call.

Resolves: COMPMID-4506
Change-Id: I48f44ed41236ca7e18da2de07bdbacc9007a3c5e
Signed-off-by: Sang-Hoon Park
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5718
Comments-Addressed: Arm Jenkins
Tested-by: Arm Jenkins
Reviewed-by: Pablo Marquez Tello
---
 arm_compute/runtime/NEON/functions/NEGEMM.h        |   3 +-
 .../NEON/functions/NEGEMMLowpMatrixMultiplyCore.h  |   3 +-
 src/core/helpers/MemoryHelpers.h                   |  14 +-
 src/runtime/NEON/functions/NEGEMM.cpp              |  19 +-
 src/runtime/NEON/functions/NEGEMMConv2d.cpp        |  21 +-
 .../functions/NEGEMMLowpMatrixMultiplyCore.cpp     |  31 ++-
 src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp  |  58 +++--
 src/runtime/cpu/operators/CpuGemmDirectConv2d.h    |  17 +-
 .../operators/internal/CpuGemmAssemblyDispatch.cpp | 100 +++++----
 .../operators/internal/CpuGemmAssemblyDispatch.h   |  18 +-
 tests/validation/NEON/ConvolutionLayer.cpp         | 239 ++++++++++++++-------
 11 files changed, 360 insertions(+), 163 deletions(-)

diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index 6fa30bd545..a5d6bb6534 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -146,7 +146,8 @@ private:
     bool _reshape_b_only_on_first_run;
     bool _is_prepared;
 
-    ITensorPack _asm_glue_tensors{};
+    struct AsmGlueTensors;
+    std::unique_ptr<AsmGlueTensors> _asm_glue_tensors;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_NEGEMM_H */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
index dc9783f9eb..ff50d6dbf7 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
@@ -171,7 +171,8 @@ private:
     bool _run_activation;
     bool _flip_signedness;
 
-    ITensorPack _asm_glue_tensors{};
+    struct AsmGlueTensors;
+    std::unique_ptr<AsmGlueTensors> _asm_glue_tensors;
 };
 } // namespace arm_compute
 #endif /*ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H */
diff --git a/src/core/helpers/MemoryHelpers.h b/src/core/helpers/MemoryHelpers.h
index 6756a90c25..dfa8e60758 100644
--- a/src/core/helpers/MemoryHelpers.h
+++ b/src/core/helpers/MemoryHelpers.h
@@ -56,12 +56,13 @@ WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequiremen
             continue;
         }
 
-        const auto aux_info = TensorInfo{ TensorShape(req.size), 1, DataType::U8 };
+        const auto alignment = req.alignment;
+        const auto aux_info  = TensorInfo{ TensorShape(req.size + alignment), 1, DataType::U8 };
         workspace_memory.emplace_back(req.slot, std::make_unique<TensorType>());
 
         auto aux_tensor = workspace_memory.back().second.get();
         ARM_COMPUTE_ERROR_ON_NULLPTR(aux_tensor);
-        aux_tensor->allocator()->init(aux_info);
+        aux_tensor->allocator()->init(aux_info, alignment);
 
         if(req.lifetime == experimental::MemoryLifetime::Temporary)
         {
@@ -82,5 +83,14 @@ WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequiremen
 
     return workspace_memory;
 }
+
+template <typename TensorType>
+WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
+                                           MemoryGroup                            &mgroup,
+                                           ITensorPack                            &run_pack)
+{
+    ITensorPack dummy_prep_pack{};
+    return manage_workspace<TensorType>(mem_reqs, mgroup, run_pack, dummy_prep_pack);
+}
 } // namespace arm_compute
 #endif /* SRC_COMMON_MEMORY_HELPERS_H */
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 7318c3e492..b526874790 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -38,6 +38,7 @@
 #include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 #include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
 #include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 #include <cmath>
@@ -46,6 +47,14 @@ using namespace arm_compute::misc::shape_calculator;
 
 namespace arm_compute
 {
+using WorkspaceDataType = WorkspaceData<Tensor>;
+
+struct NEGEMM::AsmGlueTensors
+{
+    ITensorPack       tensors{};
+    WorkspaceDataType ws{};
+};
+
 namespace
 {
 cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -63,7 +72,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
 NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>()), _ma_kernel(),
       _alpha_scale_func(nullptr), _add_bias(), _activation_func(), _tmp_a(), _tmp_b(), _tmp_d(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_alpha_scale(false),
-      _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+      _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
 {
 }
 
@@ -94,7 +103,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
         _asm_glue->configure(a->info(), b->info(), c_info_to_use, d->info(), asm_info);
         ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());
 
-        _asm_glue_tensors =
+        _asm_glue_tensors->tensors =
         {
             { ACL_SRC_0, a },
            { ACL_SRC_1, b },
            { ACL_DST, d },
         };
 
+        _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
+
         // Scale product by alpha
         if(_run_alpha_scale)
         {
@@ -323,7 +334,7 @@ void NEGEMM::run()
 
     if(_asm_glue->is_configured())
     {
-        _asm_glue->run(_asm_glue_tensors);
+        _asm_glue->run(_asm_glue_tensors->tensors);
         if(_run_alpha_scale)
         {
             _alpha_scale_func.run();
@@ -377,7 +388,7 @@ void NEGEMM::prepare()
             ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
         }
 
-        _asm_glue->prepare(_asm_glue_tensors);
+        _asm_glue->prepare(_asm_glue_tensors->tensors);
         if(!original_b_managed_by_weights_manager)
         {
             _original_b->mark_as_unused();
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
index 94ceb6d27c..790543a34a 100644
--- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
@@ -26,24 +26,37 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
 
 #include <set>
 
 namespace arm_compute
 {
-using OperatorType = cpu::CpuGemmDirectConv2d;
+using OperatorType      = cpu::CpuGemmDirectConv2d;
+using WorkspaceDataType = WorkspaceData<Tensor>;
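// A minimal sketch of the caller-side pattern that the changes above and below
// enable (illustrative only; it mirrors the new test case added further down in
// this patch, and assumes src/weights/bias/dst tensors already exist):
//
//     cpu::CpuGemmDirectConv2d op{};
//     op.configure(src.info(), weights.info(), bias.info(), dst.info(), Conv2dInfo{});
//
//     ITensorPack pack{ { ACL_SRC_0, &src }, { ACL_SRC_1, &weights }, { ACL_SRC_2, &bias }, { ACL_DST, &dst } };
//     MemoryGroup mg{};
//     auto ws = manage_workspace<Tensor>(op.workspace(), mg, pack); // injects the ACL_INT_* workspace slots
//     op.run(pack);                                                 // imports the injected buffers, runs, releases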
 struct NEGEMMConv2d::Impl
 {
     ITensorPack tensors{};
+    MemoryGroup mg{};
     std::unique_ptr<OperatorType> op{ nullptr };
+    WorkspaceDataType ws{};
+
+    void allocate_and_add_workspace()
+    {
+        if(op)
+        {
+            ws = manage_workspace<Tensor>(op->workspace(), mg, tensors);
+        }
+    }
 };
 
 NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
     : _impl(std::make_unique<Impl>())
 {
-    _impl->op = std::make_unique<OperatorType>(memory_manager);
+    _impl->op = std::make_unique<OperatorType>();
+    _impl->mg = MemoryGroup(memory_manager);
 }
 
 NEGEMMConv2d::~NEGEMMConv2d() = default;
@@ -55,7 +68,9 @@ void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITens
     _impl->tensors.add_const_tensor(TensorType::ACL_SRC_2, biases);
     _impl->tensors.add_tensor(TensorType::ACL_DST, output);
 
-    _impl->op->configure(input->info(), weights->info(), biases->info(), output->info(), info);
+    _impl->op->configure(input->info(), weights->info(), ((biases) ? biases->info() : nullptr), output->info(), info);
+
+    _impl->allocate_and_add_workspace();
 }
 
 Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info)
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index cc0f20e695..d42e656e0c 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -42,10 +42,17 @@
 #include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
 #include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
 #include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 namespace arm_compute
 {
+using WorkspaceDataType = WorkspaceData<Tensor>;
+
+struct NEGEMMLowpMatrixMultiplyCore::AsmGlueTensors
+{
+    ITensorPack       tensors{};
+    WorkspaceDataType ws{};
+};
 namespace
 {
 cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -66,11 +73,11 @@ using namespace arm_compute::misc::shape_calculator;
 NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
 
 NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(memory_manager, weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
      _mtx_b_reshape_kernel(), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(),
      _convert_to_signed_asymm(), _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr),
      _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false),
-    _run_activation(false), _flip_signedness(false)
+    _run_activation(false), _flip_signedness(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
 {
 }
 
@@ -149,18 +156,24 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
                 auto c_info_to_use = c == nullptr ? nullptr : c->info();
                 _asm_glue->configure(a_to_use->info(), b->info(), c_info_to_use, output->info(), asm_info);
                 _fused_assembly_path = _asm_glue->is_configured();
-                _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
-                _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output);
+                _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
+                _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output);
             }
             else
             {
                 auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : output);
                 _asm_glue->configure(a_to_use->info(), b->info(), nullptr, output_to_use->info(), asm_info);
-                _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output_to_use);
+                _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output_to_use);
             }
             _assembly_path = _asm_glue->is_configured();
-            _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
-            _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
+            _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
+            _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
+
+            if(_assembly_path)
+            {
+                _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
+            }
+
             break;
         }
         default:
@@ -520,7 +533,7 @@ void NEGEMMLowpMatrixMultiplyCore::run()
     // Run GEMM
     if(_asm_glue->is_configured())
     {
-        _asm_glue->run(_asm_glue_tensors);
+        _asm_glue->run(_asm_glue_tensors->tensors);
     }
     else
     {
@@ -590,7 +603,7 @@ void NEGEMMLowpMatrixMultiplyCore::prepare()
             ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
         }
 
-        _asm_glue->prepare(_asm_glue_tensors);
+        _asm_glue->prepare(_asm_glue_tensors->tensors);
         if(!original_b_managed_by_weights_manager)
         {
             _original_b->mark_as_unused();
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index b47a08a5e9..7b7b68a93b 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -53,8 +53,10 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
         ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
         ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
     };
-    PixelValue type_min{};
-    PixelValue type_max{};
+
+    PixelValue type_min{};
+    PixelValue type_max{};
+
     std::tie(type_min, type_max) = get_min_max(data_type);
     int32_t min_activation = type_min.get<int32_t>();
     int32_t max_activation = type_max.get<int32_t>();
@@ -87,8 +89,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
 }
 } // namespace
 
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
-    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
       _activation_func(std::make_unique<CpuActivation>()),
       _weights_permute_func(std::make_unique<CpuPermute>()),
       _permuted_weights_info(),
@@ -163,6 +165,8 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *
 }
 void CpuGemmDirectConv2d::run(ITensorPack &tensors)
 {
+    import_workspace_memory(tensors);
+
     prepare(tensors);
 
     _gemm_asm_func->run(tensors);
@@ -170,22 +174,14 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
     {
         _activation_func->run(tensors);
     }
-}
 
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
-    // TODO: This function will be removed when memory injection is implemeted.
-    ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
-    _permuted_weights->allocator()->free();
-    _permuted_weights->allocator()->init(_permuted_weights_info);
-    _permuted_weights->allocator()->allocate();
+    free_imported_workspace_memory();
 }
 
 void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
 {
     if(!_is_prepared)
     {
-        allocate_permuted_weights();
         ITensorPack permute_tensors
         {
             { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
@@ -202,5 +198,41 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
     }
 }
 
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+    experimental::MemoryRequirements req = _gemm_asm_func->workspace();
+
+    auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0);
+
+    if(req.size() > 0)
+    {
+        index = req.back().slot + 1;
+
+        constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4);
+        ARM_COMPUTE_UNUSED(max_index); // in order to prevent a build error when assertions are disabled.
+        ARM_COMPUTE_ERROR_ON(index > max_index);
+    }
+
+    req.emplace_back(index, _permuted_weights_info.total_size(), 0);
+
+    return req;
+}
+
+void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors)
+{
+    auto imported_tensor = tensors.get_tensor(workspace().back().slot);
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
+
+    auto imported_memory = imported_tensor->buffer();
+    _permuted_weights->allocator()->init(_permuted_weights_info);
+    _permuted_weights->allocator()->import_memory(imported_memory);
+}
+
+void CpuGemmDirectConv2d::free_imported_workspace_memory()
+{
+    _permuted_weights->allocator()->free();
+}
+
 } // namespace cpu
 } // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2349..305a076908 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -48,7 +48,7 @@ class CpuGemmDirectConv2d : public ICpuOperator
 {
 public:
     /** Constructor */
-    CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+    CpuGemmDirectConv2d();
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
     /** Destructor */
     ~CpuGemmDirectConv2d();
@@ -80,15 +80,16 @@ public:
     void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &info);
     /** Static function to check if given info will lead to a valid configuration of @ref CpuGemmDirectConv2d
      *
-     * Similar to CpuGemmDirectConv2d::configure()
+     * Similar to @ref CpuGemmDirectConv2d::configure()
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info);
 
     // Inherited methods overridden:
-    void run(ITensorPack &tensors) override;
-    void prepare(ITensorPack &constants) override;
+    void                             run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &constants) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
     std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
@@ -100,11 +101,13 @@ private:
     bool _is_prepared{ false };
     bool _run_activation{ false };
 
-    /** Function to allocated a tensor for permuted weights
+    /** Function to import workspace tensors
      *
-     * @note This function will be removed when memory injection is properly implemented.
+     * @param[in] tensors Tensor pack that includes workspace tensors
      */
-    void allocate_permuted_weights();
+    void import_workspace_memory(ITensorPack &tensors);
+    /** Function to free used workspace tensors */
+    void free_imported_workspace_memory();
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 0c511ff548..53d71a3b80 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -240,7 +240,7 @@ public:
      */
     void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                    arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+                   IWeightsManager *weights_manager, const OutputStage &os = {});
 
     /** Set requantization shifts to be used
      *
@@ -265,13 +265,42 @@ public:
     bool is_configured() const override;
 
 private:
-    /** Allocate a workspace tensor.
+    static constexpr size_t _workspace_alignment{ 4096 };
+
+    /** Function to get the memory requirements */
+    experimental::MemoryRequirements get_workspace() const override
+    {
+        experimental::MemoryRequirements req{};
+        const auto size = _gemm_kernel_asm->get_working_size();
+        if(size > 0)
+        {
+            req.emplace_back(TensorType::ACL_INT, size, _workspace_alignment);
+        }
+        return req;
+    }
+
+    /** Function to import workspace tensors
      *
-     * @param[in] workspace_size Size to allocate.
-     * @param[in] memory_group   Tensor memory group.
-     * @param[in] alignment      Workspace memory alignment.
+     * @param[in] tensors Tensor pack that includes workspace tensors
      */
-    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+    void import_workspace(ITensorPack &tensors)
+    {
+        const auto size = _gemm_kernel_asm->get_working_size();
+
+        if(size > 0)
+        {
+            auto imported_tensor = tensors.get_tensor(TensorType::ACL_INT);
+            ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
+            const size_t workspace_size = _gemm_kernel_asm->get_working_size();
+            _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + _workspace_alignment) }, 1, DataType::S8), _workspace_alignment);
+            _workspace.allocator()->import_memory(imported_tensor->buffer());
+        }
+    }
+    /** Function to free used workspace tensors */
+    void free_imported_workspace()
+    {
+        _workspace.allocator()->free();
+    }
+
     /** Configure the indirect buffer
      *
      * @param[in]  a    Input tensor containing the Matrix A.
@@ -333,6 +362,9 @@ private:
     }
 };
 
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+constexpr size_t Fallback<TypeInput, TypeOutput, OutputStage>::_workspace_alignment;
+
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
 Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
@@ -470,7 +502,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                              arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
+                                                             IWeightsManager *weights_manager, const OutputStage &os)
 {
     ARM_COMPUTE_UNUSED(c);
     arm_gemm::GemmConfig gemm_cfg;
@@ -492,13 +524,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
     auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
     acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
-    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
-    if(workspace_size > 0)
-    {
-        // Allocate workspace
-        const unsigned int alignment = 4096;
-        allocate_workspace(workspace_size, memory_group, alignment);
-    }
 
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -586,15 +611,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
     }
 }
 
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
-{
-    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
-    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
-    memory_group.manage(&_workspace);
-    _workspace.allocator()->allocate();
-}
-
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 {
@@ -609,6 +625,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     auto d = tensors.get_tensor(TensorType::ACL_DST);
 
+    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+
+    import_workspace(tensors);
+
     int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
     int       ldb = 0;
     const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
@@ -684,10 +704,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
                                             bias, 0);
     // Schedule
     NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
+    free_imported_workspace();
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                      const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                      arm_gemm::Activation activation, const AsmGemmInfo &info,
                      IWeightsManager *weights_manager)
 {
@@ -699,12 +720,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
+    fallback->configure(a, b, c, d, args, info, weights_manager);
     arm_gemm = std::move(fallback);
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
                            const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                            arm_gemm::Activation activation, const AsmGemmInfo &info,
                            IWeightsManager *weights_manager)
 {
@@ -744,14 +765,14 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
 
     }
 
     // Configure fallback
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
+    fallback->configure(a, b, c, d, args, info, weights_manager, gemm_requant_info);
     arm_gemm = std::move(fallback);
 }
 } //namespace
 
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(IWeightsManager *weights_manager)
+    : _arm_gemm(nullptr), _weights_manager(weights_manager)
 {
 }
 
@@ -806,40 +827,40 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
     switch(a->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
        case DataType::QASYMM8:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             }
             break;
         case DataType::S8:
        case DataType::QASYMM8_SIGNED:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             }
             break;
 #endif /* __aarch64__ */
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
@@ -860,10 +881,13 @@ bool CpuGemmAssemblyDispatch::is_configured() const
 
 void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
     _arm_gemm->run(tensors);
 }
+
+experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
+{
+    return is_configured() ? _arm_gemm->get_workspace() : experimental::MemoryRequirements{};
+}
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c75c..154def6708 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,7 +24,6 @@
 #ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 #define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 
-#include "arm_compute/runtime/IMemoryManager.h"
 #include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/Tensor.h"
@@ -62,7 +61,7 @@ class CpuGemmAssemblyDispatch : public ICpuOperator
 {
 public:
     /** Constructor */
-    CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+    CpuGemmAssemblyDispatch(IWeightsManager *weights_manager = nullptr);
     /** Default destructor */
     ~CpuGemmAssemblyDispatch() = default;
 
@@ -71,10 +70,11 @@ public:
     class IFallback
     {
     public:
-        virtual void run(ITensorPack &tensors)     = 0;
-        virtual void prepare(ITensorPack &tensors) = 0;
-        virtual bool is_configured() const         = 0;
-        virtual ~IFallback()                       = default;
+        virtual void run(ITensorPack &tensors)                          = 0;
+        virtual void prepare(ITensorPack &tensors)                      = 0;
+        virtual bool is_configured() const                              = 0;
+        virtual ~IFallback()                                            = default;
+        virtual experimental::MemoryRequirements get_workspace() const  = 0;
     };
 
 public:
@@ -113,12 +113,12 @@ public:
     bool is_configured() const;
 
     // Inherited methods overridden:
-    void prepare(ITensorPack &tensors) override;
-    void run(ITensorPack &tensors) override;
+    void                             prepare(ITensorPack &tensors) override;
+    void                             run(ITensorPack &tensors) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
     std::unique_ptr<IFallback> _arm_gemm;        /**< Interface for the arm_gemm fallback */
-    MemoryGroup                _memory_group;    /**< Function memory group */
     IWeightsManager           *_weights_manager; /**< Pointer to the weights manager */
 };
 } // namespace cpu
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index b435744cdc..f38f9034a4 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -28,6 +28,8 @@
 #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "arm_compute/runtime/TensorAllocator.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
 #include "tests/NEON/Accessor.h"
 #include "tests/PaddingCalculator.h"
 #include "tests/datasets/LargeConvolutionLayerDataset.h"
@@ -169,13 +171,13 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradConvolutionLayerFixture<float>, frame
     validate(Accessor(_target), _reference, abs_tolerance_f32);
 }
 FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEWinogradConvolutionLayerMixedDataLayoutFixture<float>, framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(combine(combine(combine(combine(
-                       framework::dataset::make("Input", TensorShape(8U, 8U, 32U)),
-                       framework::dataset::make("Weight", TensorShape(1U, 3U, 32U, 1U))),
-                       framework::dataset::make("Bias", TensorShape(1U))),
-                       framework::dataset::make("Output", TensorShape(8U, 6U, 1U))),
-                       framework::dataset::make("PadStrideInfo", PadStrideInfo(1, 1, 0, 0))),
-                       framework::dataset::make("Dilation", Size2D(1U, 1U))),
+                       combine(combine(combine(combine(combine(combine(combine(combine(
+                           framework::dataset::make("Input", TensorShape(8U, 8U, 32U)),
+                           framework::dataset::make("Weight", TensorShape(1U, 3U, 32U, 1U))),
+                           framework::dataset::make("Bias", TensorShape(1U))),
+                           framework::dataset::make("Output", TensorShape(8U, 6U, 1U))),
+                           framework::dataset::make("PadStrideInfo", PadStrideInfo(1, 1, 0, 0))),
+                           framework::dataset::make("Dilation", Size2D(1U, 1U))),
                            framework::dataset::make("DataType", { DataType::F32 })),
                            ActivationFunctionsDataset),
                            framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
@@ -408,9 +410,7 @@ TEST_SUITE(Float)
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
 TEST_SUITE(BFLOAT16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::BFLOAT16)),
-                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::BFLOAT16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
                        ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
 }
 TEST_SUITE_END() // BFLOAT16
 #endif           /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::F16)),
-                       framework::dataset::make("DataLayout", { DataLayout::NCHW })),
-                       ActivationFunctionsDataset))
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NCHW })), ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
 }
 TEST_SUITE_END() // FP16
 #endif           /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::F32)),
-                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                        ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
 }
 FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerMixedDataLayoutFixture<float>, framework::DatasetMode::ALL,
-                       combine(combine(combine(combine(combine(combine(combine(combine(combine(
-                       framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
-                       framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
-                       framework::dataset::make("Bias", TensorShape(2U))),
-                       framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
-                       framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
-                       framework::dataset::make("Dilation", Size2D(1, 1))),
-                       framework::dataset::make("ReshapeWeights", { true })),
framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - ActivationFunctionsDataset)) + combine(combine(combine(combine(combine(combine(combine(combine(combine( + framework::dataset::make("Input", TensorShape(23U, 27U, 5U)), + framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))), + framework::dataset::make("Bias", TensorShape(2U))), + framework::dataset::make("Output", TensorShape(11U, 25U, 2U))), + framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))), + framework::dataset::make("Dilation", Size2D(1, 1))), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); @@ -479,28 +474,25 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), - QuantizedActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); } FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( - framework::dataset::make("Input", TensorShape(23U, 27U, 5U)), - framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))), - framework::dataset::make("Bias", TensorShape(2U))), - framework::dataset::make("Output", TensorShape(11U, 25U, 2U))), - framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))), - framework::dataset::make("Dilation", Size2D(1, 1))), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), - QuantizedActivationFunctionsDataset)) + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + framework::dataset::make("Input", TensorShape(23U, 27U, 5U)), + framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))), + framework::dataset::make("Bias", TensorShape(2U))), + framework::dataset::make("Output", TensorShape(11U, 25U, 2U))), + framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))), + framework::dataset::make("Dilation", Size2D(1, 1))), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { 
+                           framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+                           QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
@@ -509,28 +501,25 @@ TEST_SUITE_END() // QASYMM8
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
-                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
-                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
-                       QuantizedActivationFunctionsDataset))
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
 FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL,
-                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
-                       framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
-                       framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
-                       framework::dataset::make("Bias", TensorShape(2U))),
-                       framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
-                       framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
-                       framework::dataset::make("Dilation", Size2D(1, 1))),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
-                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
-                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
-                       QuantizedActivationFunctionsDataset))
+                       combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                           framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+                           framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
+                           framework::dataset::make("Bias", TensorShape(2U))),
+                           framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
+                           framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
+                           framework::dataset::make("Dilation", Size2D(1, 1))),
+                           framework::dataset::make("ReshapeWeights", { true })),
+                           framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+                           framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                           framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+                           QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
@@ -571,13 +560,117 @@ TEST_SUITE(DirectGEMMConv2d)
 
 template <typename T>
 using NEDirectGEMMConv2dLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConv2d, T>;
 
+/** Test case to check that an operator configured once can be run for different groups of tensors */
+TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
+{
+    auto       conv        = std::make_unique<cpu::CpuGemmDirectConv2d>();
+    const auto src_info    = TensorInfo(TensorShape(1U, 5U, 2U), 1, DataType::F32, DataLayout::NHWC);
+    const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, DataType::F32, DataLayout::NHWC);
+    const auto bias_info   = TensorInfo(TensorShape(3U), 1, DataType::F32, DataLayout::NHWC);
+    auto       dst_info    = TensorInfo(TensorShape(1U, 7U, 3U), 1, DataType::F32, DataLayout::NHWC);
+    const auto conv_info   = Conv2dInfo{};
+
+    conv->configure(&src_info, &weight_info, &bias_info, &dst_info, conv_info);
+
+    auto run_conv = [&]() -> Tensor
+    {
+        auto pack = ITensorPack{};
+        auto mg   = MemoryGroup{};
+        auto ws   = manage_workspace<Tensor>(conv->workspace(), mg, pack);
+
+        // tensors are newly created on every call of this lambda function
+        auto src    = create_tensor<Tensor>(src_info);
+        auto weight = create_tensor<Tensor>(weight_info);
+        auto bias   = create_tensor<Tensor>(bias_info);
+        auto dst    = create_tensor<Tensor>(dst_info);
+
+        src.allocator()->allocate();
+        weight.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        pack.add_const_tensor(TensorType::ACL_SRC_0, &src);
+        pack.add_const_tensor(TensorType::ACL_SRC_1, &weight);
+        pack.add_const_tensor(TensorType::ACL_SRC_2, &bias);
+        pack.add_tensor(TensorType::ACL_DST, &dst);
+
+        // the values aren't too important, as long as the same values are used in each execution.
+        std::vector<float> static_number_to_be_filled{};
+        for(size_t i = 0; i < weight_info.tensor_shape().total_size(); i++)
+        {
+            static_number_to_be_filled.emplace_back(i);
+        }
+
+        library->fill_static_values(Accessor(src), static_number_to_be_filled);
+        library->fill_static_values(Accessor(weight), static_number_to_be_filled);
+        library->fill_static_values(Accessor(bias), static_number_to_be_filled);
+
+        // This operator is configured once and captured by this lambda.
+        conv->run(pack);
+        return dst;
+    };
+
+    auto result_0 = run_conv();
+    auto result_1 = run_conv();
+
+    for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); i++)
+    {
+        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+    }
+}
+
+TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
+{
+    auto       conv        = std::make_unique<NEGEMMConv2d>();
+    const auto src_info    = TensorInfo(TensorShape(1U, 5U, 2U), 1, DataType::F32, DataLayout::NHWC);
+    const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, DataType::F32, DataLayout::NHWC);
+    const auto bias_info   = TensorInfo(TensorShape(3U), 1, DataType::F32, DataLayout::NHWC);
+    auto       dst_info    = TensorInfo(TensorShape(1U, 7U, 3U), 1, DataType::F32, DataLayout::NHWC);
+    const auto conv_info   = Conv2dInfo{};
+
+    auto run_conv = [&]()
+    {
+        auto src    = create_tensor<Tensor>(src_info);
+        auto weight = create_tensor<Tensor>(weight_info);
+        auto bias   = create_tensor<Tensor>(bias_info);
+        auto dst    = create_tensor<Tensor>(dst_info);
+
+        conv->configure(&src, &weight, &bias, &dst, conv_info);
+
+        src.allocator()->allocate();
+        weight.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        // the values aren't too important, as long as the same values are used in each execution.
+        std::vector<float> static_number_to_be_filled{};
+        for(size_t i = 0; i < weight_info.tensor_shape().total_size(); i++)
+        {
+            static_number_to_be_filled.emplace_back(i);
+        }
+
+        library->fill_static_values(Accessor(src), static_number_to_be_filled);
+        library->fill_static_values(Accessor(weight), static_number_to_be_filled);
+        library->fill_static_values(Accessor(bias), static_number_to_be_filled);
+
+        conv->run();
+
+        return dst;
+    };
+
+    auto result_0 = run_conv();
+    auto result_1 = run_conv();
+
+    for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); i++)
+    {
+        ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+    }
+}
+
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::F32)),
-                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                       ActivationFunctionsDataset))
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
 }
@@ -601,11 +694,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::QASYMM8)),
-                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
-                       QuantizedActivationFunctionsDataset))
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
 TEST_SUITE_END() // QASYMM8
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                       framework::dataset::make("ReshapeWeights", { true })),
-                       framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
-                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
-                       QuantizedActivationFunctionsDataset))
+                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
-- cgit v1.2.1
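A note on the req.size + alignment over-allocation introduced in MemoryHelpers.h above: a raw
buffer handed to import_memory() carries no alignment guarantee of its own, so reserving
alignment extra bytes ensures an aligned sub-range of the requested size always fits inside the
allocation. A minimal sketch of that arithmetic (illustrative of the standard technique, not a
copy of the library's TensorAllocator internals; align_up is a hypothetical helper):

#include <cstdint>

// Given a buffer of (size + alignment) bytes starting at ptr, the first aligned
// address is at most (alignment - 1) bytes in, leaving at least `size` usable
// bytes behind it. `alignment` must be a power of two.
static inline uint8_t *align_up(uint8_t *ptr, uintptr_t alignment)
{
    const auto addr = reinterpret_cast<uintptr_t>(ptr);
    return reinterpret_cast<uint8_t *>((addr + alignment - 1) & ~(alignment - 1));
}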