aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/CpuGemmDirectConv2d.h')
-rw-r--r--src/runtime/cpu/operators/CpuGemmDirectConv2d.h40
1 files changed, 18 insertions, 22 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2349..b572f36a3a 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -24,14 +24,12 @@
#ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
#define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
-#include "arm_compute/core/ITensorInfo.h"
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/core/TensorInfo.h"
#include "src/core/common/Macros.h"
-#include "src/core/cpu/ICpuKernel.h"
#include "src/runtime/cpu/ICpuOperator.h"
-
-#include <memory>
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
namespace arm_compute
{
@@ -40,15 +38,11 @@ class ITensor;
struct Conv2dInfo;
namespace cpu
{
-class CpuGemmAssemblyDispatch;
-class CpuActivation;
-class CpuPermute;
-
class CpuGemmDirectConv2d : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+ CpuGemmDirectConv2d();
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
/** Destructor */
~CpuGemmDirectConv2d();
@@ -89,22 +83,24 @@ public:
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &constants) override;
+ experimental::MemoryRequirements workspace() const override;
private:
+ enum AuxTensorIdx
+ {
+ AsmGemmWorkspace = 0,
+ Pretranspose,
+ PermutedWeights,
+ Count
+ };
+
std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
std::unique_ptr<CpuActivation> _activation_func;
std::unique_ptr<CpuPermute> _weights_permute_func;
- const ITensorInfo *_original_weights_info{};
- TensorInfo _permuted_weights_info;
- std::unique_ptr<Tensor> _permuted_weights{ nullptr };
- bool _is_prepared{ false };
- bool _run_activation{ false };
-
- /** Function to allocate a tensor for permuted weights
- *
- * @note This function will be removed when memory injection is properly implemented.
- */
- void allocate_permuted_weights();
+ experimental::MemoryRequirements _aux_mem;
+ TensorInfo _perm_weights;
+ bool _run_activation;
+ bool _is_prepared;
};
} // namespace cpu
} // namespace arm_compute