aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chapters/comparison.adoc3
-rw-r--r--chapters/ewise_binary.adoc16
-rw-r--r--chapters/ewise_ternary.adoc1
-rw-r--r--chapters/pseudocode.adoc20
4 files changed, 39 insertions, 1 deletions
diff --git a/chapters/comparison.adoc b/chapters/comparison.adoc
index f4da361..4ef52d6 100644
--- a/chapters/comparison.adoc
+++ b/chapters/comparison.adoc
@@ -17,6 +17,7 @@ include::{generated}/operators/EQUAL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -39,6 +40,7 @@ include::{generated}/operators/GREATER.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -61,6 +63,7 @@ include::{generated}/operators/GREATER_EQUAL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/ewise_binary.adoc b/chapters/ewise_binary.adoc
index 963d712..4af347a 100644
--- a/chapters/ewise_binary.adoc
+++ b/chapters/ewise_binary.adoc
@@ -18,6 +18,7 @@ include::{generated}/operators/ADD.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -37,6 +38,7 @@ include::{generated}/operators/ARITHMETIC_RIGHT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -66,6 +68,7 @@ include::{generated}/operators/BITWISE_AND.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -85,6 +88,7 @@ include::{generated}/operators/BITWISE_OR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -104,6 +108,7 @@ include::{generated}/operators/BITWISE_XOR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -126,6 +131,7 @@ include::{generated}/operators/INTDIV.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -149,6 +155,7 @@ include::{generated}/operators/LOGICAL_AND.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -168,6 +175,7 @@ include::{generated}/operators/LOGICAL_LEFT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -188,6 +196,7 @@ include::{generated}/operators/LOGICAL_RIGHT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -208,6 +217,7 @@ include::{generated}/operators/LOGICAL_OR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -227,6 +237,7 @@ include::{generated}/operators/LOGICAL_XOR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -246,6 +257,7 @@ include::{generated}/operators/MAXIMUM.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -265,6 +277,7 @@ include::{generated}/operators/MINIMUM.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -284,6 +297,7 @@ include::{generated}/operators/MUL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
ERROR_IF(in_t != int32_t && shift > 0);
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
@@ -313,6 +327,7 @@ include::{generated}/operators/POW.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -332,6 +347,7 @@ include::{generated}/operators/SUB.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/ewise_ternary.adoc b/chapters/ewise_ternary.adoc
index 0b8097d..57cf599 100644
--- a/chapters/ewise_ternary.adoc
+++ b/chapters/ewise_ternary.adoc
@@ -17,6 +17,7 @@ include::{generated}/operators/SELECT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(broadcast_shape(shape1, shape2), shape3));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc
index db699d1..42f123b 100644
--- a/chapters/pseudocode.adoc
+++ b/chapters/pseudocode.adoc
@@ -125,7 +125,25 @@ void tensor_write<type>(<type> *address, dim_t shape, dim_t index, <type> value)
}
----
-==== Broadcast Helper
+==== Broadcast Helpers
+
+The following function derives the broadcast output shape from the input shapes.
+
+[source,c++]
+----
+dim_t broadcast_shape(dim_t shape1, dim_t shape2) {
+ ERROR_IF(rank(shape1) != rank(shape2));
+ dim_t shape = shape1;
+ for (int32_t i = 0; i < rank(shape); i++) {
+ if (shape[i] == 1) {
+ shape[i] = shape2[i];
+ } else {
+ ERROR_IF(shape2[i] != 1 && shape2[i] != shape[i]);
+ }
+ }
+ return shape;
+}
+----
The following function maps an index in the output tensor to an index in the input tensor.