diff options
author | Dominic Symes <dominic.symes@arm.com> | 2023-03-06 17:14:15 +0000 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2023-03-20 10:39:56 +0000 |
commit | 67a6f7f26f092860a4e1b3f6ac0c2e5cde9bf685 (patch) | |
tree | 2be362915e3897118303ef34531db6b2b0200142 | |
parent | 18548921437ef60e5553a75517449918e5c42b1d (diff) | |
download | specification-67a6f7f26f092860a4e1b3f6ac0c2e5cde9bf685.tar.gz |
Check the output shape of broadcast operations
For an operation that performs broadcast the
output shape size must be the maximum of the input shape
sizes in each dimension.
Additionally, the input dimension size must be 1 whenever
an input shape does not match the output shape size
in a dimension.
Signed-off-by: Dominic Symes <dominic.symes@arm.com>
Change-Id: I89492f4ef22da76f84f12e720c79634ea42545bc
-rw-r--r-- | chapters/comparison.adoc | 3 | ||||
-rw-r--r-- | chapters/ewise_binary.adoc | 16 | ||||
-rw-r--r-- | chapters/ewise_ternary.adoc | 1 | ||||
-rw-r--r-- | chapters/pseudocode.adoc | 20 |
4 files changed, 39 insertions, 1 deletions
diff --git a/chapters/comparison.adoc b/chapters/comparison.adoc index f4da361..4ef52d6 100644 --- a/chapters/comparison.adoc +++ b/chapters/comparison.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -39,6 +40,7 @@ include::{generated}/operators/GREATER.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -61,6 +63,7 @@ include::{generated}/operators/GREATER_EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_binary.adoc b/chapters/ewise_binary.adoc index 963d712..4af347a 100644 --- a/chapters/ewise_binary.adoc +++ b/chapters/ewise_binary.adoc @@ -18,6 +18,7 @@ include::{generated}/operators/ADD.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -37,6 +38,7 @@ include::{generated}/operators/ARITHMETIC_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -66,6 +68,7 @@ include::{generated}/operators/BITWISE_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -85,6 +88,7 @@ include::{generated}/operators/BITWISE_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -104,6 +108,7 @@ include::{generated}/operators/BITWISE_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -126,6 +131,7 @@ include::{generated}/operators/INTDIV.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -149,6 +155,7 @@ include::{generated}/operators/LOGICAL_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -168,6 +175,7 @@ include::{generated}/operators/LOGICAL_LEFT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -188,6 +196,7 @@ include::{generated}/operators/LOGICAL_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -208,6 +217,7 @@ include::{generated}/operators/LOGICAL_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -227,6 +237,7 @@ include::{generated}/operators/LOGICAL_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -246,6 +257,7 @@ include::{generated}/operators/MAXIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -265,6 +277,7 @@ include::{generated}/operators/MINIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -284,6 +297,7 @@ include::{generated}/operators/MUL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); ERROR_IF(in_t != int32_t && shift > 0); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); @@ -313,6 +327,7 @@ include::{generated}/operators/POW.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -332,6 +347,7 @@ include::{generated}/operators/SUB.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_ternary.adoc b/chapters/ewise_ternary.adoc index 0b8097d..57cf599 100644 --- a/chapters/ewise_ternary.adoc +++ b/chapters/ewise_ternary.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/SELECT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(broadcast_shape(shape1, shape2), shape3)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc index db699d1..42f123b 100644 --- a/chapters/pseudocode.adoc +++ b/chapters/pseudocode.adoc @@ -125,7 +125,25 @@ void tensor_write<type>(<type> *address, dim_t shape, dim_t index, <type> value) } ---- -==== Broadcast Helper +==== Broadcast Helpers + +The following function derives the broadcast output shape from the input shapes. + +[source,c++] +---- +dim_t broadcast_shape(dim_t shape1, dim_t shape2) { + ERROR_IF(rank(shape1) != rank(shape2)); + dim_t shape = shape1; + for (int32_t i = 0; i < rank(shape); i++) { + if (shape[i] == 1) { + shape[i] = shape2[i]; + } else { + ERROR_IF(shape2[i] != 1 && shape2[i] != shape[i]); + } + } + return shape; +} +---- The following function maps an index in the output tensor to an index in the input tensor. |