Diffstat (limited to 'chapters/tensor_ops.adoc')
-rw-r--r-- chapters/tensor_ops.adoc | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/chapters/tensor_ops.adoc b/chapters/tensor_ops.adoc
index b2c220e..6780b1c 100644
--- a/chapters/tensor_ops.adoc
+++ b/chapters/tensor_ops.adoc
@@ -31,7 +31,7 @@ None
[source,c++]
----
-assert(axis >= 0 && axis < rank(shape1) && rank(shape1) <= 4);
+REQUIRE(axis >= 0 && axis < rank(shape1) && rank(shape1) <= 4);
if (axis == 0) {
left_shape = [];
} else {
@@ -42,7 +42,7 @@ if (axis == rank(shape1)-1) {
} else {
right_shape = shape1[axis+1:rank(shape1) - 1];
}
-assert(flatten(left_shape, right_shape) == shape);
+REQUIRE(flatten(left_shape, right_shape) == shape);
for_each(left_index in left_shape) {
for_each(right_index in right_shape) {
in_t max_value = minimum_value<in_t>;
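
For context, the substantive change throughout this diff is the swap from C-style assert to the spec's REQUIRE check. A minimal sketch of how such a macro could behave is below; the ERROR_IF name and the error-flag mechanism are assumptions for illustration, not something this diff defines:

[source,c++]
----
// Hypothetical sketch: unlike assert(), which aborts in debug builds
// and disappears entirely under NDEBUG, REQUIRE records a checked
// error that the operation must report.
static bool tosa_error = false;           // assumed error flag

#define ERROR_IF(cond) do { if (cond) tosa_error = true; } while (0)
#define REQUIRE(cond)  ERROR_IF(!(cond))  // violation => checked error
----

Under this reading, a violated REQUIRE yields a defined error result instead of vanishing from release builds, which would explain preferring it over assert.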
@@ -97,8 +97,8 @@ This performs an average pooling over the given input tensor. A sliding window o
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
-assert(in_t == int8_t || output_zp == 0); // Zero point only for int8_t
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
+REQUIRE(in_t == int8_t || output_zp == 0); // Zero point only for int8_t
pad = flatten([0,0], pad, [0,0]);
for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W, 0 <= c < C) {
in_t output_val;
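
The two REQUIREs above encode the rule that only int8_t tensors may carry non-zero zero points. As a hedged illustration of where those zero points enter the arithmetic (this is a sketch of one output element, not the lines elided from this hunk):

[source,c++]
----
#include <cstdint>

// Sketch: quantized values represent real = scale * (q - zp), so the
// input offset is removed before averaging and the output offset is
// reapplied afterwards. Rounding and clamping details are omitted.
int8_t avg_pool_one(const int8_t* window, int count,
                    int32_t input_zp, int32_t output_zp) {
    int32_t acc = 0;
    for (int i = 0; i < count; i++) {
        acc += window[i] - input_zp;   // remove the input offset
    }
    int32_t avg = acc / count;         // mean over the window
    return (int8_t)(avg + output_zp);  // restore the output offset
}
----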
@@ -164,8 +164,8 @@ Performs a 2D convolution over the given tensor input, using the weight tensor.
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
-assert(weight_t == int8_t || weight_zp == 0);
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
+REQUIRE(weight_t == int8_t || weight_zp == 0);
pad = flatten([0,0], pad, [0,0]);
for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W, 0 <= oc < OC) {
acc_t acc = 0;
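
As with AVG_POOL2D, the REQUIREs here restrict non-zero zero points to int8_t. A self-contained sketch of the multiply-accumulate they feed into follows; it applies equally to the CONV3D, DEPTHWISE_CONV2D, and FULLY_CONNECTED hunks below (the function name and flat-array layout are illustrative assumptions):

[source,c++]
----
#include <cstdint>

// Sketch: both operands are shifted by their zero points before the
// product is accumulated at higher precision, matching
// (value - input_zp) * (weight - weight_zp) summed into acc_t.
int32_t dot_zp(const int8_t* x, const int8_t* w, int n,
               int32_t input_zp, int32_t weight_zp) {
    int32_t acc = 0;
    for (int i = 0; i < n; i++) {
        acc += (x[i] - input_zp) * (w[i] - weight_zp);
    }
    return acc;
}
----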
@@ -225,8 +225,8 @@ Performs a 3D convolution over the given input tensor.
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
-assert(weight_t == int8_t || weight_zp == 0);
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
+REQUIRE(weight_t == int8_t || weight_zp == 0);
pad = flatten([0,0], pad, [0,0]);
for_each(0 <= n < N, 0 <= od < D, 0 <= oy < H, 0 <= ox < W, 0 <= oc < OC) {
acc_t acc = 0;
@@ -289,8 +289,8 @@ Performs 2D convolutions separately over each channel of the given tensor input,
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
-assert(weight_t == int8_t || weight_zp == 0);
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
+REQUIRE(weight_t == int8_t || weight_zp == 0);
pad = flatten([0,0], pad, [0,0]);
for_each(0 <= n < N, 0 <= oy < H, 0 <= ox < W, 0 <= c < (C * M), 0 <= m < M) {
acc_t acc = 0;
@@ -347,8 +347,8 @@ Performs a fully connected network.
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
-assert(weight_t == int8_t || weight_zp == 0);
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only for int8_t
+REQUIRE(weight_t == int8_t || weight_zp == 0);
for_each(0 <= n < N, 0 <= oc < OC) {
acc_t acc = 0;
for_each(0 <= ic < IC) {
@@ -398,7 +398,7 @@ Performs two dimensional matrix multiplications. This allows both inputs to be a
[source,c++]
----
-assert(in_t == int8_t || (A_zp == 0 && B_zp == 0)); // Zero point only for int8_t
+REQUIRE(in_t == int8_t || (A_zp == 0 && B_zp == 0)); // Zero point only for int8_t
for_each(0 <= n < N, 0 <= h < H, 0 <= w < W) {
acc_t acc = 0;
for_each(0 <= c < C) {
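
For MATMUL, both operands share one zero-point rule, so a single REQUIRE covers A_zp and B_zp together. A hedged, self-contained sketch of the batched multiply this loop nest describes (the row-major [N,H,C] x [N,C,W] layout is my assumption):

[source,c++]
----
#include <cstdint>

// Sketch: out[n][h][w] = sum_c (A[n][h][c] - A_zp) * (B[n][c][w] - B_zp)
void matmul_zp(const int8_t* A, const int8_t* B, int32_t* out,
               int N, int H, int W, int C,
               int32_t A_zp, int32_t B_zp) {
    for (int n = 0; n < N; n++)
        for (int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                int32_t acc = 0;
                for (int c = 0; c < C; c++)
                    acc += (A[(n*H + h)*C + c] - A_zp)
                         * (B[(n*C + c)*W + w] - B_zp);
                out[(n*H + h)*W + w] = acc;
            }
}
----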
@@ -499,8 +499,8 @@ Performs a 2D transposed convolution over the given tensor input, using the weig
[source,c++]
----
-assert(in_t == int8_t || input_zp == 0); // Zero point only allowed for int8_t
-assert(weight_t == int8_t || weight_zp == 0);
+REQUIRE(in_t == int8_t || input_zp == 0); // Zero point only allowed for int8_t
+REQUIRE(weight_t == int8_t || weight_zp == 0);
for_each(index in out_shape) {
    tensor_write<acc_t>(output, [N,OH,OW,OC], index, bias[index[3]]);
}
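
The initialization loop at the end of this hunk seeds every output element with its channel's bias before the transposed convolution scatters contributions in. A hedged reading of that step as plain C++ (row-major [N,OH,OW,OC] layout assumed):

[source,c++]
----
#include <cstdint>

// Sketch of tensor_write<acc_t>(output, [N,OH,OW,OC], index, bias[index[3]]):
// index[3] is the output-channel coordinate, so each element starts
// at the bias for its channel.
void init_output_with_bias(int32_t* output, const int32_t* bias,
                           int N, int OH, int OW, int OC) {
    for (int i = 0; i < N * OH * OW; i++) {
        for (int oc = 0; oc < OC; oc++) {
            output[i * OC + oc] = bias[oc];
        }
    }
}
----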