From 33e03074c36d85de87e9032a2583b04ce8ddcd6b Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 14 Jan 2021 13:43:40 +0000 Subject: Cycle estimate-based kernel selection for dot product quantized s8/u8 kernels Resolves: COMPMID-3990 Signed-off-by: Georgios Pinitas Change-Id: If840c79209940535450f4ea1cbf6b0ec646a168e Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4866 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- .../NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 27 +++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp') diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index eede1a4f76..d9035c8917 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -573,7 +573,7 @@ public: } // Estimate cycles for given problem given provided parameters - static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params) { + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params, const OutputStage &os = {} ) { // Note: Current hybrid kernels don't actually round up height (they // have paths for each possible height). Might need to make this // configurable in future. @@ -591,6 +591,31 @@ public: uint64_t total_cycles = mac_cycles; + // Quantizing kernels with separate quantize need to add in the extra stages. + if (std::is_same<OutputStage, Requantize32>::value && SeparateQuantize) { + const Requantize32 *qp = reinterpret_cast<const Requantize32 *>(&os); + + // Row sums: need to consider each value in A (batch * multi * M * K)... + uint64_t rowsum_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * roundup(args._Ksize, strategy::k_unroll()); + + // ... but row sums are skipped if B offset==0. 
+ if (qp->b_offset == 0) { + rowsum_bytes = 0; + } + + // Use "prepare bytes per cycle" to store "row sum values per cycle". + float rowsum_cycles = static_cast<float>(rowsum_bytes) / params.prepare_bytes_cycle; + + // Requantize: need to consider each value in C (batch * multi * M * N) + uint64_t requantize_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * args._Nsize; + + // Use "merge bytes per cycle" to store "requantize values per cycle". + float requantize_cycles = static_cast<float>(requantize_bytes) / params.merge_bytes_cycle; + + // Recalculate total_cycles with the extra components. + total_cycles = mac_cycles + rowsum_cycles + requantize_cycles; + } + return total_cycles; } -- cgit v1.2.1