From b3be45759bdd0749ae3a16fe470820f0d9830ea9 Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Tue, 18 May 2021 10:46:00 +0100 Subject: Implement memory injection in CpuDirectGemmConv2d The following operators are now stateless by implementing memory injection. - CpuDirectGemmConv2d - CpuGemmAssemblyDispatch A test case is added to test if CpuDirectGemmConv2d can run on different group of tensors with a single configure. Resolves: COMPMID-4506 Change-Id: I48f44ed41236ca7e18da2de07bdbacc9007a3c5e Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5718 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Pablo Marquez Tello --- src/runtime/cpu/operators/CpuGemmDirectConv2d.h | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.h') diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h index 6aa17c2349..305a076908 100644 --- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h +++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h @@ -48,7 +48,7 @@ class CpuGemmDirectConv2d : public ICpuOperator { public: /** Constructor */ - CpuGemmDirectConv2d(const std::shared_ptr &memory_manager = nullptr); + CpuGemmDirectConv2d(); ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d); /** Destructor */ ~CpuGemmDirectConv2d(); @@ -80,15 +80,16 @@ public: void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &info); /** Static function to check if given info will lead to a valid configuration of @ref CpuGemmDirectConv2d * - * Similar to CpuGemmDirectConv2d::configure() + * Similar to @ref CpuGemmDirectConv2d::configure() * * @return a status */ static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &constants) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &constants) override; + experimental::MemoryRequirements workspace() const override; private: std::unique_ptr _gemm_asm_func; @@ -100,11 +101,13 @@ private: bool _is_prepared{ false }; bool _run_activation{ false }; - /** Function to allocated a tensor for permuted weights + /** Function to import workspace tensors * - * @note This function will be removed when memory injection is properly implemented. + * @param[in] tensors Tensor pack includes workspace tensors */ - void allocate_permuted_weights(); + void import_workspace_memory(ITensorPack &tensors); + /** Function free used workspace tensors */ + void free_imported_workspace_memory(); }; } // namespace cpu } // namespace arm_compute -- cgit v1.2.1