aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-14 17:09:57 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-10-18 18:50:08 +0000
commitcc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2 (patch)
tree2d664f87e3fdd75de8c6794f6f6c8d6364ece6bb /reference_model/src/ops/ewise_binary.cc
parente807aae606a78d923a2565052f7c2179e3050650 (diff)
downloadreference_model-cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2.tar.gz
More ERROR_IF supports
- Also delay tensor allocation after operator being validated ERROR_IF can be caught first before 0 or negative dimension set the graph_status to UNPREDICTABLE - Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc20
1 files changed, 6 insertions, 14 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 023158c..6808604 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,26 +60,16 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- // In some ops, only rank of input and output tensor needs to match
- if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
- {
- if (inputs[0]->matchRank(*outputs[0]))
- {
- std::string err =
- "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
- printNodeValidationError(err.c_str());
- return 1;
- }
- }
- // Otherwise both rand/type of input and output must match
- else if (inputs[0]->matchRankType(*outputs[0]))
+ if (inputs[0]->matchRank(*outputs[0]))
{
std::string err =
- "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
printNodeValidationError(err.c_str());
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
+
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -532,6 +522,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8");
}
else if (inputs[0]->getDtype() == DType_INT16)
{
@@ -540,6 +531,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32");
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);