aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2021-05-18 10:46:00 +0100
committerPablo Marquez Tello <pablo.tello@arm.com>2021-05-27 16:33:44 +0000
commitb3be45759bdd0749ae3a16fe470820f0d9830ea9 (patch)
tree10bb8c1c0a049a23c00781c64e993f1b197c0d05
parentbc91297c865808ed2c321febc405179f63195ff8 (diff)
downloadComputeLibrary-b3be45759bdd0749ae3a16fe470820f0d9830ea9.tar.gz
Implement memory injection in CpuDirectGemmConv2d
The following operators are now stateless by implementing memory injection. - CpuDirectGemmConv2d - CpuGemmAssemblyDispatch A test case is added to test if CpuDirectGemmConv2d can run on different group of tensors with a single configure. Resolves: COMPMID-4506 Change-Id: I48f44ed41236ca7e18da2de07bdbacc9007a3c5e Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5718 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMM.h3
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h3
-rw-r--r--src/core/helpers/MemoryHelpers.h14
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp19
-rw-r--r--src/runtime/NEON/functions/NEGEMMConv2d.cpp21
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp31
-rw-r--r--src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp58
-rw-r--r--src/runtime/cpu/operators/CpuGemmDirectConv2d.h17
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp100
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h18
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp239
11 files changed, 360 insertions, 163 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index 6fa30bd545..a5d6bb6534 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -146,7 +146,8 @@ private:
bool _reshape_b_only_on_first_run;
bool _is_prepared;
- ITensorPack _asm_glue_tensors{};
+ struct AsmGlueTensors;
+ std::unique_ptr<AsmGlueTensors> _asm_glue_tensors;
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEGEMM_H */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
index dc9783f9eb..ff50d6dbf7 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
@@ -171,7 +171,8 @@ private:
bool _run_activation;
bool _flip_signedness;
- ITensorPack _asm_glue_tensors{};
+ struct AsmGlueTensors;
+ std::unique_ptr<AsmGlueTensors> _asm_glue_tensors;
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H */
diff --git a/src/core/helpers/MemoryHelpers.h b/src/core/helpers/MemoryHelpers.h
index 6756a90c25..dfa8e60758 100644
--- a/src/core/helpers/MemoryHelpers.h
+++ b/src/core/helpers/MemoryHelpers.h
@@ -56,12 +56,13 @@ WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirement
continue;
}
- const auto aux_info = TensorInfo{ TensorShape(req.size), 1, DataType::U8 };
+ const auto alignment = req.alignment;
+ const auto aux_info = TensorInfo{ TensorShape(req.size + alignment), 1, DataType::U8 };
workspace_memory.emplace_back(req.slot, std::make_unique<TensorType>());
auto aux_tensor = workspace_memory.back().second.get();
ARM_COMPUTE_ERROR_ON_NULLPTR(aux_tensor);
- aux_tensor->allocator()->init(aux_info);
+ aux_tensor->allocator()->init(aux_info, alignment);
if(req.lifetime == experimental::MemoryLifetime::Temporary)
{
@@ -82,5 +83,14 @@ WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirement
return workspace_memory;
}
+
+template <typename TensorType>
+WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
+ MemoryGroup &mgroup,
+ ITensorPack &run_pack)
+{
+ ITensorPack dummy_prep_pack{};
+ return manage_workspace<TensorType>(mem_reqs, mgroup, run_pack, dummy_prep_pack);
+}
} // namespace arm_compute
#endif /* SRC_COMMON_MEMORY_HELPERS_H */
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 7318c3e492..b526874790 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -38,6 +38,7 @@
#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
#include <cmath>
@@ -46,6 +47,14 @@ using namespace arm_compute::misc::shape_calculator;
namespace arm_compute
{
+using WorkspaceDataType = WorkspaceData<Tensor>;
+
+struct NEGEMM::AsmGlueTensors
+{
+ ITensorPack tensors{};
+ WorkspaceDataType ws{};
+};
+
namespace
{
cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -63,7 +72,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>()), _ma_kernel(),
_alpha_scale_func(nullptr), _add_bias(), _activation_func(), _tmp_a(), _tmp_b(), _tmp_d(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_alpha_scale(false),
- _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+ _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
{
}
@@ -94,7 +103,7 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_asm_glue->configure(a->info(), b->info(), c_info_to_use, d->info(), asm_info);
ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());
- _asm_glue_tensors =
+ _asm_glue_tensors->tensors =
{
{ ACL_SRC_0, a },
{ ACL_SRC_1, b },
@@ -102,6 +111,8 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
{ ACL_DST, d },
};
+ _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
+
// Scale product by alpha
if(_run_alpha_scale)
{
@@ -323,7 +334,7 @@ void NEGEMM::run()
if(_asm_glue->is_configured())
{
- _asm_glue->run(_asm_glue_tensors);
+ _asm_glue->run(_asm_glue_tensors->tensors);
if(_run_alpha_scale)
{
_alpha_scale_func.run();
@@ -377,7 +388,7 @@ void NEGEMM::prepare()
ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
}
- _asm_glue->prepare(_asm_glue_tensors);
+ _asm_glue->prepare(_asm_glue_tensors->tensors);
if(!original_b_managed_by_weights_manager)
{
_original_b->mark_as_unused();
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
index 94ceb6d27c..790543a34a 100644
--- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
@@ -26,24 +26,37 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
#include <set>
namespace arm_compute
{
-using OperatorType = cpu::CpuGemmDirectConv2d;
+using OperatorType = cpu::CpuGemmDirectConv2d;
+using WorkspaceDataType = WorkspaceData<Tensor>;
struct NEGEMMConv2d::Impl
{
ITensorPack tensors{};
+ MemoryGroup mg{};
std::unique_ptr<OperatorType> op{ nullptr };
+ WorkspaceDataType ws{};
+
+ void allocate_and_add_workspace()
+ {
+ if(op)
+ {
+ ws = manage_workspace<Tensor>(op->workspace(), mg, tensors);
+ }
+ }
};
NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
: _impl(std::make_unique<Impl>())
{
- _impl->op = std::make_unique<OperatorType>(memory_manager);
+ _impl->op = std::make_unique<OperatorType>();
+ _impl->mg = MemoryGroup(memory_manager);
}
NEGEMMConv2d::~NEGEMMConv2d() = default;
@@ -55,7 +68,9 @@ void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITens
_impl->tensors.add_const_tensor(TensorType::ACL_SRC_2, biases);
_impl->tensors.add_tensor(TensorType::ACL_DST, output);
- _impl->op->configure(input->info(), weights->info(), biases->info(), output->info(), info);
+ _impl->op->configure(input->info(), weights->info(), ((biases) ? biases->info() : nullptr), output->info(), info);
+
+ _impl->allocate_and_add_workspace();
}
Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info)
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index cc0f20e695..d42e656e0c 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -42,10 +42,17 @@
#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
#include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
namespace arm_compute
{
+using WorkspaceDataType = WorkspaceData<Tensor>;
+struct NEGEMMLowpMatrixMultiplyCore::AsmGlueTensors
+{
+ ITensorPack tensors{};
+ WorkspaceDataType ws{};
+};
namespace
{
cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -66,11 +73,11 @@ using namespace arm_compute::misc::shape_calculator;
NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(memory_manager, weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
+ : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
_mtx_b_reshape_kernel(), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(),
_convert_to_signed_asymm(), _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0),
_b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false),
- _run_activation(false), _flip_signedness(false)
+ _run_activation(false), _flip_signedness(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
{
}
@@ -149,18 +156,24 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
auto c_info_to_use = c == nullptr ? nullptr : c->info();
_asm_glue->configure(a_to_use->info(), b->info(), c_info_to_use, output->info(), asm_info);
_fused_assembly_path = _asm_glue->is_configured();
- _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
- _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output);
+ _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
+ _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output);
}
else
{
auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : output);
_asm_glue->configure(a_to_use->info(), b->info(), nullptr, output_to_use->info(), asm_info);
- _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output_to_use);
+ _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output_to_use);
}
_assembly_path = _asm_glue->is_configured();
- _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
- _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
+ _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
+ _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
+
+ if(_assembly_path)
+ {
+ _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
+ }
+
break;
}
default:
@@ -520,7 +533,7 @@ void NEGEMMLowpMatrixMultiplyCore::run()
// Run GEMM
if(_asm_glue->is_configured())
{
- _asm_glue->run(_asm_glue_tensors);
+ _asm_glue->run(_asm_glue_tensors->tensors);
}
else
{
@@ -590,7 +603,7 @@ void NEGEMMLowpMatrixMultiplyCore::prepare()
ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
}
- _asm_glue->prepare(_asm_glue_tensors);
+ _asm_glue->prepare(_asm_glue_tensors->tensors);
if(!original_b_managed_by_weights_manager)
{
_original_b->mark_as_unused();
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index b47a08a5e9..7b7b68a93b 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -53,8 +53,10 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src,
ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
};
- PixelValue type_min{};
- PixelValue type_max{};
+
+ PixelValue type_min{};
+ PixelValue type_max{};
+
std::tie(type_min, type_max) = get_min_max(data_type);
int32_t min_activation = type_min.get<int32_t>();
int32_t max_activation = type_max.get<int32_t>();
@@ -87,8 +89,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
}
} // namespace
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+ : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
_activation_func(std::make_unique<CpuActivation>()),
_weights_permute_func(std::make_unique<CpuPermute>()),
_permuted_weights_info(),
@@ -163,6 +165,8 @@ Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *
}
void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
+ import_workspace_memory(tensors);
+
prepare(tensors);
_gemm_asm_func->run(tensors);
@@ -170,22 +174,14 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
{
_activation_func->run(tensors);
}
-}
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
- // TODO: This function will be removed when memory injection is implemeted.
- ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
- _permuted_weights->allocator()->free();
- _permuted_weights->allocator()->init(_permuted_weights_info);
- _permuted_weights->allocator()->allocate();
+ free_imported_workspace_memory();
}
void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- allocate_permuted_weights();
ITensorPack permute_tensors
{
{ TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
@@ -202,5 +198,41 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
}
}
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+ experimental::MemoryRequirements req = _gemm_asm_func->workspace();
+
+ auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0);
+
+ if(req.size() > 0)
+ {
+ index = req.back().slot + 1;
+
+ constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4);
+ ARM_COMPUTE_UNUSED(max_index); // in order to prevent build error with assertion is disabled.
+ ARM_COMPUTE_ERROR_ON(index > max_index);
+ }
+
+ req.emplace_back(index, _permuted_weights_info.total_size(), 0);
+
+ return req;
+}
+
+void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors)
+{
+ auto imported_tensor = tensors.get_tensor(workspace().back().slot);
+
+ ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
+
+ auto imported_memory = imported_tensor->buffer();
+ _permuted_weights->allocator()->init(_permuted_weights_info);
+ _permuted_weights->allocator()->import_memory(imported_memory);
+}
+
+void CpuGemmDirectConv2d::free_imported_workspace_memory()
+{
+ _permuted_weights->allocator()->free();
+}
+
} // namespace cpu
} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2349..305a076908 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -48,7 +48,7 @@ class CpuGemmDirectConv2d : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+ CpuGemmDirectConv2d();
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
/** Destructor */
~CpuGemmDirectConv2d();
@@ -80,15 +80,16 @@ public:
void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &info);
/** Static function to check if given info will lead to a valid configuration of @ref CpuGemmDirectConv2d
*
- * Similar to CpuGemmDirectConv2d::configure()
+ * Similar to @ref CpuGemmDirectConv2d::configure()
*
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
+ experimental::MemoryRequirements workspace() const override;
private:
std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
@@ -100,11 +101,13 @@ private:
bool _is_prepared{ false };
bool _run_activation{ false };
- /** Function to allocated a tensor for permuted weights
+ /** Function to import workspace tensors
*
- * @note This function will be removed when memory injection is properly implemented.
+ * @param[in] tensors Tensor pack includes workspace tensors
*/
- void allocate_permuted_weights();
+ void import_workspace_memory(ITensorPack &tensors);
+ /** Function free used workspace tensors */
+ void free_imported_workspace_memory();
};
} // namespace cpu
} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 0c511ff548..53d71a3b80 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -240,7 +240,7 @@ public:
*/
void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
- MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+ IWeightsManager *weights_manager, const OutputStage &os = {});
/** Set requantization shifts to be used
*
@@ -265,13 +265,42 @@ public:
bool is_configured() const override;
private:
- /** Allocate a workspace tensor.
+ static constexpr size_t _workspace_alignment{ 4096 };
+ /** Function to get the memory requirements */
+ experimental::MemoryRequirements get_workspace() const override
+ {
+ experimental::MemoryRequirements req{};
+ const auto size = _gemm_kernel_asm->get_working_size();
+ if(size > 0)
+ {
+ req.emplace_back(TensorType::ACL_INT, size, _workspace_alignment);
+ }
+ return req;
+ }
+
+ /** Function to import workspace tensors
*
- * @param[in] workspace_size Size to allocate.
- * @param[in] memory_group Tensor memory group.
- * @param[in] alignment Workspace memory alignment.
+ * @param[in] tensors Tensor pack includes workspace tensors
*/
- void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+ void import_workspace(ITensorPack &tensors)
+ {
+ const auto size = _gemm_kernel_asm->get_working_size();
+
+ if(size > 0)
+ {
+ auto imported_tensor = tensors.get_tensor(TensorType::ACL_INT);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
+ const size_t workspace_size = _gemm_kernel_asm->get_working_size();
+ _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + _workspace_alignment) }, 1, DataType::S8), _workspace_alignment);
+ _workspace.allocator()->import_memory(imported_tensor->buffer());
+ }
+ }
+ /** Function free used workspace tensors */
+ void free_imported_workspace()
+ {
+ _workspace.allocator()->free();
+ }
+
/** Configure the indirect buffer
*
* @param[in] a Input tensor containing the Matrix A.
@@ -334,6 +363,9 @@ private:
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
+constexpr size_t Fallback<TypeInput, TypeOutput, OutputStage>::_workspace_alignment;
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
{
@@ -470,7 +502,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
- MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
+ IWeightsManager *weights_manager, const OutputStage &os)
{
ARM_COMPUTE_UNUSED(c);
arm_gemm::GemmConfig gemm_cfg;
@@ -492,13 +524,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
- const size_t workspace_size = _gemm_kernel_asm->get_working_size();
- if(workspace_size > 0)
- {
- // Allocate workspace
- const unsigned int alignment = 4096;
- allocate_workspace(workspace_size, memory_group, alignment);
- }
//if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
//the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -587,15 +612,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
-{
- ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
- _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
- memory_group.manage(&_workspace);
- _workspace.allocator()->allocate();
-}
-
-template <typename TypeInput, typename TypeOutput, class OutputStage>
bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
{
return _optimised_kernel != nullptr;
@@ -609,6 +625,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
auto d = tensors.get_tensor(TensorType::ACL_DST);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+
+ import_workspace(tensors);
+
int lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
int ldb = 0;
const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
@@ -684,10 +704,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
bias, 0);
// Schedule
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
+ free_imported_workspace();
}
template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
@@ -699,12 +720,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
+ fallback->configure(a, b, c, d, args, info, weights_manager);
arm_gemm = std::move(fallback);
}
template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
@@ -744,14 +765,14 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
}
// Configure fallback
- fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
+ fallback->configure(a, b, c, d, args, info, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
} //namespace
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(IWeightsManager *weights_manager)
+ : _arm_gemm(nullptr), _weights_manager(weights_manager)
{
}
@@ -806,40 +827,40 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
switch(a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
}
break;
#endif /* __aarch64__ */
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
break;
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
@@ -860,10 +881,13 @@ bool CpuGemmAssemblyDispatch::is_configured() const
void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
{
- MemoryGroupResourceScope scope_mg(_memory_group);
-
ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
_arm_gemm->run(tensors);
}
+
+experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
+{
+ return is_configured() ? _arm_gemm->get_workspace() : experimental::MemoryRequirements{};
+}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c75c..154def6708 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,7 +24,6 @@
#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
-#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
@@ -62,7 +61,7 @@ class CpuGemmAssemblyDispatch : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+ CpuGemmAssemblyDispatch(IWeightsManager *weights_manager = nullptr);
/** Defautl destructor */
~CpuGemmAssemblyDispatch() = default;
@@ -71,10 +70,11 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual bool is_configured() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
+ virtual experimental::MemoryRequirements get_workspace() const = 0;
};
public:
@@ -113,12 +113,12 @@ public:
bool is_configured() const;
// Inherited methods overridden:
- void prepare(ITensorPack &tensors) override;
- void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ experimental::MemoryRequirements workspace() const override;
private:
std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
- MemoryGroup _memory_group; /**< Function memory group */
IWeightsManager *_weights_manager; /**< Pointer to the weights manager */
};
} // namespace cpu
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index b435744cdc..f38f9034a4 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -28,6 +28,8 @@
#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
#include "tests/NEON/Accessor.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/LargeConvolutionLayerDataset.h"
@@ -169,13 +171,13 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradConvolutionLayerFixture<float>, frame
validate(Accessor(_target), _reference, abs_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEWinogradConvolutionLayerMixedDataLayoutFixture<float>, framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(combine(combine(combine(
- framework::dataset::make("Input", TensorShape(8U, 8U, 32U)),
- framework::dataset::make("Weight", TensorShape(1U, 3U, 32U, 1U))),
- framework::dataset::make("Bias", TensorShape(1U))),
- framework::dataset::make("Output", TensorShape(8U, 6U, 1U))),
- framework::dataset::make("PadStrideInfo", PadStrideInfo(1, 1, 0, 0))),
- framework::dataset::make("Dilation", Size2D(1U, 1U))),
+ combine(combine(combine(combine(combine(combine(combine(combine(
+ framework::dataset::make("Input", TensorShape(8U, 8U, 32U)),
+ framework::dataset::make("Weight", TensorShape(1U, 3U, 32U, 1U))),
+ framework::dataset::make("Bias", TensorShape(1U))),
+ framework::dataset::make("Output", TensorShape(8U, 6U, 1U))),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(1, 1, 0, 0))),
+ framework::dataset::make("Dilation", Size2D(1U, 1U))),
framework::dataset::make("DataType", { DataType::F32 })),
ActivationFunctionsDataset),
framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
@@ -408,9 +410,7 @@ TEST_SUITE(Float)
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
TEST_SUITE(BFLOAT16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::BFLOAT16)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::BFLOAT16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
ActivationFunctionsDataset))
{
// Validate output
@@ -422,10 +422,7 @@ TEST_SUITE_END() // BFLOAT16
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW })),
- ActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NCHW })), ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
@@ -435,26 +432,24 @@ TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
}
FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerMixedDataLayoutFixture<float>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(
- framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
- framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
- framework::dataset::make("Bias", TensorShape(2U))),
- framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
- framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
- framework::dataset::make("Dilation", Size2D(1, 1))),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+ framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
+ framework::dataset::make("Bias", TensorShape(2U))),
+ framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
+ framework::dataset::make("Dilation", Size2D(1, 1))),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
@@ -479,28 +474,25 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
- framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
- framework::dataset::make("Bias", TensorShape(2U))),
- framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
- framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
- framework::dataset::make("Dilation", Size2D(1, 1))),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+ framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
+ framework::dataset::make("Bias", TensorShape(2U))),
+ framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
+ framework::dataset::make("Dilation", Size2D(1, 1))),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -509,28 +501,25 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
- framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
- framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
- framework::dataset::make("Bias", TensorShape(2U))),
- framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
- framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
- framework::dataset::make("Dilation", Size2D(1, 1))),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+ framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U))),
+ framework::dataset::make("Bias", TensorShape(2U))),
+ framework::dataset::make("Output", TensorShape(11U, 25U, 2U))),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0))),
+ framework::dataset::make("Dilation", Size2D(1, 1))),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -571,13 +560,117 @@ TEST_SUITE(DirectGEMMConv2d)
template <typename T>
using NEDirectGEMMConv2dLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConv2d, T>;
+/** Test case to test if an operator configured once can be run for different group of tensors */
+TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
+{
+ auto conv = std::make_unique<cpu::CpuGemmDirectConv2d>();
+ const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, DataType::F32, DataLayout::NHWC);
+ const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, DataType::F32, DataLayout::NHWC);
+ const auto bias_info = TensorInfo(TensorShape(3U), 1, DataType::F32, DataLayout::NHWC);
+ auto dst_info = TensorInfo(TensorShape(1U, 7U, 3U), 1, DataType::F32, DataLayout::NHWC);
+ const auto conv_info = Conv2dInfo{};
+
+ conv->configure(&src_info, &weight_info, &bias_info, &dst_info, conv_info);
+
+ auto run_conv = [&]() -> Tensor
+ {
+ auto pack = ITensorPack{};
+ auto mg = MemoryGroup{};
+ auto ws = manage_workspace<Tensor>(conv->workspace(), mg, pack);
+
+ // tensors are newly created every call of this lambda function
+ auto src = create_tensor<Tensor>(src_info);
+ auto weight = create_tensor<Tensor>(weight_info);
+ auto bias = create_tensor<Tensor>(bias_info);
+ auto dst = create_tensor<Tensor>(dst_info);
+
+ src.allocator()->allocate();
+ weight.allocator()->allocate();
+ bias.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ pack.add_const_tensor(TensorType::ACL_SRC_0, &src);
+ pack.add_const_tensor(TensorType::ACL_SRC_1, &weight);
+ pack.add_const_tensor(TensorType::ACL_SRC_2, &bias);
+ pack.add_tensor(TensorType::ACL_DST, &dst);
+
+ // value aren't too important if the same values are used in each execution.
+ std::vector<float> static_number_to_be_filled{};
+ for(size_t i = 0; i < weight_info.tensor_shape().total_size(); i++)
+ {
+ static_number_to_be_filled.emplace_back(i);
+ }
+
+ library->fill_static_values(Accessor(src), static_number_to_be_filled);
+ library->fill_static_values(Accessor(weight), static_number_to_be_filled);
+ library->fill_static_values(Accessor(bias), static_number_to_be_filled);
+
+ // This operator is configured once and captured by this lambda.
+ conv->run(pack);
+ return dst;
+ };
+
+ auto result_0 = run_conv();
+ auto result_1 = run_conv();
+
+ for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); i++)
+ {
+ ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ }
+}
+
+TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
+{
+ auto conv = std::make_unique<NEGEMMConv2d>();
+ const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, DataType::F32, DataLayout::NHWC);
+ const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, DataType::F32, DataLayout::NHWC);
+ const auto bias_info = TensorInfo(TensorShape(3U), 1, DataType::F32, DataLayout::NHWC);
+ auto dst_info = TensorInfo(TensorShape(1U, 7U, 3U), 1, DataType::F32, DataLayout::NHWC);
+ const auto conv_info = Conv2dInfo{};
+
+ auto run_conv = [&]()
+ {
+ auto src = create_tensor<Tensor>(src_info);
+ auto weight = create_tensor<Tensor>(weight_info);
+ auto bias = create_tensor<Tensor>(bias_info);
+ auto dst = create_tensor<Tensor>(dst_info);
+
+ conv->configure(&src, &weight, &bias, &dst, conv_info);
+
+ src.allocator()->allocate();
+ weight.allocator()->allocate();
+ bias.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ // value aren't too important if the same values are used in each execution.
+ std::vector<float> static_number_to_be_filled{};
+ for(size_t i = 0; i < weight_info.tensor_shape().total_size(); i++)
+ {
+ static_number_to_be_filled.emplace_back(i);
+ }
+
+ library->fill_static_values(Accessor(src), static_number_to_be_filled);
+ library->fill_static_values(Accessor(weight), static_number_to_be_filled);
+ library->fill_static_values(Accessor(bias), static_number_to_be_filled);
+
+ conv->run();
+
+ return dst;
+ };
+
+ auto result_0 = run_conv();
+ auto result_1 = run_conv();
+
+ for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); i++)
+ {
+ ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ }
+}
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- ActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
@@ -601,11 +694,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -614,11 +704,8 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
- framework::dataset::make("ReshapeWeights", { true })),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataLayout", { DataLayout::NHWC })),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
- QuantizedActivationFunctionsDataset))
+ framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_qasymm8);