diff options
Diffstat (limited to 'chapters/tensor_ops.adoc')
-rw-r--r-- | chapters/tensor_ops.adoc | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc index 291751f..b2c220e 100644 --- a/chapters/tensor_ops.adoc +++ b/chapters/tensor_ops.adoc @@ -18,9 +18,9 @@ This returns the index with the largest value across the given axis of the input |=== |Argument|Type|Name|Shape|Description -|Input|in_t*|input|shape1|Input tensor dimension k \<=4 -|Attribute|int|axis|-|Axis in range 0 to k-1 -|Output|out_t*|output|shape|Output tensor dimension k-1 +|Input|in_t*|input|shape1|Input tensor with rank from 1 to 4 +|Attribute|int|axis|-|Axis in range from 0 to rank(shape1)-1 +|Output|out_t*|output|shape|Output tensor, with rank = rank(shape1)-1 |=== *Quantization Parameters:* @@ -31,20 +31,30 @@ None [source,c++] ---- -assert(axis >= 0 && axis < k && k <=4); -left_shape = shape1[0:axis-1]; -right_shape = shape1[axis+1:k-1]; +assert(axis >= 0 && axis < rank(shape1) && rank(shape1) <= 4); +if (axis == 0) { + left_shape = []; +} else { + left_shape = shape1[0:axis - 1]; +} +if (axis == rank(shape1)-1) { + right_shape = []; +} else { + right_shape = shape1[axis+1:rank(shape1) - 1]; +} assert(flatten(left_shape, right_shape) == shape); -for_each(left_index in left_shape, right_index in right_shape ) - in_t max_value = minimum_value<in_t>; - int32_t max_index = 0; - for (i = 0; i < shape[axis]; i++) { - index = flatten(left_index, [i], right_index); - in_t value = tensor_read<in_t>(input, shape1, index); - if (value > max_value) { max_value = value; max_index=i; } +for_each(left_index in left_shape) { + for_each(right_index in right_shape) { + in_t max_value = minimum_value<in_t>; + int32_t max_index = 0; + for (i = 0; i < shape[axis]; i++) { + index = flatten(left_index, [i], right_index); + in_t value = tensor_read<in_t>(input, shape1, index); + if (value > max_value) { max_value = value; max_index = i; } + } + index = flatten(left_index, right_index); + tensor_write<int32_t>(output, shape, index, max_index); } - index = flatten(left_index, right_index); - tensor_write<int32_t>(output, shape, index, max_index); } ---- |