aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-11-11 19:35:30 +0000
committerKevin Cheng <kevin.cheng@arm.com>2021-11-11 17:00:53 -0800
commit2131a4d2653c0cd8eeddd934f96da8f85717e2d9 (patch)
tree829b565a0ba01d06b30cb1206ee98cdec334b00a
parent9fe172483b77dcaa0bfe7e97af4a934d6ef01a16 (diff)
downloadreference_model-2131a4d2653c0cd8eeddd934f96da8f85717e2d9.tar.gz
Fix broadcast bug
- test like [1] + [2] = [1] should be treated as invalid test - modify matchRankShape() function so it allows size 1 only on the source tensor but not target tensor Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I6bbb6a63dc1143712e7eef736a991cac419b009e
-rw-r--r--reference_model/src/tensor.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 5536583..d857dc8 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -136,9 +136,11 @@ public:
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}
@@ -158,9 +160,11 @@ public:
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}