aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h35
1 files changed, 15 insertions, 20 deletions
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 0bbae49a7e..ffc097c75c 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -21,14 +21,15 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef SRC_NEGEMMASSEMBLYDISPATCH_H
-#define SRC_NEGEMMASSEMBLYDISPATCH_H
+#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
+#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
-#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
+#include "src/core/common/Macros.h"
+#include "src/runtime/cpu/ICpuOperator.h"
namespace arm_compute
{
@@ -57,29 +58,23 @@ struct AsmGemmInfo
};
/** Assembly kernel glue */
-class CpuGemmAssemblyDispatch : public IFunction
+class CpuGemmAssemblyDispatch : public ICpuOperator
{
public:
/** Constructor */
CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
- /** Prevent instances of this class from being copy constructed */
- CpuGemmAssemblyDispatch(const CpuGemmAssemblyDispatch &) = delete;
- /** Prevent instances of this class from being copied */
- CpuGemmAssemblyDispatch &operator=(const CpuGemmAssemblyDispatch &) = delete;
- /** Default move constructor */
- CpuGemmAssemblyDispatch(CpuGemmAssemblyDispatch &&) = default;
- /** Default move assignment operator */
- CpuGemmAssemblyDispatch &operator=(CpuGemmAssemblyDispatch &&) = default;
/** Defautl destructor */
~CpuGemmAssemblyDispatch() = default;
+ ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmAssemblyDispatch);
+
class IFallback
{
public:
- virtual void run() = 0;
- virtual void prepare() = 0;
- virtual bool is_configured() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -91,7 +86,7 @@ public:
* @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
* @param[in] info GEMM meta-data
*/
- void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info);
+ void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info);
/** Indicates whether or not this function can be used to process the given parameters.
*
@@ -118,8 +113,8 @@ public:
bool is_configured() const;
// Inherited methods overridden:
- void prepare() override;
- void run() override;
+ void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
private:
std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
@@ -128,4 +123,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /* SRC_NEGEMMASSEMBLYDISPATCH_H */
+#endif /* ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H */