From 70b140f0219464a239bb40e53023d7de02eb71cc Mon Sep 17 00:00:00 2001 From: Dominic Symes Date: Tue, 9 Feb 2021 15:23:05 +0000 Subject: Add batch dimension to MatMul Add batch dimension for consistency with other operators. Change-Id: I9b1734a1a60304f46a14a6cda1bd6be6678f1037 Signed-off-by: Dominic Symes --- chapters/tensor_ops.adoc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc index 76f39ca..571b9aa 100644 --- a/chapters/tensor_ops.adoc +++ b/chapters/tensor_ops.adoc @@ -362,16 +362,16 @@ for_each (0 <= n < N, 0 <= oc < OC) { |=== ==== MATMUL -Performs a two dimensional matrix multiplication. This allows both inputs to be activations, rather than reserving weights as an attribute in the FULLY_CONNECTED operator. +Performs two dimensional matrix multiplications. This allows both inputs to be activations, rather than reserving weights as an attribute in the FULLY_CONNECTED operator. *Arguments:* |=== |Argument|Type|Name|Shape|Description -|Input|in_t*|A|[M,K]|Input tensor A -|Input|in_t*|B|[K,N]|Input tensor B -|Output|acc_t*|C|[M,N]|Output tensor C +|Input|in_t*|A|[N,H,C]|Input tensor A, N matrices of size HxC +|Input|in_t*|B|[N,C,W]|Input tensor B, N matrices of size CxW +|Output|acc_t*|output|[N,H,W]|Output tensor, N matrices of size HxW |=== *Quantization Parameters:* @@ -388,14 +388,14 @@ Performs a two dimensional matrix multiplication. This allows both inputs to be [source,c] ---- assert(in_t == int8_t || (A_zp == 0 && B_zp == 0)) // Zero point only for int8 -for_each (0 <= m < M, 0 <= n < N) { +for_each (0 <= n < N, 0 <= h < H, 0 <= w < W) { acc_t acc = 0 - for_each (0 <= k < K) { - in_t value1 = tensor_read(A, [M,K], [m,k], A_zp) - in_t value2 = tensor_read(B, [K,N], [k,n], B_zp) + for_each (0 <= c < C) { + in_t value1 = tensor_read(A, [N,H,C], [n,h,c], A_zp) + in_t value2 = tensor_read(B, [N,C,W], [n,c,w], B_zp) acc = apply_add(acc, value1 * value2) } - tensor_write(C, [M,N], [m,n], acc) + tensor_write(output, [N,H,W], [n,h,w], acc) } ---- -- cgit v1.2.1