aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/assembly/arm_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/assembly/arm_gemm.hpp')
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp116
1 files changed, 115 insertions, 1 deletions
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 247cb1d470..48fd7c6b43 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,6 +47,57 @@ enum class GemmMethod
GEMM_HYBRID_QUANTIZED
};
+/** Memory layouts for the weights tensor.
+ *
+ * * UNSPECIFIED is used to select kernels that do not run in
+ * variable weights mode.
+ *
+ * * ANY is used to query the kernel database to retrieve any of the
+ * kernels that runs in variable weights mode. Once a kernel is
+ * found, the specific format expected by the kernel can be
+ * retrieved by the user for reordering the weights tensor
+ * accordingly.
+ *
+ * The other values OHWIo{interleave_by}i{block_by} describe the
+ * memory layout of a 4D tensor with layout OHWI that has been
+ * transformed into a 4D tensor with dimensions O'HWI' where:
+ *
+ * O' = first multiple of {interleave_by} s.t. O<=O'
+ * I' = first multiple of {block_by} s.t. I<=I'
+ *
+ * The total size of the dst tensor is O' x H x W x I'
+ *
+ * The access function of the tensor with layout
+ * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+ * access function, where the 6 parameters are computed as follows:
+ *
+ * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+ *
+ * x4 = h RANGE [0, H-1] SIZE: H
+ * x3 = w RANGE [0, W-1] SIZE: W
+ * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by}
+ * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by}
+ * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by}
+ * TOTAL SIZE: O' * H * W * I'
+ *
+ * 4D 6D
+ * ----------------- -----------------------------------
+ * value(o, h, w, i) = x5 * H * W * I' * {interleave_by}
+ * + x4 * W * I' * {interleave_by}
+ * + x3 * I' * {interleave_by}
+ * + x2 * {interleave_by} * {block_by}
+ * + x1 * {block_by}
+ * + x0
+ *
+ * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created
+ * for the OHWIo{interleave_by}i{block_by} format is in reality seen
+ * as a 2D tensor, where the number of rows is O'/{interleave_by}
+ * and the number of columns is {interleave_by} * H * W * I'.
+ *
+ * The postfix *_bf16 is for the memory layout needed for the
+ * fast-mode kernels, in which the weights are passed in bfloat16
+ * format.
+ */
enum class WeightFormat
{
UNSPECIFIED = 0x1,
@@ -87,6 +138,69 @@ enum class WeightFormat
OHWIo64i8 = 0x804000
};
+// OHWIo<interleave_by>i<block_by>
+inline int interleave_by(const WeightFormat wf)
+{
+ return ((int)wf >> 8) & 0xFFF;
+}
+inline int block_by(const WeightFormat wf)
+{
+ return ((int)wf >> 20) & 0xF;
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+ return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
+}
+
+inline std::string to_string(WeightFormat wf)
+{
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf: \
+ return #wf;
+ switch(wf)
+ {
+ __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+ __CASE_WEIGHT_FORMAT(ANY)
+ __CASE_WEIGHT_FORMAT(OHWI)
+ __CASE_WEIGHT_FORMAT(OHWIo2)
+ __CASE_WEIGHT_FORMAT(OHWIo4)
+ __CASE_WEIGHT_FORMAT(OHWIo8)
+ __CASE_WEIGHT_FORMAT(OHWIo16)
+ __CASE_WEIGHT_FORMAT(OHWIo32)
+ __CASE_WEIGHT_FORMAT(OHWIo64)
+ __CASE_WEIGHT_FORMAT(OHWIo128)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2)
+ __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2)
+ __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2)
+ __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2)
+ __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2)
+ __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4)
+ __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4)
+ __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4)
+ __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4)
+ __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4)
+ __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+ __CASE_WEIGHT_FORMAT(OHWIo2i8)
+ __CASE_WEIGHT_FORMAT(OHWIo4i8)
+ __CASE_WEIGHT_FORMAT(OHWIo8i8)
+ __CASE_WEIGHT_FORMAT(OHWIo16i8)
+ __CASE_WEIGHT_FORMAT(OHWIo32i8)
+ __CASE_WEIGHT_FORMAT(OHWIo64i8)
+ default:
+ return "invalid value";
+ }
+#undef __CASE_WEIGHT_FORMAT
+}
+
struct KernelDescription
{
GemmMethod method = GemmMethod::DEFAULT;
@@ -230,6 +344,6 @@ template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
template <typename Top, typename Tret, class OutputStage = Nothing>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm