diff options
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 20 |
1 files changed, 6 insertions, 14 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 023158c..6808604 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -60,26 +60,16 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() return 1; } - // In some ops, only rank of input and output tensor needs to match - if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL) - { - if (inputs[0]->matchRank(*outputs[0])) - { - std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; - printNodeValidationError(err.c_str()); - return 1; - } - } - // Otherwise both rand/type of input and output must match - else if (inputs[0]->matchRankType(*outputs[0])) + if (inputs[0]->matchRank(*outputs[0])) { std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match"; + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; printNodeValidationError(err.c_str()); return 1; } + ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match"); + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); @@ -532,6 +522,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8"); return 1; } + ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8"); } else if (inputs[0]->getDtype() == DType_INT16) { @@ -540,6 +531,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16"); return 1; } + ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32"); } in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); |