aboutsummaryrefslogtreecommitdiff
path: root/chapters/pseudocode.adoc
blob: 89545036b9f307e42999238aadc27f95ca4d01a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
// (C) COPYRIGHT 2021-2023 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.

== TOSA Pseudocode

The TOSA pseudocode provides precise descriptions of TOSA operations.
Each operator contains pseudocode describing the operator's functionality.
This section contains pseudocode functions shared across multiple operators in the specification.

=== Operator Validation Helpers

The following functions are used to define the valid conditions for TOSA operators.

The REQUIRE function defines the conditions required by the TOSA operator.
If the conditions are not met then the result of the TOSA graph is marked as unpredictable.
Once the tosa_graph_result is set to tosa_unpredictable, the whole graph is considered unpredictable.

The ERROR_IF function defines a condition that must set an error if the condition holds and the graph is not unpredictable.
Note that if a graph contains both unpredictable and error statements then the result of tosa_execute_graph() is tosa_unpredictable.
This condition is captured in the ERROR_IF function.

*Implementation Notes*

* An implementation is not required to detect unpredictable behavior. If tosa_execute_graph() returns tosa_unpredictable then the tosa_test_compliance() function does not require any specific output from an implementation.
* An implementation is required to detect errors in a graph that does not have unpredictable behavior (see tosa_test_compliance).
* An acceptable implementation is to stop and report an error on the first ERROR_IF condition that occurs. This satisfies tosa_test_compliance() even if the result of tosa_execute_graph() was tosa_unpredictable.
* If the tosa_execute_graph() result is tosa_unpredictable or tosa_error, then there is no requirement on the implementation to execute any portion of the TOSA graph.

[source,c++]
----
void REQUIRE(condition) {
    // A failed requirement marks the graph result as unpredictable. The
    // assignment is unconditional, so it replaces any earlier result,
    // including a previously recorded error.
    if (condition) {
        return;
    }
    tosa_graph_result = tosa_unpredictable;
}

void ERROR_IF(condition) {
    // An error is a predictable outcome. Once the graph has been marked
    // unpredictable, no error is recorded and the condition is not evaluated.
    if (tosa_graph_result == tosa_unpredictable) {
        return;
    }
    if (condition) {
        tosa_graph_result = tosa_error;
    }
}

void LEVEL_CHECK(condition) {
    // Level limits are enforced as ordinary requirements: when a level is
    // specified and its condition does not hold, the graph result becomes
    // unpredictable (not an error).
    REQUIRE(condition);
}
----

=== Tensor Access Helpers

==== Tensor Utilities

[source,c++]
----
// Convert tensor index coordinates to an element offset.
// The offset is computed in row-major order: the last axis varies fastest.
// The index must provide exactly one in-range coordinate per axis of shape;
// otherwise the result is unpredictable.
size_t tensor_index_to_offset(shape_t shape, shape_t index) {
    size_t size = tensor_size(shape);  // check tensor shape is valid
    // A shorter index would read past its end in the loop below
    REQUIRE(rank(index) == rank(shape));
    size_t offset = 0;
    for (int32_t i = 0; i < rank(shape); i++) {
        REQUIRE(index[i] >= 0 && index[i] < shape[i]);
        offset = offset * shape[i] + index[i];
    }
    return offset;
}

// Convert an element offset to tensor index coordinates.
// Coordinates are peeled off starting from the last (fastest varying) axis,
// reversing the row-major computation in tensor_index_to_offset.
shape_t tensor_offset_to_index(shape_t shape, size_t offset) {
    size_t size = tensor_size(shape);  // validates the shape
    REQUIRE(offset < size);
    shape_t index(rank(shape));        // one coordinate per axis of shape
    size_t remaining = offset;
    int32_t axis = rank(shape);
    while (axis > 0) {
        axis--;
        index[axis] = remaining % shape[axis];
        remaining /= shape[axis];
    }
    return index;
}

// Check the tensor shape is valid and return the tensor size in elements.
// Every dimension must be at least 1 and the running product of dimensions
// must not exceed maximum<size_t>. A rank 0 shape has size 1.
size_t tensor_size(shape_t shape) {
    size_t elements = 1;
    for (int32_t d = 0; d < rank(shape); d++) {
        // Positive extent, and multiplying it in must not overflow size_t
        REQUIRE(1 <= shape[d] && shape[d] <= maximum<size_t> / elements);
        elements *= shape[d];
    }
    return elements;
}

// Return the size of the tensor along the given axis.
// Any axis at or beyond rank(shape) is treated as having size 1, so a
// rank 0 tensor returns 1 for every axis.
size_t shape_dim(shape_t shape, int axis) {
    if (axis < rank(shape)) {
        return shape[axis];
    }
    return 1;
}
----

==== Tensor Read

tensor_read reads a single data value out of the given tensor.
The shape argument contains the shape of the tensor.
The index argument contains the coordinates within the tensor of the value to be read.

[source,c++]
----
in_t tensor_read<in_t>(in_t *address, shape_t shape, shape_t index) {
    // Locate the element (this also validates shape and index), then load it
    return address[tensor_index_to_offset(shape, index)];
}
----

==== Tensor Write

tensor_write writes a single data value into the given tensor.
The shape argument contains the shape of the tensor.
The index argument contains the coordinates within the tensor of the value to be written.
The value argument is the value to be written at the given coordinates.

[source,c++]
----
void tensor_write<type>(<type> *address, shape_t shape, shape_t index, <type> value) {
    // Locate the element (this also validates shape and index), then store value there
    address[tensor_index_to_offset(shape, index)] = value;
}
----

==== Variable Tensor Allocate

variable_tensor_allocate allocates the mutable persistent memory block for storing variable tensors.
The shape argument contains the shape of the allocated memory block for the variable_tensor.
The uid argument is a globally unique identifier for variable tensors.

[source,c++]
----
tensor_t* variable_tensor_allocate<in_t>(shape_t shape, int32_t uid) {
    // Validate the shape and compute the element count before allocating
    size_t num_elements = tensor_size(shape);
    tensor_t *var_tensor = new tensor_t;
    // Identity and type information of the variable
    var_tensor->uid = uid;
    var_tensor->shape = shape;
    var_tensor->type = in_t;
    // Fresh storage: the variable has not been written yet
    var_tensor->data = new in_t[num_elements];
    var_tensor->is_written = false;
    return var_tensor;
}
----

==== Variable Tensor Lookup

variable_tensor_lookup returns the variable tensor that was previously allocated with the given uid, or NULL if no such variable tensor has been allocated.
The uid argument is a globally unique identifier for variable tensors.

[source,c++]
----
// Return the variable tensor previously allocated with the given uid, or
// NULL if none has been allocated. The return type is a pointer, matching
// variable_tensor_allocate, so that NULL is a valid result.
tensor_t* variable_tensor_lookup(int32_t uid) {
    // The global all_allocated_variable_tensors is instantiated the first
    // time the TOSA graph is executed
    for_each(tensor_t* allocated_tensor in all_allocated_variable_tensors) {
        if (allocated_tensor->uid == uid) {
            return allocated_tensor;
        }
    }
    return NULL;
}
----

==== Broadcast Helpers

The following function derives the broadcast output shape from the input shapes.

[source,c++]
----
shape_t broadcast_shape(shape_t shape1, shape_t shape2) {
    ERROR_IF(rank(shape1) != rank(shape2));
    shape_t shape = shape1;
    for (int32_t axis = 0; axis < rank(shape); axis++) {
        if (shape[axis] != 1) {
            // A non-unit extent is only compatible with 1 or an equal extent
            ERROR_IF(shape2[axis] != 1 && shape2[axis] != shape[axis]);
            continue;
        }
        // A unit extent in shape1 takes the extent from shape2
        shape[axis] = shape2[axis];
    }
    return shape;
}
----

The following function maps an index in the output tensor to an index in the input tensor.

[source,c++]
----
// The index argument should be a valid location within out_shape.
// Returns the location within in_shape that contributes to the output at
// index under the broadcasting rules: on every axis where the two shapes
// differ, in_shape must have extent 1 and the coordinate is reset to 0.

shape_t apply_broadcast(shape_t out_shape, shape_t in_shape, shape_t index) {
    ERROR_IF(rank(out_shape) != rank(in_shape));
    ERROR_IF(rank(out_shape) != rank(index));
    for (int32_t axis = 0; axis < rank(out_shape); axis++) {
        if (out_shape[axis] == in_shape[axis]) {
            continue;   // no broadcasting on this axis
        }
        ERROR_IF(in_shape[axis] != 1);
        index[axis] = 0;
    }
    return index;
}
----

=== General Pseudocode Helpers

This section contains general pseudocode utility functions used throughout the specification.

==== Arithmetic Helpers

The following functions provide arithmetic while defining requirements such that values stay in the valid range.

[source,c++]
----
in_t apply_add_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer add: widen both operands to 64 bits, then require that the
        // result is representable as a signed in_t value.
        int64_t total = sign_extend<int64_t>(a) + sign_extend<int64_t>(b);
        REQUIRE(total >= minimum_s<in_t> && total <= maximum_s<in_t>);
        return static_cast<in_t>(total);
    }
    return a + b;
}

in_t apply_add_u<in_t>(in_t a, in_t b) {
    if (is_floating_point(in_t)) return a + b;
    // Add the operands as zero-extended 64-bit values, then require that the
    // sum fits in the unsigned interpretation of in_t. minimum_u/maximum_u
    // take the signed type and apply make_unsigned themselves.
    uint64_t c = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b);
    REQUIRE(c >= minimum_u<in_t> && c <= maximum_u<in_t>);
    return truncate<in_t>(c);
}

in_t apply_arith_rshift<in_t>(in_t a, in_t b) {
    // Sign-extend before shifting so that the vacated high bits are filled
    // with copies of the sign bit
    int32_t shifted = sign_extend<int32_t>(a) >> sign_extend<int32_t>(b);
    return static_cast<in_t>(shifted);
}

in_t apply_intdiv_s<in_t>(in_t a, in_t b) {
    // Divide the sign-extended operands in 64 bits; the quotient must fit in
    // the signed range of in_t.
    int64_t dividend = sign_extend<int64_t>(a);
    int64_t divisor = sign_extend<int64_t>(b);
    int64_t quotient = dividend / divisor;
    REQUIRE(quotient >= minimum_s<in_t> && quotient <= maximum_s<in_t>);
    return static_cast<in_t>(quotient);
}

// Ceiling: the input value rounded up (towards +infinity) to an integer value.
in_t apply_ceil<in_t>(in_t input) {
    return input value rounded up to nearest integer
}

in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) {
    // The clip range [min_val, max_val] must be non-empty. Integers are
    // compared as signed values.
    if (is_floating_point(in_t)) {
        REQUIRE(min_val <= max_val);
    }
    else {
        REQUIRE(sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val));
    }
    // Clamp from below, then from above
    value = apply_max_s<in_t>(value, min_val);
    value = apply_min_s<in_t>(value, max_val);
    return value;
}

in_t apply_clip_u<in_t>(in_t value, in_t min_val, in_t max_val) {
    // The bounds are compared as unsigned values and must form a non-empty range
    REQUIRE(zero_extend<int64_t>(min_val) <= zero_extend<int64_t>(max_val));
    // Clamp from below, then from above
    in_t clamped = apply_max_u<in_t>(value, min_val);
    return apply_min_u<in_t>(clamped, max_val);
}

// Natural exponential function: e raised to the power input.
in_t apply_exp<in_t>(in_t input) {
    return e to the power input
}

// Floor: the input value rounded down (towards -infinity) to an integer value.
in_t apply_floor<in_t>(in_t input) {
    return input value rounded down to nearest integer
}

in_t apply_log<in_t>(in_t input) {
    // Special cases: the logarithm of zero is -infinity and the logarithm of
    // a negative value is NaN.
    if (input == 0) {
        return -INFINITY;
    }
    else if (input < 0) {
        return NaN;
    }
    return the natural logarithm of input
}

in_t apply_logical_rshift<in_t>(in_t a, in_t b) {
    // Zero-extending a before the shift fills the vacated high bits with
    // zeros (a logical shift) instead of copies of the sign bit. The result
    // is held in the same 32-bit unsigned type the operands were extended to.
    uint32_t c = zero_extend<uint32_t>(a) >> zero_extend<uint32_t>(b);
    return static_cast<in_t>(c);
}

in_t apply_max_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer version: compare as signed values
        return (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b)) ? a : b;
    }
    // Floating-point version: a NaN in either operand propagates
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a >= b) ? a : b;
}

in_t apply_max_u<in_t>(in_t a, in_t b) {
    // Compare as unsigned values: both operands are zero-extended to the same
    // unsigned type, avoiding a mixed signed/unsigned comparison.
    if (zero_extend<uint64_t>(a) >= zero_extend<uint64_t>(b)) return a; else return b;
}

in_t apply_min_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer version: compare as signed values
        return (sign_extend<int64_t>(a) < sign_extend<int64_t>(b)) ? a : b;
    }
    // Floating-point version: a NaN in either operand propagates
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a < b) ? a : b;
}

in_t apply_min_u<in_t>(in_t a, in_t b) {
    // Compare as unsigned values by zero-extending both operands
    return (zero_extend<int64_t>(a) < zero_extend<int64_t>(b)) ? a : b;
}

in_t apply_mul_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer multiply in 64 bits. No range REQUIRE is applied here; the
        // static_cast back to in_t truncates to the width of in_t.
        int64_t product = sign_extend<int64_t>(a) * sign_extend<int64_t>(b);
        return static_cast<in_t>(product);
    }
    return a * b;
}

// Power function: a raised to the power b.
// In this pseudocode '**' denotes exponentiation; it is not a C++ operator.
in_t apply_pow<in_t>(in_t a, in_t b) {
    return a ** b; // a raised to the power b
}

// Square root of input.
in_t apply_sqrt<in_t>(in_t input) {
    return the square root of input
}

in_t apply_sub_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer subtract: compute in 64 bits, then require that the
        // difference is representable as a signed in_t value.
        int64_t difference = sign_extend<int64_t>(a) - sign_extend<int64_t>(b);
        REQUIRE(difference >= minimum_s<in_t> && difference <= maximum_s<in_t>);
        return static_cast<in_t>(difference);
    }
    return a - b;
}

in_t apply_sub_u<in_t>(in_t a, in_t b) {
    // Subtract as zero-extended 64-bit values. If b > a, the difference wraps
    // to a value far above maximum_u<in_t>, so the REQUIRE rejects underflow.
    // minimum_u/maximum_u take the signed type and apply make_unsigned themselves.
    uint64_t c = zero_extend<uint64_t>(a) - zero_extend<uint64_t>(b);
    REQUIRE(c >= minimum_u<in_t> && c <= maximum_u<in_t>);
    return truncate<in_t>(c);
}

// Return the number of leading zero bits in the 32-bit two's complement
// representation of a, scanning from the most significant bit.
// Returns 32 when a is zero.
int32_t count_leading_zeros(int32_t a) {
    int32_t acc = 32;
    if (a != 0) {
        // Start with only bit 31 set. The literal is unsigned so that forming
        // this mask does not shift into the sign bit of a signed int.
        uint32_t mask = 1u << (32 - 1); // width of int32_t - 1
        acc = 0;
        // a is non-zero, so some bit is set and the scan terminates
        while ((mask & static_cast<uint32_t>(a)) == 0) {
            mask = mask >> 1;
            acc = acc + 1;
        }
    }
    return acc;
}
----

==== Type Conversion Helpers

The following definitions indicate the type to be used when the given parameters are provided.

[source,c++]
----

// Returns a signed version of the given type.
// Integer types map to their signed counterpart; bool_t and the
// floating-point types are returned unchanged (a no-op).
Type make_signed(Type in_t)
{
    switch(in_t) {
        case i8_t:
            return int8_t;
        case i16_t:
            return int16_t;
        case i32_t:
            return int32_t;
        case i48_t:
            return int48_t;
        case bool_t:
        case fp16_t:
        case bf16_t:
        case fp32_t:
            return in_t;  // already signed or not an integer type
    }
}

// Returns the unsigned type corresponding to the given type.
// Only i8_t and i16_t are accepted; calling this with any other type is an error.

Type make_unsigned(Type in_t)
{
    ERROR_IF(in_t != i8_t && in_t != i16_t);
    switch(in_t) {
        case i16_t:
            return uint16_t;
        case i8_t:
            return uint8_t;
    }
}

out_t static_cast<out_t>(in_t value)
{
    // Converts value from in_t to out_t. Operates similarly to the C++
    // standard static_cast, but is limited to simple numeric conversion for TOSA.
    // Sign extends signed integer input types if needed
    // Zero extends unsigned integer input types if needed
    // Truncates when converting to a smaller width data type
    // Conversion from integer to floating-point is exact if possible
    // If converting between signless integer types, treated as signed integer
}

out_t bitcast<out_t>(in_t value)
{
    // Reinterprets the bits of value as if they were of type out_t; no
    // numeric conversion is performed.
    // Only supported for integer types of the same bit width
}
----

==== Numeric Accuracy Helpers

For a floating point number of type in_t a normal value is of the form (1.x * 2^e).
The fractional part 'x' has a number of fractional or mantissa bits depending on the type.
The exponent 'e' has a normal range depending on the type.
The functions below return the ranges according to type.

[source,c++]
----
fp64_t exp2(int n) {
    // For n < -1075, 2^n is smaller than the smallest fp64 denormal
    if (n < -1075) {
        return 0.0;
    }
    REQUIRE(n <= 1023);
    // Build 2^n by repeatedly doubling or halving 1.0
    fp64_t result = 1.0;
    for (; n > 0; n--) result = result * 2.0;
    for (; n < 0; n++) result = result / 2.0;
    return result;
}

// Return the integer n such that 2^n <= v < 2^(n+1).
// v must be positive and finite; otherwise the result is unpredictable.
int ilog2(fp64_t v) {
    REQUIRE(0 < v && v < infinity);
    int n = 0;
    // Scale v into [1.0, 2.0), counting the powers of two removed or added
    while (v >= 2.0) { v = v/2.0; n++; }
    while (v <  1.0) { v = v*2.0; n--; }
    return n;
}

// Smallest positive normal value of the floating-point type in_t
fp64_t normal_min<in_t>() {
  switch (in_t) {
    case fp32_t:
    case bf16_t:
      return exp2(-126);  // fp32 and bf16 share the same minimum normal exponent
    case fp16_t:
      return exp2(-14);
  }
}

fp64_t normal_max<in_t>() {
  // Largest finite value of in_t: an all-ones fraction with the maximum
  // exponent emax, which equals 2^(emax+1) - 2^(emax-frac)
  int emax;
  int frac;
  switch (in_t) {
    case fp32_t: emax = 127; frac = 23; break;
    case bf16_t: emax = 127; frac =  7; break;
    case fp16_t: emax =  15; frac = 10; break;
  }
  return exp2(emax + 1) - exp2(emax - frac);
}

// Number of fractional (mantissa) bits of the floating-point type in_t
int normal_frac<in_t> () {
  switch (in_t) {
    case bf16_t: return  7;
    case fp16_t: return 10;
    case fp32_t: return 23;
  }
}
----

The following functions check if a test value in floating-point format in_t is within an error range compared to a reference value.
The functions assume that denormal values may be flushed to zero.
For the first function, the permitted error range is specified as num_ulp which is converted to an error bound as specified by the code.
For the second function, the permitted error range is specified as an absolute error bound.

[source,c++]
----
// Check test_value against ref_value with a tolerance of num_ulp units in
// the last place of in_t, measured at the magnitude of ref_value (but never
// below the smallest normal value of in_t). If ref_value is zero or not a
// normal fp64 value, the tolerance is zero.
bool tosa_reference_check_fp<in_t>(in_t test_value, fp64_t ref_value, fp64_t num_ulp) {
  fp64_t err_bnd = 0.0;
  if (is_normal_fp64(ref_value) && abs(ref_value) != 0) {
    int ref_exp = ilog2(abs(ref_value));
    fp64_t ref_pow2 = max(exp2(ref_exp), normal_min<in_t>());
    fp64_t val_ulp  = ref_pow2 * exp2(-normal_frac<in_t>());
    err_bnd = val_ulp * num_ulp;
  }
  return tosa_reference_check_fp_bnd<in_t>(test_value, ref_value, err_bnd);
}

// Check that test_value lies within err_bnd of ref_value, allowing for
// overflow to infinity and for denormal values being flushed to zero.
// A NaN reference matches only a NaN test value.
bool tosa_reference_check_fp_bnd<in_t>(in_t test_value, fp64_t ref_value, fp64_t err_bnd) {
  if (is_a_NaN(ref_value)) {
    return is_a_NaN(test_value);
  }
  REQUIRE(err_bnd >= 0.0);
  // Mirror a negative reference into the positive half so that one set of
  // range checks applies
  if (ref_value < 0) {
    ref_value  = -ref_value;
    test_value = -test_value;
  }
  fp64_t max_normal = normal_max<in_t>();
  fp64_t min_normal = normal_min<in_t>();
  fp64_t ref_max = ref_value + err_bnd;
  fp64_t ref_min = ref_value - err_bnd;
  // Bounds beyond the largest finite value of in_t become infinity
  if (ref_max > max_normal) ref_max = infinity;
  if (ref_min > max_normal) ref_min = infinity;
  // Bounds in the subnormal range are widened: the upper bound rises to the
  // smallest normal value and the lower bound drops to zero
  if (ref_max < min_normal) ref_max = min_normal;
  if (ref_min < min_normal) ref_min = 0;
  return (static_cast<fp64_t>(test_value) >= ref_min &&
          static_cast<fp64_t>(test_value) <= ref_max);
}

----

==== Numeric Conversion Helpers

The following definitions are used in pseudocode to do numeric conversions.
Where the *float_t* type is used, it represents all of the floating-point data types supported by the given profile.
See <<Number formats>> for details on the floating-point formats.

[source,c++]
----
int round_to_nearest_int(float_t f)
  Converts the floating-point value f to an integer, rounding to the nearest integer value.
  For the required precision see the section: Main inference precision requirements.

float_t round_to_nearest_float(in_t f)
  Converts the input value into floating-point, rounding to the nearest representable value.
  For the required precision see the section: Main inference precision requirements.

out_t sign_extend<out_t>(in_t input)
  Floating-point values are unchanged.
  For two's complement integer values where out_t has more bits than in_t, replicate the top bit of input for all bits between the top bit of input and the top bit of output.

out_t zero_extend<out_t>(in_t input)
  Floating-point values are unchanged.
  For two's complement integer values where out_t has more bits than in_t, insert zero values for all bits between the top bit of input and the top bit of output.

out_t truncate<out_t>(in_t input)
  Output is the N least significant bits of input, where N is the bit width of out_t.
  No-op for floating-point types.
----

The following definition is used to flatten a list of lists into a single list.

[source,c++]
----
// Concatenate the elements of every list in lists, in order, into a single
// list and return it.
in_t* flatten(in_t lists[]) {
    in_t* output = [];
    for_each(list in lists) {
        for_each(element in list) {
            output.append(element);
        }
    }
    return output;
}
----

Generic helper functions used to keep the pseudocode concise.

[source,c++]
----

bool_t is_floating_point(type) {
    // True exactly for the floating-point data types fp16_t, bf16_t and fp32_t
    return type == fp16_t || type == bf16_t || type == fp32_t;
}

// Integer divide that truncates towards zero, e.g. idiv(-7, 2) == -3.
int32_t idiv(int32_t input1, int32_t input2) {
    const int32_t quotient = input1 / input2;
    return quotient;
}

// Integer division that checks input1 is a multiple of input2.
// A non-zero remainder is an error.

int32_t idiv_check(int32_t input1, int32_t input2) {
    ERROR_IF(input1 % input2 != 0);
    // The division is exact here, so no rounding occurs
    int32_t quotient = input1 / input2;
    return quotient;
}

// Perform an integer division with rounding towards minus infinity.
// Valid for any signs of input1 and input2 (input2 must be non-zero).

int32_t idiv_floor(int32_t input1, int32_t input2) {
    int32_t rval = input1 / input2;  // truncates towards zero
    // Truncation rounded up (instead of down) exactly when the division is
    // inexact and the operands have opposite signs
    if ((input1 % input2 != 0) && ((input1 < 0) != (input2 < 0))) {
        rval--;
    }
    return rval;
}

int32_t length(in_t input)
    return number of elements in input list

int32_t rank(in_t input)
    return the rank (number of dimensions) of an input tensor or shape

int32_t sum(in_t input[])
    return the sum of values of an input list

bool isNaN(float_t input)
    return true if floating-point input value is NaN

float_t pi()
    return the value of pi

float_t sin(angle)
    return sine of angle given in radians

float_t cos(angle)
    return cosine of angle given in radians

bool power_of_two(int32_t value)
    return true if value is a power of two, false otherwise

in_out_t maximum_s<Type T>
    return the maximum value when interpreting type T as a signed value as returned by the make_signed helper.

in_out_t minimum_s<Type T>
    return the minimum value when interpreting type T as a signed value as returned by the make_signed helper.

in_out_t maximum_u<Type T>
    return the maximum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.

in_out_t minimum_u<Type T>
    return the minimum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.
----