diff options
-rw-r--r-- | chapters/comparison.adoc | 3 | ||||
-rw-r--r-- | chapters/ewise_binary.adoc | 16 | ||||
-rw-r--r-- | chapters/ewise_ternary.adoc | 1 | ||||
-rw-r--r-- | chapters/pseudocode.adoc | 20 |
4 files changed, 39 insertions, 1 deletions
diff --git a/chapters/comparison.adoc b/chapters/comparison.adoc index f4da361..4ef52d6 100644 --- a/chapters/comparison.adoc +++ b/chapters/comparison.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -39,6 +40,7 @@ include::{generated}/operators/GREATER.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -61,6 +63,7 @@ include::{generated}/operators/GREATER_EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_binary.adoc b/chapters/ewise_binary.adoc index 963d712..4af347a 100644 --- a/chapters/ewise_binary.adoc +++ b/chapters/ewise_binary.adoc @@ -18,6 +18,7 @@ include::{generated}/operators/ADD.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -37,6 +38,7 @@ include::{generated}/operators/ARITHMETIC_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -66,6 +68,7 @@ include::{generated}/operators/BITWISE_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -85,6 +88,7 @@ include::{generated}/operators/BITWISE_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -104,6 +108,7 @@ include::{generated}/operators/BITWISE_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -126,6 +131,7 @@ include::{generated}/operators/INTDIV.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -149,6 +155,7 @@ include::{generated}/operators/LOGICAL_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -168,6 +175,7 @@ include::{generated}/operators/LOGICAL_LEFT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -188,6 +196,7 @@ include::{generated}/operators/LOGICAL_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -208,6 +217,7 @@ include::{generated}/operators/LOGICAL_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -227,6 +237,7 @@ include::{generated}/operators/LOGICAL_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -246,6 +257,7 @@ include::{generated}/operators/MAXIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -265,6 +277,7 @@ include::{generated}/operators/MINIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -284,6 +297,7 @@ include::{generated}/operators/MUL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); ERROR_IF(in_t != int32_t && shift > 0); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); @@ -313,6 +327,7 @@ include::{generated}/operators/POW.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -332,6 +347,7 @@ include::{generated}/operators/SUB.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_ternary.adoc b/chapters/ewise_ternary.adoc index 0b8097d..57cf599 100644 --- a/chapters/ewise_ternary.adoc +++ b/chapters/ewise_ternary.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/SELECT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(broadcast_shape(shape1, shape2), shape3)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc index db699d1..42f123b 100644 --- a/chapters/pseudocode.adoc +++ b/chapters/pseudocode.adoc @@ -125,7 +125,25 @@ void tensor_write<type>(<type> *address, dim_t shape, dim_t index, <type> value) } ---- -==== Broadcast Helper +==== Broadcast Helpers + +The following function derives the broadcast output shape from the input shapes. + +[source,c++] +---- +dim_t broadcast_shape(dim_t shape1, dim_t shape2) { + ERROR_IF(rank(shape1) != rank(shape2)); + dim_t shape = shape1; + for (int32_t i = 0; i < rank(shape); i++) { + if (shape[i] == 1) { + shape[i] = shape2[i]; + } else { + ERROR_IF(shape2[i] != 1 && shape2[i] != shape[i]); + } + } + return shape; +} +---- The following function maps an index in the output tensor to an index in the input tensor. |