From 67a6f7f26f092860a4e1b3f6ac0c2e5cde9bf685 Mon Sep 17 00:00:00 2001 From: Dominic Symes Date: Mon, 6 Mar 2023 17:14:15 +0000 Subject: Check the output shape of broadcast operations For an operation that performs broadcast the output shape size must be the maximum of the input shape sizes in each dimension. Additionally, the input dimension size must be 1 whenever an input shape does not match the output shape size in a dimension. Signed-off-by: Dominic Symes Change-Id: I89492f4ef22da76f84f12e720c79634ea42545bc --- chapters/comparison.adoc | 3 +++ chapters/ewise_binary.adoc | 16 ++++++++++++++++ chapters/ewise_ternary.adoc | 1 + chapters/pseudocode.adoc | 20 +++++++++++++++++++- 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/chapters/comparison.adoc b/chapters/comparison.adoc index f4da361..4ef52d6 100644 --- a/chapters/comparison.adoc +++ b/chapters/comparison.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -39,6 +40,7 @@ include::{generated}/operators/GREATER.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -61,6 +63,7 @@ include::{generated}/operators/GREATER_EQUAL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_binary.adoc b/chapters/ewise_binary.adoc index 963d712..4af347a 100644 --- a/chapters/ewise_binary.adoc +++ b/chapters/ewise_binary.adoc @@ -18,6 +18,7 @@ include::{generated}/operators/ADD.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -37,6 +38,7 @@ include::{generated}/operators/ARITHMETIC_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -66,6 +68,7 @@ include::{generated}/operators/BITWISE_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -85,6 +88,7 @@ include::{generated}/operators/BITWISE_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -104,6 +108,7 @@ include::{generated}/operators/BITWISE_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -126,6 +131,7 @@ include::{generated}/operators/INTDIV.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -149,6 +155,7 @@ include::{generated}/operators/LOGICAL_AND.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -168,6 +175,7 @@ include::{generated}/operators/LOGICAL_LEFT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -188,6 +196,7 @@ include::{generated}/operators/LOGICAL_RIGHT_SHIFT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -208,6 +217,7 @@ include::{generated}/operators/LOGICAL_OR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -227,6 +237,7 @@ include::{generated}/operators/LOGICAL_XOR.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -246,6 +257,7 @@ include::{generated}/operators/MAXIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -265,6 +277,7 @@ include::{generated}/operators/MINIMUM.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -284,6 +297,7 @@ include::{generated}/operators/MUL.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); ERROR_IF(in_t != int32_t && shift > 0); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); @@ -313,6 +327,7 @@ include::{generated}/operators/POW.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); @@ -332,6 +347,7 @@ include::{generated}/operators/SUB.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(shape1, shape2)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/ewise_ternary.adoc b/chapters/ewise_ternary.adoc index 0b8097d..57cf599 100644 --- a/chapters/ewise_ternary.adoc +++ b/chapters/ewise_ternary.adoc @@ -17,6 +17,7 @@ include::{generated}/operators/SELECT.adoc[] [source,c++] ---- +ERROR_IF(shape != broadcast_shape(broadcast_shape(shape1, shape2), shape3)); for_each(index in shape) { dim_t index1 = apply_broadcast(shape, shape1, index); dim_t index2 = apply_broadcast(shape, shape2, index); diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc index db699d1..42f123b 100644 --- a/chapters/pseudocode.adoc +++ b/chapters/pseudocode.adoc @@ -125,7 +125,25 @@ void tensor_write( *address, dim_t shape, dim_t index, value) } ---- -==== Broadcast Helper +==== Broadcast Helpers + +The following function derives the broadcast output shape from the input shapes. + +[source,c++] +---- +dim_t broadcast_shape(dim_t shape1, dim_t shape2) { + ERROR_IF(rank(shape1) != rank(shape2)); + dim_t shape = shape1; + for (int32_t i = 0; i < rank(shape); i++) { + if (shape[i] == 1) { + shape[i] = shape2[i]; + } else { + ERROR_IF(shape2[i] != 1 && shape2[i] != shape[i]); + } + } + return shape; +} +---- The following function maps an index in the output tensor to an index in the input tensor. -- cgit v1.2.1