aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/tensor.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 5536583..d857dc8 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -136,9 +136,11 @@ public:
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}
@@ -158,9 +160,11 @@ public:
if (shape[i] != ref.shape[i])
{
if (!broadcastOk ||
- // For broadcasts, at least one operand must have size 1
- // if they don't both match
- (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ // For broadcasts, the order of *this and ref matters.
+ // *this should be the source tensor.
+ // ref should be the target tensor. In most of the case, ref is expected to be the output tensor.
+ // this->shape must have size 1 if they don't match
+ (broadcastOk && (shape[i] != 1)))
{
return 1;
}