From 74e2ceba954ed6111b3e3ce40c5ff88fe79ff043 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Fri, 20 Oct 2023 15:58:55 -0700 Subject: Initial FP8 support Adds support for Open Compute Project (OCP) 8-bit floating point operations to the TOSA specification. Both E4M3 and E5M2 types are supported for profiles as indicated in the Supported Data Types table for each operator. FP8 operator list ARGMAX AVGPOOL CONV2D CONV3D DEPTHWISE_CONV2D MATMUL MAX_POOL2D TRANSPOSE_CONV2D CONST CAST CONCAT PAD DIM RESHAPE REVERSE SLICE TILE TRANSPOSE GATHER SCATTER Signed-off-by: Eric Kunze Change-Id: I3dd83f48afcc3c880c5c88039337ff4f1fd95b1b --- chapters/introduction.adoc | 33 +++- chapters/pseudocode.adoc | 3 +- pseudocode/library/generic_helpers.tosac | 11 +- pseudocode/library/numeric_accuracy_helpers.tosac | 6 + .../library/numeric_conversion_helpers.tosac | 8 +- pseudocode/library/type_conversion_helpers.tosac | 9 +- pseudocode/operators/CAST.tosac | 31 ++-- tools/dictionary.dic | 2 + tosa.xml | 182 +++++++++++++++++++++ tosa.xsd | 2 + 10 files changed, 262 insertions(+), 25 deletions(-) diff --git a/chapters/introduction.adoc b/chapters/introduction.adoc index 9d53510..17c16a8 100644 --- a/chapters/introduction.adoc +++ b/chapters/introduction.adoc @@ -245,7 +245,8 @@ Multiplication of an infinity by a zero must produce a NaN. + Otherwise the result must be within 0.5 ulp of the mathematical result. | <> -| Floating-point result overflows must be set to infinity of the correct sign. + +| Result overflows when converting between fp32_t, bf16_t and fp16_t must be set to infinity of the correct sign. + +fp8e4m3_t and fp8e5m2_t must use the saturation mode rules defined in <> when converting from the wider floating-point types. + Floating-point result underflows must be set to zero of the correct sign. + Cast from floating-point to integer result overflows must be saturated. + Cast from floating-point to integer must be rounded using round to nearest, ties to even, rounding mode. 
+ @@ -339,7 +340,7 @@ This may be, for example, a convolution. This section defines the accuracy required for these operations. In this section: -* "fp64 arithmetic" refers to double-precision floating-point arithmetic defined by IEEE 754 (<>[1]) +* "fp64 arithmetic" refers to double-precision floating-point arithmetic defined by <<IEEE-754>> * `operation_fp64()` is an fp64 reference implementation of the operation * `operation_imp()` is the implementation under test * `local_bound` is defined as follows: @@ -537,10 +538,29 @@ The number formats supported by a given operator are listed in its table of supp | (1<<47)-1 |Signed 48-bit two's-complement value. +|fp8e4m3_t +| -448 +| 448 +| 8-bit floating-point defined by <<OCP-OFP8>> with four bits of exponent and three bits of mantissa. + +Normal values must be supported. + +Denormal values must be supported. + +The NaN encoding must be supported. + +Signed zero must be supported. + +|fp8e5m2_t +| -infinity +| +infinity +| 8-bit floating-point defined by <<OCP-OFP8>> with five bits of exponent and two bits of mantissa. + +Normal values must be supported. + +Denormal values must be supported. + +Positive and negative infinity must be supported. + +NaN encodings must be supported. + +Signed zero must be supported. + |fp16_t | -infinity | +infinity -| 16-bit half-precision floating-point defined by <>[1]. + +| 16-bit half-precision floating-point defined by <<IEEE-754>> . + Normal values must be supported. + Denormal values must either be supported or flushed to zero. + Positive and negative infinity must be supported. + @@ -560,7 +580,7 @@ Signed zero must be supported. |fp32_t | -infinity | +infinity -| 32-bit single-precision floating-point defined by <>[1]. + +| 32-bit single-precision floating-point defined by <<IEEE-754>> . + Normal values must be supported. + Denormal values must either be supported or flushed to zero. + Positive and negative infinity must be supported. + @@ -570,7 +590,7 @@ Signed zero must be supported. 
|fp64_t | -infinity | + infinity -| 64-bit double-precision floating-point defined by <>[1]. + +| 64-bit double-precision floating-point defined by <<IEEE-754>>. + Normal values must be supported. + Denormal values must either be supported or flushed to zero. + Positive and negative infinity must be supported. + @@ -744,4 +764,5 @@ void generate_lookup_table(int16_t *table, int32_t (*reference)(int32_t)) The following publications are referred to in this specification, or provide more information: -. IEEE Std 754-2008, _IEEE Standard for Floating-point Arithmetic_, August 2008. +. [[IEEE-754]]IEEE Std 754-2008, _IEEE Standard for Floating-point Arithmetic_, August 2008. +. [[OCP-OFP8]]Open Compute Project OCP 8-bit Floating Point Specification (OFP8) Revision 1.0 diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc index acce9c9..53b1142 100644 --- a/chapters/pseudocode.adoc +++ b/chapters/pseudocode.adoc @@ -1,7 +1,7 @@ // // This confidential and proprietary software may be used only as // authorised by a licensing agreement from ARM Limited -// (C) COPYRIGHT 2021-2023 ARM Limited +// (C) COPYRIGHT 2021-2024 ARM Limited // ALL RIGHTS RESERVED // The entire notice above must be reproduced on all authorised // copies and copies may only be made to the extent permitted @@ -142,6 +142,7 @@ include::{pseudocode}/library/arithmetic_helpers.tosac[lines=10..-1] The following definitions indicate the type to be used when the given parameters are provided. + [source,c++] ---- include::{pseudocode}/library/type_conversion_helpers.tosac[lines=10..-1] diff --git a/pseudocode/library/generic_helpers.tosac b/pseudocode/library/generic_helpers.tosac index a9d71ec..a2fdbe0 100644 --- a/pseudocode/library/generic_helpers.tosac +++ b/pseudocode/library/generic_helpers.tosac @@ -8,11 +8,20 @@ // by a licensing agreement from ARM Limited. 
bool_t is_floating_point(type) { - if (type == fp16_t || type == fp32_t || type == bf16_t) + if (type == fp16_t || type == fp32_t || type == bf16_t || type == fp8e4m3_t || type == fp8e5m2_t) return true; return false; } +bool_t is_saturating_float_type(type) { + // Saturate for the fp8 formats, all other floats do not saturate + if (type == fp8e4m3_t || type == fp8e5m2_t) { + return true; + } + return false; +} + + int32_t idiv(int32_t input1, int32_t input2) { return input1 / input2; // Integer divide that truncates towards zero } diff --git a/pseudocode/library/numeric_accuracy_helpers.tosac b/pseudocode/library/numeric_accuracy_helpers.tosac index 4a2b111..b89d898 100644 --- a/pseudocode/library/numeric_accuracy_helpers.tosac +++ b/pseudocode/library/numeric_accuracy_helpers.tosac @@ -31,6 +31,8 @@ fp64_t normal_min() { case fp32_t: return exp2(-126); case bf16_t: return exp2(-126); case fp16_t: return exp2( -14); + case fp8e4m3_t: return exp2(-6); + case fp8e5m2_t: return exp2(-14); } } @@ -39,6 +41,8 @@ fp64_t normal_max() { case fp32_t: return exp2(128) - exp2(127-23); case bf16_t: return exp2(128) - exp2(127- 7); case fp16_t: return exp2( 16) - exp2( 15-10); + case fp8e4m3_t: return exp2( 9) - exp2( 8-2); + case fp8e5m2_t: return exp2( 16) - exp2( 15-2); } } @@ -48,5 +52,7 @@ int normal_frac () { case fp32_t: return 23; case fp16_t: return 10; case bf16_t: return 7; + case fp8e4m3_t: return 3; + case fp8e5m2_t: return 2; } } diff --git a/pseudocode/library/numeric_conversion_helpers.tosac b/pseudocode/library/numeric_conversion_helpers.tosac index fac7078..576351f 100644 --- a/pseudocode/library/numeric_conversion_helpers.tosac +++ b/pseudocode/library/numeric_conversion_helpers.tosac @@ -11,8 +11,14 @@ int round_to_nearest_int(float_t f) Converts the floating-point value to f, with rounding to the nearest integer value. For the required precision see the section: Main inference precision requirements. 
-float_t round_to_nearest_float(in_t f) +float_t round_to_nearest_float_nonsaturating(in_t f) Converts the input value into floating-point, rounding to the nearest representable value. + Values that are not NaN outside of the representable range of the destination type must be set to infinity of the correct sign. + For the required precision see the section: Main inference precision requirements. + +float_t round_to_nearest_float_saturating(in_t f) + Converts the input value into floating-point, rounding to the nearest representable normal value. + Values that are not NaN outside of the representable range must return the maximum representable normal value of the correct sign. For the required precision see the section: Main inference precision requirements. out_t sign_extend(in_t input) diff --git a/pseudocode/library/type_conversion_helpers.tosac b/pseudocode/library/type_conversion_helpers.tosac index f26c589..f2b42a6 100644 --- a/pseudocode/library/type_conversion_helpers.tosac +++ b/pseudocode/library/type_conversion_helpers.tosac @@ -11,6 +11,9 @@ // A no-op for floating-point types Type make_signed(Type in_t) { + if (is_floating_point()) { + return in_t; + } switch(in_t) { case bool_t: return bool_t; @@ -22,12 +25,6 @@ Type make_signed(Type in_t) return int32_t; case i48_t: return int48_t; - case fp16_t: - return fp16_t; - case bf16_t: - return bf16_t; - case fp32_t: - return fp32_t; } } diff --git a/pseudocode/operators/CAST.tosac b/pseudocode/operators/CAST.tosac index fac73e3..fd3ce72 100644 --- a/pseudocode/operators/CAST.tosac +++ b/pseudocode/operators/CAST.tosac @@ -12,16 +12,27 @@ for_each(index in shape) { out_t out; if (out_t == bool_t) { out = (in != 0) ? true : false; - } else if (in_t == bool_t) { - out = (in) ? 
1 : 0; - } else if (out_t == fp16_t || out_t == bf16_t || out_t == fp32_t) { - out = round_to_nearest_float(in); - } else if (in_t == fp16_t || in_t == bf16_t || in_t == fp32_t) { - out = truncate(apply_clip_s(round_to_nearest_int(in), minimum, maximum)); - } else if (sizeof(out_t) >= sizeof(in_t)) { - out = sign_extend(in); + } else if (is_floating_point_type()) { + // Conversion to float cases + if (in_t == bool_t) { + out = (in) ? 1.0 : 0.0; + } + if (is_saturating_float_type()) { + out = round_to_nearest_float_saturating(in); + } else { + out = round_to_nearest_float_nonsaturating(in); + } } else { - out = truncate(in); + // Conversion to integer cases + if (in_t == bool_t) { + out = (in) ? 1 : 0; + } else if (is_floating_point_type()) { + out = truncate(apply_clip_s(round_to_nearest_int(in), minimum, maximum)); + } else if (sizeof(out_t) >= sizeof(in_t)) { + out = sign_extend(in); + } else { + out = truncate(in); + } } - tensor_write(output, shape, index, out); + tensor_write(output, shape, index, out) } diff --git a/tools/dictionary.dic b/tools/dictionary.dic index 6b83c53..53377a0 100644 --- a/tools/dictionary.dic +++ b/tools/dictionary.dic @@ -49,6 +49,8 @@ multipler NaN NPUs OC +OCP +OFP pre precisions pseudocode diff --git a/tosa.xml b/tosa.xml index 691c35c..19822f6 100644 --- a/tosa.xml +++ b/tosa.xml @@ -36,6 +36,13 @@ + + + + + + + @@ -99,6 +106,13 @@ + + + + + + + @@ -181,6 +195,13 @@ + + + + + + + @@ -267,6 +288,13 @@ + + + + + + + @@ -349,6 +377,13 @@ + + + + + + + @@ -492,6 +527,13 @@ + + + + + + + @@ -546,6 +588,13 @@ + + + + + + + @@ -654,6 +703,13 @@ + + + + + + + @@ -1932,6 +1988,13 @@ + + + + + + + @@ -1974,6 +2037,13 @@ + + + + + + + @@ -2010,6 +2080,13 @@ + + + + + + + @@ -2047,6 +2124,13 @@ + + + + + + + @@ -2084,6 +2168,13 @@ + + + + + + + @@ -2126,6 +2217,13 @@ used. + + + + + + + @@ -2163,6 +2261,13 @@ used. + + + + + + + @@ -2200,6 +2305,13 @@ used. + + + + + + + @@ -2237,6 +2349,13 @@ used. + + + + + + + @@ -2276,6 +2395,13 @@ used. 
+ + + + + + + @@ -2422,10 +2548,38 @@ used. + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -2438,6 +2592,13 @@ used. + + + + + + + @@ -2454,6 +2615,13 @@ used. + + + + + + + @@ -2552,6 +2720,13 @@ used. + + + + + + + @@ -2586,6 +2761,13 @@ used. + + + + + + + diff --git a/tosa.xsd b/tosa.xsd index a52c1a7..e0afbe2 100644 --- a/tosa.xsd +++ b/tosa.xsd @@ -38,6 +38,8 @@ + + -- cgit v1.2.1