aboutsummaryrefslogtreecommitdiff
path: root/pseudocode/operators/ARGMAX.tosac
blob: bbe301fa813051f0045ce998d7de53eaf5c1d16d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
// (C) COPYRIGHT 2020-2024 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.

// ARGMAX: for every position of the input outside the reduction axis, write
// to the output the index (along `axis`) of the largest element found.
//
// Operands as used below:
//   input  - tensor of element type in_t, shape `shape1`
//   axis   - dimension of `shape1` to reduce; must be in [0, rank(shape1))
//   output - tensor of element type out_t, shape `shape`, which must equal
//            `shape1` with dimension `axis` removed
ERROR_IF(axis < 0 || axis >= rank(shape1));
// Dimensions of shape1 preceding the reduction axis (none when axis == 0).
if (axis == 0) {
    left_shape = [];
} else {
    left_shape = shape1[0:axis - 1];
}
// Dimensions of shape1 following the reduction axis (none when axis is last).
if (axis == rank(shape1)-1) {
    right_shape = [];
} else {
    right_shape = shape1[axis+1:rank(shape1) - 1];
}
// The output shape is the input shape with the reduced dimension removed.
ERROR_IF(flatten(left_shape, right_shape) != shape);
for_each(left_index in left_shape) {
    for_each(right_index in right_shape) {
        in_t max_value = minimum_s<in_t>;
        out_t max_index = 0;
        // Iterate over the reduction axis of the *input* shape. `shape` is the
        // output shape, which no longer contains this dimension: shape[axis]
        // would be the following dimension, or out of range when axis is the
        // last dimension of shape1.
        for (i = 0; i < shape1[axis]; i++) {
            dim_t index = flatten(left_index, [i], right_index);
            in_t value = tensor_read<in_t>(input, shape1, index);
            // Assuming apply_max_s returns the larger operand, this updates
            // only when value is strictly greater than max_value, so the
            // first occurrence of the maximum is the index reported.
            // NOTE(review): the result for floating-point NaN inputs depends
            // on how apply_max_s treats NaN - confirm the intended behaviour.
            if (apply_max_s<in_t>(value, max_value) != max_value) {
                max_value = value;
                max_index = i;
            }
        }
        dim_t index = flatten(left_index, right_index);
        tensor_write<out_t>(output, shape, index, max_index);
    }
}