diff options
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.h')
-rw-r--r-- | src/runtime/cpu/operators/CpuGemmDirectConv2d.h | 40 |
1 files changed, 18 insertions, 22 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h index 6aa17c2349..b572f36a3a 100644 --- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h +++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h @@ -24,14 +24,12 @@ #ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H #define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H -#include "arm_compute/core/ITensorInfo.h" -#include "arm_compute/core/experimental/Types.h" -#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/core/TensorInfo.h" #include "src/core/common/Macros.h" -#include "src/core/cpu/ICpuKernel.h" #include "src/runtime/cpu/ICpuOperator.h" - -#include <memory> +#include "src/runtime/cpu/operators/CpuActivation.h" +#include "src/runtime/cpu/operators/CpuPermute.h" +#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h" namespace arm_compute { @@ -40,15 +38,11 @@ class ITensor; struct Conv2dInfo; namespace cpu { -class CpuGemmAssemblyDispatch; -class CpuActivation; -class CpuPermute; - class CpuGemmDirectConv2d : public ICpuOperator { public: /** Constructor */ - CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr); + CpuGemmDirectConv2d(); ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d); /** Destructor */ ~CpuGemmDirectConv2d(); @@ -89,22 +83,24 @@ public: // Inherited methods overridden: void run(ITensorPack &tensors) override; void prepare(ITensorPack &constants) override; + experimental::MemoryRequirements workspace() const override; private: + enum AuxTensorIdx + { + AsmGemmWorkspace = 0, + Pretranspose, + PermutedWeights, + Count + }; + std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func; std::unique_ptr<CpuActivation> _activation_func; std::unique_ptr<CpuPermute> _weights_permute_func; - const ITensorInfo *_original_weights_info{}; - TensorInfo _permuted_weights_info; - std::unique_ptr<Tensor> _permuted_weights{ nullptr }; - bool _is_prepared{ false }; - bool _run_activation{ false }; - - /** Function to allocated a tensor for permuted weights - * - * @note This function will be removed when memory injection is properly implemented. - */ - void allocate_permuted_weights(); + experimental::MemoryRequirements _aux_mem; + TensorInfo _perm_weights; + bool _run_activation; + bool _is_prepared; }; } // namespace cpu } // namespace arm_compute |