diff options
Diffstat (limited to 'src/cpu/operators/CpuGemm.h')
-rw-r--r-- | src/cpu/operators/CpuGemm.h | 66 |
1 files changed, 40 insertions, 26 deletions
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h index 9b08e5d0f6..6b30d134fa 100644 --- a/src/cpu/operators/CpuGemm.h +++ b/src/cpu/operators/CpuGemm.h @@ -24,12 +24,12 @@ #ifndef ARM_COMPUTE_CPU_GEMM_H #define ARM_COMPUTE_CPU_GEMM_H -#include "src/cpu/ICpuOperator.h" - #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/function_info/GEMMInfo.h" + +#include "src/cpu/ICpuOperator.h" #include "src/cpu/kernels/CpuGemmInterleave4x4Kernel.h" #include "src/cpu/kernels/CpuGemmMatrixAdditionKernel.h" #include "src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h" @@ -93,16 +93,26 @@ public: * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and * if the reshape of matrix B should happen only for the first run */ - void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, - float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + void configure(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + ITensorInfo *d, + float alpha, + float beta, + const GEMMInfo &gemm_info = GEMMInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CpuGemm. * * Similar to @ref CpuGemm::configure() * * @return a status */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, - float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + static Status validate(const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *d, + float alpha, + float beta, + const GEMMInfo &gemm_info = GEMMInfo()); /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * @@ -111,12 +121,16 @@ public: * the value of arm_compute::WeightFormat need to be passed via the * parameter gemm_info. */ - static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, - const GEMMInfo &gemm_info = GEMMInfo()); + static Status has_opt_impl(arm_compute::WeightFormat &weight_format, + const ITensorInfo *a, + const ITensorInfo *b, + const ITensorInfo *c, + const ITensorInfo *d, + const GEMMInfo &gemm_info = GEMMInfo()); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &constants) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &constants) override; experimental::MemoryRequirements workspace() const override; /** Indicates if the convolution executes in variable weights mode. @@ -138,28 +152,28 @@ private: Count }; - std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{ nullptr }; - std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{ nullptr }; - std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{ nullptr }; - std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{ nullptr }; - std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{ nullptr }; - std::unique_ptr<CpuActivation> _alpha_scale_func{ nullptr }; - std::unique_ptr<CpuAdd> _add_bias{ nullptr }; - std::unique_ptr<CpuActivation> _activation_func{ nullptr }; + std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{nullptr}; + std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{nullptr}; + std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{nullptr}; + std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{nullptr}; + std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{nullptr}; + std::unique_ptr<CpuActivation> _alpha_scale_func{nullptr}; + std::unique_ptr<CpuAdd> _add_bias{nullptr}; + std::unique_ptr<CpuActivation> _activation_func{nullptr}; TensorInfo _tmp_a{}; TensorInfo _tmp_b{}; TensorInfo _tmp_d{}; - bool _run_vector_matrix_multiplication{ false }; - bool _run_alpha_scale{ false }; - bool _run_addition{ false }; - bool _run_bias_addition{ false }; - bool _run_activation{ false }; - bool _reshape_b_only_on_first_run{ false }; - bool _is_prepared{ false }; + bool _run_vector_matrix_multiplication{false}; + bool _run_alpha_scale{false}; + bool _run_addition{false}; + bool _run_bias_addition{false}; + bool _run_activation{false}; + bool _reshape_b_only_on_first_run{false}; + bool _is_prepared{false}; - experimental::MemoryRequirements _aux_mem{ Count }; + experimental::MemoryRequirements _aux_mem{Count}; }; } // namespace cpu } // namespace arm_compute |