Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.h')
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.h  34
1 file changed, 33 insertions(+), 1 deletion(-)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 0c51c92359..588c45294a 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -82,6 +82,38 @@ public:
public:
/** If supported, create a Compute Library function, else fall back to the arm_gemm function.
*
+ * @note Configuring "batches"
+ * The shapes of @p a, @p b and @p d are arranged as follows:
+ * Lowest dimension <-> Highest dimension
+ * a: [K, M, Batch, Multi]
+ * b: [N, K, Multi]
+ * d: [N, M, Batch, Multi]
+ *
+ * The "Batch" refers to where "Batch" number of MxK slices of tensor a multiplies with a single KxN slice of b
+ * The "Multi" refers to where "Multi" number of individual multiplication of a with b
+ *
+ * For example, the following are some valid input shape configurations (case (3) is also sketched as code after the list)
+ *
+ * (1) Normal 2D gemm
+ * a: [K=3, M=4]
+ * b: [N=5, K=3]
+ * d: [N=5, M=4]
+ *
+ * (2) Batches of a sharing b (e.g. gemm-based batched convolution where b is the shared weights)
+ * a: [K=3, M=4, Batch=9]
+ * b: [N=5, K=3]
+ * d: [N=5, M=4, Batch=9]
+ *
+ * (3) "Batches" of independent gemm (e.g. batched matmul)
+ * a: [K=3, M=4, Batch=1, Multi=7]
+ * b: [N=5, K=3, Multi=7]
+ * d: [N=5, M=4, Batch=1, Multi=7]
+ *
+ * (4) "Batches" of independent gemm where b is also shared
+ * a: [K=3, M=4, Batch=4, Multi=7]
+ * b: [N=5, K=3, Multi=7]
+ * d: [N=5, M=4, Batch=4, Multi=7]
+ *
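+ * A minimal sketch of case (3) expressed as tensor infos follows. The TensorInfo
+ * construction and validate() call shown are assumptions for illustration only
+ * and are not introduced by this patch:
+ *
+ *   TensorInfo a_info(TensorShape(3U, 4U, 1U, 7U), 1, DataType::F32); // [K, M, Batch, Multi]
+ *   TensorInfo b_info(TensorShape(5U, 3U, 7U), 1, DataType::F32);     // [N, K, Multi]
+ *   TensorInfo d_info(TensorShape(5U, 4U, 1U, 7U), 1, DataType::F32); // [N, M, Batch, Multi]
+ *   // No bias (c == nullptr); default-constructed AsmGemmInfo assumed for brevity
+ *   Status st = CpuGemmAssemblyDispatch::validate(&a_info, &b_info, nullptr, &d_info, AsmGemmInfo{});
+ *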
* @param[in] a Input tensor (Matrix A)
* @param[in] b Input tensor (Matrix B)
* @param[in] c Input tensor (Matrix C) used to pass the bias for quantized calculations