diff options
Diffstat (limited to 'chapters/pseudocode.adoc')
-rw-r--r-- | chapters/pseudocode.adoc | 165 |
1 files changed, 144 insertions, 21 deletions
diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc index c026089..55c35d4 100644 --- a/chapters/pseudocode.adoc +++ b/chapters/pseudocode.adoc @@ -1,7 +1,7 @@ // // This confidential and proprietary software may be used only as // authorised by a licensing agreement from ARM Limited -// (C) COPYRIGHT 2021-2022 ARM Limited +// (C) COPYRIGHT 2021-2023 ARM Limited // ALL RIGHTS RESERVED // The entire notice above must be reproduced on all authorised // copies and copies may only be made to the extent permitted @@ -221,21 +221,44 @@ The following functions provide arithmetic while defining requirements such that [source,c++] ---- -in_t apply_add<in_t>(in_t a, in_t b) { +in_t apply_add_s<in_t>(in_t a, in_t b) { if (is_floating_point(in_t)) return a + b; - int64_t c = (int64_t)a + (int64_t)b; - REQUIRE(c >= minimum<in_t> && c <= maximum<in_t>); - return (in_t)c; + int64_t c = sign_extend<int64_t>(a) + sign_extend<int64_t>(b); + REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>); + return static_cast<in_t>(c); +} + +in_t apply_add_u<in_t>(in_t a, in_t b) { + if (is_floating_point(in_t)) return a + b; + uint64_t c = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b); + REQUIRE(c >= minimum_u<in_u_t> && c <= maximum_u<in_u_t>); + return truncate<in_t>(c); +} + +in_t apply_arith_rshift<in_t>(in_t a, in_t b) { + int32_t c = sign_extend<int32_t>(a) >> sign_extend<int32_t>(b); + return static_cast<in_t>(c); +} + +in_t apply_intdiv_s<in_t>(in_t a, in_t b) { + int64_t c = sign_extend<int64_t>(a) / sign_extend<int64_t>(b); + REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>); + return static_cast<in_t>(c); } in_t apply_ceil<in_t>(in_t input) { return input value rounded up to nearest integer } -in_t apply_clip<in_t>(in_t value, in_t min_val, in_t max_val) { - REQUIRE(min_val <= max_val); - value = apply_max(value, min_val); - value = apply_min(value, max_val); +in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) { + if (is_floating_point(in_t>) { + REQUIRE(min_val <= max_val); + } + else { + REQUIRE(sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val)); + } + value = apply_max_s<in_t>(value, min_val); + value = apply_min_s<in_t>(value, max_val); return value; } @@ -257,22 +280,37 @@ in_t apply_log<in_t>(in_t input) { return the natural logarithm of input } -in_t apply_max<in_t>(in_t a, in_t b) { +in_t apply_logical_rshift<in_t>(in_t a, in_t b) { + uint64_t c = zero_extend<uint32_t>(a) >> zero_extend<uint32_t>(b); + return static_cast<in_t>(c); +} + +in_t apply_max_s<in_t>(in_t a, in_t b) { if (is_floating_point(in_t)) { if (isNaN(a) || isNaN(b)) { return NaN; } + if (a >= b) return a; else return b; } - if (a >= b) return a; else return b; + // Integer version + if (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b)) return a; else return b; } -in_t apply_min<in_t>(in_t a, in_t b) { +in_t apply_min_s<in_t>(in_t a, in_t b) { if (is_floating_point(in_t)) { if (isNaN(a) || isNaN(b)) { return NaN; } + if (a < b) return a; else return b; } - if (a < b) return a; else return b; + // Integer version + if (sign_extend<int64_t>(a) < sign_extend<int64_t>(b)) return a; else return b; +} + +in_t apply_mul_s<in_t>(in_t a, in_t b) { + if (is_floating_point(in_t)) return a * b; + int64_t c = sign_extend<int64_t>(a) * sign_extend<int64_t>(b); + return static_cast<in_t>(c); } in_t apply_pow<in_t>(in_t a, in_t b) { @@ -283,11 +321,17 @@ in_t apply_sqrt<in_t>(in_t input) { return the square root of input } -in_t apply_sub<in_t>(in_t a, in_t b) { +in_t apply_sub_s<in_t>(in_t a, in_t b) { if (is_floating_point(in_t)) return a - b; - int64_t c = (int64_t)a - (int64_t)b; - REQUIRE(c >= minimum<in_t> && c <= maximum<in_t>); - return (in_t)c; + int64_t c = sign_extend<int64_t>(a) - sign_extend<int64_t>(b); + REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>); + return static_cast<in_t>(c); +} + +in_t apply_sub_u<in_t>(in_t a, in_t b) { + uint64_t c = zero_extend<uint64_t>(a) - zero_extend<uint64_t>(b); + REQUIRE(c >= minimum_u<in_u_t> && c <= maximum_u<in_u_t>); + return truncate<in_t>(c); } int32_t count_leading_zeros(int32_t a) { @@ -305,6 +349,69 @@ int32_t count_leading_zeros(int32_t a) { } ---- +==== Type Conversion Helpers + +The following definitions indicate the type to be used when the given parameters are provided. + +[source,c++] +---- + +// Returns a signed version of the given type +// A no-op for floating-point types +Type make_signed(Type in_t) +{ + switch(in_t) { + case bool_t: + return bool_t; + case i8_t: + return int8_t; + case i16_t: + return int16_t; + case i32_t: + return int32_t; + case i48_t: + return int48_t; + case fp16_t: + return fp16_t; + case bf16_t: + return bf16_t; + case fp32_t: + return fp32_t; + } +} + +// Returns the usigned type of the given type +// Error to call this with anything but i8_t or i16_t + +Type make_unsigned(Type in_t) +{ + ERROR_IF(in_t != i8_t && in_t != i16_t); + switch(in_t) { + case i8_t: + return uint8_t; + case i16_t: + return uint16_t; + } +} + +out_t static_cast<out_t>(in_t value) +{ + // Operates similar to the c++ standard static_cast + // Limited to simple numeric conversion for TOSA. + // Sign extends signed integer input types if needed + // Zero extends unsigned integer input types if needed + // Truncates when converting to a smaller width data type + // Conversion from integer to floating-point is exact if possible + // If converting between signless integer types, treated as signed integer +} + +out_t bitcast<out_t>(in_t value) +{ + // Treats the bits of value as if they were of type out_t + // Only supported for integer types of the same bit width +} +---- + ==== Numeric Conversion Helpers The following definitions are used in pseudocode to do numeric conversions. @@ -321,13 +428,17 @@ float_t round_to_nearest_float(in_t f) Converts the input value into floating-point, rounding to the nearest representable value. For the required precision see the section: Main inference precision requirements. -out_t sign_extend(in_t input) - Only valid for two's complement integer values where out_t has more bits than in_t. - Output = input - Replicate the top bit of input for all bits between the top bit of input and the top bit of output. +out_t sign_extend<out_t>(in_t input) + Floating point values are unchanged. + For two's complement integer values where out_t has more bits than in_t, replicate the top bit of input for all bits between the top bit of input and the top bit of output. + +out_t zero_extend<out_t>(in_t input) + Floating point values are unchanged. + For two's complement integer values where out_t has more bits than in_t, insert zero values for all bits between the top bit of input and the top bit of output. out_t truncate(in_t input) output is the sizeof(out_t) least significant bits in input. + Nop for floating-point types ---- The following definition is used to flatten a list of lists into a single list. @@ -389,4 +500,16 @@ float_t cos(angle) bool power_of_two(int32_t value) return true if value is a power of two, false otherwise + +in_out_t maximum_s<Type T> + return the maximum value when interpreting type T as a signed value as returned by the make_signed helper. + +in_out_t minimum_s<Type T> + return the minimum value when interpreting type T as a signed value as returned by the make_signed helper. + +in_out_t maximum_u<Type T> + return the maximum value when interpreting type T as an unsigned value as returned by the make_unsigned helper. + +in_out_t minimum_u<Type T> + return the minimum value when interpreting type T as an unsigned value as returned by the make_unsigned helper. ---- |