Diffstat (limited to 'src/cpu/operators/CpuGemm.h')
-rw-r--r-- | src/cpu/operators/CpuGemm.h | 66
1 file changed, 40 insertions, 26 deletions
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 9b08e5d0f6..6b30d134fa 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -24,12 +24,12 @@
#ifndef ARM_COMPUTE_CPU_GEMM_H
#define ARM_COMPUTE_CPU_GEMM_H
-#include "src/cpu/ICpuOperator.h"
-
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/function_info/GEMMInfo.h"
+
+#include "src/cpu/ICpuOperator.h"
#include "src/cpu/kernels/CpuGemmInterleave4x4Kernel.h"
#include "src/cpu/kernels/CpuGemmMatrixAdditionKernel.h"
#include "src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
@@ -93,16 +93,26 @@ public:
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
* if the reshape of matrix B should happen only for the first run
*/
- void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
- float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ void configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info = GEMMInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CpuGemm.
*
* Similar to @ref CpuGemm::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
- float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ static Status validate(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *d,
+ float alpha,
+ float beta,
+ const GEMMInfo &gemm_info = GEMMInfo());
/** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters.
*
@@ -111,12 +121,16 @@ public:
* the value of arm_compute::WeightFormat need to be passed via the
* parameter gemm_info.
*/
- static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
- const GEMMInfo &gemm_info = GEMMInfo());
+ static Status has_opt_impl(arm_compute::WeightFormat &weight_format,
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *d,
+ const GEMMInfo &gemm_info = GEMMInfo());
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &constants) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &constants) override;
experimental::MemoryRequirements workspace() const override;
/** Indicates if the convolution executes in variable weights mode.
@@ -138,28 +152,28 @@ private:
Count
};
- std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{ nullptr };
- std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{ nullptr };
- std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{ nullptr };
- std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{ nullptr };
- std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{ nullptr };
- std::unique_ptr<CpuActivation> _alpha_scale_func{ nullptr };
- std::unique_ptr<CpuAdd> _add_bias{ nullptr };
- std::unique_ptr<CpuActivation> _activation_func{ nullptr };
+ std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{nullptr};
+ std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{nullptr};
+ std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{nullptr};
+ std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{nullptr};
+ std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{nullptr};
+ std::unique_ptr<CpuActivation> _alpha_scale_func{nullptr};
+ std::unique_ptr<CpuAdd> _add_bias{nullptr};
+ std::unique_ptr<CpuActivation> _activation_func{nullptr};
TensorInfo _tmp_a{};
TensorInfo _tmp_b{};
TensorInfo _tmp_d{};
- bool _run_vector_matrix_multiplication{ false };
- bool _run_alpha_scale{ false };
- bool _run_addition{ false };
- bool _run_bias_addition{ false };
- bool _run_activation{ false };
- bool _reshape_b_only_on_first_run{ false };
- bool _is_prepared{ false };
+ bool _run_vector_matrix_multiplication{false};
+ bool _run_alpha_scale{false};
+ bool _run_addition{false};
+ bool _run_bias_addition{false};
+ bool _run_activation{false};
+ bool _reshape_b_only_on_first_run{false};
+ bool _is_prepared{false};
- experimental::MemoryRequirements _aux_mem{ Count };
+ experimental::MemoryRequirements _aux_mem{Count};
};
} // namespace cpu
} // namespace arm_compute
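
Usage sketch (not part of the patch above): the snippet below shows how the reordered configure()/validate()/run()/prepare() declarations are typically exercised by a caller inside the library. The tensor-pack slot IDs (ACL_SRC_0, ACL_SRC_1, ACL_DST) follow Compute Library conventions; the helper function name and the omission of workspace()/auxiliary-memory handling are simplifications for illustration only.

// Minimal, hedged sketch of driving cpu::CpuGemm directly.
#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/experimental/Types.h"

#include "src/cpu/operators/CpuGemm.h"

using namespace arm_compute;

// Hypothetical helper: compute D = alpha * A * B (beta = 0, no bias tensor).
void run_gemm_sketch(ITensor *a, ITensor *b, ITensor *d, float alpha)
{
    // validate() mirrors configure() but is static and side-effect free,
    // so support can be checked before any operator state is created.
    const Status st = cpu::CpuGemm::validate(a->info(), b->info(), nullptr, d->info(), alpha, 0.f);
    if (st.error_code() != ErrorCode::OK)
    {
        return; // configuration not supported
    }

    cpu::CpuGemm gemm;
    gemm.configure(a->info(), b->info(), nullptr, d->info(), alpha, 0.f);

    // Operators are stateless at run time: tensors are passed via an ITensorPack.
    ITensorPack pack{{ACL_SRC_0, a}, {ACL_SRC_1, b}, {ACL_DST, d}};
    gemm.prepare(pack); // e.g. reshapes B once when reshape_b_only_on_first_run is set
    gemm.run(pack);
    // A real caller (e.g. NEGEMM) would also service gemm.workspace() to
    // provide the auxiliary tensors reported in _aux_mem; that plumbing is
    // omitted here for brevity.
}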