diff options
Diffstat (limited to 'reference_model/src')
39 files changed, 558 insertions, 373 deletions
diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h index 2d47462..26901cf 100644 --- a/reference_model/src/func_debug.h +++ b/reference_model/src/func_debug.h @@ -97,16 +97,31 @@ struct func_debug_t } #endif -#ifndef ASSERT_MSG_NODE -#define ASSERT_MSG_NODE(COND, fmt, ...) \ +#ifndef REQUIRE +#define REQUIRE(COND, fmt, ...) \ if (!(COND)) \ { \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("REQUIRE() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ + __func__, #COND); \ + fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + this->parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); \ + } +#endif + +#ifndef ERROR_IF +#define ERROR_IF(COND, fmt, ...) \ + if ((COND)) \ + { \ + if (this->parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \ + { \ + this->parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR); \ + } \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ __func__, #COND); \ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ this->dumpNode(g_func_debug.func_debug_file); \ func_print_backtrace(g_func_debug.func_debug_file); \ - assert(COND); \ + return 1; \ } #endif @@ -130,14 +145,6 @@ struct func_debug_t abort(); #endif -#ifndef FATAL_ERROR_NODE -#define FATAL_ERROR_NODE(fmt, ...) \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ - fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - this->dumpNode(g_func_debug.func_debug_file); \ - func_print_backtrace(g_func_debug.func_debug_file); \ - abort(); -#endif #ifndef SIMPLE_FATAL_ERROR #define SIMPLE_FATAL_ERROR(fmt, ...) \ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc index b57b9dd..f765700 100644 --- a/reference_model/src/graph_node.cc +++ b/reference_model/src/graph_node.cc @@ -19,10 +19,11 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_) +GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const uint64_t id_) { - nodeType = nodeType_; - nodeId = id_; + parent_sgt = parent_sgt_; + nodeType = nodeType_; + nodeId = id_; inputs.clear(); outputs.clear(); inputNames.clear(); diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index bf80859..14a8acc 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -18,6 +18,7 @@ #include "attribute.h" #include "quant_info.h" +#include "subgraph_traverser.h" #include "tensor.h" #include "tosa_generated.h" #include <iostream> @@ -139,12 +140,14 @@ namespace TosaReference { +class SubgraphTraverser; + // Nodes in the graph (e.g., tosa operators) are defined with this base // class. class GraphNode { public: - GraphNode(const tosa::Op& nodeType, const uint64_t id_); + GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_); virtual ~GraphNode(); int addInputName(std::string& name); @@ -274,6 +277,9 @@ protected: int validateRequiredOperands(); int validateRequiredRank(const Tensor* t); + // Parent SubgraphTraverser + SubgraphTraverser* parent_sgt; + // Description of the node type (e.g., CONST, CONV2D, etc...) tosa::Op nodeType; diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 412894c..55a4848 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -64,12 +64,12 @@ int main(int argc, const char** argv) SIMPLE_FATAL_ERROR("Unable to load graph"); } - // load json first since it's easier debugging SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); if (main_gt.initializeGraph()) { - SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\""); + WARNING("Unable to initialize main graph traverser."); + goto done; } if (main_gt.linkTensorsAndNodes()) @@ -95,49 +95,76 @@ int main(int argc, const char** argv) if (g_func_config.eval) { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) { - SIMPLE_FATAL_ERROR("Error evaluating network. Giving up."); + ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID || + main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); } - // make sure output tensor is evaluated and show its value - int num_output_tensors = main_gt.getNumOutputTensors(); - bool all_output_valid = true; - for (int i = 0; i < num_output_tensors; i++) + // Only generate output tensor if graph is valid. + if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID) { - const Tensor* ct = main_gt.getOutputTensor(i); - ASSERT_MEM(ct); - if (!ct->getIsValid()) + // make sure output tensor is evaluated and show its value + int num_output_tensors = main_gt.getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) { - ct->dumpTensorParams(g_func_debug.func_debug_file); - if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + const Tensor* ct = main_gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) { - ct->dumpTensor(g_func_debug.func_debug_file); + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; } - all_output_valid = false; } - } - if (!all_output_valid) - { - main_gt.dumpGraph(g_func_debug.func_debug_file); - SIMPLE_FATAL_ERROR( - "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); - } + if (!all_output_valid) + { + main_gt.dumpGraph(g_func_debug.func_debug_file); + SIMPLE_FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } - if (g_func_config.output_tensors) - { - if (writeFinalTensors(main_gt, test_desc)) + if (g_func_config.output_tensors) { - WARNING("Errors encountered in saving output tensors"); + if (writeFinalTensors(main_gt, test_desc)) + { + WARNING("Errors encountered in saving output tensors"); + } } } } done: + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + fprintf(stderr, "Graph result: UNPREDICTABLE.\n"); + break; + case GraphStatus::TOSA_ERROR: + fprintf(stderr, "Graph result: ERROR.\n"); + break; + default: + fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus()); + } + func_fini_debug(&g_func_debug); func_model_config_cleanup(); - return 0; + return (int)main_gt.getGraphStatus(); } int loadGraph(TosaSerializationHandler& tsh, json test_desc) diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 3410ba9..440f4e1 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -44,7 +44,7 @@ int OpClamp<Rank, Dtype>::register_fcn() } break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -69,7 +69,7 @@ int OpReluN<Rank, Dtype>::register_fcn() } break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -84,7 +84,7 @@ int OpSigmoid<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -99,7 +99,7 @@ int OpTanh<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index b051b9d..c834b52 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -28,8 +28,8 @@ template <int Rank, DType Dtype> class OpClamp : public UnaryNode<Rank, Dtype> { public: - OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode<Rank, Dtype>(Op_CLAMP, id_) + OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); register_fcn(); @@ -48,8 +48,8 @@ template <int Rank, DType Dtype> class OpReluN : public UnaryNode<Rank, Dtype> { public: - OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode<Rank, Dtype>(Op_RELUN, id_) + OpReluN(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(sgt_, Op_RELUN, id_) { INIT_ATTRIBUTE(ReluN); register_fcn(); @@ -68,8 +68,8 @@ template <int Rank, DType Dtype> class OpSigmoid : public UnaryNode<Rank, Dtype> { public: - OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_) + OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(sgt_, Op_SIGMOID, id_) { register_fcn(); } @@ -84,8 +84,8 @@ template <int Rank, DType Dtype> class OpTanh : public UnaryNode<Rank, Dtype> { public: - OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : UnaryNode<Rank, Dtype>(Op_TANH, id_) + OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(sgt_, Op_TANH, id_) { register_fcn(); } diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index 402e152..ab89e24 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -32,7 +32,7 @@ int OpEqual<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -48,7 +48,7 @@ int OpGreater<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -64,7 +64,7 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h index e75b1a6..5b4d0f1 100644 --- a/reference_model/src/ops/comparison.h +++ b/reference_model/src/ops/comparison.h @@ -28,8 +28,8 @@ template <int Rank, DType Dtype> class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_) + OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_) { register_fcn(); } @@ -42,8 +42,8 @@ template <int Rank, DType Dtype> class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_) + OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, qinfo_, id_) { register_fcn(); } @@ -56,8 +56,8 @@ template <int Rank, DType Dtype> class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL> { public: - OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_) + OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_) { register_fcn(); } diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 1a6a63a..0945056 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -20,8 +20,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_) - : GraphNode(op_, id_) +OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { tsh = tsh_; } @@ -148,8 +148,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, return 0; } -OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) - : OpControlFlow(tsh_, Op_COND_IF, id_) +OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_) { INIT_ATTRIBUTE(CondIf); } @@ -221,8 +221,11 @@ int OpCondIf::eval() return GraphNode::eval(); } -OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) - : OpControlFlow(tsh_, Op_WHILE_LOOP, id_) +OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_, + TosaSerializationHandler* tsh_, + TosaAttributeBase* attribute_, + uint64_t id_) + : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_) { INIT_ATTRIBUTE(WhileLoop); } diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h index 14c11bc..879cd6a 100644 --- a/reference_model/src/ops/control_flow.h +++ b/reference_model/src/ops/control_flow.h @@ -25,7 +25,7 @@ namespace TosaReference class OpControlFlow : public GraphNode { public: - OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_); + OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_); ~OpControlFlow(); virtual int evalBlock(TosaSerializationBasicBlock* block, @@ -39,7 +39,7 @@ protected: class OpCondIf : public OpControlFlow { public: - OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); + OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpCondIf(); virtual int checkTensorAttributes(); @@ -55,7 +55,7 @@ protected: class OpWhileLoop : public OpControlFlow { public: - OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); + OpWhileLoop(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_); virtual ~OpWhileLoop(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc index 5c4f29b..5fc36f3 100644 --- a/reference_model/src/ops/custom.cc +++ b/reference_model/src/ops/custom.cc @@ -19,8 +19,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpCustom::OpCustom(uint64_t id_) - : GraphNode(Op_CUSTOM, id_) +OpCustom::OpCustom(SubgraphTraverser* sgt_, uint64_t id_) + : GraphNode(sgt_, Op_CUSTOM, id_) {} OpCustom::~OpCustom() @@ -33,7 +33,7 @@ int OpCustom::checkTensorAttributes() int OpCustom::eval() { - FATAL_ERROR_NODE("not supported yet"); + FATAL_ERROR("not supported yet"); // Evaluation is trivial for constants return GraphNode::eval(); diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h index b1085a5..d14c809 100644 --- a/reference_model/src/ops/custom.h +++ b/reference_model/src/ops/custom.h @@ -26,7 +26,7 @@ namespace TosaReference class OpCustom : public GraphNode { public: - OpCustom(uint64_t id_); + OpCustom(SubgraphTraverser* sgt_, uint64_t id_); virtual ~OpCustom(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index c66d64e..86326f5 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -21,8 +21,11 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CONCAT, id_) +OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CONCAT, id_) { setRequiredOperands(-1, 1); setRequiredRank(1, 6); @@ -95,8 +98,11 @@ int OpConcat<Rank, Dtype>::eval() } template <int Rank, DType Dtype> -OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_PAD, id_) +OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -157,8 +163,11 @@ int OpPad<Rank, Dtype>::eval() } template <int InRank, int OutRank, DType Dtype> -OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESHAPE, id_) +OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESHAPE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -274,8 +283,11 @@ int OpReshape<InRank, OutRank, Dtype>::eval() } template <int Rank, DType Dtype> -OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_REVERSE, id_) +OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_REVERSE, id_) { setRequiredOperands(1, 1); setRequiredRank(1, 6); @@ -339,8 +351,11 @@ int OpReverse<Rank, Dtype>::eval() } template <int Rank, DType Dtype> -OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SLICE, id_) +OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -407,8 +422,11 @@ int OpSlice<Rank, Dtype>::eval() } template <int Rank, DType Dtype> -OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TILE, id_) +OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -466,7 +484,7 @@ template <int Rank, DType Dtype> int OpTile<Rank, Dtype>::eval() { // primary template shouldn't be called - FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); + FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); } template <DType Dtype> @@ -542,8 +560,11 @@ int OpTile<4, Dtype>::eval() } template <int Rank, DType Dtype> -OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TRANSPOSE, id_) +OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index b180b4f..c9c2602 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -27,7 +27,7 @@ template <int Rank, DType Dtype> class OpConcat : public GraphNode { public: - OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConcat(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <int Rank, DType Dtype> class OpPad : public GraphNode { public: - OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpPad(); virtual int checkTensorAttributes(); virtual int eval(); @@ -70,7 +70,7 @@ template <int InRank, int OutRank, DType Dtype> class OpReshape : public GraphNode { public: - OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReshape(); virtual int checkTensorAttributes(); @@ -94,7 +94,7 @@ template <int Rank, DType Dtype> class OpReverse : public GraphNode { public: - OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpReverse(); virtual int checkTensorAttributes(); @@ -116,7 +116,7 @@ template <int Rank, DType Dtype> class OpSlice : public GraphNode { public: - OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpSlice(); virtual int checkTensorAttributes(); @@ -139,7 +139,7 @@ template <int Rank, DType Dtype> class OpTileBase : public GraphNode { public: - OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTileBase(); virtual int checkTensorAttributes(); @@ -160,8 +160,8 @@ template <int Rank, DType Dtype> class OpTile : public OpTileBase<Rank, Dtype> { public: - OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_) + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpTileBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_) {} protected: @@ -174,8 +174,8 @@ protected: class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \ { \ public: \ - OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \ + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : OpTileBase<N, Dtype>(sgt_, attribute_, qinfo_, id_) \ {} \ \ protected: \ @@ -193,7 +193,7 @@ template <int Rank, DType Dtype> class OpTranspose : public GraphNode { public: - OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTranspose(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index baae019..ec4bc41 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -19,8 +19,8 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -OpConst::OpConst(uint64_t id_) - : GraphNode(Op_CONST, id_) +OpConst::OpConst(SubgraphTraverser* sgt_, uint64_t id_) + : GraphNode(sgt_, Op_CONST, id_) { setRequiredOperands(0, 1); } @@ -43,8 +43,11 @@ int OpConst::eval() } template <int Rank, DType Dtype> -OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_IDENTITY, id_) +OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_IDENTITY, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h index a02d441..407cf0a 100644 --- a/reference_model/src/ops/data_nodes.h +++ b/reference_model/src/ops/data_nodes.h @@ -24,7 +24,7 @@ namespace TosaReference class OpConst : public GraphNode { public: - OpConst(uint64_t id_); + OpConst(SubgraphTraverser* sgt_, uint64_t id_); virtual ~OpConst(); virtual int checkTensorAttributes(); @@ -35,7 +35,7 @@ template <int Rank, DType Dtype> class OpIdentity : public GraphNode { public: - OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpIdentity(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 3379ffe..16c4901 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -23,8 +23,11 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType InDtype, DType OutDtype> -BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(op_, id_) +BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, + const Op& op_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -203,7 +206,7 @@ int OpAdd<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -226,12 +229,12 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { - ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", - (int32_t)b, num_bits); + REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]", + (int32_t)b, num_bits); InEigenType acc = a >> b; @@ -257,7 +260,7 @@ int OpBitwiseAnd<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -274,7 +277,7 @@ int OpBitwiseOr<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -291,7 +294,7 @@ int OpBitwiseXor<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -304,15 +307,15 @@ int OpDiv<Rank, Dtype>::register_fcn() { case DType_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { - ASSERT_MSG_NODE(b != 0, "OpDiv: divisor must be non-zero value"); + REQUIRE(b != 0, "OpDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast<int64_t>(a) / b; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); - ASSERT_MSG_NODE(a <= i32_max_in_64, "OpDiv: result not in i32 range"); + REQUIRE(a <= i32_max_in_64, "OpDiv: result not in i32 range"); return static_cast<InEigenType>(res_in_64); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -327,7 +330,7 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -344,7 +347,7 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -366,7 +369,7 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { @@ -386,7 +389,7 @@ int OpLogicalOr<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -401,7 +404,7 @@ int OpLogicalXor<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -417,7 +420,7 @@ int OpMaximum<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -433,7 +436,7 @@ int OpMinimum<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -443,8 +446,6 @@ template <int Rank, DType InDtype, DType OutDtype> int OpMul<Rank, InDtype, OutDtype>::register_fcn() { int32_t shift = attribute->shift(); - ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift, - EnumNamesDType()[InDtype]); switch (InDtype) { @@ -460,8 +461,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round; result = result >> shift; - ASSERT_MSG_NODE(result >= QMin && result <= QMax, - "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax); + REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]", + result, QMin, QMax); } else { @@ -482,7 +483,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; @@ -497,7 +498,7 @@ int OpPow<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -513,15 +514,18 @@ int OpSub<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return 0; } template <int Rank, DType InDtype> -OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TABLE, id_) +OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_TABLE, id_) { setRequiredOperands(2, 1); setRequiredRank(0, 6); @@ -607,7 +611,7 @@ int OpTable<Rank, InDtype>::eval() }); break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); } return GraphNode::eval(); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index a5b1059..86b2101 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -42,7 +42,7 @@ template <int Rank, DType InDtype, DType OutDtype> class BinaryNodeBase : public GraphNode { public: - BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); + BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_); virtual ~BinaryNodeBase(); virtual int checkTensorAttributes() final; @@ -76,8 +76,8 @@ template <int Rank, DType InDtype, DType OutDtype> class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype> { public: - BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) + : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) {} virtual ~BinaryNode() {} @@ -95,8 +95,8 @@ template <DType InDtype, DType OutDtype> class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: - BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) - : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_) + BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_) + : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_) {} virtual ~BinaryNode() {} @@ -109,8 +109,8 @@ public: class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, qinfo_, id_) \ { \ register_fcn(); \ } \ @@ -142,8 +142,11 @@ template <int Rank, DType Dtype> class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> { public: - OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) + OpArithmeticRightShift(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_) { INIT_ATTRIBUTE(ArithmeticRightShift); register_fcn(); @@ -160,8 +163,8 @@ template <int Rank, DType InDtype, DType OutDtype> class OpMul : public BinaryNode<Rank, InDtype, OutDtype> { public: - OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_) + OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, qinfo_, id_) { INIT_ATTRIBUTE(Mul); register_fcn(); @@ -180,7 +183,7 @@ template <int Rank, DType InDtype> class OpTable : public GraphNode { public: - OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTable(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index d4845f9..64c4412 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -20,8 +20,11 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SELECT, id_) +OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SELECT, id_) { setRequiredOperands(3, 1); setRequiredRank(0, 6); @@ -62,7 +65,7 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes() template <int Rank, DType Dtype> int OpSelectBase<Rank, Dtype>::eval() { - FATAL_ERROR_NODE("shouldn't be called"); + FATAL_ERROR("shouldn't be called"); } template <int Rank, DType Dtype> @@ -78,9 +81,9 @@ int OpSelect<Rank, Dtype>::broadcast() this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; - ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); - ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); - ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); } return 0; diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index b354247..b80fb23 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -33,7 +33,7 @@ template <int Rank, DType Dtype> class OpSelectBase : public GraphNode { public: - OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpSelectBase(); virtual int checkTensorAttributes(); @@ -59,8 +59,8 @@ template <int Rank, DType Dtype> class OpSelect : public OpSelectBase<Rank, Dtype> { public: - OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_) {} virtual int eval(); int broadcast(); @@ -73,8 +73,8 @@ template <DType Dtype> class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: - OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_) + OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 95a1102..041bbdb 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -23,8 +23,8 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_) - : GraphNode(op_, id_) +UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -80,7 +80,7 @@ int OpAbs<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -97,7 +97,7 @@ int OpBitwiseNot<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -112,7 +112,7 @@ int OpCeil<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -128,7 +128,7 @@ int OpClz<Rank, Dtype>::register_fcn() num_bits = 32; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -159,7 +159,7 @@ int OpExp<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -174,7 +174,7 @@ int OpFloor<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -189,7 +189,7 @@ int OpLog<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -204,7 +204,7 @@ int OpLogicalNot<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -213,6 +213,12 @@ int OpLogicalNot<Rank, Dtype>::register_fcn() template <int Rank, DType Dtype> int OpNegate<Rank, Dtype>::register_fcn() { + if (Dtype != DType_INT8 && this->qinfo) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + } + switch (Dtype) { case DType_FLOAT: @@ -229,7 +235,6 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; case DType_INT8: - ASSERT(this->qinfo); this->fcn = [this](InEigenType a) -> OutEigenType { InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); result = std::min(std::max(result, static_cast<InEigenType>(QMin)), static_cast<InEigenType>(QMax)); @@ -237,7 +242,7 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -252,7 +257,7 @@ int OpReciprocal<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; @@ -267,7 +272,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn() this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; default: - FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } return 0; diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 0db3cfb..374c8e4 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -26,7 +26,7 @@ template <int Rank, DType Dtype> class UnaryNode : public GraphNode { public: - UnaryNode(const Op& nodeType, const uint64_t id_); + UnaryNode(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_); virtual ~UnaryNode(); virtual int checkTensorAttributes() final; @@ -49,8 +49,8 @@ protected: class Op##Opname : public UnaryNode<Rank, Dtype> \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ @@ -66,8 +66,8 @@ protected: class Op##Opname : public UnaryNode<Rank, Dtype> \ { \ public: \ - Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ - : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ INIT_QINFO(Unary); \ register_fcn(); \ diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 829a6e0..f4decae 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -22,8 +22,11 @@ using namespace Eigen; using namespace tosa; template <DType InDtype, DType OutDtype> -OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESIZE, id_) +OpResize<InDtype, OutDtype>::OpResize(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESIZE, id_) { setRequiredOperands(1, 1); setRequiredRank(4, 4); @@ -102,10 +105,13 @@ int OpResize<InDtype, OutDtype>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]"); - ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift < 1 || shift > 11, "OpResize: attribute shift should be within [1, 11]"); + ERROR_IF(stride[0] <= 0 || stride[0] >= (16 << shift), "OpResize: invalid attribute stride_x"); + ERROR_IF(stride[1] <= 0 || stride[1] >= (16 << shift), "OpResize: invalid attribute stride_y"); + ERROR_IF(offset[0] <= (-16 << shift) || offset[0] >= (16 << shift), "OpResize: invalid attribute offset_x"); + ERROR_IF(offset[1] <= (-16 << shift) || offset[1] >= (16 << shift), "OpResize: invalid attribute offset_y"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -125,8 +131,8 @@ int OpResize<InDtype, OutDtype>::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) @@ -167,10 +173,10 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; - ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift"); - ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride"); - ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); - ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + ERROR_IF(shift != 0, "OpResize: float mode must have 0 shift"); + ERROR_IF(stride_fp[0] <= 0.0f || stride_fp[1] <= 0.0f, "OpResize: invalid attribute stride"); + ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); + ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch"); for (int b = 0; b < out_batch; b++) for (int c = 0; c < out_channels; c++) @@ -190,8 +196,8 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval() int32_t ix0 = MAX(ix, 0); int32_t ix1 = MIN(ix + 1, in_width - 1); - ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", - iy0, iy1, ix0, ix1); + REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, + iy1, ix0, ix1); OutEigenType acc; if (mode == ResizeMode_BILINEAR) diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 5dd14c8..095dc7d 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -27,7 +27,7 @@ template <DType InDtype, DType OutDtype> class OpResize : public GraphNode { public: - OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpResize(); virtual int checkTensorAttributes() final; virtual int eval(); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 726ab7c..2d9e428 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -32,7 +32,8 @@ using namespace TosaReference; using namespace tosa; -GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, +GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, + TosaSerializationHandler* tsh, Op opType, TosaAttributeBase* attribute, TosaQuantInfoBase* qinfo, @@ -349,7 +350,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // data_nodes case Op_CONST: - return new OpConst(id); + return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); @@ -398,13 +399,13 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // custom case Op_CUSTOM: - return new OpCustom(id); + return new OpCustom(sgt, id); // control_flow case Op_COND_IF: - return new OpCondIf(tsh, attribute, id); + return new OpCondIf(sgt, tsh, attribute, id); case Op_WHILE_LOOP: - return new OpWhileLoop(tsh, attribute, id); + return new OpWhileLoop(sgt, tsh, attribute, id); // Ops not recognized default: diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 0c116b6..eaa359c 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -24,55 +24,55 @@ #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id); + return new OP<RANK, DType_##DTYPE>(sgt, attribute, qinfo, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); + return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id); + return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, qinfo, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); + return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ { \ case 0: \ - return new OP<0>(attribute, qinfo, id); \ + return new OP<0>(sgt, attribute, qinfo, id); \ case 1: \ - return new OP<1>(attribute, qinfo, id); \ + return new OP<1>(sgt, attribute, qinfo, id); \ case 2: \ - return new OP<2>(attribute, qinfo, id); \ + return new OP<2>(sgt, attribute, qinfo, id); \ case 3: \ - return new OP<3>(attribute, qinfo, id); \ + return new OP<3>(sgt, attribute, qinfo, id); \ case 4: \ - return new OP<4>(attribute, qinfo, id); \ + return new OP<4>(sgt, attribute, qinfo, id); \ case 5: \ - return new OP<5>(attribute, qinfo, id); \ + return new OP<5>(sgt, attribute, qinfo, id); \ case 6: \ - return new OP<6>(attribute, qinfo, id); \ + return new OP<6>(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ if (inputDType == DType_##DTYPE) \ { \ - return new OP<DType_##DTYPE>(attribute, qinfo, id); \ + return new OP<DType_##DTYPE>(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \ + return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \ + return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ @@ -221,10 +221,14 @@ namespace TosaReference { +class SubgraphTraverser; +class GraphNode; + class OpFactory { public: - static GraphNode* newOp(tosa::TosaSerializationHandler* tsh, + static GraphNode* newOp(SubgraphTraverser* sgt, + tosa::TosaSerializationHandler* tsh, tosa::Op opType, tosa::TosaAttributeBase* attribute, tosa::TosaQuantInfoBase* qinfo, diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index 97a7aa7..107c7a8 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -21,8 +21,8 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) - : GraphNode(op_, id_) +ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) + : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 4); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index cf75812..f4e29b9 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -27,7 +27,7 @@ template <int Rank, DType Dtype> class ReduceNode : public GraphNode { public: - ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); + ReduceNode(SubgraphTraverser* sgt_, const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); virtual ~ReduceNode(); virtual int checkTensorAttributes(); virtual int eval() = 0; @@ -48,8 +48,8 @@ template <int Rank, DType Dtype> class OpReduceAll : public ReduceNode<Rank, Dtype> { public: - OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_) + OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); }; @@ -58,8 +58,8 @@ template <int Rank, DType Dtype> class OpReduceAny : public ReduceNode<Rank, Dtype> { public: - OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_) + OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_) {} virtual int eval(); }; @@ -68,8 +68,8 @@ template <int Rank, DType Dtype> class OpReduceMax : public ReduceNode<Rank, Dtype> { public: - OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, id_) + OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MAX, attribute_, id_) {} virtual int eval(); }; @@ -78,8 +78,8 @@ template <int Rank, DType Dtype> class OpReduceMin : public ReduceNode<Rank, Dtype> { public: - OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_) + OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MIN, attribute_, id_) {} virtual int eval(); }; @@ -88,8 +88,8 @@ template <int Rank, DType Dtype> class OpReduceProduct : public ReduceNode<Rank, Dtype> { public: - OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_) + OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) {} virtual int eval(); }; @@ -98,8 +98,8 @@ template <int Rank, DType Dtype> class OpReduceSum : public ReduceNode<Rank, Dtype> { public: - OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_) + OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_) {} virtual int eval(); }; diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 478b776..02ec54f 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -21,8 +21,11 @@ using namespace Eigen; using namespace tosa; template <DType Dtype> -OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_GATHER, id_) +OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_GATHER, id_) { setRequiredOperands(2, 1); } @@ -102,7 +105,7 @@ int OpGather<Dtype>::eval() for (int32_t w = 0; w < W; w++) { int32_t k = this->indices->getTensor()(n, w); - ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + REQUIRE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); for (int32_t c = 0; c < C; c++) { EigenType value = this->values->getTensor()(n, k, c); @@ -115,8 +118,11 @@ int OpGather<Dtype>::eval() } template <DType Dtype> -OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_SCATTER, id_) +OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_SCATTER, id_) { setRequiredOperands(3, 1); } @@ -206,7 +212,7 @@ int OpScatter<Dtype>::eval() for (int w = 0; w < W; w++) { int32_t k = this->indices->getTensor()(n, w); - ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); for (int c = 0; c < C; c++) { EigenType value = this->input->getTensor()(n, w, c); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index 17ea723..66b584a 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -27,7 +27,7 @@ template <DType Dtype> class OpGather : public GraphNode { public: - OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpGather(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpScatter : public GraphNode { public: - OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpScatter(); virtual int checkTensorAttributes(); diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 0007553..045c0a5 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -22,8 +22,11 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType Dtype> -OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_ARGMAX, id_) +OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -66,8 +69,11 @@ int OpArgMax<Rank, Dtype>::eval() } template <DType Dtype> -OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_AVG_POOL2D, id_) +OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_AVG_POOL2D, id_) { setRequiredOperands(1, 1); setRequiredRank(4); @@ -142,9 +148,6 @@ ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_ int32_t left_index = pad_left / stride; int32_t right_index = pad_right / stride; - // not handle ultra small activation yet - ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet"); - // minus the number of pad bit this index cover while (left_index >= 0) { @@ -176,7 +179,8 @@ int OpAvgPool2d<Dtype>::eval() int out_width = this->out->getShape()[2]; int out_channels = this->out->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -260,12 +264,19 @@ int OpAvgPool2d<Dtype>::eval() if (Dtype != DType_FLOAT) { - this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { - int32_t multiplier, shift; - TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); + try + { + this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { + int32_t multiplier, shift; + TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift); - return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false); - }); + return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false); + }); + } + catch (std::string desc) + { + REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str()); + } this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); @@ -279,8 +290,11 @@ int OpAvgPool2d<Dtype>::eval() } template <DType InDtype, DType WeightDtype> -OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CONV2D, id_) +OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -361,13 +375,12 @@ int OpConv2d<InDtype, WeightDtype>::eval() int out_width = this->output->getShape()[2]; int out_channels = this->output->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, - in_channels); - ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, - out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, + out_channels); + ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -469,10 +482,11 @@ int OpConv2d<InDtype, WeightDtype>::eval() } template <DType InDtype, DType WeightDtype> -OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_, +OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_DEPTHWISE_CONV2D, id_) + : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -553,14 +567,13 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() int out_width = this->output->getShape()[2]; int out_channels = this->output->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", - f_in_channels, in_channels); - ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels, - "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, - out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", - b_out_channels, out_channels); + ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", + in_channels * f_multiplier, out_channels); + ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, + out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -651,10 +664,11 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() } template <DType InDtype, DType WeightDtype> -OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_, +OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_FULLY_CONNECTED, id_) + : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); setRequiredRank(2); @@ -738,8 +752,11 @@ int OpFullyConnected<InDtype, WeightDtype>::eval() } template <DType Dtype> -OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_MATMUL, id_) +OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); setRequiredRank(3); @@ -866,8 +883,11 @@ int OpMatMul<Dtype>::eval() } template <DType Dtype> -OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_MAX_POOL2D, id_) +OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_MAX_POOL2D, id_) { setRequiredOperands(1, 1); setRequiredRank(4); @@ -936,7 +956,8 @@ int OpMaxPool2d<Dtype>::eval() int out_width = this->out->getShape()[2]; int out_channels = this->out->getShape()[3]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels); int padding_top = this->attribute->padding()[0]; int padding_bottom = this->attribute->padding()[1]; @@ -1004,10 +1025,11 @@ int OpMaxPool2d<Dtype>::eval() } template <DType InDtype, DType OutDtype> -OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_, +OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_TRANSPOSE_CONV2D, id_) + : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); setRequiredRank(4); @@ -1104,13 +1126,13 @@ int OpTransposeConv2d<InDtype, OutDtype>::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); - ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", - f_in_channels, in_channels); - ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", - f_out_channels, out_channels); - ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", - b_out_channels, out_channels); + ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", + f_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, + out_channels); DEBUG_INFO(OP, "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 9aaa140..6ffc27d 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -28,7 +28,7 @@ template <int Rank, DType Dtype> class OpArgMax : public GraphNode { public: - OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpArgMax(); virtual int checkTensorAttributes(); @@ -49,7 +49,7 @@ template <DType Dtype> class OpAvgPool2d : public GraphNode { public: - OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpAvgPool2d(); virtual int checkTensorAttributes(); @@ -80,7 +80,7 @@ template <DType InDtype, DType WeightDtype> class OpConv2d : public GraphNode { public: - OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpConv2d(); virtual int checkTensorAttributes() final; @@ -112,7 +112,7 @@ template <DType InDtype, DType WeightDtype> class OpDepthwiseConv2d : public GraphNode { public: - OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpDepthwiseConv2d(); virtual int checkTensorAttributes() final; @@ -144,7 +144,7 @@ template <DType InDtype, DType WeightDtype> class OpFullyConnected : public GraphNode { public: - OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpFullyConnected(); virtual int checkTensorAttributes() final; @@ -174,7 +174,7 @@ template <DType Dtype> class OpMatMul : public GraphNode { public: - OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMatMul(); virtual int checkTensorAttributes() final; @@ -205,7 +205,7 @@ template <DType Dtype> class OpMaxPool2d : public GraphNode { public: - OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpMaxPool2d(); virtual int checkTensorAttributes(); @@ -226,7 +226,7 @@ template <DType InDtype, DType WeightDtype> class OpTransposeConv2d : public GraphNode { public: - OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpTransposeConv2d(); virtual int checkTensorAttributes() final; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index d988c57..657eebf 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -23,8 +23,11 @@ using namespace Eigen; using namespace tosa; template <int Rank, DType InDtype, DType OutDtype> -OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_RESCALE, id_) +OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_RESCALE, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); @@ -101,53 +104,68 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() int32_t channel_multiplier, channel_shift; Eigen::array<Eigen::Index, 2> begin, size; size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 }); - for (int32_t i = 0; i < shape_2d[1]; i++) + try { - begin = Eigen::array<Eigen::Index, 2>({ 0, i }); - curr_channel_slice_prescaled = input_reshaped.slice(begin, size); - channel_multiplier = multiplier[i]; - channel_shift = shift[i]; - curr_channel_slice_postscaled = - curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, - double_round, scale32](InEigenType in_val) -> OutEigenType { - InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, - channel_shift, double_round); - else - scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier, - channel_shift); - OutEigenType out_val = (OutEigenType)(scaled + output_zp); - out_val = std::max<OutEigenType>(out_val, QMin); - out_val = std::min<OutEigenType>(out_val, QMax); - return out_val; - }); - - for (int32_t j = 0; j < shape_2d[0]; j++) + for (int32_t i = 0; i < shape_2d[1]; i++) { - output_2d(j, i) = curr_channel_slice_postscaled(j, 0); + begin = Eigen::array<Eigen::Index, 2>({ 0, i }); + curr_channel_slice_prescaled = input_reshaped.slice(begin, size); + channel_multiplier = multiplier[i]; + channel_shift = shift[i]; + curr_channel_slice_postscaled = + curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, + double_round, scale32](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier, + channel_shift, double_round); + else + scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier, + channel_shift); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max<OutEigenType>(out_val, QMin); + out_val = std::min<OutEigenType>(out_val, QMax); + return out_val; + }); + + for (int32_t j = 0; j < shape_2d[0]; j++) + { + output_2d(j, i) = curr_channel_slice_postscaled(j, 0); + } } } + catch (std::string desc) + { + REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str()); + } } else { int32_t tensor_multiplier = multiplier[0]; int32_t tensor_shift = shift[0]; - output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, - scale32](InEigenType in_val) -> OutEigenType { - InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; - int32_t scaled; - if (scale32) - scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, - double_round); - else - scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); - OutEigenType out_val = (OutEigenType)(scaled + output_zp); - out_val = std::max<OutEigenType>(out_val, QMin); - out_val = std::min<OutEigenType>(out_val, QMax); - return out_val; - }); + try + { + output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round, + scale32](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled; + if (scale32) + scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift, + double_round); + else + scaled = + TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max<OutEigenType>(out_val, QMin); + out_val = std::min<OutEigenType>(out_val, QMax); + return out_val; + }); + } + catch (std::string desc) + { + REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str()); + } } // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn] @@ -162,8 +180,11 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() } template <int Rank, DType InDtype, DType OutDtype> -OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) - : GraphNode(Op_CAST, id_) +OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(sgt_, Op_CAST, id_) { setRequiredOperands(1, 1); setRequiredRank(0, 6); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 6ec4d6d..060e14e 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -26,7 +26,7 @@ template <int Rank, DType InDtype, DType OutDtype> class OpRescale : public GraphNode { public: - OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpRescale(); virtual int checkTensorAttributes() final; @@ -140,7 +140,7 @@ template <int Rank, DType InDtype, DType OutDtype> class OpCast : public GraphNode { public: - OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); virtual ~OpCast(); virtual int checkTensorAttributes() final; diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index c595869..4f6a525 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -44,9 +44,17 @@ public: static int32_t apply_scale_32(int32_t value, int32_t multiplier, int32_t shift, bool double_round = true) { - ASSERT_MSG(multiplier >= 0, "apply_scale_32() error: multiplier should >= 0 but is %d", multiplier); - ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_32() error: shift should be within [2, 62] but is %d", - shift); + if (multiplier < 0) + { + std::string desc = "apply_scale_32() error: multiplier should >= 0 but is " + std::to_string(multiplier); + throw desc; + } + if (shift < 2 || shift > 62) + { + std::string desc = + "apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift); + throw desc; + } int64_t round = 1L << (shift - 1); if (double_round) { @@ -57,21 +65,35 @@ public: } int64_t result = (int64_t)value * multiplier + round; result = result >> shift; - ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), - "apply_scale_32() error: scaled result exceed int32 numeric range"); + if (result < -(1L << 31) || result >= (1L << 31)) + { + std::string desc = "apply_scale_32() error: scaled result exceeds int32 numeric range"; + throw desc; + } return static_cast<int32_t>(result); } static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift) { - ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier); - ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_16() error: shift should be within [2, 62] but is %d", - shift); + if (multiplier < 0) + { + std::string desc = "apply_scale_16() error: multiplier should >= 0 but is " + std::to_string(multiplier); + throw desc; + } + if (shift < 2 || shift > 62) + { + std::string desc = + "apply_scale_16(): shift value should stay within [2, 62] but is " + std::to_string(shift); + throw desc; + } int64_t round = 1L << (shift - 1); int64_t result = value * (int64_t)multiplier + round; result = result >> shift; - ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), - "apply_scale_16() error: scaled result exceed int32 numeric range"); + if (result < -(1L << 31) || result >= (1L << 31)) + { + std::string desc = "apply_scale_16() error: scaled result exceeds int32 numeric range"; + throw desc; + } return static_cast<int32_t>(result); } }; diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index bdf6fbc..ef7bae6 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -21,6 +21,8 @@ using namespace tosa; SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh) { + graph_status = GraphStatus::TOSA_VALID; + block = _block; tsh = _tsh; @@ -166,7 +168,7 @@ int SubgraphTraverser::initializeGraph() DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size()); - GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, + GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype, input_rank, output_dtype, output_rank, weight_dtype, weight_rank); if (!node) { @@ -221,16 +223,25 @@ int SubgraphTraverser::initializeGraph() for (auto ts : block->GetTensors()) { + // Bail out if any dimension is invalid. + for (auto& dim : ts->GetShape()) + { + if (dim <= 0) + { + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; + } + } DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + if (!ts->GetData().empty()) { if (tensor->allocate()) { - WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); - return 1; + SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); } switch (ts->GetDtype()) @@ -316,7 +327,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str()); } } @@ -332,7 +343,7 @@ int SubgraphTraverser::initializeGraph() } else { - FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str()); + FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str()); } } @@ -395,13 +406,14 @@ int SubgraphTraverser::evaluateNextNode() if (!tensor->is_allocated()) if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); + FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str()); } } if (currNode->eval()) { - FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID()); + WARNING("Failed to evaluate node: %lu", currNode->getID()); + return 1; } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 3f4eecf..4be6c1f 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -16,15 +16,22 @@ #ifndef SUBGRAPH_TRAVERSER_H #define SUBGRAPH_TRAVERSER_H -#include "model_common.h" - #include "graph_node.h" +#include "model_common.h" #include "ops/op_factory.h" +#include "tensor.h" #include "tosa_serialization_handler.h" namespace TosaReference { +enum class GraphStatus : int +{ + TOSA_VALID = 0, + TOSA_UNPREDICTABLE = 1, + TOSA_ERROR = 2, +}; + class SubgraphTraverser { public: @@ -36,6 +43,15 @@ public: int evaluateNextNode(); int evaluateAll(); + GraphStatus getGraphStatus() const + { + return graph_status; + } + void setGraphStatus(GraphStatus status) + { + graph_status = status; + } + int linkTensorsAndNodes(); int validateGraph(); @@ -59,6 +75,8 @@ private: GraphNode* getNextNode(); + GraphStatus graph_status; + // pointer to serialization library and corresponding basic block TosaSerializationBasicBlock* block; TosaSerializationHandler* tsh; diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 1efebe3..f2a3a98 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -20,15 +20,13 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -TosaReference::Tensor::Tensor(std::string tensorName_, - DType tensorDtype_, - std::vector<int> shape_) -{ - tensorName = std::string(tensorName_); - tensorDtype = tensorDtype_; - shape = std::vector<int>(shape_); - producer = nullptr; - isValid = false; +TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_) +{ + tensorName = std::string(tensorName_); + tensorDtype = tensorDtype_; + shape = std::vector<int>(shape_); + producer = nullptr; + isValid = false; consumers.clear(); isSubgraphInput = false; isSubgraphOutput = false; diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 6c0622e..e97554f 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -604,13 +604,6 @@ class TensorFactory public: static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank) { - // Bail out if any dimension is invalid. - for (auto& dim : shape_) - { - if (dim <= 0) - goto done; - } - switch (tensorDtype_) { case DType_FLOAT: @@ -697,7 +690,6 @@ public: break; } - done: std::string shape_str("["); for (auto& dim : shape_) { |