diff options
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r-- | reference_model/src/tensor.h | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 3fa23f9..5536583 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -148,6 +148,28 @@ public: return 0; } + const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const + { + if (matchRank(ref)) + return 1; + + for (size_t i = 0; i < shape.size(); i++) + { + if (shape[i] != ref.shape[i]) + { + if (!broadcastOk || + // For broadcasts, at least one operand must have size 1 + // if they don't both match + (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1))) + { + return 1; + } + } + } + + return 0; + } + // Sometimes we might want to match several semi-compatible types, // so just check rank and size here const int matchRankSize(const Tensor& ref) const |