aboutsummaryrefslogtreecommitdiff
path: root/chapters/pseudocode.adoc
blob: d674c9cf4fd89543603a2acce661672a45576b22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
// (C) COPYRIGHT 2021-2023 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.

== TOSA Pseudocode

The TOSA pseudocode provides precise descriptions of TOSA operations.
Each operator contains pseudocode describing the operator's functionality.
This section contains pseudocode functions shared across multiple operators in the specification.

=== Operator Validation Helpers

The following functions are used to define the valid conditions for TOSA operators.

The REQUIRE function defines the conditions required by the TOSA operator.
If the conditions are not met then the result of the TOSA graph is marked as unpredictable.
Once the tosa_graph_result is set to tosa_unpredictable, the whole graph is considered unpredictable.

The ERROR_IF function defines a condition that must set an error if the condition holds and the graph is not unpredictable.
Note that if a graph contains both unpredictable and error statements then the result of tosa_execute_graph() is tosa_unpredictable.
This condition is captured in the ERROR_IF function.

*Implementation Notes*

* An implementation is not required to detect unpredictable behavior. If tosa_execute_graph() returns tosa_unpredictable then the tosa_test_compliance() function does not require any specific output from an implementation.
* An implementation is required to detect errors in a graph that does not have unpredictable behavior (see tosa_test_compliance).
* An acceptable implementation is to stop and report an error on the first ERROR_IF condition that occurs. This satisfies tosa_test_compliance() even if the tosa_execute_graph() was tosa_unpredictable.
* If the tosa_execute_graph() result is tosa_unpredictable or tosa_error, then there is no requirement on the implementation to execute any portion of the TOSA graph.

[source,c++]
----
void REQUIRE(condition) {
    // Nothing to record when the requirement holds
    if (condition) {
        return;
    }
    // A failed requirement marks the graph as unpredictable,
    // overriding any previous result
    tosa_graph_result = tosa_unpredictable;
}

void ERROR_IF(condition) {
    // Error encodes a predictable error state, so a graph that is already
    // marked as unpredictable keeps that result.
    if (tosa_graph_result == tosa_unpredictable) {
        return;
    }
    // The graph is still predictable: register the error if the condition holds
    if (condition) {
        tosa_graph_result = tosa_error;
    }
}

// Checks a condition imposed by a level.
void LEVEL_CHECK(condition) {
    // If a level is specified and the level condition fails then
    // the result is unpredictable.
    // Delegates to REQUIRE, so a failing level condition is handled in the
    // same way as any other failed requirement.
    REQUIRE(condition);
}
----

=== Tensor Access Helpers

==== Tensor Utilities

[source,c++]
----
// Convert tensor index coordinates to an element offset.
// The last dimension of the shape varies fastest in the resulting offset.
size_t tensor_index_to_offset(dim_t shape, dim_t index) {
    // Called only for its checks: tensor_size requires the shape to be valid
    tensor_size(shape);
    size_t linear = 0;
    int32_t axis = 0;
    while (axis < rank(shape)) {
        // Every coordinate must lie inside its dimension
        REQUIRE(0 <= index[axis] && index[axis] < shape[axis]);
        linear = linear * shape[axis] + index[axis];
        axis++;
    }
    return linear;
}

// Convert an element offset to tensor index coordinates
dim_t tensor_offset_to_index(dim_t shape, size_t offset) {
    // tensor_size also checks that the tensor shape is valid
    size_t total = tensor_size(shape);
    REQUIRE(offset < total);
    dim_t coords(rank(shape));    // one coordinate per dimension of shape
    // Peel coordinates off the offset, starting from the last dimension
    int32_t axis = rank(shape);
    while (axis > 0) {
        axis--;
        coords[axis] = offset % shape[axis];
        offset = offset / shape[axis];
    }
    return coords;
}

// Check the tensor shape is valid and return the tensor size in elements
size_t tensor_size(dim_t shape) {
    size_t elements = 1;
    for (int32_t axis = 0; axis < rank(shape); axis++) {
        // Each dimension must be at least 1 and small enough that the
        // running product still fits in size_t
        REQUIRE(shape[axis] >= 1 && shape[axis] <= maximum<size_t> / elements);
        elements = elements * shape[axis];
    }
    return elements;
}

// Return the size of the tensor in the given axis
// Any axis at or beyond the rank of the shape has size 1,
// so a rank=0 tensor returns 1 for all axes
size_t shape_dim(dim_t shape, int axis) {
    if (axis < rank(shape)) {
        return shape[axis];
    }
    return 1;
}
----

==== Tensor Read

tensor_read reads a single data value out of the given tensor.
The shape argument contains the shape of the tensor.
Index is the coordinates within the tensor of the value to be read.

[source,c++]
----
in_t tensor_read<in_t>(in_t *address, dim_t shape, dim_t index) {
    // tensor_index_to_offset checks the index against the shape and
    // converts it to the element offset used to index address
    return address[tensor_index_to_offset(shape, index)];
}
----

==== Tensor Write

tensor_write writes a single data value into the given tensor.
The shape argument contains the shape of the tensor.
Index is the coordinates within the tensor of the value to be written.
value is the value to be written to the given coordinate.

[source,c++]
----
void tensor_write<type>(<type> *address, dim_t shape, dim_t index, <type> value) {
    // tensor_index_to_offset checks the index against the shape and
    // converts it to the element offset used to index address
    address[tensor_index_to_offset(shape, index)] = value;
}
----

==== Variable Tensor Allocate

variable_tensor_allocate allocates the mutable persistent memory block for storing variable tensors.
The shape argument contains the shape of the allocated memory block for the variable_tensor.
The uid argument is a globally unique identifier for variable tensors.

[source,c++]
----
tensor_t* variable_tensor_allocate<in_t>(dim_t shape, int32_t uid) {
    // tensor_size checks the shape and gives the number of elements to allocate
    size_t element_count = tensor_size(shape);
    tensor_t *variable = new tensor_t;
    // Record the identity and layout of the variable tensor
    variable->uid = uid;
    variable->type = in_t;
    variable->shape = shape;
    // Allocate the storage; it is marked as not yet written
    variable->data = new in_t[element_count];
    variable->is_written = false;
    return variable;
}
----

==== Variable Tensor Lookup

variable_tensor_lookup checks whether a variable tensor has been allocated or not.
The uid argument is a globally unique identifier for variable tensors.

[source,c++]
----
// Returns a pointer to the allocated variable tensor with the given uid,
// or NULL if no variable tensor with that uid has been allocated.
// The result is a pointer so that NULL is a valid return value and so that
// it matches the tensor_t* returned by variable_tensor_allocate.
tensor_t* variable_tensor_lookup(int32_t uid) {
    // The global all_allocated_variable_tensors was instantiated at the first
    // time of executing the tosa graph
    for_each(tensor_t* allocated_tensor in all_allocated_variable_tensors) {
        if (allocated_tensor->uid == uid) {
            return allocated_tensor;
        }
    }
    // No variable tensor has been allocated for this uid
    return NULL;
}
----

==== Broadcast Helpers

The following function derives the broadcast output shape from the input shapes.

[source,c++]
----
dim_t broadcast_shape(dim_t shape1, dim_t shape2) {
    // Both input shapes must have the same rank
    ERROR_IF(rank(shape1) != rank(shape2));
    // Start from shape1 and replace each size 1 dimension with the size from shape2
    dim_t result = shape1;
    for (int32_t axis = 0; axis < rank(result); axis++) {
        if (result[axis] != 1) {
            // shape2 must either match this dimension or have size 1 here
            ERROR_IF(shape2[axis] != 1 && shape2[axis] != result[axis]);
        } else {
            result[axis] = shape2[axis];
        }
    }
    return result;
}
----

The following function maps an index in the output tensor to an index in the input tensor.

[source,c++]
----
// Maps a location in the output tensor to the contributing location in the input.
// The index argument should be a valid location within out_shape.
// The returned location within in_shape is the one that contributes
// to that output location under the broadcasting rules.

dim_t apply_broadcast(dim_t out_shape, dim_t in_shape, dim_t index) {
    ERROR_IF(rank(out_shape) != rank(in_shape));
    ERROR_IF(rank(out_shape) != rank(index));
    for (int32_t axis = 0; axis < rank(out_shape); axis++) {
        if (out_shape[axis] == in_shape[axis]) {
            continue;   // sizes match in this dimension: keep the coordinate
        }
        // Sizes differ, which is only allowed when the input has size 1 here;
        // every output coordinate then maps to input coordinate 0
        ERROR_IF(in_shape[axis] != 1);
        index[axis] = 0;
    }
    return index;
}
----

=== General Pseudocode Helpers

This section contains general pseudocode utility functions used throughout the specification.

==== Arithmetic Helpers

The following functions provide arithmetic while defining requirements such that values stay in the valid range.

[source,c++]
----
in_t apply_add_s<in_t>(in_t a, in_t b) {
    if (is_floating_point(in_t)) {
        // Floating-point values are added directly, with no range requirement
        return a + b;
    }
    // Integer version: add in 64 bits, then require the sum to fit in signed in_t
    int64_t sum = sign_extend<int64_t>(a) + sign_extend<int64_t>(b);
    REQUIRE(minimum_s<in_t> <= sum && sum <= maximum_s<in_t>);
    return static_cast<in_t>(sum);
}

// Unsigned addition of a and b.
// For integer types the operands are zero extended to 64 bits, the sum is
// required to lie within the unsigned range of in_t, and the result is
// truncated back to in_t.
in_t apply_add_u<in_t>(in_t a, in_t b) {
    if (is_floating_point(in_t)) return a + b;
    uint64_t c = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b);
    // minimum_u / maximum_u interpret the given type as unsigned themselves,
    // so they are passed in_t directly
    REQUIRE(c >= minimum_u<in_t> && c <= maximum_u<in_t>);
    return truncate<in_t>(c);
}

in_t apply_arith_rshift<in_t>(in_t a, in_t b) {
    // Both the value and the shift amount are sign extended to 32 bits
    // before the right shift is applied
    int32_t value = sign_extend<int32_t>(a);
    int32_t amount = sign_extend<int32_t>(b);
    int32_t shifted = value >> amount;
    return static_cast<in_t>(shifted);
}

in_t apply_intdiv_s<in_t>(in_t a, in_t b) {
    // Divide in 64 bits, then require the quotient to fit in signed in_t
    int64_t numerator = sign_extend<int64_t>(a);
    int64_t denominator = sign_extend<int64_t>(b);
    int64_t quotient = numerator / denominator;
    REQUIRE(minimum_s<in_t> <= quotient && quotient <= maximum_s<in_t>);
    return static_cast<in_t>(quotient);
}

// Returns the input value rounded up to the nearest integer.
// The body is a prose description rather than executable pseudocode.
in_t apply_ceil<in_t>(in_t input) {
    return input value rounded up to nearest integer
}

// Clips value to the range [min_val, max_val] using signed comparison.
// Requires min_val <= max_val; integer bounds are compared after sign
// extension to 64 bits.
in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) {
    if (is_floating_point(in_t)) {
        REQUIRE(min_val <= max_val);
    }
    else {
        REQUIRE(sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val));
    }
    // Raise the value to the lower bound, then lower it to the upper bound
    value = apply_max_s<in_t>(value, min_val);
    value = apply_min_s<in_t>(value, max_val);
    return value;
}

// Returns e raised to the power of input.
// The body is a prose description rather than executable pseudocode.
in_t apply_exp<in_t>(in_t input) {
    return e to the power input
}

// Returns the input value rounded down to the nearest integer.
// The body is a prose description rather than executable pseudocode.
in_t apply_floor<in_t>(in_t input) {
    return input value rounded down to nearest integer
}

// Returns the natural logarithm of input.
// An input of zero gives -INFINITY and a negative input gives NaN.
// The final return is a prose description rather than executable pseudocode.
in_t apply_log<in_t>(in_t input) {
    if (input == 0) {
        return -INFINITY;
    }
    else if (input < 0) {
        return NaN;
    }
    return the natural logarithm of input
}

// Logical (zero filling) right shift of a by b.
// Both operands are zero extended to 32 bits, so the shift result is held
// in a 32-bit unsigned value before conversion back to in_t.
in_t apply_logical_rshift<in_t>(in_t a, in_t b) {
    uint32_t c = zero_extend<uint32_t>(a) >> zero_extend<uint32_t>(b);
    return static_cast<in_t>(c);
}

in_t apply_max_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer version: compare as signed 64-bit values
        return (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b)) ? a : b;
    }
    // Floating-point version: a NaN in either operand gives NaN
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a >= b) ? a : b;
}

in_t apply_min_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point(in_t)) {
        // Integer version: compare as signed 64-bit values
        return (sign_extend<int64_t>(a) < sign_extend<int64_t>(b)) ? a : b;
    }
    // Floating-point version: a NaN in either operand gives NaN
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a < b) ? a : b;
}

in_t apply_mul_s<in_t>(in_t a, in_t b) {
    if (is_floating_point(in_t)) {
        // Floating-point values are multiplied directly
        return a * b;
    }
    // Integer version: multiply in 64 bits, then convert the product back to in_t
    int64_t product = sign_extend<int64_t>(a) * sign_extend<int64_t>(b);
    return static_cast<in_t>(product);
}

// Returns a raised to the power b.
// The ** notation below denotes exponentiation.
in_t apply_pow<in_t>(in_t a, in_t b) {
    return a ** b; // a raised to the power b
}

// Returns the square root of input.
// The body is a prose description rather than executable pseudocode.
in_t apply_sqrt<in_t>(in_t input) {
    return the square root of input
}

in_t apply_sub_s<in_t>(in_t a, in_t b) {
    if (is_floating_point(in_t)) {
        // Floating-point values are subtracted directly, with no range requirement
        return a - b;
    }
    // Integer version: subtract in 64 bits, then require the difference to fit in signed in_t
    int64_t difference = sign_extend<int64_t>(a) - sign_extend<int64_t>(b);
    REQUIRE(minimum_s<in_t> <= difference && difference <= maximum_s<in_t>);
    return static_cast<in_t>(difference);
}

// Unsigned subtraction of b from a.
// The operands are zero extended to 64 bits, the difference is required to
// lie within the unsigned range of in_t, and the result is truncated back
// to in_t.
in_t apply_sub_u<in_t>(in_t a, in_t b) {
    uint64_t c = zero_extend<uint64_t>(a) - zero_extend<uint64_t>(b);
    // minimum_u / maximum_u interpret the given type as unsigned themselves,
    // so they are passed in_t directly
    REQUIRE(c >= minimum_u<in_t> && c <= maximum_u<in_t>);
    return truncate<in_t>(c);
}

// Returns the number of zero bits above the most significant set bit of the
// 32-bit value a. Returns 32 when a is zero.
int32_t count_leading_zeros(int32_t a) {
    int32_t acc = 32;
    if (a != 0) {
        // Unsigned literal: shifting a signed 1 into the sign bit is not portable
        uint32_t mask = 1u << (32 - 1); // top bit: width of int32_t - 1
        acc = 0;
        // Walk the mask down from the top bit until it meets a set bit of a;
        // a is non-zero here so the loop terminates
        while ((mask & a) == 0) {
            mask = mask >> 1;
            acc = acc + 1;
        }
    }
    return acc;
}
----

==== Type Conversion Helpers

The following definitions indicate the type to be used when the given parameters are provided.

[source,c++]
----

// Returns a signed version of the given type
// Boolean and floating-point types are returned unchanged
Type make_signed(Type in_t)
{
    switch(in_t) {
        // Signless integer types map to the signed type of the same width
        case i8_t:
            return int8_t;
        case i16_t:
            return int16_t;
        case i32_t:
            return int32_t;
        case i48_t:
            return int48_t;
        // Nothing to convert for these types
        case bool_t:
        case fp16_t:
        case bf16_t:
        case fp32_t:
            return in_t;
    }
}

// Returns the unsigned type of the given type
// Error to call this with anything but i8_t or i16_t

Type make_unsigned(Type in_t)
{
    // Only i8_t and i16_t have an unsigned counterpart
    ERROR_IF(in_t != i8_t && in_t != i16_t);
    if (in_t == i8_t) {
        return uint8_t;
    }
    if (in_t == i16_t) {
        return uint16_t;
    }
}

out_t static_cast<out_t>(in_t value)
{
    // Operates similarly to the C++ standard static_cast,
    // limited to simple numeric conversions for TOSA:
    // - Sign extends signed integer input types if needed
    // - Zero extends unsigned integer input types if needed
    // - Truncates when converting to a smaller width data type
    // - Conversion from integer to floating-point is exact if possible
    // - If converting between signless integer types, the value is treated
    //   as a signed integer
}

out_t bitcast<out_t>(in_t value)
{
    // Reinterprets the bits of value as if they were of type out_t;
    // the bit pattern itself is not changed.
    // Only supported for integer types of the same bit width
}
----

==== Numeric Conversion Helpers

The following definitions are used in pseudocode to do numeric conversions.
Where the *float_t* type is used, it represents all of the floating-point data types supported by the given profile.
See <<Number formats>> for details on the floating-point formats.

[source,c++]
----
int round_to_nearest_int(float_t f)
  Converts the floating-point value f to an integer, with rounding to the nearest integer value.
  For the required precision see the section: Main inference precision requirements.

float_t round_to_nearest_float(in_t f)
  Converts the input value into floating-point, rounding to the nearest representable value.
  For the required precision see the section: Main inference precision requirements.

out_t sign_extend<out_t>(in_t input)
  Floating point values are unchanged.
  For two's complement integer values where out_t has more bits than in_t, replicate the top bit of input for all bits between the top bit of input and the top bit of output.

out_t zero_extend<out_t>(in_t input)
  Floating point values are unchanged.
  For two's complement integer values where out_t has more bits than in_t, insert zero values for all bits between the top bit of input and the top bit of output.

out_t truncate<out_t>(in_t input)
  output is the least significant bits of input, keeping as many bits as the width of out_t.
  Nop for floating-point types.
----

The following definition is used to flatten a list of lists into a single list.

[source,c++]
----
// Flattens a list of lists into a single list.
// The elements keep their order: every element of the first list, then
// every element of the second list, and so on.
in_t* flatten(in_t lists[]) {
    in_t* output = [];
    for_each(list in lists) {
        for_each(element in list) {
            output.append(element);
        }
    }
    // Hand the flattened list back to the caller
    return output;
}
----

Generic helper functions used to keep the pseudocode concise.

[source,c++]
----

bool_t is_floating_point(type) {
    // The floating-point types are fp16_t, bf16_t and fp32_t
    return (type == fp16_t) || (type == bf16_t) || (type == fp32_t);
}

int32_t idiv(int32_t input1, int32_t input2) {
    // Integer divide: the quotient is truncated towards zero
    int32_t quotient = input1 / input2;
    return quotient;
}

// Integer division that checks input1 is a multiple of input2

int32_t idiv_check(int32_t input1, int32_t input2) {
    ERROR_IF(input1 % input2 != 0); // input1 must be a multiple of input2
    return input1 / input2;         // exact quotient without rounding
}

// Perform an integer division with rounding towards minus infinity.
// The result is the floor of input1 / input2 for either sign of input1
// and input2.

int32_t idiv_floor(int32_t input1, int32_t input2) {
    // Integer divide truncates towards zero
    int32_t rval = input1 / input2;
    // Truncation rounds a negative inexact quotient upwards. The quotient is
    // negative when the operands have different signs, and inexact when the
    // remainder is non-zero; in that case step down to the floor.
    if ((input1 % input2 != 0) && ((input1 < 0) != (input2 < 0))) {
        rval--;
    }
    return rval;
}

int32_t length(in_t input)
    return number of elements in input list

int32_t rank(in_t input)
    return rank of an input tensor

int32_t sum(in_t input[])
    return the sum of values of an input list

bool isNaN(float_t input)
    return true if floating-point input value is NaN

float_t pi()
    return value of pi

float_t sin(angle)
    return sine of angle given in radians

float_t cos(angle)
    return cosine of angle given in radians

bool power_of_two(int32_t value)
    return true if value is a power of two, false otherwise

in_out_t maximum_s<Type T>
    return the maximum value when interpreting type T as a signed value as returned by the make_signed helper.

in_out_t minimum_s<Type T>
    return the minimum value when interpreting type T as a signed value as returned by the make_signed helper.

in_out_t maximum_u<Type T>
    return the maximum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.

in_out_t minimum_u<Type T>
    return the minimum value when interpreting type T as an unsigned value as returned by the make_unsigned helper.
----