aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/tensor.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-11-08 11:19:10 -0800
committerKevin Cheng <kevin.cheng@arm.com>2021-11-09 20:59:06 +0000
commit1c3c847a4368817e2c9e3af66d5deb4c67993cbc (patch)
tree913962b18e07dc5e0d837dc182b76066538ebc65 /reference_model/src/tensor.h
parent01c359de3ad4fc617ec7726fd7103749f6d3933f (diff)
downloadreference_model-1c3c847a4368817e2c9e3af66d5deb4c67993cbc.tar.gz
Check valid broadcastable shape for binary and ternary ops
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I9ed3d8971a133b4cbb2cf7d827f4e69d55dee246
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r--reference_model/src/tensor.h22
1 files changed, 22 insertions, 0 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 3fa23f9..5536583 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -148,6 +148,28 @@ public:
return 0;
}
+ const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const
+ {
+ if (matchRank(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ {
+ if (!broadcastOk ||
+ // For broadcasts, at least one operand must have size 1
+ // if they don't both match
+ (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ {
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+ }
+
// Sometimes we might want to match several semi-compatible types,
// so just check rank and size here
const int matchRankSize(const Tensor& ref) const