diff options
Diffstat (limited to 'chapters/tensor_ops.adoc')
-rw-r--r-- | chapters/tensor_ops.adoc | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc index ff5f25a..16b0341 100644 --- a/chapters/tensor_ops.adoc +++ b/chapters/tensor_ops.adoc @@ -102,7 +102,7 @@ ERROR_IF(in_t != int8_t && input_zp != 0); // Zero point only for int8_t ERROR_IF(in_t != int8_t && output_zp != 0); // Zero point only for int8_t ERROR_IF(kernel_y < 1 || kernel_x < 1); // kernel size must be >= 1 ERROR_IF(stride_y < 1 || stride_x < 1); -ERROR_IF(pad_top < 0 || pad_buttom < 0 || pad_left < 0 || pad_right < 0); +ERROR_IF(pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0); // Padding must be less than kernel size to avoid // a divide-by-zero. ERROR_IF(pad_right >= kernel_x || pad_left >= kernel_x); @@ -118,12 +118,12 @@ for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W, 0 <= c < C ) { for_each(0 <= ky < kernel_y, 0 <= kx < kernel_x) { y = iy + ky; x = ix + kx; - acc_t value = tensor_read<in_t>(input, [N,IH,IW,C], [n,y,x,c], input_zp, pad); - acc = apply_add<acc_t>(acc, value); // Only values from the input tensor are used to calculate the // average, padding does not count if (0 <= y < IH and 0 <= x < IW) { count++; + acc_t value = tensor_read<in_t>(input, [N,IH,IW,C], [n,y,x,c], input_zp); + acc = apply_add<acc_t>(acc, value); } } if (is_float(out_t)) { @@ -190,9 +190,11 @@ for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W; 0 <= oc < OC) { for_each(0 <= ky < KH, 0 <= kx < KW, 0 <= ic < IC) { y = iy + ky * dilation_y; x = ix + kx * dilation_x; - acc_t value = tensor_read<in_t>(input, [N,IH,IW,IC], [n,y,x,ic], input_zp, pad); - acc_t weight = tensor_read<weight_t>(weight, [OC,KH,KW,IC], [oc,ky,kx,ic], weight_zp); - acc = apply_add<acc_t>(acc, value * weight); + if (0 <= y < IH && 0 <= x < IW) { + acc_t value = tensor_read<in_t>(input, [N,IH,IW,IC], [n,y,x,ic], input_zp); + acc_t weight = tensor_read<weight_t>(weight, [OC,KH,KW,IC], [oc,ky,kx,ic], weight_zp); + acc = apply_add<acc_t>(acc, value * weight); + } } acc = apply_add<acc_t>(acc, bias[oc]); tensor_write<acc_t>(output, [N,H,W,OC], [n,oy,ox,oc], acc); @@ -256,9 +258,11 @@ for_each(0 <= n < N, 0 <= od < D, 0 <= oy < H, 0 <= ox < W; 0 <= oc < OC) { d = id + kd * dilation_d; y = iy + ky * dilation_y; x = ix + kx * dilation_x; - acc_t value = tensor_read<in_t>(input, [N,ID,IH,IW,IC], [n,d,y,x,ic], input_zp, pad); - acc_t weight = tensor_read<weight_t>(weight,[OC,KD,KH,KW,IC],[oc,kd,ky,kx,ic], weight_zp); - acc = apply_add<acc_t>(acc, value * weight); + if (0 <= x < IW && 0 <= y < IH && 0 <= d <= ID) { + acc_t value = tensor_read<in_t>(input, [N,ID,IH,IW,IC], [n,d,y,x,ic], input_zp); + acc_t weight = tensor_read<weight_t>(weight,[OC,KD,KH,KW,IC],[oc,kd,ky,kx,ic], weight_zp); + acc = apply_add<acc_t>(acc, value * weight); + } } acc = apply_add<acc_t>(acc, bias[oc]); tensor_write<acc_t>(output, [N,D,H,W,OC], [n,od,oy,ox,oc], acc); @@ -321,9 +325,11 @@ for_each(0 <= n<N, 0 <= oy < H, 0 <= ox < W; 0 <= c < (C * M), 0 <= m < M) { for_each(0 <= ky < KH, 0 <= kx < KW) { y = iy + ky * dilation_y; x = ix + kx * dilation_x; - acc_t value = tensor_read<in_t>(input, [N,H,W,C], [n,y,x,c], input_zp, pad); - acc_t weight = tensor_read<weight_t>(weight, [KH,KW,C,M], [ky,kx,c,m], weight_zp); - acc = apply_add<acc_t>(acc, value * weight); + if (0 <= y < IH && 0 <= x < IW) { + acc_t value = tensor_read<in_t>(input, [N,H,W,C], [n,y,x,c], input_zp); + acc_t weight = tensor_read<weight_t>(weight, [KH,KW,C,M], [ky,kx,c,m], weight_zp); + acc = apply_add<acc_t>(acc, value * weight); + } } acc = apply_add<acc_t>(acc, bias[(c * M) + m]); tensor_write<acc_t>(output, [N,H,W,C * M], [n,oy,ox,c * M + m], acc); @@ -467,7 +473,7 @@ None ---- ERROR_IF(kernel_y < 1 || kernel_x < 1); // kernel size must be >= 1 ERROR_IF(stride_y < 1 || stride_x < 1); -ERROR_IF(pad_top < 0 || pad_buttom < 0 || pad_left < 0 || pad_right < 0); +ERROR_IF(pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0); // Padding must be less than kernel size, otherwise no // input values will be used. ERROR_IF(pad_right >= kernel_x || pad_left >= kernel_x); |