diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 27 |
1 files changed, 26 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index eede1a4f76..d9035c8917 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -573,7 +573,7 @@ public: } // Estimate cycles for given problem given provided parameters - static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms) { + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms, const OutputStage &os = {} ) { // Note: Current hybrid kernels don't actually round up height (they // have paths for each possible height). Might need to make this // configurable in future. @@ -591,6 +591,31 @@ public: uint64_t total_cycles = mac_cycles; + // Quantizing kernels with separate quantize need to add in the extra stages. + if (std::is_same<OutputStage, Requantize32>::value && SeparateQuantize) { + const Requantize32 *qp = reinterpret_cast<const Requantize32 *>(&os); + + // Row sums: need to consider each value in A (batch * multi * M * K)... + uint64_t rowsum_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * roundup(args._Ksize, strategy::k_unroll()); + + // ... but row sums are skipped if B offset==0. + if (qp->b_offset == 0) { + rowsum_bytes = 0; + } + + // Use "prepare bytes per cycle" to store "row sum values per cycle". + float rowsum_cycles = static_cast<float>(rowsum_bytes) / params.prepare_bytes_cycle; + + // Requantize: need to consider each value in C (batch * multi * M * N) + uint64_t requantize_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * args._Nsize; + + // Use "merge bytes per cycle" to store "requantize values per cycle". + float requantize_cycles = static_cast<float>(requantize_bytes) / params.merge_bytes_cycle; + + // Recalculate total_cycles with the extra components. + total_cycles = mac_cycles + rowsum_cycles + requantize_cycles; + } + return total_cycles; } |