aboutsummaryrefslogtreecommitdiff
path: root/pseudocode
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2023-10-20 15:58:55 -0700
committerEric Kunze <eric.kunze@arm.com>2024-02-14 16:36:04 -0800
commit74e2ceba954ed6111b3e3ce40c5ff88fe79ff043 (patch)
tree7e1967b073313d7df4885693eda931230d401eb0 /pseudocode
parent9fe5e964e2193f0e345670f7f4098beecd7fd6eb (diff)
downloadspecification-74e2ceba954ed6111b3e3ce40c5ff88fe79ff043.tar.gz
Initial FP8 support
Adds support for Open Compute Project (OCP) 8-bit floating point operations to the TOSA specification. Both E4M3 and E5M2 types are supported for profiles as indicated in the Supported Data Types table for each operator. FP8 operator list ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD DIM RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: I3dd83f48afcc3c880c5c88039337ff4f1fd95b1b
Diffstat (limited to 'pseudocode')
-rw-r--r--pseudocode/library/generic_helpers.tosac11
-rw-r--r--pseudocode/library/numeric_accuracy_helpers.tosac6
-rw-r--r--pseudocode/library/numeric_conversion_helpers.tosac8
-rw-r--r--pseudocode/library/type_conversion_helpers.tosac9
-rw-r--r--pseudocode/operators/CAST.tosac31
5 files changed, 47 insertions, 18 deletions
diff --git a/pseudocode/library/generic_helpers.tosac b/pseudocode/library/generic_helpers.tosac
index a9d71ec..a2fdbe0 100644
--- a/pseudocode/library/generic_helpers.tosac
+++ b/pseudocode/library/generic_helpers.tosac
@@ -8,11 +8,20 @@
// by a licensing agreement from ARM Limited.
bool_t is_floating_point(type) {
- if (type == fp16_t || type == fp32_t || type == bf16_t)
+ if (type == fp16_t || type == fp32_t || type == bf16_t || type == fp8e4m3_t || type == fp8e5m2_t)
return true;
return false;
}
+bool_t is_saturating_float_type(type) {
+ // Saturate for the fp8 formats, all other floats do not saturate
+ if (type == fp8e4m3_t || type == fp8e5m2_t) {
+ return true;
+ }
+ return false;
+}
+
+
int32_t idiv(int32_t input1, int32_t input2) {
return input1 / input2; // Integer divide that truncates towards zero
}
diff --git a/pseudocode/library/numeric_accuracy_helpers.tosac b/pseudocode/library/numeric_accuracy_helpers.tosac
index 4a2b111..b89d898 100644
--- a/pseudocode/library/numeric_accuracy_helpers.tosac
+++ b/pseudocode/library/numeric_accuracy_helpers.tosac
@@ -31,6 +31,8 @@ fp64_t normal_min<in_t>() {
case fp32_t: return exp2(-126);
case bf16_t: return exp2(-126);
case fp16_t: return exp2( -14);
+ case fp8e4m3_t: return exp2(-6);
+ case fp8e5m2_t: return exp2(-14);
}
}
@@ -39,6 +41,8 @@ fp64_t normal_max<in_t>() {
case fp32_t: return exp2(128) - exp2(127-23);
case bf16_t: return exp2(128) - exp2(127- 7);
case fp16_t: return exp2( 16) - exp2( 15-10);
+ case fp8e4m3_t: return exp2( 9) - exp2( 8-2);
+ case fp8e5m2_t: return exp2( 16) - exp2( 15-2);
}
}
@@ -48,5 +52,7 @@ int normal_frac<in_t> () {
case fp32_t: return 23;
case fp16_t: return 10;
case bf16_t: return 7;
+ case fp8e4m3_t: return 3;
+ case fp8e5m2_t: return 2;
}
}
diff --git a/pseudocode/library/numeric_conversion_helpers.tosac b/pseudocode/library/numeric_conversion_helpers.tosac
index fac7078..576351f 100644
--- a/pseudocode/library/numeric_conversion_helpers.tosac
+++ b/pseudocode/library/numeric_conversion_helpers.tosac
@@ -11,8 +11,14 @@ int round_to_nearest_int(float_t f)
Converts the floating-point value to f, with rounding to the nearest integer value.
For the required precision see the section: Main inference precision requirements.
-float_t round_to_nearest_float(in_t f)
+float_t round_to_nearest_float_nonsaturating(in_t f)
Converts the input value into floating-point, rounding to the nearest representable value.
+ Values that are not NaN outside of the representable range of the destination type must be set to infinity of the correct sign.
+ For the required precision see the section: Main inference precision requirements.
+
+float_t round_to_nearest_float_saturating(in_t f)
+ Converts the input value into floating-point, rounding to the nearest representable normal value.
+ Values that are not NaN outside of the representable range must return the maximum representable normal value of the correct sign.
For the required precision see the section: Main inference precision requirements.
out_t sign_extend<out_t>(in_t input)
diff --git a/pseudocode/library/type_conversion_helpers.tosac b/pseudocode/library/type_conversion_helpers.tosac
index f26c589..f2b42a6 100644
--- a/pseudocode/library/type_conversion_helpers.tosac
+++ b/pseudocode/library/type_conversion_helpers.tosac
@@ -11,6 +11,9 @@
// A no-op for floating-point types
Type make_signed(Type in_t)
{
+ if (is_floating_point<in_t>()) {
+ return in_t;
+ }
switch(in_t) {
case bool_t:
return bool_t;
@@ -22,12 +25,6 @@ Type make_signed(Type in_t)
return int32_t;
case i48_t:
return int48_t;
- case fp16_t:
- return fp16_t;
- case bf16_t:
- return bf16_t;
- case fp32_t:
- return fp32_t;
}
}
diff --git a/pseudocode/operators/CAST.tosac b/pseudocode/operators/CAST.tosac
index fac73e3..fd3ce72 100644
--- a/pseudocode/operators/CAST.tosac
+++ b/pseudocode/operators/CAST.tosac
@@ -12,16 +12,27 @@ for_each(index in shape) {
out_t out;
if (out_t == bool_t) {
out = (in != 0) ? true : false;
- } else if (in_t == bool_t) {
- out = (in) ? 1 : 0;
- } else if (out_t == fp16_t || out_t == bf16_t || out_t == fp32_t) {
- out = round_to_nearest_float(in);
- } else if (in_t == fp16_t || in_t == bf16_t || in_t == fp32_t) {
- out = truncate<out_t>(apply_clip_s<i32_t>(round_to_nearest_int(in), minimum<out_t>, maximum<out_t>));
- } else if (sizeof(out_t) >= sizeof(in_t)) {
- out = sign_extend<out_t>(in);
+ } else if (is_floating_point_type<out_t>()) {
+ // Conversion to float cases
+ if (in_t == bool_t) {
+ out = (in) ? 1.0 : 0.0;
+ }
+ if (is_saturating_float_type<out_t>()) {
+ out = round_to_nearest_float_saturating(in);
+ } else {
+ out = round_to_nearest_float_nonsaturating(in);
+ }
} else {
- out = truncate<out_t>(in);
+ // Conversion to integer cases
+ if (in_t == bool_t) {
+ out = (in) ? 1 : 0;
+ } else if (is_floating_point_type<in_t>()) {
+ out = truncate<out_t>(apply_clip_s<i32_t>(round_to_nearest_int(in), minimum<out_t>, maximum<out_t>));
+ } else if (sizeof(out_t) >= sizeof(in_t)) {
+ out = sign_extend<out_t>(in);
+ } else {
+ out = truncate<out_t>(in);
+ }
}
- tensor_write<out_t>(output, shape, index, out);
+ tensor_write<out_t>(output, shape, index, out)
}