aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominic Symes <dominic.symes@arm.com>2023-03-06 17:14:15 +0000
committerDominic Symes <dominic.symes@arm.com>2023-03-20 10:39:56 +0000
commit67a6f7f26f092860a4e1b3f6ac0c2e5cde9bf685 (patch)
tree2be362915e3897118303ef34531db6b2b0200142
parent18548921437ef60e5553a75517449918e5c42b1d (diff)
downloadspecification-67a6f7f26f092860a4e1b3f6ac0c2e5cde9bf685.tar.gz
Check the output shape of broadcast operations
For an operation that performs broadcast the output shape size must be the maximum of the input shape sizes in each dimension. Additionally, the input dimension size must be 1 whenever an input shape does not match the output shape size in a dimension. Signed-off-by: Dominic Symes <dominic.symes@arm.com> Change-Id: I89492f4ef22da76f84f12e720c79634ea42545bc
-rw-r--r--chapters/comparison.adoc3
-rw-r--r--chapters/ewise_binary.adoc16
-rw-r--r--chapters/ewise_ternary.adoc1
-rw-r--r--chapters/pseudocode.adoc20
4 files changed, 39 insertions, 1 deletions
diff --git a/chapters/comparison.adoc b/chapters/comparison.adoc
index f4da361..4ef52d6 100644
--- a/chapters/comparison.adoc
+++ b/chapters/comparison.adoc
@@ -17,6 +17,7 @@ include::{generated}/operators/EQUAL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -39,6 +40,7 @@ include::{generated}/operators/GREATER.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -61,6 +63,7 @@ include::{generated}/operators/GREATER_EQUAL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/ewise_binary.adoc b/chapters/ewise_binary.adoc
index 963d712..4af347a 100644
--- a/chapters/ewise_binary.adoc
+++ b/chapters/ewise_binary.adoc
@@ -18,6 +18,7 @@ include::{generated}/operators/ADD.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -37,6 +38,7 @@ include::{generated}/operators/ARITHMETIC_RIGHT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -66,6 +68,7 @@ include::{generated}/operators/BITWISE_AND.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -85,6 +88,7 @@ include::{generated}/operators/BITWISE_OR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -104,6 +108,7 @@ include::{generated}/operators/BITWISE_XOR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -126,6 +131,7 @@ include::{generated}/operators/INTDIV.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -149,6 +155,7 @@ include::{generated}/operators/LOGICAL_AND.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -168,6 +175,7 @@ include::{generated}/operators/LOGICAL_LEFT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -188,6 +196,7 @@ include::{generated}/operators/LOGICAL_RIGHT_SHIFT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -208,6 +217,7 @@ include::{generated}/operators/LOGICAL_OR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -227,6 +237,7 @@ include::{generated}/operators/LOGICAL_XOR.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -246,6 +257,7 @@ include::{generated}/operators/MAXIMUM.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -265,6 +277,7 @@ include::{generated}/operators/MINIMUM.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -284,6 +297,7 @@ include::{generated}/operators/MUL.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
ERROR_IF(in_t != int32_t && shift > 0);
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
@@ -313,6 +327,7 @@ include::{generated}/operators/POW.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
@@ -332,6 +347,7 @@ include::{generated}/operators/SUB.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(shape1, shape2));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/ewise_ternary.adoc b/chapters/ewise_ternary.adoc
index 0b8097d..57cf599 100644
--- a/chapters/ewise_ternary.adoc
+++ b/chapters/ewise_ternary.adoc
@@ -17,6 +17,7 @@ include::{generated}/operators/SELECT.adoc[]
[source,c++]
----
+ERROR_IF(shape != broadcast_shape(broadcast_shape(shape1, shape2), shape3));
for_each(index in shape) {
dim_t index1 = apply_broadcast(shape, shape1, index);
dim_t index2 = apply_broadcast(shape, shape2, index);
diff --git a/chapters/pseudocode.adoc b/chapters/pseudocode.adoc
index db699d1..42f123b 100644
--- a/chapters/pseudocode.adoc
+++ b/chapters/pseudocode.adoc
@@ -125,7 +125,25 @@ void tensor_write<type>(<type> *address, dim_t shape, dim_t index, <type> value)
}
----
-==== Broadcast Helper
+==== Broadcast Helpers
+
+The following function derives the broadcast output shape from the input shapes.
+
+[source,c++]
+----
+dim_t broadcast_shape(dim_t shape1, dim_t shape2) {
+ ERROR_IF(rank(shape1) != rank(shape2));
+ dim_t shape = shape1;
+ for (int32_t i = 0; i < rank(shape); i++) {
+ if (shape[i] == 1) {
+ shape[i] = shape2[i];
+ } else {
+ ERROR_IF(shape2[i] != 1 && shape2[i] != shape[i]);
+ }
+ }
+ return shape;
+}
+----
The following function maps an index in the output tensor to an index in the input tensor.