diff options
Diffstat (limited to 'pseudocode/operators/ARGMAX.tosac')
-rw-r--r-- | pseudocode/operators/ARGMAX.tosac | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/pseudocode/operators/ARGMAX.tosac b/pseudocode/operators/ARGMAX.tosac new file mode 100644 index 0000000..bbe301f --- /dev/null +++ b/pseudocode/operators/ARGMAX.tosac @@ -0,0 +1,37 @@ +// +// This confidential and proprietary software may be used only as +// authorised by a licensing agreement from ARM Limited +// (C) COPYRIGHT 2020-2024 ARM Limited +// ALL RIGHTS RESERVED +// The entire notice above must be reproduced on all authorised +// copies and copies may only be made to the extent permitted +// by a licensing agreement from ARM Limited. + +ERROR_IF(axis < 0 || axis >= rank(shape1)); +if (axis == 0) { + left_shape = []; +} else { + left_shape = shape1[0:axis - 1]; +} +if (axis == rank(shape1)-1) { + right_shape = []; +} else { + right_shape = shape1[axis+1:rank(shape1) - 1]; +} +ERROR_IF(flatten(left_shape, right_shape) != shape); +for_each(left_index in left_shape) { + for_each(right_index in right_shape) { + in_t max_value = minimum_s<in_t>; + out_t max_index = 0; + for (i = 0; i < shape[axis]; i++) { + dim_t index = flatten(left_index, [i], right_index); + in_t value = tensor_read<in_t>(input, shape1, index); + if (apply_max_s<in_t>(value, max_value) != max_value) { + max_value = value; + max_index = i; + } + } + dim_t index = flatten(left_index, right_index); + tensor_write<out_t>(output, shape, index, max_index); + } +} |