diff options
Diffstat (limited to 'src/cpu/operators/CpuGemm.h')
-rw-r--r-- | src/cpu/operators/CpuGemm.h | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 6b30d134fa..a05258d206 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_CPU_GEMM_H
-#define ARM_COMPUTE_CPU_GEMM_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUGEMM_H
+#define ACL_SRC_CPU_OPERATORS_CPUGEMM_H
 
 #include "arm_compute/core/ITensorPack.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -36,6 +36,7 @@
 #include "src/cpu/kernels/CpuGemmTranspose1xWKernel.h"
 #include "src/cpu/operators/CpuActivation.h"
 #include "src/cpu/operators/CpuAdd.h"
+#include "src/cpu/operators/CpuTranspose.h"
 #include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 #include <memory>
@@ -144,16 +145,17 @@ public:
 private:
     enum AuxTensorIdx
     {
-        AsmGemmWorkspace = 0,
-        Pretraspose,
-        InterleavedLHS,
-        TransposedRHS,
+        /* Slots 0 - 2 reserved for CpuGemmAssemblyDispatch */
+        InterleavedLHS = 3,
+        PreTransposedRHS,
+        Transposed1xWRHS,
         TempResult,
         Count
     };
 
     std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{nullptr};
-    std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{nullptr};
+    std::unique_ptr<CpuTranspose> _pretranspose_b_func{nullptr};
+    std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose1xW_b_kernel{nullptr};
     std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{nullptr};
     std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{nullptr};
     std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{nullptr};
@@ -162,10 +164,13 @@ private:
     std::unique_ptr<CpuActivation> _activation_func{nullptr};
 
     TensorInfo _tmp_a{};
+    TensorInfo _pretransposed_b{};
     TensorInfo _tmp_b{};
     TensorInfo _tmp_d{};
 
     bool _run_vector_matrix_multiplication{false};
+    bool _run_interleave_transpose{
+        true}; /**< If we run CpuGemmInterleave4x4Kernel on lhs and CpuGemmTranspose1xWKernel on rhs */
     bool _run_alpha_scale{false};
     bool _run_addition{false};
     bool _run_bias_addition{false};
@@ -177,4 +182,4 @@ private:
 };
 } // namespace cpu
 } // namespace arm_compute
-#endif /*ARM_COMPUTE_CPU_GEMM_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUGEMM_H