diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-11-08 11:19:10 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-11-09 20:59:06 +0000 |
commit | 1c3c847a4368817e2c9e3af66d5deb4c67993cbc (patch) | |
tree | 913962b18e07dc5e0d837dc182b76066538ebc65 /reference_model/src/ops/ewise_binary.cc | |
parent | 01c359de3ad4fc617ec7726fd7103749f6d3933f (diff) | |
download | reference_model-1c3c847a4368817e2c9e3af66d5deb4c67993cbc.tar.gz |
Check valid broadcastable shape for binary and ternary ops
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I9ed3d8971a133b4cbb2cf7d827f4e69d55dee246
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 41 |
1 files changed, 16 insertions, 25 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index b199f69..287ad92 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -60,10 +60,18 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() return 1; } - if (inputs[0]->matchRank(*outputs[0])) + if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */)) { std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match"; + printNodeValidationError(err.c_str()); + return 1; + } + + if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */)) + { + std::string err = + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match"; printNodeValidationError(err.c_str()); return 1; } @@ -82,31 +90,14 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() template <int Rank, DType InDtype, DType OutDtype> int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() { - auto output_shape = result->getTensor().dimensions(); - - std::vector<int> a_shape, b_shape; - - a_shape = a->getShape(); - b_shape = b->getShape(); + const std::vector<int>& a_shape = a->getShape(); + const std::vector<int>& b_shape = b->getShape(); + const std::vector<int>& output_shape = result->getShape(); - for (int i = 0; i < (int)a_shape.size(); i++) + for (int i = 0; i < Rank; i++) { - if (a_shape[i] != output_shape[i] && a_shape[i] == 1) - { - bcast_a[i] = output_shape[i]; - } - else - { - bcast_a[i] = 1; - } - if (b_shape[i] != output_shape[i] && b_shape[i] == 1) - { - bcast_b[i] = output_shape[i]; - } - else - { - bcast_b[i] = 1; - } + bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; + bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } return 0; |