From a4d748b08accce06fab93e2d2b96e499b35ae89b Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 28 Mar 2023 22:06:56 +0000 Subject: [reference model] Add precise mode This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly Change-Id: I156055216ad61710096497a8fa1a653be2a602a3 --- reference_model/src/ops/ewise_binary.cc | 210 ++++++++++++++++++-------------- 1 file changed, 116 insertions(+), 94 deletions(-) (limited to 'reference_model/src/ops/ewise_binary.cc') diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 6aa0c0f..c5801e7 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) @@ -37,11 +37,11 @@ BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } -template +template BinaryNodeBase::~BinaryNodeBase() {} -template +template int BinaryNodeBase::checkTensorAttributes() { // Check Tosa Level @@ -90,7 +90,7 @@ int BinaryNodeBase::checkTensorAttributes() return 0; } -template +template int BinaryNodeBase::broadcast() { const std::vector& a_shape = a->getShape(); @@ -106,7 +106,7 @@ int BinaryNodeBase::broadcast() return 0; } -template +template int BinaryNode::eval() { this->broadcast(); @@ -124,7 +124,7 @@ int BinaryNode::eval() } // still need to partial specialize this, or Eigen will throw static assertion -template +template int BinaryNode<0, InDtype, OutDtype>::eval() { this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn); @@ -132,12 +132,12 @@ int BinaryNode<0, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template +template int OpAdd::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast(a) + b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); @@ -146,36 +146,39 @@ int OpAdd::register_fcn() return static_cast(res_in_64); }; break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a + b); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template +template int OpArithmeticRightShift::register_fcn() { bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: num_bits = 8; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: num_bits = 16; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: num_bits = 32; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { @@ -195,69 +198,69 @@ int OpArithmeticRightShift::register_fcn() return 0; } -template +template OpArithmeticRightShift::~OpArithmeticRightShift() { if (attribute) delete attribute; } -template +template int OpBitwiseAnd::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpBitwiseOr::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpBitwiseXor::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpIntdiv::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast(a) / b; @@ -268,47 +271,47 @@ int OpIntdiv::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template +template int OpLogicalAnd::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLogicalLeftShift::register_fcn() { switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a << b)); }; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a << b)); }; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); @@ -316,32 +319,32 @@ int OpLogicalLeftShift::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLogicalRightShift::register_fcn() { switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a) >> b); }; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast(static_cast(a) >> b); }; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); @@ -349,91 +352,96 @@ int OpLogicalRightShift::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLogicalOr::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLogicalXor::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpMaximum::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpMinimum::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpMul::register_fcn() { int32_t shift = attribute->shift(); switch (InDtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a * b); }; break; - case DType_INT32: + case TOSA_REF_TYPE_FP64: + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; + case TOSA_REF_TYPE_INT32: this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { int64_t result; if (shift > 0) @@ -457,8 +465,8 @@ int OpMul::register_fcn() return static_cast(result); }; break; - case DType_INT8: - case DType_INT16: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; @@ -468,41 +476,44 @@ int OpMul::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template +template OpMul::~OpMul() { if (attribute) delete attribute; } -template +template int OpPow::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(powf(a, b)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpSub::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast(a) - b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); @@ -511,19 +522,22 @@ int OpSub::register_fcn() return static_cast(res_in_64); }; break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a - b); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template +template OpTable::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -535,13 +549,13 @@ OpTable::OpTable(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Table); } -template +template OpTable::~OpTable() { if (attribute) delete attribute; } -template +template int OpTable::checkTensorAttributes() { // Check Tosa Level @@ -573,12 +587,12 @@ int OpTable::checkTensorAttributes() return 0; } -template +template int OpTable::eval() { switch (InDtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; @@ -587,7 +601,7 @@ int OpTable::eval() return value; }); break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { // 1. make sure input is int16 range int32_t input_truncated = std::min(std::max(in, QInMin), QInMax); @@ -610,7 +624,7 @@ int OpTable::eval() }); break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return GraphNode::eval(); @@ -630,11 +644,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); @@ -672,11 +688,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); @@ -684,15 +702,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); @@ -703,3 +724,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL); -- cgit v1.2.1