diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 101 |
1 files changed, 26 insertions, 75 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index c33f646..a11d855 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -32,10 +32,8 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); setRequiredRank(0, 6); - a_rank = b_rank = max_input_rank = -1; - a = b = nullptr; - a_rank0 = b_rank0 = nullptr; - result = nullptr; + a = b = nullptr; + result = nullptr; fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } @@ -55,54 +53,37 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() return 1; } - a_rank = inputs[0]->getRank(); - b_rank = inputs[1]->getRank(); - if (a_rank != 0 && b_rank != 0 && a_rank != b_rank) - { - printNodeValidationError("Binary operator input ranks must match"); - return 1; - } - - max_input_rank = a_rank >= b_rank ? a_rank : b_rank; - - // A & B must be the same types - if (inputs[0]->matchType(*inputs[1])) + // A & B must be the same rank and types + if (inputs[0]->matchRankType(*inputs[1])) { printNodeValidationError("Binary operator input types must match"); return 1; } - // Result's geometry must match, but the type may be wider - if (outputs[0]->getRank() != max_input_rank) - { - printNodeValidationError("Binary operator input and output genometry must match"); - return 1; - } - - if (a_rank == max_input_rank) - { - a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - } - else - { - a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]); - } - - if (b_rank == max_input_rank) + // Input and output rank must match + // If it's not MUL, type also needs to match as well. + if (nodeType != Op_MUL) { - b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Binary operators (except MUL) input and output rank and type must match"); + return 1; + } } else { - b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]); + if (inputs[0]->matchRank(*outputs[0])) + { + printNodeValidationError("MUL operator input and output rank must match"); + return 1; + } } + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - // either a or b can be rank0 - // a_rank0 and b_rank0 can't be valid at the same time. - // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0' - ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result); + ASSERT_MEM(a && b && result); return 0; } @@ -114,25 +95,10 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() std::vector<int> a_shape, b_shape; - if (a_rank == max_input_rank) - { - a_shape = a->getShape(); - } - else - { - a_shape.assign(max_input_rank, 1); - } + a_shape = a->getShape(); + b_shape = b->getShape(); - if (b_rank == max_input_rank) - { - b_shape = b->getShape(); - } - else - { - b_shape.assign(max_input_rank, 1); - } - - for (int i = 0; i < max_input_rank; i++) + for (int i = 0; i < (int)a_shape.size(); i++) { if (a_shape[i] != output_shape[i] && a_shape[i] == 1) { @@ -164,23 +130,8 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval() reshaper.fill(1); TIn ia, ib; - if (this->a_rank == this->max_input_rank) - { - ia = this->a->getTensor().broadcast(this->bcast_a); - } - else - { - ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a); - } - - if (this->b_rank == this->max_input_rank) - { - ib = this->b->getTensor().broadcast(this->bcast_b); - } - else - { - ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b); - } + ia = this->a->getTensor().broadcast(this->bcast_a); + ib = this->b->getTensor().broadcast(this->bcast_b); this->result->getTensor() = ia.binaryExpr(ib, this->fcn); @@ -475,7 +426,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() } else { - result = static_cast<int64_t>(a) * b; + result = static_cast<int64_t>(a) * b; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min()); REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); |