aboutsummaryrefslogtreecommitdiff
path: root/chapters/tensor_ops.adoc
diff options
context:
space:
mode:
Diffstat (limited to 'chapters/tensor_ops.adoc')
-rw-r--r--chapters/tensor_ops.adoc40
1 files changed, 25 insertions, 15 deletions
diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc
index 291751f..b2c220e 100644
--- a/chapters/tensor_ops.adoc
+++ b/chapters/tensor_ops.adoc
@@ -18,9 +18,9 @@ This returns the index with the largest value across the given axis of the input
|===
|Argument|Type|Name|Shape|Description
-|Input|in_t*|input|shape1|Input tensor dimension k \<=4
-|Attribute|int|axis|-|Axis in range 0 to k-1
-|Output|out_t*|output|shape|Output tensor dimension k-1
+|Input|in_t*|input|shape1|Input tensor with rank from 1 to 4
+|Attribute|int|axis|-|Axis in range from 0 to rank(shape1)-1
+|Output|out_t*|output|shape|Output tensor, with rank = rank(shape1)-1
|===
*Quantization Parameters:*
@@ -31,20 +31,30 @@ None
[source,c++]
----
-assert(axis >= 0 && axis < k && k <=4);
-left_shape = shape1[0:axis-1];
-right_shape = shape1[axis+1:k-1];
+assert(axis >= 0 && axis < rank(shape1) && rank(shape1) <= 4);
+if (axis == 0) {
+ left_shape = [];
+} else {
+ left_shape = shape1[0:axis - 1];
+}
+if (axis == rank(shape1)-1) {
+ right_shape = [];
+} else {
+ right_shape = shape1[axis+1:rank(shape1) - 1];
+}
assert(flatten(left_shape, right_shape) == shape);
-for_each(left_index in left_shape, right_index in right_shape )
- in_t max_value = minimum_value<in_t>;
- int32_t max_index = 0;
- for (i = 0; i < shape[axis]; i++) {
- index = flatten(left_index, [i], right_index);
- in_t value = tensor_read<in_t>(input, shape1, index);
- if (value > max_value) { max_value = value; max_index=i; }
+for_each(left_index in left_shape) {
+ for_each(right_index in right_shape) {
+ in_t max_value = minimum_value<in_t>;
+ int32_t max_index = 0;
+ for (i = 0; i < shape[axis]; i++) {
+ index = flatten(left_index, [i], right_index);
+ in_t value = tensor_read<in_t>(input, shape1, index);
+ if (value > max_value) { max_value = value; max_index = i; }
+ }
+ index = flatten(left_index, right_index);
+ tensor_write<int32_t>(output, shape, index, max_index);
}
- index = flatten(left_index, right_index);
- tensor_write<int32_t>(output, shape, index, max_index);
}
----