diff options
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r-- | reference_model/src/tensor.h | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 5536583..d857dc8 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -136,9 +136,11 @@ public: if (shape[i] != ref.shape[i]) { if (!broadcastOk || - // For broadcasts, at least one operand must have size 1 - // if they don't both match - (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1))) + // For broadcasts, the order of *this and ref matters. + // *this should be the source tensor. + // ref should be the target tensor. In most of the case, ref is expected to be the output tensor. + // this->shape must have size 1 if they don't match + (broadcastOk && (shape[i] != 1))) { return 1; } @@ -158,9 +160,11 @@ public: if (shape[i] != ref.shape[i]) { if (!broadcastOk || - // For broadcasts, at least one operand must have size 1 - // if they don't both match - (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1))) + // For broadcasts, the order of *this and ref matters. + // *this should be the source tensor. + // ref should be the target tensor. In most of the case, ref is expected to be the output tensor. + // this->shape must have size 1 if they don't match + (broadcastOk && (shape[i] != 1))) { return 1; } |