author     Eric Kunze <eric.kunze@arm.com>  2024-02-14 16:33:31 -0800
committer  Eric Kunze <eric.kunze@arm.com>  2024-03-13 16:31:19 -0700
commit     0afe61f88ce3d2f445c5f01ae5567cb1b0b7f303 (patch)
tree       bfa03c381634090e9074d3e1167c3256a31f3c84
parent     6dd341093507157aabbea00b90ca8902509cfd4f (diff)
download   specification-0afe61f88ce3d2f445c5f01ae5567cb1b0b7f303.tar.gz
Modify convolution operators to improve bias handling
Accumulator size moves to an enumerated attribute, out_t for floating-point
changes to be the size of the input. Bias for floating-point also becomes
the bit width of the input type.

Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Change-Id: I7369417adbb1106ce34a1978e7f511a30272c318
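To make the new contract concrete, here is a minimal C++ sketch of the revised order of operations (an illustration written for this page, not code from the commit; the function name is invented and the template parameters mirror the pseudocode's in_t/weight_t/acc_t/out_t): accumulate at the precision named by the acc_type attribute, narrow to out_t, then add the bias at out_t precision.

    #include <cstddef>

    // Illustrative only: generic dot product following the revised
    // order of operations in this commit.
    template <typename in_t, typename weight_t, typename acc_t, typename out_t>
    out_t dot_product_with_bias(const in_t* values, const weight_t* weights,
                                std::size_t kernel_size, out_t bias) {
        acc_t acc = 0;  // accumulate at the precision selected by acc_type
        for (std::size_t k = 0; k < kernel_size; ++k) {
            acc += static_cast<acc_t>(values[k]) * static_cast<acc_t>(weights[k]);
        }
        out_t out = static_cast<out_t>(acc);  // narrow to out_t first...
        return out + bias;                    // ...then add bias at out_t precision
    }

For the "fp16 with fp32 accumulate" mode this would instantiate with acc_t = fp32 and out_t = fp16; before this change, out_t (and therefore the bias) would have been fp32.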
-rw-r--r--  chapters/introduction.adoc                  |  4
-rw-r--r--  pseudocode/operators/CONV2D.tosac           | 17
-rw-r--r--  pseudocode/operators/CONV3D.tosac           |  5
-rw-r--r--  pseudocode/operators/DEPTHWISE_CONV2D.tosac |  5
-rw-r--r--  pseudocode/operators/TRANSPOSE_CONV2D.tosac | 34
-rw-r--r--  tosa.xml                                    | 96
-rw-r--r--  tosa.xsd                                    |  2
7 files changed, 94 insertions, 69 deletions
diff --git a/chapters/introduction.adoc b/chapters/introduction.adoc
index c34bf7b..c6764d3 100644
--- a/chapters/introduction.adoc
+++ b/chapters/introduction.adoc
@@ -397,7 +397,7 @@ if (!local_bound) {
output_bnd = operation_fp64(input_abs, weight_abs, bias_abs);
size_t T = tensor_size(output_shape);  // number of dot product results
-size_t ksb = (max_value(bias_abs) > 0) ? (KS + 1) : KS; // kernel size and bias
+size_t ksb = ceil(KS / exp2(normal_frac<acc_t>() - normal_frac<out_t>())) + ((max_value(bias_abs) > 0) ? 1 : 0); // kernel size and bias
fp64_t out_err_sum = 0.0;
fp64_t out_err_sumsq = 0.0;
for_each(index in output_shape) {
@@ -412,7 +412,7 @@ for_each(index in output_shape) {
REQUIRE(out_ref == 0.0 && out_imp == 0.0);
out_err = 0.0;
} else { // 0.0 < out_bnd < infinity
- fp64_t out_err_bnd = max(out_bnd * exp2(-1-normal_frac<acc_t>()), normal_min<acc_t>());
+ fp64_t out_err_bnd = max(out_bnd * exp2(-1-normal_frac<out_t>()), normal_min<out_t>());
out_err = (static_cast<fp64_t>(out_imp) - out_ref) / out_err_bnd;
REQUIRE(abs(out_err) <= ksb);
}
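Worked instance of the revised bound (my arithmetic, not text from the commit): for fp16 output with fp32 accumulation, normal_frac<acc_t>() = 23 and normal_frac<out_t>() = 10, so a kernel of size KS = 4096 with a nonzero bias gives ksb = ceil(4096 / 2^13) + 1 = 2. A standalone C++ check of that arithmetic:

    #include <cmath>
    #include <cstdio>

    // Constants are assumptions for illustration: fp32 accumulate,
    // fp16 output, KS = 4096, nonzero bias.
    int main() {
        const double KS = 4096.0;
        const int frac_acc = 23;   // normal_frac<fp32>()
        const int frac_out = 10;   // normal_frac<fp16>()
        const bool nonzero_bias = true;
        double ksb = std::ceil(KS / std::exp2(frac_acc - frac_out))
                   + (nonzero_bias ? 1 : 0);
        std::printf("ksb = %g\n", ksb);  // prints: ksb = 2
        return 0;
    }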
diff --git a/pseudocode/operators/CONV2D.tosac b/pseudocode/operators/CONV2D.tosac
index fe61747..0ae0e81 100644
--- a/pseudocode/operators/CONV2D.tosac
+++ b/pseudocode/operators/CONV2D.tosac
@@ -17,24 +17,25 @@ ERROR_IF(OW != idiv_check(IW - 1 + pad_left + pad_right - (KW - 1) * dilation_x,
ERROR_IF(BC != OC && BC != 1);
for_each(0 <= n < N, 0 <= oy < OH, 0 <= ox < OW, 0 <= oc < OC) {
- out_t acc = 0;
+ acc_t acc = 0;
index_t iy = oy * stride_y - pad_top;
index_t ix = ox * stride_x - pad_left;
for_each(0 <= ky < KH, 0 <= kx < KW, 0 <= ic < IC) {
index_t y = iy + ky * dilation_y;
index_t x = ix + kx * dilation_x;
if (0 <= y < IH && 0 <= x < IW) {
- out_t value = static_cast<out_t>(tensor_read<in_t>(input,
+ acc_t value = static_cast<acc_t>(tensor_read<in_t>(input,
[N,IH,IW,IC],
[n,y,x,ic]));
- out_t weight = static_cast<out_t>(tensor_read<weight_t>(weight,
+ acc_t weight = static_cast<acc_t>(tensor_read<weight_t>(weight,
[OC,KH,KW,IC],
[oc,ky,kx,ic]));
- value = apply_sub_s<out_t>(value, static_cast<out_t>(input_zp));
- weight = apply_sub_s<out_t>(weight, static_cast<out_t>(weight_zp));
- acc = apply_add_s<out_t>(acc, apply_mul_s<out_t>(value, weight));
+ value = apply_sub_s<acc_t>(value, static_cast<acc_t>(input_zp));
+ weight = apply_sub_s<acc_t>(weight, static_cast<acc_t>(weight_zp));
+ acc = apply_add_s<acc_t>(acc, apply_mul_s<acc_t>(value, weight));
}
}
- acc = apply_add_s<out_t>(acc, bias[(BC == 1) ? 0 : oc]);
- tensor_write<out_t>(output, [N,OH,OW,OC], [n,oy,ox,oc], acc);
+ out_t out = static_cast<out_t>(acc);
+ out = apply_add_s<out_t>(out, bias[(BC == 1) ? 0 : oc]);
+ tensor_write<out_t>(output, [N,OH,OW,OC], [n,oy,ox,oc], out);
}
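The reordering is observable in floating-point: narrowing the accumulator to out_t before the bias add can round differently than adding the bias at accumulator precision and narrowing once. A self-contained C++ check, written for this page rather than taken from the commit (double stands in for a wide acc_t, float for out_t, and the values are contrived to expose the difference):

    #include <cstdio>

    int main() {
        double acc  = 16777217.0;  // 2^24 + 1: exact in double, rounds in float
        float  bias = 1.0f;

        // Revised order: narrow the accumulator first, add bias in out_t.
        float cast_then_add = static_cast<float>(acc) + bias;
        // Previous order: add bias at accumulator precision, narrow once.
        float add_then_cast = static_cast<float>(acc + static_cast<double>(bias));

        std::printf("%.1f vs %.1f\n", cast_then_add, add_then_cast);
        // prints: 16777216.0 vs 16777218.0
        return 0;
    }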
diff --git a/pseudocode/operators/CONV3D.tosac b/pseudocode/operators/CONV3D.tosac
index 7568564..e53b7eb 100644
--- a/pseudocode/operators/CONV3D.tosac
+++ b/pseudocode/operators/CONV3D.tosac
@@ -38,6 +38,7 @@ for_each(0 <= n < N, 0 <= od < OD, 0 <= oy < OH, 0 <= ox < OW, 0 <= oc < OC) {
acc = apply_add_s<out_t>(acc, apply_mul_s<out_t>(value, weight));
}
}
- acc = apply_add_s<out_t>(acc, bias[(BC == 1) ? 0 : oc]);
- tensor_write<out_t>(output, [N,OD,OH,OW,OC], [n,od,oy,ox,oc], acc);
+ out_t out = static_cast<out_t>(acc);
+ out = apply_add_s<out_t>(out, bias[(BC == 1) ? 0 : oc]);
+ tensor_write<out_t>(output, [N,OD,OH,OW,OC], [n,od,oy,ox,oc], out);
}
diff --git a/pseudocode/operators/DEPTHWISE_CONV2D.tosac b/pseudocode/operators/DEPTHWISE_CONV2D.tosac
index a473375..419d2eb 100644
--- a/pseudocode/operators/DEPTHWISE_CONV2D.tosac
+++ b/pseudocode/operators/DEPTHWISE_CONV2D.tosac
@@ -35,6 +35,7 @@ for_each(0 <= n < N, 0 <= oy < OH, 0 <= ox < OW, 0 <= c < C, 0 <= m < M) {
acc = apply_add_s<out_t>(acc, apply_mul_s<out_t>(value, weight));
}
}
- acc = apply_add_s<out_t>(acc, bias[(BC == 1) ? 0 : (c * M) + m]);
- tensor_write<out_t>(output, [N,OH,OW,C * M], [n,oy,ox,c * M + m], acc);
+ out_t out = static_cast<out_t>(acc);
+ out = apply_add_s<out_t>(out, bias[(BC == 1) ? 0 : (c * M) + m]);
+ tensor_write<out_t>(output, [N,OH,OW,C * M], [n,oy,ox,c * M + m], out);
}
diff --git a/pseudocode/operators/TRANSPOSE_CONV2D.tosac b/pseudocode/operators/TRANSPOSE_CONV2D.tosac
index ab61348..6713b30 100644
--- a/pseudocode/operators/TRANSPOSE_CONV2D.tosac
+++ b/pseudocode/operators/TRANSPOSE_CONV2D.tosac
@@ -16,20 +16,24 @@ ERROR_IF(OH != (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH);
ERROR_IF(OW != (IW - 1) * stride_x + out_pad_left + out_pad_right + KW);
ERROR_IF(BC != OC && BC != 1);
-for_each(index in [N, OH, OW, OC]) {
- tensor_write<out_t>(output, [N,OH,OW,OC], index, bias[(BC == 1) ? 0 : index[3]]);
-}
-for_each(0 <= n < N, 0 <= iy < IH, 0 <= ix < IW, 0 <= oc < OC,
- 0 <= ic < IC, 0 <= ky < KH, 0 <= kx < KW) {
- index_t oy = iy * stride_y + out_pad_top + ky;
- index_t ox = ix * stride_x + out_pad_left + kx;
- if (oy >= 0 && oy < OH && ox >= 0 && ox < OW) {
- out_t acc = static_cast<out_t>(tensor_read<out_t>(output, [N,OH,OW,OC], [n,oy,ox,oc]));
- out_t value = static_cast<out_t>(tensor_read<in_t>(input, [N,IH,IW,IC], [n,iy,ix,ic]));
- out_t weight = static_cast<out_t>(tensor_read<weight_t>(weight, [OC,KH,KW,IC], [oc,ky,kx,ic]));
- value = apply_sub_s<out_t>(value, static_cast<out_t>(input_zp));
- weight = apply_sub_s<out_t>(weight, static_cast<out_t>(weight_zp));
- acc = apply_add_s<out_t>(acc, apply_mul_s<out_t>(value, weight));
- tensor_write<out_t>(output, [N,OH,OW,OC], [n,oy,ox,oc], acc);
+for_each(0 <= n < N, 0 <= iy < IH, 0 <= ix < IW, 0 <= dy < stride_y, 0 <= dx < stride_x, 0 <= oc < OC) {
+ acc_t acc = 0;
+ index_t oy = iy * stride_y + dy + out_pad_top;
+ index_t ox = ix * stride_x + dx + out_pad_left;
+
+ for_each(0 <= sy * stride_y < KH - dy, 0 <= sx * stride_x < KW - dx, 0 <= ic < IC) {
+ index_t y = iy - sy;
+ index_t x = ix - sx;
+ index_t ky = dy + sy * stride_y;
+ index_t kx = dx + sx * stride_x;
+ if (y >= 0 && x >= 0) {
+ acc_t value = static_cast<acc_t>(tensor_read<in_t>(input, [N,IH,IW,IC], [n,y,x,ic]));
+ acc_t weight_value = static_cast<acc_t>(tensor_read<weight_t>(weight, [OC,KH,KW,IC], [oc,ky,kx,ic]));
+ value = apply_sub_s<acc_t>(value, static_cast<acc_t>(input_zp));
+ weight_value = apply_sub_s<acc_t>(weight_value, static_cast<acc_t>(weight_zp));
+ acc = apply_add_s<acc_t>(acc, apply_mul_s<acc_t>(value, weight_value));
+ }
}
+
+ out_t out = static_cast<out_t>(acc);
+ out = apply_add_s<out_t>(out, bias[(BC == 1) ? 0 : oc]);
+ tensor_write<out_t>(output, [N,OH,OW,OC], [n,oy,ox,oc], out);
}
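The rewrite turns the old scatter loop (initialize the output with the bias, then accumulate contributions into it) into a gather loop (each output element sums its own taps, then adds the bias once). A quick self-contained check, mine rather than the commit's, that the gather-style indices reproduce the scatter relation oy == y * stride_y + ky + out_pad_top for every tap:

    #include <cassert>

    int main() {
        // Example geometry; the specific numbers are arbitrary.
        const int stride_y = 2, out_pad_top = 1, KH = 5, IH = 8;
        for (int iy = 0; iy < IH; ++iy) {
            for (int dy = 0; dy < stride_y; ++dy) {
                const int oy = iy * stride_y + dy + out_pad_top;
                for (int sy = 0; sy * stride_y < KH - dy; ++sy) {
                    const int y  = iy - sy;             // input row the gather form reads
                    const int ky = dy + sy * stride_y;  // kernel row it pairs with
                    // Same relation the removed scatter loop maintained.
                    assert(y * stride_y + ky + out_pad_top == oy);
                }
            }
        }
        return 0;
    }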
diff --git a/tosa.xml b/tosa.xml
index 17c82f9..675eeb6 100644
--- a/tosa.xml
+++ b/tosa.xml
@@ -92,8 +92,8 @@
<levellimit value="pad_right" limit="MAX_KERNEL"/>
<rank min="1" max="1"/>
</argument>
- <argument category="attribute" name="acc_size" type="tensor_t" shape="-" tensor-element-type="acc_size_t">
- <description>Enumerated type, must be one of INT32, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
+ <argument category="attribute" name="acc_type" type="tensor_t" shape="-" tensor-element-type="acc_type_t">
+ <description>Enumerated type, must be one of INT32, INT48, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
<rank min="0" max="0"/>
</argument>
<argument category="attribute" name="input_zp" type="tensor_t" shape="-" tensor-element-type="in_out_t">
@@ -174,6 +174,10 @@
<description>[dilation_y, dilation_x]</description>
<rank min="1" max="1"/>
</argument>
+ <argument category="attribute" name="acc_type" type="tensor_t" shape="-" tensor-element-type="acc_type_t">
+ <description>Enumerated type, must be one of INT32, INT48, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
+ <rank min="0" max="0"/>
+ </argument>
<argument category="attribute" name="input_zp" type="tensor_t" shape="-" tensor-element-type="in_t">
<description>Input tensor zero point. Must be zero for non-int8 types.</description>
<rank min="0" max="0"/>
@@ -199,32 +203,33 @@
<type name='in_t' />
<type name='weight_t' />
<type name='out_t' />
+ <type name='acc_t' />
</types>
- <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" >
+ <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="BI"/>
</typesupport>
- <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t">
+ <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="EXT-INT4"/>
</typesupport>
- <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t">
+ <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t" acc_t="i48_t">
<op_profile name="EXT-INT16"/>
</typesupport>
- <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t">
+ <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E4M3"/>
</typesupport>
- <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t">
+ <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E5M2"/>
</typesupport>
- <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t">
+ <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp32_t">
+ <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="fp32_t">
+ <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="bf16_t" acc_t="fp32_t">
<op_profile name="EXT-BF16"/>
</typesupport>
- <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t">
+ <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
</operator>
@@ -268,6 +273,10 @@
<description>[dilation_d, dilation_y, dilation_x]</description>
<rank min="1" max="1"/>
</argument>
+ <argument category="attribute" name="acc_type" type="tensor_t" shape="-" tensor-element-type="acc_type_t">
+ <description>Enumerated type, must be one of INT32, INT48, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
+ <rank min="0" max="0"/>
+ </argument>
<argument category="attribute" name="input_zp" type="tensor_t" shape="-" tensor-element-type="in_t">
<description>Input tensor zero point. Must be zero for non-int8 types.</description>
<rank min="0" max="0"/>
@@ -294,31 +303,31 @@
<type name='weight_t' />
<type name='out_t' />
+ <type name='acc_t' />
</types>
- <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t">
+ <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="BI"/>
</typesupport>
- <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t">
+ <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="EXT-INT4"/>
</typesupport>
- <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t">
+ <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t" acc_t="i48_t">
<op_profile name="EXT-INT16"/>
</typesupport>
- <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t">
+ <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E4M3"/>
</typesupport>
- <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t">
+ <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E5M2"/>
</typesupport>
- <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t">
+ <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp32_t">
+ <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="fp32_t">
+ <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="bf16_t" acc_t="fp32_t">
<op_profile name="EXT-BF16"/>
</typesupport>
- <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t">
+ <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
</operator>
@@ -358,6 +367,10 @@
<description>[dilation_y, dilation_x]</description>
<rank min="1" max="1"/>
</argument>
+ <argument category="attribute" name="acc_type" type="tensor_t" shape="-" tensor-element-type="acc_type_t">
+ <description>Enumerated type, must be one of INT32, INT48, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
+ <rank min="0" max="0"/>
+ </argument>
<argument category="attribute" name="input_zp" type="tensor_t" shape="-" tensor-element-type="in_t">
<description>Input tensor zero point. Must be zero for non-int8 types.</description>
<rank min="0" max="0"/>
@@ -384,31 +397,31 @@
<type name='weight_t' />
<type name='out_t' />
+ <type name='acc_t' />
</types>
- <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" >
+ <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="BI"/>
</typesupport>
- <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t" >
+ <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="EXT-INT4"/>
</typesupport>
- <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t">
+ <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t" acc_t="i48_t">
<op_profile name="EXT-INT16"/>
</typesupport>
- <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t">
+ <typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" weight_t="fp8e4m3_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E4M3"/>
</typesupport>
- <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t">
+ <typesupport mode="fp8e5m2 with fp16 accumulate" in_t="fp8e5m2_t" weight_t="fp8e5m2_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E5M2"/>
</typesupport>
- <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t">
+ <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp32_t">
+ <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="fp32_t">
+ <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="bf16_t" acc_t="fp32_t">
<op_profile name="EXT-BF16"/>
</typesupport>
- <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t">
+ <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
</operator>
@@ -684,6 +697,10 @@
<levellimit value="stride_x" limit="MAX_STRIDE"/>
<rank min="1" max="1"/>
</argument>
+ <argument category="attribute" name="acc_type" type="tensor_t" shape="-" tensor-element-type="acc_type_t">
+ <description>Enumerated type, must be one of INT32, INT48, FP16, FP32, as defined in the Supported Data Types table for this operation</description>
+ <rank min="0" max="0"/>
+ </argument>
<argument category="attribute" name="input_zp" type="tensor_t" shape="-" tensor-element-type="in_t">
<description>Input tensor zero point. Must be zero for non-int8 types.</description>
<rank min="0" max="0"/>
@@ -704,37 +721,37 @@
<description>Output tensor</description>
<rank min="4" max="4"/>
</argument>
- </arguments>
+ </arguments>
<types>
<type name='in_t' />
<type name='weight_t' />
<type name='out_t' />
+ <type name='acc_t' />
</types>
- <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t">
+ <typesupport mode="signed 8x8 with int32 accumulate" in_t="i8_t" weight_t="i8_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="BI"/>
</typesupport>
- <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t">
+ <typesupport mode="signed 8x4 with int32 accumulate" in_t="i8_t" weight_t="i4_t" out_t="i32_t" acc_t="i32_t">
<op_profile name="EXT-INT4"/>
</typesupport>
- <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t">
+ <typesupport mode="signed 16x8 with int48 accumulate" in_t="i16_t" weight_t="i8_t" out_t="i48_t" acc_t="i48_t">
<op_profile name="EXT-INT16"/>
</typesupport>
<typesupport mode="fp8e4m3 with fp16 accumulate" in_t="fp8e4m3_t" out_t="fp16_t">
<op_profile name="EXT-FP8E4M3"/>
</typesupport>
- <typesupport mode="fp8e5m2" in_t="fp8e5m2_t" out_t="fp16_t">
+ <typesupport mode="fp8e5m2" in_t="fp8e5m2_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="EXT-FP8E5M2"/>
</typesupport>
- <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t">
+ <typesupport mode="fp16 with fp16 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp16_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp32_t">
+ <typesupport mode="fp16 with fp32 accumulate" in_t="fp16_t" weight_t="fp16_t" out_t="fp16_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
- <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="fp32_t">
+ <typesupport mode="bf16 with fp32 accumulate" in_t="bf16_t" weight_t="bf16_t" out_t="bf16_t" acc_t="fp32_t">
<op_profile name="EXT-BF16"/>
</typesupport>
- <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t">
+ <typesupport mode="fp32 with fp32 accumulate" in_t="fp32_t" weight_t="fp32_t" out_t="fp32_t" acc_t="fp32_t">
<op_profile name="MI"/>
</typesupport>
</operator>
@@ -3299,10 +3316,11 @@ used.</description>
<enumval value="1" name="BILINEAR" description="Bilinear resize"/>
</enum>
- <enum name="acc_size_t" description="Allowed accumulator sizes">
+ <enum name="acc_type_t" description="Allowed accumulator types">
<enumval value="0" name="INT32" description="32-bit integer"/>
<enumval value="1" name="FP16" description="16-bit floating-point"/>
<enumval value="2" name="FP32" description="32-bit floating-point"/>
+ <enumval value="3" name="INT48" description="48-bit integer"/>
</enum>
<enum name="var_t" description="Variable tensor data type">
diff --git a/tosa.xsd b/tosa.xsd
index b39a2f4..54610d5 100644
--- a/tosa.xsd
+++ b/tosa.xsd
@@ -57,7 +57,7 @@
<xs:simpleType name="enumtypename">
<xs:restriction base="xs:string">
<xs:enumeration value="resize_mode_t"/>
- <xs:enumeration value="acc_size_t"/>
+ <xs:enumeration value="acc_type_t"/>
</xs:restriction>
</xs:simpleType>