From acb550f4410ae861e53cae27a9feb4b11d45769f Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 29 Jun 2021 15:32:19 -0700 Subject: Replace node level check ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF() - Adding return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) will abort immediately, or will return one of the three return codes Part of the codes are re-written to pass REQUIRE() to the top-level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() on test generation script - Update test regression script to interface with reference model change Signed-off-by: Kevin Cheng Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1 --- reference_model/src/ops/activation_funcs.cc | 8 +- reference_model/src/ops/activation_funcs.h | 16 ++-- reference_model/src/ops/comparison.cc | 6 +- reference_model/src/ops/comparison.h | 12 +-- reference_model/src/ops/control_flow.cc | 15 ++-- reference_model/src/ops/control_flow.h | 6 +- reference_model/src/ops/custom.cc | 6 +- reference_model/src/ops/custom.h | 2 +- reference_model/src/ops/data_layout.cc | 51 ++++++++---- reference_model/src/ops/data_layout.h | 22 +++--- reference_model/src/ops/data_nodes.cc | 11 ++- reference_model/src/ops/data_nodes.h | 4 +- reference_model/src/ops/ewise_binary.cc | 62 ++++++++------- reference_model/src/ops/ewise_binary.h | 27 ++++--- reference_model/src/ops/ewise_ternary.cc | 15 ++-- reference_model/src/ops/ewise_ternary.h | 10 +-- reference_model/src/ops/ewise_unary.cc | 33 ++++---- reference_model/src/ops/ewise_unary.h | 10 +-- reference_model/src/ops/image.cc | 34 ++++---- reference_model/src/ops/image.h | 2 +- reference_model/src/ops/op_factory.cc | 11 +-- reference_model/src/ops/op_factory.h | 34 ++++---- reference_model/src/ops/reduction.cc | 4 +- reference_model/src/ops/reduction.h | 26 +++--- reference_model/src/ops/scatter_gather.cc | 18 +++-- reference_model/src/ops/scatter_gather.h | 4 +- reference_model/src/ops/tensor_ops.cc | 118 +++++++++++++++++----------- reference_model/src/ops/tensor_ops.h | 16 ++-- reference_model/src/ops/type_conversion.cc | 105 +++++++++++++++---------- reference_model/src/ops/type_conversion.h | 4 +- 30 files changed, 397 insertions(+), 295 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 3410ba9..440f4e1 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -44,7 +44,7 @@ int OpClamp::register_fcn() } break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -69,7 +69,7 @@ int OpReluN::register_fcn() } break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -84,7 +84,7 @@ int OpSigmoid::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -99,7 +99,7 @@ int OpTanh::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index b051b9d..c834b52 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -28,8 +28,8 @@ template class OpClamp : public UnaryNode { public: - OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode(Op_CLAMP, id_) + OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); register_fcn(); @@ -48,8 +48,8 @@ template class OpReluN : public UnaryNode { public: - OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode(Op_RELUN, id_) + OpReluN(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(sgt_, Op_RELUN, id_) { INIT_ATTRIBUTE(ReluN); register_fcn(); @@ -68,8 +68,8 @@ template class OpSigmoid : public UnaryNode { public: - OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode(Op_SIGMOID, id_) + OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(sgt_, Op_SIGMOID, id_) { register_fcn(); } @@ -84,8 +84,8 @@ template class OpTanh : public UnaryNode { public: - OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode(Op_TANH, id_) + OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode(sgt_, Op_TANH, id_) { register_fcn(); } diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index 402e152..ab89e24 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -32,7 +32,7 @@ int OpEqual::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -48,7 +48,7 @@ int OpGreater::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -64,7 +64,7 @@ int OpGreaterEqual::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h index e75b1a6..5b4d0f1 100644 --- a/reference_model/src/ops/comparison.h +++ b/reference_model/src/ops/comparison.h @@ -28,8 +28,8 @@ template class OpEqual : public BinaryNode { public: - OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(Op_EQUAL, qinfo_, id_) + OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(sgt_, Op_EQUAL, qinfo_, id_) { register_fcn(); } @@ -42,8 +42,8 @@ template class OpGreater : public BinaryNode { public: - OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(Op_GREATER, qinfo_, id_) + OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(sgt_, Op_GREATER, qinfo_, id_) { register_fcn(); } @@ -56,8 +56,8 @@ template class OpGreaterEqual : public BinaryNode { public: - OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(Op_EQUAL, qinfo_, id_) + OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(sgt_, Op_EQUAL, qinfo_, id_) { register_fcn(); } diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 1a6a63a..0945056 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -20,8 +20,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_) - : GraphNode(op_, id_) +OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { tsh = tsh_; } @@ -148,8 +148,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, return 0; } -OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) - : OpControlFlow(tsh_, Op_COND_IF, id_) +OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_) { INIT_ATTRIBUTE(CondIf); } @@ -221,8 +221,11 @@ int OpCondIf::eval() return GraphNode::eval(); } -OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) - : OpControlFlow(tsh_, Op_WHILE_LOOP, id_) +OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_, + TosaSerializationHandler* tsh_, + TosaAttributeBase* attribute_, + uint64_t id_) + : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_) { INIT_ATTRIBUTE(WhileLoop); } diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h index 14c11bc..879cd6a 100644 --- a/reference_model/src/ops/control_flow.h +++ b/reference_model/src/ops/control_flow.h @@ -25,7 +25,7 @@ namespace TosaReference class OpControlFlow : public GraphNode { public: - OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_); + OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_); ~OpControlFlow(); virtual int evalBlock(TosaSerializationBasicBlock* block, @@ -39,7 +39,7 @@ protected: class OpCondIf : public OpControlFlow { public: - OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); + OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpCondIf(); virtual int checkTensorAttributes(); @@ -55,7 +55,7 @@ protected: class OpWhileLoop : public OpControlFlow { public: - OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); + OpWhileLoop(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpWhileLoop(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc index 5c4f29b..5fc36f3 100644 --- a/reference_model/src/ops/custom.cc +++ b/reference_model/src/ops/custom.cc @@ -19,8 +19,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpCustom::OpCustom(uint64_t id_) - : GraphNode(Op_CUSTOM, id_) +OpCustom::OpCustom(SubgraphTraverser* sgt_, uint64_t id_) + : GraphNode(sgt_, Op_CUSTOM, id_) {} OpCustom::~OpCustom() @@ -33,7 +33,7 @@ int OpCustom::checkTensorAttributes() int OpCustom::eval() { - FATAL_ERROR_NODE("not supported yet"); + FATAL_ERROR("not supported yet"); // Evaluation is trivial for constants return GraphNode::eval(); diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h index b1085a5..d14c809 100644 --- a/reference_model/src/ops/custom.h +++ b/reference_model/src/ops/custom.h @@ -26,7 +26,7 @@ namespace TosaReference class OpCustom : public GraphNode { public: - OpCustom(uint64_t id_); + OpCustom(SubgraphTraverser* sgt_, uint64_t id_); virtual ~OpCustom(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index c66d64e..86326f5 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -21,8 +21,11 @@ using namespace Eigen; using namespace tosa; template -OpConcat::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CONCAT, id_) +OpConcat::OpConcat(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CONCAT, id_) { setRequiredOperands(-1, 1); setRequiredRank(1, 6); @@ -95,8 +98,11 @@ int OpConcat::eval() } template -OpPad::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_PAD, id_) +OpPad::OpPad(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -157,8 +163,11 @@ int OpPad::eval() } template -OpReshape::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESHAPE, id_) +OpReshape::OpReshape(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESHAPE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -274,8 +283,11 @@ int OpReshape::eval() } template -OpReverse::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_REVERSE, id_) +OpReverse::OpReverse(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_REVERSE, id_) { setRequiredOperands(1, 1); setRequiredRank(1, 6); @@ -339,8 +351,11 @@ int OpReverse::eval() } template -OpSlice::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SLICE, id_) +OpSlice::OpSlice(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -407,8 +422,11 @@ int OpSlice::eval() } template -OpTileBase::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TILE, id_) +OpTileBase::OpTileBase(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -466,7 +484,7 @@ template int OpTile::eval() { // primary template shouldn't be called - FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); + FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); } template @@ -542,8 +560,11 @@ int OpTile<4, Dtype>::eval() } template -OpTranspose::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TRANSPOSE, id_) +OpTranspose::OpTranspose(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index b180b4f..c9c2602 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -27,7 +27,7 @@ template class OpConcat : public GraphNode { public: - OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConcat(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpPad : public GraphNode { public: - OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpPad(); virtual int checkTensorAttributes(); virtual int eval(); @@ -70,7 +70,7 @@ template class OpReshape : public GraphNode { public: - OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReshape(); virtual int checkTensorAttributes(); @@ -94,7 +94,7 @@ template class OpReverse : public GraphNode { public: - OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReverse(); virtual int checkTensorAttributes(); @@ -116,7 +116,7 @@ template class OpSlice : public GraphNode { public: - OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpSlice(); virtual int checkTensorAttributes(); @@ -139,7 +139,7 @@ template class OpTileBase : public GraphNode { public: - OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTileBase(); virtual int checkTensorAttributes(); @@ -160,8 +160,8 @@ template class OpTile : public OpTileBase { public: - OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpTileBase(attribute_, qinfo_, id_) + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpTileBase(sgt_, attribute_, qinfo_, id_) {} protected: @@ -174,8 +174,8 @@ protected: class OpTile : public OpTileBase \ { \ public: \ - OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : OpTileBase(attribute_, qinfo_, id_) \ + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : OpTileBase(sgt_, attribute_, qinfo_, id_) \ {} \ \ protected: \ @@ -193,7 +193,7 @@ template class OpTranspose : public GraphNode { public: - OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTranspose(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index baae019..ec4bc41 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -19,8 +19,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpConst::OpConst(uint64_t id_) - : GraphNode(Op_CONST, id_) +OpConst::OpConst(SubgraphTraverser* sgt_, uint64_t id_) + : GraphNode(sgt_, Op_CONST, id_) { setRequiredOperands(0, 1); } @@ -43,8 +43,11 @@ int OpConst::eval() } template -OpIdentity::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_IDENTITY, id_) +OpIdentity::OpIdentity(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_IDENTITY, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h index a02d441..407cf0a 100644 --- a/reference_model/src/ops/data_nodes.h +++ b/reference_model/src/ops/data_nodes.h @@ -24,7 +24,7 @@ namespace TosaReference class OpConst : public GraphNode { public: - OpConst(uint64_t id_); + OpConst(SubgraphTraverser* sgt_, uint64_t id_); virtual ~OpConst(); virtual int checkTensorAttributes(); @@ -35,7 +35,7 @@ template class OpIdentity : public GraphNode { public: - OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpIdentity(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 3379ffe..16c4901 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -23,8 +23,11 @@ using namespace Eigen; using namespace tosa; template -BinaryNodeBase::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(op_, id_) +BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, + const Op& op_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -203,7 +206,7 @@ int OpAdd::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -226,12 +229,12 @@ int OpArithmeticRightShift::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", - (int32_t)b, num_bits); + REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", + (int32_t)b, num_bits); InEigenType acc = a >> b; @@ -257,7 +260,7 @@ int OpBitwiseAnd::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -274,7 +277,7 @@ int OpBitwiseOr::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -291,7 +294,7 @@ int OpBitwiseXor::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -304,15 +307,15 @@ int OpDiv::register_fcn() { case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { - ASSERT_MSG_NODE(b != 0, "OpDiv: divisor must be non-zero value"); + REQUIRE(b != 0, "OpDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast(a) / b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); - ASSERT_MSG_NODE(a <= i32_max_in_64, "OpDiv: result not in i32 range"); + REQUIRE(a <= i32_max_in_64, "OpDiv: result not in i32 range"); return static_cast(res_in_64); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -327,7 +330,7 @@ int OpLogicalAnd::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -344,7 +347,7 @@ int OpLogicalLeftShift::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -366,7 +369,7 @@ int OpLogicalRightShift::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { @@ -386,7 +389,7 @@ int OpLogicalOr::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -401,7 +404,7 @@ int OpLogicalXor::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -417,7 +420,7 @@ int OpMaximum::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -433,7 +436,7 @@ int OpMinimum::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -443,8 +446,6 @@ template int OpMul::register_fcn() { int32_t shift = attribute->shift(); - ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift, - EnumNamesDType()[InDtype]); switch (InDtype) { @@ -460,8 +461,8 @@ int OpMul::register_fcn() result = static_cast(a) * static_cast(b) + round; result = result >> shift; - ASSERT_MSG_NODE(result >= QMin && result <= QMax, - "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); + REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]", + result, QMin, QMax); } else { @@ -482,7 +483,7 @@ int OpMul::register_fcn() }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -497,7 +498,7 @@ int OpPow::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -513,15 +514,18 @@ int OpSub::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template -OpTable::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TABLE, id_) +OpTable::OpTable(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -607,7 +611,7 @@ int OpTable::eval() }); break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return GraphNode::eval(); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index a5b1059..86b2101 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -42,7 +42,7 @@ template class BinaryNodeBase : public GraphNode { public: - BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); + BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; @@ -76,8 +76,8 @@ template class BinaryNode : public BinaryNodeBase { public: - BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase(op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) + : BinaryNodeBase(sgt_, op_, qinfo_, id_) {} virtual ~BinaryNode() {} @@ -95,8 +95,8 @@ template class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: - BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) + : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) {} virtual ~BinaryNode() {} @@ -109,8 +109,8 @@ public: class Op##Opname : public BinaryNode \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode(Op_##OPNAME, qinfo_, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : BinaryNode(sgt_, Op_##OPNAME, qinfo_, id_) \ { \ register_fcn(); \ } \ @@ -142,8 +142,11 @@ template class OpArithmeticRightShift : public BinaryNode { public: - OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + OpArithmeticRightShift(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : BinaryNode(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) { INIT_ATTRIBUTE(ArithmeticRightShift); register_fcn(); @@ -160,8 +163,8 @@ template class OpMul : public BinaryNode { public: - OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode(Op_MUL, qinfo_, id_) + OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode(sgt_, Op_MUL, qinfo_, id_) { INIT_ATTRIBUTE(Mul); register_fcn(); @@ -180,7 +183,7 @@ template class OpTable : public GraphNode { public: - OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index d4845f9..64c4412 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -20,8 +20,11 @@ using namespace Eigen; using namespace tosa; template -OpSelectBase::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SELECT, id_) +OpSelectBase::OpSelectBase(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SELECT, id_) { setRequiredOperands(3, 1); setRequiredRank(0, 6); @@ -62,7 +65,7 @@ int OpSelectBase::checkTensorAttributes() template int OpSelectBase::eval() { - FATAL_ERROR_NODE("shouldn't be called"); + FATAL_ERROR("shouldn't be called"); } template @@ -78,9 +81,9 @@ int OpSelect::broadcast() this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; - ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); - ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); - ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); } return 0; diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index b354247..b80fb23 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -33,7 +33,7 @@ template class OpSelectBase : public GraphNode { public: - OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpSelectBase(); virtual int checkTensorAttributes(); @@ -59,8 +59,8 @@ template class OpSelect : public OpSelectBase { public: - OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase(attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase(sgt_, attribute_, qinfo_, id_) {} virtual int eval(); int broadcast(); @@ -73,8 +73,8 @@ template class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: - OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 95a1102..041bbdb 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -23,8 +23,8 @@ using namespace Eigen; using namespace tosa; template -UnaryNode::UnaryNode(const Op& op_, uint64_t id_) - : GraphNode(op_, id_) +UnaryNode::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -80,7 +80,7 @@ int OpAbs::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -97,7 +97,7 @@ int OpBitwiseNot::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -112,7 +112,7 @@ int OpCeil::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -128,7 +128,7 @@ int OpClz::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -159,7 +159,7 @@ int OpExp::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -174,7 +174,7 @@ int OpFloor::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -189,7 +189,7 @@ int OpLog::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -204,7 +204,7 @@ int OpLogicalNot::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -213,6 +213,12 @@ int OpLogicalNot::register_fcn() template int OpNegate::register_fcn() { + if (Dtype != DType_INT8 && this->qinfo) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + } + switch (Dtype) { case DType_FLOAT: @@ -229,7 +235,6 @@ int OpNegate::register_fcn() }; break; case DType_INT8: - ASSERT(this->qinfo); this->fcn = [this](InEigenType a) -> OutEigenType { InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); result = std::min(std::max(result, static_cast(QMin)), static_cast(QMax)); @@ -237,7 +242,7 @@ int OpNegate::register_fcn() }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -252,7 +257,7 @@ int OpReciprocal::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -267,7 +272,7 @@ int OpRsqrt::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 0db3cfb..374c8e4 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -26,7 +26,7 @@ template class UnaryNode : public GraphNode { public: - UnaryNode(const Op& nodeType, const uint64_t id_); + UnaryNode(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_); virtual ~UnaryNode(); virtual int checkTensorAttributes() final; @@ -49,8 +49,8 @@ protected: class Op##Opname : public UnaryNode \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode(Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ @@ -66,8 +66,8 @@ protected: class Op##Opname : public UnaryNode \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode(Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode(sgt_, Op_##OPNAME, id_) \ { \ INIT_QINFO(Unary); \ register_fcn(); \ diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 829a6e0..f4decae 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -22,8 +22,11 @@ using namespace Eigen; using namespace tosa; template -OpResize::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESIZE, id_) +OpResize::OpResize(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESIZE, id_) { setRequiredOperands(1, 1); setRequiredRank(4, 4); @@ -102,10 +105,13 @@ int OpResize::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]"); - ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift < 1 || shift > 11, "OpResize: attribute shift should be within [1, 11]"); + ERROR_IF(stride[0] <= 0 || stride[0] >= (16 << shift), "OpResize: invalid attribute stride_x"); + ERROR_IF(stride[1] <= 0 || stride[1] >= (16 << shift), "OpResize: invalid attribute stride_y"); + ERROR_IF(offset[0] <= (-16 << shift) || offset[0] >= (16 << shift), "OpResize: invalid attribute offset_x"); + ERROR_IF(offset[1] <= (-16 << shift) || offset[1] >= (16 << shift), "OpResize: invalid attribute offset_y"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -125,8 +131,8 @@ int OpResize::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) @@ -167,10 +173,10 @@ int OpResize::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift"); - ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift != 0, "OpResize: float mode must have 0 shift"); + ERROR_IF(stride_fp[0] <= 0.0f || stride_fp[1] <= 0.0f, "OpResize: invalid attribute stride"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -190,8 +196,8 @@ int OpResize::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 5dd14c8..095dc7d 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -27,7 +27,7 @@ template class OpResize : public GraphNode { public: - OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpResize(); virtual int checkTensorAttributes() final; virtual int eval(); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 726ab7c..2d9e428 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -32,7 +32,8 @@ using namespace TosaReference; using namespace tosa; -GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, +GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, + TosaSerializationHandler* tsh, Op opType, TosaAttributeBase* attribute, TosaQuantInfoBase* qinfo, @@ -349,7 +350,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // data_nodes case Op_CONST: - return new OpConst(id); + return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); @@ -398,13 +399,13 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // custom case Op_CUSTOM: - return new OpCustom(id); + return new OpCustom(sgt, id); // control_flow case Op_COND_IF: - return new OpCondIf(tsh, attribute, id); + return new OpCondIf(sgt, tsh, attribute, id); case Op_WHILE_LOOP: - return new OpWhileLoop(tsh, attribute, id); + return new OpWhileLoop(sgt, tsh, attribute, id); // Ops not recognized default: diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 0c116b6..eaa359c 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -24,55 +24,55 @@ #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP(attribute, qinfo, id); + return new OP(sgt, attribute, qinfo, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP(attribute, qinfo, id); + return new OP(sgt, attribute, qinfo, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP(attribute, qinfo, id); + return new OP(sgt, attribute, qinfo, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP(attribute, qinfo, id); + return new OP(sgt, attribute, qinfo, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ { \ case 0: \ - return new OP<0>(attribute, qinfo, id); \ + return new OP<0>(sgt, attribute, qinfo, id); \ case 1: \ - return new OP<1>(attribute, qinfo, id); \ + return new OP<1>(sgt, attribute, qinfo, id); \ case 2: \ - return new OP<2>(attribute, qinfo, id); \ + return new OP<2>(sgt, attribute, qinfo, id); \ case 3: \ - return new OP<3>(attribute, qinfo, id); \ + return new OP<3>(sgt, attribute, qinfo, id); \ case 4: \ - return new OP<4>(attribute, qinfo, id); \ + return new OP<4>(sgt, attribute, qinfo, id); \ case 5: \ - return new OP<5>(attribute, qinfo, id); \ + return new OP<5>(sgt, attribute, qinfo, id); \ case 6: \ - return new OP<6>(attribute, qinfo, id); \ + return new OP<6>(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ if (inputDType == DType_##DTYPE) \ { \ - return new OP(attribute, qinfo, id); \ + return new OP(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ - return new OP(attribute, qinfo, id); \ + return new OP(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP(attribute, qinfo, id); \ + return new OP(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ @@ -221,10 +221,14 @@ namespace TosaReference { +class SubgraphTraverser; +class GraphNode; + class OpFactory { public: - static GraphNode* newOp(tosa::TosaSerializationHandler* tsh, + static GraphNode* newOp(SubgraphTraverser* sgt, + tosa::TosaSerializationHandler* tsh, tosa::Op opType, tosa::TosaAttributeBase* attribute, tosa::TosaQuantInfoBase* qinfo, diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index 97a7aa7..107c7a8 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -21,8 +21,8 @@ using namespace Eigen; using namespace tosa; template -ReduceNode::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) - : GraphNode(op_, id_) +ReduceNode::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 4); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index cf75812..f4e29b9 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -27,7 +27,7 @@ template class ReduceNode : public GraphNode { public: - ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); + ReduceNode(SubgraphTraverser* sgt_, const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); virtual ~ReduceNode(); virtual int checkTensorAttributes(); virtual int eval() = 0; @@ -48,8 +48,8 @@ template class OpReduceAll : public ReduceNode { public: - OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_ALL, attribute_, id_) + OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); }; @@ -58,8 +58,8 @@ template class OpReduceAny : public ReduceNode { public: - OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_ALL, attribute_, id_) + OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); }; @@ -68,8 +68,8 @@ template class OpReduceMax : public ReduceNode { public: - OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_MAX, attribute_, id_) + OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_MAX, attribute_, id_) {} virtual int eval(); }; @@ -78,8 +78,8 @@ template class OpReduceMin : public ReduceNode { public: - OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_MIN, attribute_, id_) + OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_MIN, attribute_, id_) {} virtual int eval(); }; @@ -88,8 +88,8 @@ template class OpReduceProduct : public ReduceNode { public: - OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_PRODUCT, attribute_, id_) + OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) {} virtual int eval(); }; @@ -98,8 +98,8 @@ template class OpReduceSum : public ReduceNode { public: - OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode(Op_REDUCE_SUM, attribute_, id_) + OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 478b776..02ec54f 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -21,8 +21,11 @@ using namespace Eigen; using namespace tosa; template -OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_GATHER, id_) +OpGather::OpGather(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_GATHER, id_) { setRequiredOperands(2, 1); } @@ -102,7 +105,7 @@ int OpGather::eval() for (int32_t w = 0; w < W; w++) { int32_t k = this->indices->getTensor()(n, w); - ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + REQUIRE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); for (int32_t c = 0; c < C; c++) { EigenType value = this->values->getTensor()(n, k, c); @@ -115,8 +118,11 @@ int OpGather::eval() } template -OpScatter::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SCATTER, id_) +OpScatter::OpScatter(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SCATTER, id_) { setRequiredOperands(3, 1); } @@ -206,7 +212,7 @@ int OpScatter::eval() for (int w = 0; w < W; w++) { int32_t k = this->indices->getTensor()(n, w); - ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); for (int c = 0; c < C; c++) { EigenType value = this->input->getTensor()(n, w, c); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index 17ea723..66b584a 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -27,7 +27,7 @@ template class OpGather : public GraphNode { public: - OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpGather(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpScatter : public GraphNode { public: - OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpScatter(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 0007553..045c0a5 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -22,8 +22,11 @@ using namespace Eigen; using namespace tosa; template -OpArgMax::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_ARGMAX, id_) +OpArgMax::OpArgMax(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -66,8 +69,11 @@ int OpArgMax::eval() } template -OpAvgPool2d::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_AVG_POOL2D, id_) +OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_AVG_POOL2D, id_) { setRequiredOperands(1, 1); setRequiredRank(4); @@ -142,9 +148,6 @@ ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_ int32_t left_index = pad_left / stride; int32_t right_index = pad_right / stride; - // not handle ultra small activation yet - ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet"); - // minus the number of pad bit this index cover while (left_index >= 0) { @@ -176,7 +179,8 @@ int OpAvgPool2d::eval() int out_width = this->out->getShape()[2]; int out_channels = this->out->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -260,12 +264,19 @@ int OpAvgPool2d::eval() if (Dtype != DType_FLOAT) { - this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { - int32_t multiplier, shift; - TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); + try + { + this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { + int32_t multiplier, shift; + TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); - return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false); - }); + return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false); + }); + } + catch (std::string desc) + { + REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str()); + } this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); @@ -279,8 +290,11 @@ int OpAvgPool2d::eval() } template -OpConv2d::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CONV2D, id_) +OpConv2d::OpConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -361,13 +375,12 @@ int OpConv2d::eval() int out_width = this->output->getShape()[2]; int out_channels = this->output->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, - in_channels); - ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, - out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, + out_channels); + ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -469,10 +482,11 @@ int OpConv2d::eval() } template -OpDepthwiseConv2d::OpDepthwiseConv2d(TosaAttributeBase* attribute_, +OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_DEPTHWISE_CONV2D, id_) + : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -553,14 +567,13 @@ int OpDepthwiseConv2d::eval() int out_width = this->output->getShape()[2]; int out_channels = this->output->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", - f_in_channels, in_channels); - ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels, - "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, - out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", - b_out_channels, out_channels); + ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", + in_channels * f_multiplier, out_channels); + ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, + out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -651,10 +664,11 @@ int OpDepthwiseConv2d::eval() } template -OpFullyConnected::OpFullyConnected(TosaAttributeBase* attribute_, +OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_FULLY_CONNECTED, id_) + : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); setRequiredRank(2); @@ -738,8 +752,11 @@ int OpFullyConnected::eval() } template -OpMatMul::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_MATMUL, id_) +OpMatMul::OpMatMul(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); setRequiredRank(3); @@ -866,8 +883,11 @@ int OpMatMul::eval() } template -OpMaxPool2d::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_MAX_POOL2D, id_) +OpMaxPool2d::OpMaxPool2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_MAX_POOL2D, id_) { setRequiredOperands(1, 1); setRequiredRank(4); @@ -936,7 +956,8 @@ int OpMaxPool2d::eval() int out_width = this->out->getShape()[2]; int out_channels = this->out->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -1004,10 +1025,11 @@ int OpMaxPool2d::eval() } template -OpTransposeConv2d::OpTransposeConv2d(TosaAttributeBase* attribute_, +OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TRANSPOSE_CONV2D, id_) + : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -1104,13 +1126,13 @@ int OpTransposeConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", - f_in_channels, in_channels); - ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", - f_out_channels, out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", - b_out_channels, out_channels); + ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", + f_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, + out_channels); DEBUG_INFO(OP, "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 9aaa140..6ffc27d 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template class OpArgMax : public GraphNode { public: - OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -80,7 +80,7 @@ template class OpConv2d : public GraphNode { public: - OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -112,7 +112,7 @@ template class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -144,7 +144,7 @@ template class OpFullyConnected : public GraphNode { public: - OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -174,7 +174,7 @@ template class OpMatMul : public GraphNode { public: - OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -205,7 +205,7 @@ template class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -226,7 +226,7 @@ template class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index d988c57..657eebf 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -23,8 +23,11 @@ using namespace Eigen; using namespace tosa; template -OpRescale::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESCALE, id_) +OpRescale::OpRescale(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESCALE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -101,53 +104,68 @@ int OpRescale::eval() int32_t channel_multiplier, channel_shift; Eigen::array begin, size; size = Eigen::array({ shape_2d[0], 1 }); - for (int32_t i = 0; i < shape_2d[1]; i++) + try { - begin = Eigen::array({ 0, i }); - curr_channel_slice_prescaled = input_reshaped.slice(begin, size); - channel_multiplier = multiplier[i]; - channel_shift = shift[i]; - curr_channel_slice_postscaled = - curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, - double_round, scale32](InEigenType in_val) -> OutEigenType { - InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, - channel_shift, double_round); - else - scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier, - channel_shift); - OutEigenType out_val = (OutEigenType)(scaled + output_zp); - out_val = std::max(out_val, QMin); - out_val = std::min(out_val, QMax); - return out_val; - }); - - for (int32_t j = 0; j < shape_2d[0]; j++) + for (int32_t i = 0; i < shape_2d[1]; i++) { - output_2d(j, i) = curr_channel_slice_postscaled(j, 0); + begin = Eigen::array({ 0, i }); + curr_channel_slice_prescaled = input_reshaped.slice(begin, size); + channel_multiplier = multiplier[i]; + channel_shift = shift[i]; + curr_channel_slice_postscaled = + curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, + double_round, scale32](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, + channel_shift, double_round); + else + scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier, + channel_shift); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max(out_val, QMin); + out_val = std::min(out_val, QMax); + return out_val; + }); + + for (int32_t j = 0; j < shape_2d[0]; j++) + { + output_2d(j, i) = curr_channel_slice_postscaled(j, 0); + } } } + catch (std::string desc) + { + REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str()); + } } else { int32_t tensor_multiplier = multiplier[0]; int32_t tensor_shift = shift[0]; - output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, - scale32](InEigenType in_val) -> OutEigenType { - InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, - double_round); - else - scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); - OutEigenType out_val = (OutEigenType)(scaled + output_zp); - out_val = std::max(out_val, QMin); - out_val = std::min(out_val, QMax); - return out_val; - }); + try + { + output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, + scale32](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, + double_round); + else + scaled = + TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max(out_val, QMin); + out_val = std::min(out_val, QMax); + return out_val; + }); + } + catch (std::string desc) + { + REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str()); + } } // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn] @@ -162,8 +180,11 @@ int OpRescale::eval() } template -OpCast::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CAST, id_) +OpCast::OpCast(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CAST, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 6ec4d6d..060e14e 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -26,7 +26,7 @@ template class OpRescale : public GraphNode { public: - OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpRescale(); virtual int checkTensorAttributes() final; @@ -140,7 +140,7 @@ template class OpCast : public GraphNode { public: - OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpCast(); virtual int checkTensorAttributes() final; -- cgit v1.2.1