aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc41
1 files changed, 16 insertions, 25 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b199f69..287ad92 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,10 +60,18 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (inputs[0]->matchRank(*outputs[0]))
+ if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */))
{
std::string err =
- "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match";
+ printNodeValidationError(err.c_str());
+ return 1;
+ }
+
+ if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */))
+ {
+ std::string err =
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match";
printNodeValidationError(err.c_str());
return 1;
}
@@ -82,31 +90,14 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
template <int Rank, DType InDtype, DType OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
{
- auto output_shape = result->getTensor().dimensions();
-
- std::vector<int> a_shape, b_shape;
-
- a_shape = a->getShape();
- b_shape = b->getShape();
+ const std::vector<int>& a_shape = a->getShape();
+ const std::vector<int>& b_shape = b->getShape();
+ const std::vector<int>& output_shape = result->getShape();
- for (int i = 0; i < (int)a_shape.size(); i++)
+ for (int i = 0; i < Rank; i++)
{
- if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
- {
- bcast_a[i] = output_shape[i];
- }
- else
- {
- bcast_a[i] = 1;
- }
- if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
- {
- bcast_b[i] = output_shape[i];
- }
- else
- {
- bcast_b[i] = 1;
- }
+ bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
+ bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
}
return 0;