aboutsummaryrefslogtreecommitdiff
path: root/chapters/pseudocode.adoc
diff options
context:
space:
mode:
Diffstat (limited to 'chapters/pseudocode.adoc')
-rw-r--r--chapters/pseudocode.adoc165
1 files changed, 144 insertions, 21 deletions
diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc
index c026089..55c35d4 100644
--- a/chapters/pseudocode.adoc
+++ b/chapters/pseudocode.adoc
@@ -1,7 +1,7 @@
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
-// (C) COPYRIGHT 2021-2022 ARM Limited
+// (C) COPYRIGHT 2021-2023 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
@@ -221,21 +221,44 @@ The following functions provide arithmetic while defining requirements such that
[source,c++]
----
-in_t apply_add<in_t>(in_t a, in_t b) {
+in_t apply_add_s<in_t>(in_t a, in_t b) {
if (is_floating_point(in_t)) return a + b;
- int64_t c = (int64_t)a + (int64_t)b;
- REQUIRE(c >= minimum<in_t> && c <= maximum<in_t>);
- return (in_t)c;
+ int64_t c = sign_extend<int64_t>(a) + sign_extend<int64_t>(b);
+ REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>);
+ return static_cast<in_t>(c);
+}
+
+in_t apply_add_u<in_t>(in_t a, in_t b) {
+ if (is_floating_point(in_t)) return a + b;
+ uint64_t c = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b);
+ REQUIRE(c >= minimum_u<in_u_t> && c <= maximum_u<in_u_t>);
+ return truncate<in_t>(c);
+}
+
+in_t apply_arith_rshift<in_t>(in_t a, in_t b) {
+ int32_t c = sign_extend<int32_t>(a) >> sign_extend<int32_t>(b);
+ return static_cast<in_t>(c);
+}
+
+in_t apply_intdiv_s<in_t>(in_t a, in_t b) {
+ int64_t c = sign_extend<int64_t>(a) / sign_extend<int64_t>(b);
+ REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>);
+ return static_cast<in_t>(c);
}
in_t apply_ceil<in_t>(in_t input) {
return input value rounded up to nearest integer
}
-in_t apply_clip<in_t>(in_t value, in_t min_val, in_t max_val) {
- REQUIRE(min_val <= max_val);
- value = apply_max(value, min_val);
- value = apply_min(value, max_val);
+in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) {
+ if (is_floating_point(in_t>) {
+ REQUIRE(min_val <= max_val);
+ }
+ else {
+ REQUIRE(sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val));
+ }
+ value = apply_max_s<in_t>(value, min_val);
+ value = apply_min_s<in_t>(value, max_val);
return value;
}
@@ -257,22 +280,37 @@ in_t apply_log<in_t>(in_t input) {
return the natural logarithm of input
}
-in_t apply_max<in_t>(in_t a, in_t b) {
+in_t apply_logical_rshift<in_t>(in_t a, in_t b) {
+ uint64_t c = zero_extend<uint32_t>(a) >> zero_extend<uint32_t>(b);
+ return static_cast<in_t>(c);
+}
+
+in_t apply_max_s<in_t>(in_t a, in_t b) {
if (is_floating_point(in_t)) {
if (isNaN(a) || isNaN(b)) {
return NaN;
}
+ if (a >= b) return a; else return b;
}
- if (a >= b) return a; else return b;
+ // Integer version
+ if (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b)) return a; else return b;
}
-in_t apply_min<in_t>(in_t a, in_t b) {
+in_t apply_min_s<in_t>(in_t a, in_t b) {
if (is_floating_point(in_t)) {
if (isNaN(a) || isNaN(b)) {
return NaN;
}
+ if (a < b) return a; else return b;
}
- if (a < b) return a; else return b;
+ // Integer version
+ if (sign_extend<int64_t>(a) < sign_extend<int64_t>(b)) return a; else return b;
+}
+
+in_t apply_mul_s<in_t>(in_t a, in_t b) {
+ if (is_floating_point(in_t)) return a * b;
+ int64_t c = sign_extend<int64_t>(a) * sign_extend<int64_t>(b);
+ return static_cast<in_t>(c);
}
in_t apply_pow<in_t>(in_t a, in_t b) {
@@ -283,11 +321,17 @@ in_t apply_sqrt<in_t>(in_t input) {
return the square root of input
}
-in_t apply_sub<in_t>(in_t a, in_t b) {
+in_t apply_sub_s<in_t>(in_t a, in_t b) {
if (is_floating_point(in_t)) return a - b;
- int64_t c = (int64_t)a - (int64_t)b;
- REQUIRE(c >= minimum<in_t> && c <= maximum<in_t>);
- return (in_t)c;
+ int64_t c = sign_extend<int64_t>(a) - sign_extend<int64_t>(b);
+ REQUIRE(c >= minimum_s<in_t> && c <= maximum_s<in_t>);
+ return static_cast<in_t>(c);
+}
+
+in_t apply_sub_u<in_t>(in_t a, in_t b) {
+ uint64_t c = zero_extend<uint64_t>(a) - zero_extend<uint64_t>(b);
+ REQUIRE(c >= minimum_u<in_u_t> && c <= maximum_u<in_u_t>);
+ return truncate<in_t>(c);
}
int32_t count_leading_zeros(int32_t a) {
@@ -305,6 +349,69 @@ int32_t count_leading_zeros(int32_t a) {
}
----
+==== Type Conversion Helpers
+
+The following definitions indicate the type to be used when the given parameters are provided.
+
+[source,c++]
+----
+
+// Returns a signed version of the given type
+// A no-op for floating-point types
+Type make_signed(Type in_t)
+{
+ switch(in_t) {
+ case bool_t:
+ return bool_t;
+ case i8_t:
+ return int8_t;
+ case i16_t:
+ return int16_t;
+ case i32_t:
+ return int32_t;
+ case i48_t:
+ return int48_t;
+ case fp16_t:
+ return fp16_t;
+ case bf16_t:
+ return bf16_t;
+ case fp32_t:
+ return fp32_t;
+ }
+}
+
+// Returns the usigned type of the given type
+// Error to call this with anything but i8_t or i16_t
+
+Type make_unsigned(Type in_t)
+{
+ ERROR_IF(in_t != i8_t && in_t != i16_t);
+ switch(in_t) {
+ case i8_t:
+ return uint8_t;
+ case i16_t:
+ return uint16_t;
+ }
+}
+
+out_t static_cast<out_t>(in_t value)
+{
+ // Operates similar to the c++ standard static_cast
+ // Limited to simple numeric conversion for TOSA.
+ // Sign extends signed integer input types if needed
+ // Zero extends unsigned integer input types if needed
+ // Truncates when converting to a smaller width data type
+ // Conversion from integer to floating-point is exact if possible
+ // If converting between signless integer types, treated as signed integer
+}
+
+out_t bitcast<out_t>(in_t value)
+{
+ // Treats the bits of value as if they were of type out_t
+ // Only supported for integer types of the same bit width
+}
+----
+
==== Numeric Conversion Helpers
The following definitions are used in pseudocode to do numeric conversions.
@@ -321,13 +428,17 @@ float_t round_to_nearest_float(in_t f)
Converts the input value into floating-point, rounding to the nearest representable value.
For the required precision see the section: Main inference precision requirements.
-out_t sign_extend(in_t input)
- Only valid for two's complement integer values where out_t has more bits than in_t.
- Output = input
- Replicate the top bit of input for all bits between the top bit of input and the top bit of output.
+out_t sign_extend<out_t>(in_t input)
+ Floating point values are unchanged.
+ For two's complement integer values where out_t has more bits than in_t, replicate the top bit of input for all bits between the top bit of input and the top bit of output.
+
+out_t zero_extend<out_t>(in_t input)
+ Floating point values are unchanged.
+ For two's complement integer values where out_t has more bits than in_t, insert zero values for all bits between the top bit of input and the top bit of output.
out_t truncate(in_t input)
output is the sizeof(out_t) least significant bits in input.
+ Nop for floating-point types
----
The following definition is used to flatten a list of lists into a single list.
@@ -389,4 +500,16 @@ float_t cos(angle)
bool power_of_two(int32_t value)
return true if value is a power of two, false otherwise
+
+in_out_t maximum_s<Type T>
+ return the maximum value when interpreting type T as a signed value as returned by the make_signed helper.
+
+in_out_t minimum_s<Type T>
+ return the minimum value when interpreting type T as a signed value as returned by the make_signed helper.
+
+in_out_t maximum_u<Type T>
+ return the maximum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.
+
+in_out_t minimum_u<Type T>
+ return the minimum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.
----