aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h20
1 files changed, 8 insertions, 12 deletions
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c75c..355273adeb 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,10 +24,6 @@
#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
-#include "arm_compute/runtime/IMemoryManager.h"
-#include "arm_compute/runtime/IWeightsManager.h"
-#include "arm_compute/runtime/MemoryGroup.h"
-#include "arm_compute/runtime/Tensor.h"
#include "src/core/common/Macros.h"
#include "src/runtime/cpu/ICpuOperator.h"
@@ -62,7 +58,7 @@ class CpuGemmAssemblyDispatch : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+ CpuGemmAssemblyDispatch();
/** Defautl destructor */
~CpuGemmAssemblyDispatch() = default;
@@ -71,10 +67,11 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual bool is_configured() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -115,11 +112,10 @@ public:
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;
void run(ITensorPack &tensors) override;
+ experimental::MemoryRequirements workspace() const override;
private:
- std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
- MemoryGroup _memory_group; /**< Function memory group */
- IWeightsManager *_weights_manager; /**< Pointer to the weights manager */
+ std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
};
} // namespace cpu
} // namespace arm_compute