diff options
Diffstat (limited to 'src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r-- | src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h | 35 |
1 files changed, 15 insertions, 20 deletions
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 0bbae49a7e..ffc097c75c 100644 --- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -21,14 +21,15 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef SRC_NEGEMMASSEMBLYDISPATCH_H -#define SRC_NEGEMMASSEMBLYDISPATCH_H +#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H +#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H -#include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/IWeightsManager.h" #include "arm_compute/runtime/MemoryGroup.h" #include "arm_compute/runtime/Tensor.h" +#include "src/core/common/Macros.h" +#include "src/runtime/cpu/ICpuOperator.h" namespace arm_compute { @@ -57,29 +58,23 @@ struct AsmGemmInfo }; /** Assembly kernel glue */ -class CpuGemmAssemblyDispatch : public IFunction +class CpuGemmAssemblyDispatch : public ICpuOperator { public: /** Constructor */ CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr); - /** Prevent instances of this class from being copy constructed */ - CpuGemmAssemblyDispatch(const CpuGemmAssemblyDispatch &) = delete; - /** Prevent instances of this class from being copied */ - CpuGemmAssemblyDispatch &operator=(const CpuGemmAssemblyDispatch &) = delete; - /** Default move constructor */ - CpuGemmAssemblyDispatch(CpuGemmAssemblyDispatch &&) = default; - /** Default move assignment operator */ - CpuGemmAssemblyDispatch &operator=(CpuGemmAssemblyDispatch &&) = default; /** Defautl destructor */ ~CpuGemmAssemblyDispatch() = default; + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmAssemblyDispatch); + class IFallback { public: - virtual void run() = 0; - virtual void prepare() = 0; - virtual bool is_configured() const = 0; - virtual ~IFallback() = default; + virtual void run(ITensorPack &tensors) = 0; + virtual void prepare(ITensorPack &tensors) = 0; + virtual bool is_configured() const = 0; + virtual ~IFallback() = default; }; public: @@ -91,7 +86,7 @@ public: * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. * @param[in] info GEMM meta-data */ - void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info); + void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info); /** Indicates whether or not this function can be used to process the given parameters. * @@ -118,8 +113,8 @@ public: bool is_configured() const; // Inherited methods overridden: - void prepare() override; - void run() override; + void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; private: std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */ @@ -128,4 +123,4 @@ private: }; } // namespace cpu } // namespace arm_compute -#endif /* SRC_NEGEMMASSEMBLYDISPATCH_H */ +#endif /* ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H */ |