about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp | 27
1 file changed, 26 insertions, 1 deletion
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index eede1a4f76..d9035c8917 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -573,7 +573,7 @@ public:
}
// Estimate cycles for given problem given provided parameters
- static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params) {
+ static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters &params, const OutputStage &os = {} ) {
// Note: Current hybrid kernels don't actually round up height (they
// have paths for each possible height). Might need to make this
// configurable in future.
@@ -591,6 +591,31 @@ public:
uint64_t total_cycles = mac_cycles;
+ // Quantizing kernels with separate quantize need to add in the extra stages.
+ if (std::is_same<OutputStage, Requantize32>::value && SeparateQuantize) {
+ const Requantize32 *qp = reinterpret_cast<const Requantize32 *>(&os);
+
+ // Row sums: need to consider each value in A (batch * multi * M * K)...
+ uint64_t rowsum_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * roundup(args._Ksize, strategy::k_unroll());
+
+ // ... but row sums are skipped if B offset==0.
+ if (qp->b_offset == 0) {
+ rowsum_bytes = 0;
+ }
+
+ // Use "prepare bytes per cycle" to store "row sum values per cycle".
+ float rowsum_cycles = static_cast<float>(rowsum_bytes) / params.prepare_bytes_cycle;
+
+ // Requantize: need to consider each value in C (batch * multi * M * N)
+ uint64_t requantize_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * args._Nsize;
+
+ // Use "merge bytes per cycle" to store "requantize values per cycle".
+ float requantize_cycles = static_cast<float>(requantize_bytes) / params.merge_bytes_cycle;
+
+ // Recalculate total_cycles with the extra components.
+ total_cycles = mac_cycles + rowsum_cycles + requantize_cycles;
+ }
+
return total_cycles;
}