aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemm.h')
-rw-r--r-- src/cpu/operators/CpuGemm.h 21
1 files changed, 13 insertions, 8 deletions
diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h
index 6b30d134fa..a05258d206 100644
--- a/src/cpu/operators/CpuGemm.h
+++ b/src/cpu/operators/CpuGemm.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CPU_GEMM_H
-#define ARM_COMPUTE_CPU_GEMM_H
+#ifndef ACL_SRC_CPU_OPERATORS_CPUGEMM_H
+#define ACL_SRC_CPU_OPERATORS_CPUGEMM_H
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/TensorInfo.h"
@@ -36,6 +36,7 @@
#include "src/cpu/kernels/CpuGemmTranspose1xWKernel.h"
#include "src/cpu/operators/CpuActivation.h"
#include "src/cpu/operators/CpuAdd.h"
+#include "src/cpu/operators/CpuTranspose.h"
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
#include <memory>
@@ -144,16 +145,17 @@ public:
private:
enum AuxTensorIdx
{
- AsmGemmWorkspace = 0,
- Pretraspose,
- InterleavedLHS,
- TransposedRHS,
+ /* Slots 0 - 2 reserved for CpuGemmAssemblyDispatch */
+ InterleavedLHS = 3,
+ PreTransposedRHS,
+ Transposed1xWRHS,
TempResult,
Count
};
std::unique_ptr<kernels::CpuGemmInterleave4x4Kernel> _interleave_kernel{nullptr};
- std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose_kernel{nullptr};
+ std::unique_ptr<CpuTranspose> _pretranspose_b_func{nullptr};
+ std::unique_ptr<kernels::CpuGemmTranspose1xWKernel> _transpose1xW_b_kernel{nullptr};
std::unique_ptr<kernels::CpuGemmMatrixMultiplyKernel> _mm_kernel{nullptr};
std::unique_ptr<CpuGemmAssemblyDispatch> _asm_glue{nullptr};
std::unique_ptr<kernels::CpuGemmMatrixAdditionKernel> _ma_kernel{nullptr};
@@ -162,10 +164,13 @@ private:
std::unique_ptr<CpuActivation> _activation_func{nullptr};
TensorInfo _tmp_a{};
+ TensorInfo _pretransposed_b{};
TensorInfo _tmp_b{};
TensorInfo _tmp_d{};
bool _run_vector_matrix_multiplication{false};
+ bool _run_interleave_transpose{
+ true}; /**< If we run CpuGemmInterleave4x4Kernel on lhs and CpuGemmTranspose1xWKernel on rhs */
bool _run_alpha_scale{false};
bool _run_addition{false};
bool _run_bias_addition{false};
@@ -177,4 +182,4 @@ private:
};
} // namespace cpu
} // namespace arm_compute
-#endif /*ARM_COMPUTE_CPU_GEMM_H */
+#endif // ACL_SRC_CPU_OPERATORS_CPUGEMM_H