From 3170439f3938d007e58998d61eed98560c3f026c Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Mon, 25 Oct 2021 16:04:20 -0700 Subject: Remove zp subtraction from tensor_read pseudocode Operators which use the zero-point functionalty for 8-bit integer processing are updated to do the zero-point subtract in their pseudocode. Note that the PAD operator no longer takes a zero point argument, and instead requires callers to account for the zero point in the pad_const argument. Change-Id: I3bca1cae85aa2093000c420f0433633c347a29de --- chapters/tensor_ops.adoc | 45 +++++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 16 deletions(-) (limited to 'chapters/tensor_ops.adoc') diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc index ad4d75d..d7ced25 100644 --- a/chapters/tensor_ops.adoc +++ b/chapters/tensor_ops.adoc @@ -42,14 +42,14 @@ ERROR_IF(flatten(left_shape, right_shape) != shape); for_each(left_index in left_shape) { for_each(right_index in right_shape) { in_t max_value = minimum_value; - int32_t max_index = 0; + out_t max_index = 0; for (i = 0; i < shape[axis]; i++) { index = flatten(left_index, [i], right_index); in_t value = tensor_read(input, shape1, index); if (value > max_value) { max_value = value; max_index = i; } } index = flatten(left_index, right_index); - tensor_write(output, shape, index, max_index); + tensor_write(output, shape, index, max_index); } } ---- @@ -114,11 +114,12 @@ for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W, 0 <= c < C ) { // average, padding does not count if (0 <= y < IH and 0 <= x < IW) { count++; - acc_t value = tensor_read(input, [N,IH,IW,C], [n,y,x,c], input_zp); + acc_t value = tensor_read(input, [N,IH,IW,C], [n,y,x,c]); + value = value - input_zp; acc = apply_add(acc, value); } } - if (is_float(out_t)) { + if (is_float(in_t)) { output_val = acc / (float)count; } else { scale_t scale = reciprocal_scale(count); @@ -176,8 +177,10 @@ for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W; 0 <= oc < OC) { y = iy + ky * dilation_y; x = ix + kx * dilation_x; if (0 <= y < IH && 0 <= x < IW) { - acc_t value = tensor_read(input, [N,IH,IW,IC], [n,y,x,ic], input_zp); - acc_t weight = tensor_read(weight, [OC,KH,KW,IC], [oc,ky,kx,ic], weight_zp); + acc_t value = tensor_read(input, [N,IH,IW,IC], [n,y,x,ic]); + acc_t weight = tensor_read(weight, [OC,KH,KW,IC], [oc,ky,kx,ic]); + value = value - input_zp; + weight = weight - weight_zp; acc = apply_add(acc, value * weight); } } @@ -237,8 +240,10 @@ for_each(0 <= n < N, 0 <= od < D, 0 <= oy < H, 0 <= ox < W; 0 <= oc < OC) { y = iy + ky * dilation_y; x = ix + kx * dilation_x; if (0 <= x < IW && 0 <= y < IH && 0 <= d <= ID) { - acc_t value = tensor_read(input, [N,ID,IH,IW,IC], [n,d,y,x,ic], input_zp); - acc_t weight = tensor_read(weight,[OC,KD,KH,KW,IC],[oc,kd,ky,kx,ic], weight_zp); + acc_t value = tensor_read(input, [N,ID,IH,IW,IC], [n,d,y,x,ic]); + acc_t weight = tensor_read(weight,[OC,KD,KH,KW,IC],[oc,kd,ky,kx,ic]); + value = value - input_zp; + weight = weight - weight_zp; acc = apply_add(acc, value * weight); } } @@ -297,8 +302,10 @@ for_each(0 <= n(input, [N,H,W,C], [n,y,x,c], input_zp); - acc_t weight = tensor_read(weight, [KH,KW,C,M], [ky,kx,c,m], weight_zp); + acc_t value = tensor_read(input, [N,H,W,C], [n,y,x,c]); + acc_t weight = tensor_read(weight, [KH,KW,C,M], [ky,kx,c,m]); + value = value - input_zp; + weight = weight - weight_zp; acc = apply_add(acc, value * weight); } } @@ -344,8 +351,10 @@ ERROR_IF(weight_t != int8_t && weight_zp != 0); for_each(0 <= n < N, 0 <= oc < OC) { acc_t acc = 0; for_each(0 <= ic < IC) { - acc_t value = tensor_read(input, [N,IC], [n,ic], input_zp); - acc_t weight = tensor_read(weight, [OC,IC], [oc,ic], weight_zp); + acc_t value = tensor_read(input, [N,IC], [n,ic]); + acc_t weight = tensor_read(weight, [OC,IC], [oc,ic]); + value = value - input_zp; + weight = weight - weight_zp; acc = apply_add(acc, value * weight); } acc = apply_add(acc, bias[oc]); @@ -387,8 +396,10 @@ ERROR_IF(in_t != int8_t && (A_zp != 0 || B_zp != 0)); // Zero point only for int for_each(0 <= n < N, 0 <= h < H, 0 <= w < W) { acc_t acc = 0; for_each(0 <= c < C) { - acc_t value1 = tensor_read(A, [N,H,C], [n,h,c], A_zp); - acc_t value2 = tensor_read(B, [N,C,W], [n,c,w], B_zp); + acc_t value1 = tensor_read(A, [N,H,C], [n,h,c]); + acc_t value2 = tensor_read(B, [N,C,W], [n,c,w]); + value1 = value1 - A_zp; + value2 = value2 - B_zp; acc = apply_add(acc, value1 * value2); } tensor_write(output, [N,H,W], [n,h,w], acc); @@ -499,8 +510,10 @@ for_each(0 <= n < N, 0 <= iy < IH, 0 <= ix < IW, 0 <= oc < OC, ox = ix * stride_x - out_pad_left + kx; if (oy >= 0 && oy < OH && ox >= 0 && ox < OW) { acc_t acc = tensor_read(output, [N,OH,OW,OC], [n,oy,ox,oc]); - acc_t value = tensor_read(input, [N,IH,IW,IC], [n,iy,ix,ic], input_zp); - acc_t weight = tensor_read(weight, [OC,KH,KW,IC], [oc,ky,kx,ic], weight_zp); + acc_t value = tensor_read(input, [N,IH,IW,IC], [n,iy,ix,ic]); + acc_t weight = tensor_read(weight, [OC,KH,KW,IC], [oc,ky,kx,ic]); + value = value - input_zp; + weight = weight - weight_zp; acc = apply_add(acc, value * weight); tensor_write(output, [N,OH,OW,OC], [n,oy,ox,oc], acc); } -- cgit v1.2.1