about summary refs log tree commit diff
path: root/reference_model
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-29 15:32:19 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-08-20 18:07:06 +0100
commitacb550f4410ae861e53cae27a9feb4b11d45769f (patch)
treeae2f4ec558c2cdf1afa020b80a09d7ab4be5ef6d /reference_model
parent68e7aee65bda5ac03fa7def753b7dc7462554793 (diff)
downloadreference_model-acb550f4410ae861e53cae27a9feb4b11d45769f.tar.gz
Replace node-level checks ASSERT_MSG_NODE()/FATAL_ERROR_NODE() with REQUIRE() or ERROR_IF()
- Add a return code enum class: {VALID, UNPREDICTABLE, ERROR} - Runtime errors (e.g. memory allocation failure) abort immediately; otherwise one of the three return codes is returned - Part of the code is rewritten so that REQUIRE() failures propagate to the top level (e.g. apply_scale_32/16()) - Update setExpectedFailure() to setExpectedReturnCode() in the test generation script - Update the test regression script to interface with the reference model change Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ia063c936bcb2a54d6e379a5bb6801aa72d1186f1
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/func_debug.h31
-rw-r--r--reference_model/src/graph_node.cc7
-rw-r--r--reference_model/src/graph_node.h8
-rw-r--r--reference_model/src/main.cpp79
-rw-r--r--reference_model/src/ops/activation_funcs.cc8
-rw-r--r--reference_model/src/ops/activation_funcs.h16
-rw-r--r--reference_model/src/ops/comparison.cc6
-rw-r--r--reference_model/src/ops/comparison.h12
-rw-r--r--reference_model/src/ops/control_flow.cc15
-rw-r--r--reference_model/src/ops/control_flow.h6
-rw-r--r--reference_model/src/ops/custom.cc6
-rw-r--r--reference_model/src/ops/custom.h2
-rw-r--r--reference_model/src/ops/data_layout.cc51
-rw-r--r--reference_model/src/ops/data_layout.h22
-rw-r--r--reference_model/src/ops/data_nodes.cc11
-rw-r--r--reference_model/src/ops/data_nodes.h4
-rw-r--r--reference_model/src/ops/ewise_binary.cc62
-rw-r--r--reference_model/src/ops/ewise_binary.h27
-rw-r--r--reference_model/src/ops/ewise_ternary.cc15
-rw-r--r--reference_model/src/ops/ewise_ternary.h10
-rw-r--r--reference_model/src/ops/ewise_unary.cc33
-rw-r--r--reference_model/src/ops/ewise_unary.h10
-rw-r--r--reference_model/src/ops/image.cc34
-rw-r--r--reference_model/src/ops/image.h2
-rw-r--r--reference_model/src/ops/op_factory.cc11
-rw-r--r--reference_model/src/ops/op_factory.h34
-rw-r--r--reference_model/src/ops/reduction.cc4
-rw-r--r--reference_model/src/ops/reduction.h26
-rw-r--r--reference_model/src/ops/scatter_gather.cc18
-rw-r--r--reference_model/src/ops/scatter_gather.h4
-rw-r--r--reference_model/src/ops/tensor_ops.cc118
-rw-r--r--reference_model/src/ops/tensor_ops.h16
-rw-r--r--reference_model/src/ops/type_conversion.cc105
-rw-r--r--reference_model/src/ops/type_conversion.h4
-rw-r--r--reference_model/src/quant_util.h42
-rw-r--r--reference_model/src/subgraph_traverser.cc26
-rw-r--r--reference_model/src/subgraph_traverser.h22
-rw-r--r--reference_model/src/tensor.cc16
-rw-r--r--reference_model/src/tensor.h8
39 files changed, 558 insertions, 373 deletions
diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h
index 2d47462..26901cf 100644
--- a/reference_model/src/func_debug.h
+++ b/reference_model/src/func_debug.h
@@ -97,16 +97,31 @@ struct func_debug_t
}
#endif
-#ifndef ASSERT_MSG_NODE
-#define ASSERT_MSG_NODE(COND, fmt, ...) \
+#ifndef REQUIRE
+#define REQUIRE(COND, fmt, ...) \
if (!(COND)) \
{ \
- fprintf(g_func_debug.func_debug_file, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("REQUIRE() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \
+ __func__, #COND); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ this->parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); \
+ }
+#endif
+
+#ifndef ERROR_IF
+#define ERROR_IF(COND, fmt, ...) \
+ if ((COND)) \
+ { \
+ if (this->parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \
+ { \
+ this->parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR); \
+ } \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \
__func__, #COND); \
fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
this->dumpNode(g_func_debug.func_debug_file); \
func_print_backtrace(g_func_debug.func_debug_file); \
- assert(COND); \
+ return 1; \
}
#endif
@@ -130,14 +145,6 @@ struct func_debug_t
abort();
#endif
-#ifndef FATAL_ERROR_NODE
-#define FATAL_ERROR_NODE(fmt, ...) \
- fprintf(g_func_debug.func_debug_file, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
- fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
- this->dumpNode(g_func_debug.func_debug_file); \
- func_print_backtrace(g_func_debug.func_debug_file); \
- abort();
-#endif
#ifndef SIMPLE_FATAL_ERROR
#define SIMPLE_FATAL_ERROR(fmt, ...) \
fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
index b57b9dd..f765700 100644
--- a/reference_model/src/graph_node.cc
+++ b/reference_model/src/graph_node.cc
@@ -19,10 +19,11 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_)
+GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const uint64_t id_)
{
- nodeType = nodeType_;
- nodeId = id_;
+ parent_sgt = parent_sgt_;
+ nodeType = nodeType_;
+ nodeId = id_;
inputs.clear();
outputs.clear();
inputNames.clear();
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index bf80859..14a8acc 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -18,6 +18,7 @@
#include "attribute.h"
#include "quant_info.h"
+#include "subgraph_traverser.h"
#include "tensor.h"
#include "tosa_generated.h"
#include <iostream>
@@ -139,12 +140,14 @@
namespace TosaReference
{
+class SubgraphTraverser;
+
// Nodes in the graph (e.g., tosa operators) are defined with this base
// class.
class GraphNode
{
public:
- GraphNode(const tosa::Op& nodeType, const uint64_t id_);
+ GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_);
virtual ~GraphNode();
int addInputName(std::string& name);
@@ -274,6 +277,9 @@ protected:
int validateRequiredOperands();
int validateRequiredRank(const Tensor* t);
+ // Parent SubgraphTraverser
+ SubgraphTraverser* parent_sgt;
+
// Description of the node type (e.g., CONST, CONV2D, etc...)
tosa::Op nodeType;
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 412894c..55a4848 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -64,12 +64,12 @@ int main(int argc, const char** argv)
SIMPLE_FATAL_ERROR("Unable to load graph");
}
- // load json first since it's easier debugging
SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
if (main_gt.initializeGraph())
{
- SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\"");
+ WARNING("Unable to initialize main graph traverser.");
+ goto done;
}
if (main_gt.linkTensorsAndNodes())
@@ -95,49 +95,76 @@ int main(int argc, const char** argv)
if (g_func_config.eval)
{
+ // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
{
- SIMPLE_FATAL_ERROR("Error evaluating network. Giving up.");
+ ASSERT_MSG(main_gt.getGraphStatus() != GraphStatus::TOSA_VALID,
+ "Upon evaluateAll() returning 1, graph can not be VALID.");
+ }
+ else
+ {
+ ASSERT_MSG(main_gt.getGraphStatus() == GraphStatus::TOSA_VALID ||
+ main_gt.getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
}
- // make sure output tensor is evaluated and show its value
- int num_output_tensors = main_gt.getNumOutputTensors();
- bool all_output_valid = true;
- for (int i = 0; i < num_output_tensors; i++)
+ // Only generate output tensor if graph is valid.
+ if (main_gt.getGraphStatus() == GraphStatus::TOSA_VALID)
{
- const Tensor* ct = main_gt.getOutputTensor(i);
- ASSERT_MEM(ct);
- if (!ct->getIsValid())
+ // make sure output tensor is evaluated and show its value
+ int num_output_tensors = main_gt.getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
{
- ct->dumpTensorParams(g_func_debug.func_debug_file);
- if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ const Tensor* ct = main_gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
{
- ct->dumpTensor(g_func_debug.func_debug_file);
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
}
- all_output_valid = false;
}
- }
- if (!all_output_valid)
- {
- main_gt.dumpGraph(g_func_debug.func_debug_file);
- SIMPLE_FATAL_ERROR(
- "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
- }
+ if (!all_output_valid)
+ {
+ main_gt.dumpGraph(g_func_debug.func_debug_file);
+ SIMPLE_FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
- if (g_func_config.output_tensors)
- {
- if (writeFinalTensors(main_gt, test_desc))
+ if (g_func_config.output_tensors)
{
- WARNING("Errors encountered in saving output tensors");
+ if (writeFinalTensors(main_gt, test_desc))
+ {
+ WARNING("Errors encountered in saving output tensors");
+ }
}
}
}
done:
+ switch (main_gt.getGraphStatus())
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ fprintf(stderr, "Graph result: UNPREDICTABLE.\n");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ fprintf(stderr, "Graph result: ERROR.\n");
+ break;
+ default:
+ fprintf(stderr, "Unknown graph status code=%d.\n", (int)main_gt.getGraphStatus());
+ }
+
func_fini_debug(&g_func_debug);
func_model_config_cleanup();
- return 0;
+ return (int)main_gt.getGraphStatus();
}
int loadGraph(TosaSerializationHandler& tsh, json test_desc)
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 3410ba9..440f4e1 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -44,7 +44,7 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -69,7 +69,7 @@ int OpReluN<Rank, Dtype>::register_fcn()
}
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -84,7 +84,7 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -99,7 +99,7 @@ int OpTanh<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index b051b9d..c834b52 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -28,8 +28,8 @@ template <int Rank, DType Dtype>
class OpClamp : public UnaryNode<Rank, Dtype>
{
public:
- OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : UnaryNode<Rank, Dtype>(Op_CLAMP, id_)
+ OpClamp(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_)
{
INIT_ATTRIBUTE(Clamp);
register_fcn();
@@ -48,8 +48,8 @@ template <int Rank, DType Dtype>
class OpReluN : public UnaryNode<Rank, Dtype>
{
public:
- OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : UnaryNode<Rank, Dtype>(Op_RELUN, id_)
+ OpReluN(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_RELUN, id_)
{
INIT_ATTRIBUTE(ReluN);
register_fcn();
@@ -68,8 +68,8 @@ template <int Rank, DType Dtype>
class OpSigmoid : public UnaryNode<Rank, Dtype>
{
public:
- OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_)
+ OpSigmoid(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_SIGMOID, id_)
{
register_fcn();
}
@@ -84,8 +84,8 @@ template <int Rank, DType Dtype>
class OpTanh : public UnaryNode<Rank, Dtype>
{
public:
- OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : UnaryNode<Rank, Dtype>(Op_TANH, id_)
+ OpTanh(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(sgt_, Op_TANH, id_)
{
register_fcn();
}
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index 402e152..ab89e24 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -32,7 +32,7 @@ int OpEqual<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -48,7 +48,7 @@ int OpGreater<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -64,7 +64,7 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
index e75b1a6..5b4d0f1 100644
--- a/reference_model/src/ops/comparison.h
+++ b/reference_model/src/ops/comparison.h
@@ -28,8 +28,8 @@ template <int Rank, DType Dtype>
class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_)
{
register_fcn();
}
@@ -42,8 +42,8 @@ template <int Rank, DType Dtype>
class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_)
+ OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, qinfo_, id_)
{
register_fcn();
}
@@ -56,8 +56,8 @@ template <int Rank, DType Dtype>
class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
{
public:
- OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, qinfo_, id_)
{
register_fcn();
}
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 1a6a63a..0945056 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -20,8 +20,8 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
- : GraphNode(op_, id_)
+OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
{
tsh = tsh_;
}
@@ -148,8 +148,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
return 0;
}
-OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
- : OpControlFlow(tsh_, Op_COND_IF, id_)
+OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_)
{
INIT_ATTRIBUTE(CondIf);
}
@@ -221,8 +221,11 @@ int OpCondIf::eval()
return GraphNode::eval();
}
-OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
- : OpControlFlow(tsh_, Op_WHILE_LOOP, id_)
+OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_,
+ TosaSerializationHandler* tsh_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_)
{
INIT_ATTRIBUTE(WhileLoop);
}
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
index 14c11bc..879cd6a 100644
--- a/reference_model/src/ops/control_flow.h
+++ b/reference_model/src/ops/control_flow.h
@@ -25,7 +25,7 @@ namespace TosaReference
class OpControlFlow : public GraphNode
{
public:
- OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
+ OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
~OpControlFlow();
virtual int evalBlock(TosaSerializationBasicBlock* block,
@@ -39,7 +39,7 @@ protected:
class OpCondIf : public OpControlFlow
{
public:
- OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpCondIf();
virtual int checkTensorAttributes();
@@ -55,7 +55,7 @@ protected:
class OpWhileLoop : public OpControlFlow
{
public:
- OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ OpWhileLoop(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
virtual ~OpWhileLoop();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
index 5c4f29b..5fc36f3 100644
--- a/reference_model/src/ops/custom.cc
+++ b/reference_model/src/ops/custom.cc
@@ -19,8 +19,8 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-OpCustom::OpCustom(uint64_t id_)
- : GraphNode(Op_CUSTOM, id_)
+OpCustom::OpCustom(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CUSTOM, id_)
{}
OpCustom::~OpCustom()
@@ -33,7 +33,7 @@ int OpCustom::checkTensorAttributes()
int OpCustom::eval()
{
- FATAL_ERROR_NODE("not supported yet");
+ FATAL_ERROR("not supported yet");
// Evaluation is trivial for constants
return GraphNode::eval();
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
index b1085a5..d14c809 100644
--- a/reference_model/src/ops/custom.h
+++ b/reference_model/src/ops/custom.h
@@ -26,7 +26,7 @@ namespace TosaReference
class OpCustom : public GraphNode
{
public:
- OpCustom(uint64_t id_);
+ OpCustom(SubgraphTraverser* sgt_, uint64_t id_);
virtual ~OpCustom();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index c66d64e..86326f5 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -21,8 +21,11 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
-OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_CONCAT, id_)
+OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_CONCAT, id_)
{
setRequiredOperands(-1, 1);
setRequiredRank(1, 6);
@@ -95,8 +98,11 @@ int OpConcat<Rank, Dtype>::eval()
}
template <int Rank, DType Dtype>
-OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_PAD, id_)
+OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
@@ -157,8 +163,11 @@ int OpPad<Rank, Dtype>::eval()
}
template <int InRank, int OutRank, DType Dtype>
-OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_RESHAPE, id_)
+OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_RESHAPE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -274,8 +283,11 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
}
template <int Rank, DType Dtype>
-OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_REVERSE, id_)
+OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_REVERSE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(1, 6);
@@ -339,8 +351,11 @@ int OpReverse<Rank, Dtype>::eval()
}
template <int Rank, DType Dtype>
-OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_SLICE, id_)
+OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -407,8 +422,11 @@ int OpSlice<Rank, Dtype>::eval()
}
template <int Rank, DType Dtype>
-OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_TILE, id_)
+OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -466,7 +484,7 @@ template <int Rank, DType Dtype>
int OpTile<Rank, Dtype>::eval()
{
// primary template shouldn't be called
- FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+ FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
}
template <DType Dtype>
@@ -542,8 +560,11 @@ int OpTile<4, Dtype>::eval()
}
template <int Rank, DType Dtype>
-OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_TRANSPOSE, id_)
+OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index b180b4f..c9c2602 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -27,7 +27,7 @@ template <int Rank, DType Dtype>
class OpConcat : public GraphNode
{
public:
- OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpConcat();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <int Rank, DType Dtype>
class OpPad : public GraphNode
{
public:
- OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpPad();
virtual int checkTensorAttributes();
virtual int eval();
@@ -70,7 +70,7 @@ template <int InRank, int OutRank, DType Dtype>
class OpReshape : public GraphNode
{
public:
- OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpReshape();
virtual int checkTensorAttributes();
@@ -94,7 +94,7 @@ template <int Rank, DType Dtype>
class OpReverse : public GraphNode
{
public:
- OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpReverse();
virtual int checkTensorAttributes();
@@ -116,7 +116,7 @@ template <int Rank, DType Dtype>
class OpSlice : public GraphNode
{
public:
- OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpSlice();
virtual int checkTensorAttributes();
@@ -139,7 +139,7 @@ template <int Rank, DType Dtype>
class OpTileBase : public GraphNode
{
public:
- OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpTileBase();
virtual int checkTensorAttributes();
@@ -160,8 +160,8 @@ template <int Rank, DType Dtype>
class OpTile : public OpTileBase<Rank, Dtype>
{
public:
- OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpTileBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_)
{}
protected:
@@ -174,8 +174,8 @@ protected:
class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
{ \
public: \
- OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(sgt_, attribute_, qinfo_, id_) \
{} \
\
protected: \
@@ -193,7 +193,7 @@ template <int Rank, DType Dtype>
class OpTranspose : public GraphNode
{
public:
- OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpTranspose();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index baae019..ec4bc41 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -19,8 +19,8 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-OpConst::OpConst(uint64_t id_)
- : GraphNode(Op_CONST, id_)
+OpConst::OpConst(SubgraphTraverser* sgt_, uint64_t id_)
+ : GraphNode(sgt_, Op_CONST, id_)
{
setRequiredOperands(0, 1);
}
@@ -43,8 +43,11 @@ int OpConst::eval()
}
template <int Rank, DType Dtype>
-OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_IDENTITY, id_)
+OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_IDENTITY, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index a02d441..407cf0a 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -24,7 +24,7 @@ namespace TosaReference
class OpConst : public GraphNode
{
public:
- OpConst(uint64_t id_);
+ OpConst(SubgraphTraverser* sgt_, uint64_t id_);
virtual ~OpConst();
virtual int checkTensorAttributes();
@@ -35,7 +35,7 @@ template <int Rank, DType Dtype>
class OpIdentity : public GraphNode
{
public:
- OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpIdentity();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 3379ffe..16c4901 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -23,8 +23,11 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType InDtype, DType OutDtype>
-BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(op_, id_)
+BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
+ const Op& op_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
@@ -203,7 +206,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
}
return 0;
@@ -226,12 +229,12 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
num_bits = 32;
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
- ASSERT_MSG_NODE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
- (int32_t)b, num_bits);
+ REQUIRE(b >= 0 && b < num_bits, "OpArithmeticRightShift: shift value %d is out of valid range [0, %d]",
+ (int32_t)b, num_bits);
InEigenType acc = a >> b;
@@ -257,7 +260,7 @@ int OpBitwiseAnd<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -274,7 +277,7 @@ int OpBitwiseOr<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -291,7 +294,7 @@ int OpBitwiseXor<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -304,15 +307,15 @@ int OpDiv<Rank, Dtype>::register_fcn()
{
case DType_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
- ASSERT_MSG_NODE(b != 0, "OpDiv: divisor must be non-zero value");
+ REQUIRE(b != 0, "OpDiv: divisor must be non-zero value");
int64_t res_in_64 = static_cast<int64_t>(a) / b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
- ASSERT_MSG_NODE(a <= i32_max_in_64, "OpDiv: result not in i32 range");
+ REQUIRE(a <= i32_max_in_64, "OpDiv: result not in i32 range");
return static_cast<InEigenType>(res_in_64);
};
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
}
return 0;
@@ -327,7 +330,7 @@ int OpLogicalAnd<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -344,7 +347,7 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -366,7 +369,7 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn()
num_bits = 32;
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
@@ -386,7 +389,7 @@ int OpLogicalOr<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -401,7 +404,7 @@ int OpLogicalXor<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -417,7 +420,7 @@ int OpMaximum<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -433,7 +436,7 @@ int OpMinimum<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -443,8 +446,6 @@ template <int Rank, DType InDtype, DType OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
int32_t shift = attribute->shift();
- ASSERT_MSG_NODE(InDtype == DType_INT32 || shift == 0, "OpMul: shift needs to be 0 but is %d if input is %s", shift,
- EnumNamesDType()[InDtype]);
switch (InDtype)
{
@@ -460,8 +461,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
result = result >> shift;
- ASSERT_MSG_NODE(result >= QMin && result <= QMax,
- "OpMul: result %ld exceeds valid range [%ld, %ld]", result, QMin, QMax);
+ REQUIRE(result >= QMin && result <= QMax, "OpMul: result %ld exceeds valid range [%ld, %ld]",
+ result, QMin, QMax);
}
else
{
@@ -482,7 +483,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
};
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
}
return 0;
@@ -497,7 +498,7 @@ int OpPow<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -513,15 +514,18 @@ int OpSub<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
}
return 0;
}
template <int Rank, DType InDtype>
-OpTable<Rank, InDtype>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_TABLE, id_)
+OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_TABLE, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
@@ -607,7 +611,7 @@ int OpTable<Rank, InDtype>::eval()
});
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
}
return GraphNode::eval();
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index a5b1059..86b2101 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -42,7 +42,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
- BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+ BinaryNodeBase(SubgraphTraverser* sgt_, const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
virtual ~BinaryNodeBase();
virtual int checkTensorAttributes() final;
@@ -76,8 +76,8 @@ template <int Rank, DType InDtype, DType OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
- BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
- : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
+ BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<Rank, InDtype, OutDtype>(sgt_, op_, qinfo_, id_)
{}
virtual ~BinaryNode()
{}
@@ -95,8 +95,8 @@ template <DType InDtype, DType OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
- BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
- : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
+ BinaryNode(SubgraphTraverser* sgt_, const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<0, InDtype, OutDtype>(sgt_, op_, qinfo_, id_)
{}
virtual ~BinaryNode()
{}
@@ -109,8 +109,8 @@ public:
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, qinfo_, id_) \
{ \
register_fcn(); \
} \
@@ -142,8 +142,11 @@ template <int Rank, DType Dtype>
class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
{
public:
- OpArithmeticRightShift(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, Dtype, Dtype>(Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
+ OpArithmeticRightShift(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ARITHMETIC_RIGHT_SHIFT, qinfo_, id_)
{
INIT_ATTRIBUTE(ArithmeticRightShift);
register_fcn();
@@ -160,8 +163,8 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
- OpMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : BinaryNode<Rank, InDtype, OutDtype>(Op_MUL, qinfo_, id_)
+ OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, qinfo_, id_)
{
INIT_ATTRIBUTE(Mul);
register_fcn();
@@ -180,7 +183,7 @@ template <int Rank, DType InDtype>
class OpTable : public GraphNode
{
public:
- OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpTable();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index d4845f9..64c4412 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -20,8 +20,11 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
-OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_SELECT, id_)
+OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_SELECT, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(0, 6);
@@ -62,7 +65,7 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
template <int Rank, DType Dtype>
int OpSelectBase<Rank, Dtype>::eval()
{
- FATAL_ERROR_NODE("shouldn't be called");
+ FATAL_ERROR("shouldn't be called");
}
template <int Rank, DType Dtype>
@@ -78,9 +81,9 @@ int OpSelect<Rank, Dtype>::broadcast()
this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
- ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
- ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
- ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
+ ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
+ ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
}
return 0;
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index b354247..b80fb23 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -33,7 +33,7 @@ template <int Rank, DType Dtype>
class OpSelectBase : public GraphNode
{
public:
- OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpSelectBase();
virtual int checkTensorAttributes();
@@ -59,8 +59,8 @@ template <int Rank, DType Dtype>
class OpSelect : public OpSelectBase<Rank, Dtype>
{
public:
- OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<Rank, Dtype>(sgt_, attribute_, qinfo_, id_)
{}
virtual int eval();
int broadcast();
@@ -73,8 +73,8 @@ template <DType Dtype>
class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
{
public:
- OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_)
+ OpSelect(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<0, Dtype>(sgt_, attribute_, qinfo_, id_)
{}
virtual int eval();
};
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 95a1102..041bbdb 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -23,8 +23,8 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
-UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
- : GraphNode(op_, id_)
+UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -80,7 +80,7 @@ int OpAbs<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -97,7 +97,7 @@ int OpBitwiseNot<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -112,7 +112,7 @@ int OpCeil<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -128,7 +128,7 @@ int OpClz<Rank, Dtype>::register_fcn()
num_bits = 32;
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
this->fcn = [num_bits](int32_t a) -> int32_t {
@@ -159,7 +159,7 @@ int OpExp<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -174,7 +174,7 @@ int OpFloor<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -189,7 +189,7 @@ int OpLog<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -204,7 +204,7 @@ int OpLogicalNot<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -213,6 +213,12 @@ int OpLogicalNot<Rank, Dtype>::register_fcn()
template <int Rank, DType Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
+ if (Dtype != DType_INT8 && this->qinfo)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(this->qinfo->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ }
+
switch (Dtype)
{
case DType_FLOAT:
@@ -229,7 +235,6 @@ int OpNegate<Rank, Dtype>::register_fcn()
};
break;
case DType_INT8:
- ASSERT(this->qinfo);
this->fcn = [this](InEigenType a) -> OutEigenType {
InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
result = std::min(std::max(result, static_cast<InEigenType>(QMin)), static_cast<InEigenType>(QMax));
@@ -237,7 +242,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
};
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -252,7 +257,7 @@ int OpReciprocal<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
@@ -267,7 +272,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
break;
default:
- FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
return 0;
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 0db3cfb..374c8e4 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -26,7 +26,7 @@ template <int Rank, DType Dtype>
class UnaryNode : public GraphNode
{
public:
- UnaryNode(const Op& nodeType, const uint64_t id_);
+ UnaryNode(SubgraphTraverser* sgt_, const Op& nodeType, const uint64_t id_);
virtual ~UnaryNode();
virtual int checkTensorAttributes() final;
@@ -49,8 +49,8 @@ protected:
class Op##Opname : public UnaryNode<Rank, Dtype> \
{ \
public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
} \
@@ -66,8 +66,8 @@ protected:
class Op##Opname : public UnaryNode<Rank, Dtype> \
{ \
public: \
- Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
- : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
INIT_QINFO(Unary); \
register_fcn(); \
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 829a6e0..f4decae 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -22,8 +22,11 @@ using namespace Eigen;
using namespace tosa;
template <DType InDtype, DType OutDtype>
-OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_RESIZE, id_)
+OpResize<InDtype, OutDtype>::OpResize(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_RESIZE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(4, 4);
@@ -102,10 +105,13 @@ int OpResize<InDtype, OutDtype>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
- ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]");
- ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride");
- ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
- ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+ ERROR_IF(shift < 1 || shift > 11, "OpResize: attribute shift should be within [1, 11]");
+ ERROR_IF(stride[0] <= 0 || stride[0] >= (16 << shift), "OpResize: invalid attribute stride_x");
+ ERROR_IF(stride[1] <= 0 || stride[1] >= (16 << shift), "OpResize: invalid attribute stride_y");
+ ERROR_IF(offset[0] <= -(16 << shift) || offset[0] >= (16 << shift), "OpResize: invalid attribute offset_x");
+ ERROR_IF(offset[1] <= -(16 << shift) || offset[1] >= (16 << shift), "OpResize: invalid attribute offset_y");
+ ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch");
+ ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch");
for (int b = 0; b < out_batch; b++)
for (int c = 0; c < out_channels; c++)
@@ -125,8 +131,8 @@ int OpResize<InDtype, OutDtype>::eval()
int32_t ix0 = MAX(ix, 0);
int32_t ix1 = MIN(ix + 1, in_width - 1);
- ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
- iy0, iy1, ix0, ix1);
+ REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0,
+ iy1, ix0, ix1);
OutEigenType acc;
if (mode == ResizeMode_BILINEAR)
@@ -167,10 +173,10 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
- ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift");
- ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride");
- ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
- ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+ ERROR_IF(shift != 0, "OpResize: float mode must have 0 shift");
+ ERROR_IF(stride_fp[0] <= 0.0f || stride_fp[1] <= 0.0f, "OpResize: invalid attribute stride");
+ ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch");
+ ERROR_IF(in_channels != out_channels, "OpResize: output tensor channel mismatch");
for (int b = 0; b < out_batch; b++)
for (int c = 0; c < out_channels; c++)
@@ -190,8 +196,8 @@ int OpResize<DType_FLOAT, DType_FLOAT>::eval()
int32_t ix0 = MAX(ix, 0);
int32_t ix1 = MIN(ix + 1, in_width - 1);
- ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
- iy0, iy1, ix0, ix1);
+ REQUIRE(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0,
+ iy1, ix0, ix1);
OutEigenType acc;
if (mode == ResizeMode_BILINEAR)
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 5dd14c8..095dc7d 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -27,7 +27,7 @@ template <DType InDtype, DType OutDtype>
class OpResize : public GraphNode
{
public:
- OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpResize();
virtual int checkTensorAttributes() final;
virtual int eval();
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 726ab7c..2d9e428 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -32,7 +32,8 @@
using namespace TosaReference;
using namespace tosa;
-GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
+GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
+ TosaSerializationHandler* tsh,
Op opType,
TosaAttributeBase* attribute,
TosaQuantInfoBase* qinfo,
@@ -349,7 +350,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// data_nodes
case Op_CONST:
- return new OpConst(id);
+ return new OpConst(sgt, id);
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
@@ -398,13 +399,13 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// custom
case Op_CUSTOM:
- return new OpCustom(id);
+ return new OpCustom(sgt, id);
// control_flow
case Op_COND_IF:
- return new OpCondIf(tsh, attribute, id);
+ return new OpCondIf(sgt, tsh, attribute, id);
case Op_WHILE_LOOP:
- return new OpWhileLoop(tsh, attribute, id);
+ return new OpWhileLoop(sgt, tsh, attribute, id);
// Ops not recognized
default:
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 0c116b6..eaa359c 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -24,55 +24,55 @@
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id);
+ return new OP<RANK, DType_##DTYPE>(sgt, attribute, qinfo, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+ return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id);
+ return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, qinfo, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+ return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
{ \
case 0: \
- return new OP<0>(attribute, qinfo, id); \
+ return new OP<0>(sgt, attribute, qinfo, id); \
case 1: \
- return new OP<1>(attribute, qinfo, id); \
+ return new OP<1>(sgt, attribute, qinfo, id); \
case 2: \
- return new OP<2>(attribute, qinfo, id); \
+ return new OP<2>(sgt, attribute, qinfo, id); \
case 3: \
- return new OP<3>(attribute, qinfo, id); \
+ return new OP<3>(sgt, attribute, qinfo, id); \
case 4: \
- return new OP<4>(attribute, qinfo, id); \
+ return new OP<4>(sgt, attribute, qinfo, id); \
case 5: \
- return new OP<5>(attribute, qinfo, id); \
+ return new OP<5>(sgt, attribute, qinfo, id); \
case 6: \
- return new OP<6>(attribute, qinfo, id); \
+ return new OP<6>(sgt, attribute, qinfo, id); \
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
if (inputDType == DType_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(attribute, qinfo, id); \
+ return new OP<DType_##DTYPE>(sgt, attribute, qinfo, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, qinfo, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
@@ -221,10 +221,14 @@
namespace TosaReference
{
+class SubgraphTraverser;
+class GraphNode;
+
class OpFactory
{
public:
- static GraphNode* newOp(tosa::TosaSerializationHandler* tsh,
+ static GraphNode* newOp(SubgraphTraverser* sgt,
+ tosa::TosaSerializationHandler* tsh,
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
tosa::TosaQuantInfoBase* qinfo,
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 97a7aa7..107c7a8 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -21,8 +21,8 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
-ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
- : GraphNode(op_, id_)
+ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
+ : GraphNode(sgt_, op_, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 4);
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
index cf75812..f4e29b9 100644
--- a/reference_model/src/ops/reduction.h
+++ b/reference_model/src/ops/reduction.h
@@ -27,7 +27,7 @@ template <int Rank, DType Dtype>
class ReduceNode : public GraphNode
{
public:
- ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
+ ReduceNode(SubgraphTraverser* sgt_, const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
virtual ~ReduceNode();
virtual int checkTensorAttributes();
virtual int eval() = 0;
@@ -48,8 +48,8 @@ template <int Rank, DType Dtype>
class OpReduceAll : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ OpReduceAll(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ALL, attribute_, id_)
{}
virtual int eval();
};
@@ -58,8 +58,8 @@ template <int Rank, DType Dtype>
class OpReduceAny : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ OpReduceAny(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_ANY, attribute_, id_)
{}
virtual int eval();
};
@@ -68,8 +68,8 @@ template <int Rank, DType Dtype>
class OpReduceMax : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, id_)
+ OpReduceMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MAX, attribute_, id_)
{}
virtual int eval();
};
@@ -78,8 +78,8 @@ template <int Rank, DType Dtype>
class OpReduceMin : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_)
+ OpReduceMin(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_MIN, attribute_, id_)
{}
virtual int eval();
};
@@ -88,8 +88,8 @@ template <int Rank, DType Dtype>
class OpReduceProduct : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_)
+ OpReduceProduct(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
{}
virtual int eval();
};
@@ -98,8 +98,8 @@ template <int Rank, DType Dtype>
class OpReduceSum : public ReduceNode<Rank, Dtype>
{
public:
- OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_)
+ OpReduceSum(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
{}
virtual int eval();
};
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index 478b776..02ec54f 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -21,8 +21,11 @@ using namespace Eigen;
using namespace tosa;
template <DType Dtype>
-OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_GATHER, id_)
+OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_GATHER, id_)
{
setRequiredOperands(2, 1);
}
@@ -102,7 +105,7 @@ int OpGather<Dtype>::eval()
for (int32_t w = 0; w < W; w++)
{
int32_t k = this->indices->getTensor()(n, w);
- ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ REQUIRE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
for (int32_t c = 0; c < C; c++)
{
EigenType value = this->values->getTensor()(n, k, c);
@@ -115,8 +118,11 @@ int OpGather<Dtype>::eval()
}
template <DType Dtype>
-OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_SCATTER, id_)
+OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_SCATTER, id_)
{
setRequiredOperands(3, 1);
}
@@ -206,7 +212,7 @@ int OpScatter<Dtype>::eval()
for (int w = 0; w < W; w++)
{
int32_t k = this->indices->getTensor()(n, w);
- ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
for (int c = 0; c < C; c++)
{
EigenType value = this->input->getTensor()(n, w, c);
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index 17ea723..66b584a 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -27,7 +27,7 @@ template <DType Dtype>
class OpGather : public GraphNode
{
public:
- OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpGather();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <DType Dtype>
class OpScatter : public GraphNode
{
public:
- OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpScatter();
virtual int checkTensorAttributes();
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 0007553..045c0a5 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -22,8 +22,11 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType Dtype>
-OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_ARGMAX, id_)
+OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -66,8 +69,11 @@ int OpArgMax<Rank, Dtype>::eval()
}
template <DType Dtype>
-OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_AVG_POOL2D, id_)
+OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(4);
@@ -142,9 +148,6 @@ ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_
int32_t left_index = pad_left / stride;
int32_t right_index = pad_right / stride;
- // not handle ultra small activation yet
- ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet");
-
// minus the number of pad bit this index cover
while (left_index >= 0)
{
@@ -176,7 +179,8 @@ int OpAvgPool2d<Dtype>::eval()
int out_width = this->out->getShape()[2];
int out_channels = this->out->getShape()[3];
- ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(in_batch != out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(in_channels != out_channels, "OpAvgPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);
int padding_top = this->attribute->padding()[0];
int padding_bottom = this->attribute->padding()[1];
@@ -260,12 +264,19 @@ int OpAvgPool2d<Dtype>::eval()
if (Dtype != DType_FLOAT)
{
- this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
- int32_t multiplier, shift;
- TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);
+ try
+ {
+ this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
+ int32_t multiplier, shift;
+ TosaReference::QuantUtil::reciprocal_scale(div, multiplier, shift);
- return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
- });
+ return (OutEigenType)TosaReference::QuantUtil::apply_scale_32(value, multiplier, shift, false);
+ });
+ }
+ catch (std::string desc)
+ {
+ REQUIRE(false, "OpAvgPool2d apply_scale_32() fails: %s.", desc.c_str());
+ }
this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
@@ -279,8 +290,11 @@ int OpAvgPool2d<Dtype>::eval()
}
template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_CONV2D, id_)
+OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_CONV2D, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(4);
@@ -361,13 +375,12 @@ int OpConv2d<InDtype, WeightDtype>::eval()
int out_width = this->output->getShape()[2];
int out_channels = this->output->getShape()[3];
- ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
- ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
- in_channels);
- ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
- out_channels);
- ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels,
- out_channels);
+ ERROR_IF(in_batch != out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(f_in_channels != in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
+ out_channels);
+ ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);
int padding_top = this->attribute->padding()[0];
int padding_bottom = this->attribute->padding()[1];
@@ -469,10 +482,11 @@ int OpConv2d<InDtype, WeightDtype>::eval()
}
template <DType InDtype, DType WeightDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_,
+OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
TosaQuantInfoBase* qinfo_,
uint64_t id_)
- : GraphNode(Op_DEPTHWISE_CONV2D, id_)
+ : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(4);
@@ -553,14 +567,13 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
int out_width = this->output->getShape()[2];
int out_channels = this->output->getShape()[3];
- ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
- ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d",
- f_in_channels, in_channels);
- ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels,
- "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier,
- out_channels);
- ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
- b_out_channels, out_channels);
+ ERROR_IF(in_batch != out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(f_in_channels != in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
+ in_channels * f_multiplier, out_channels);
+ ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
+ out_channels);
int padding_top = this->attribute->padding()[0];
int padding_bottom = this->attribute->padding()[1];
@@ -651,10 +664,11 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
}
template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_,
+OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
TosaQuantInfoBase* qinfo_,
uint64_t id_)
- : GraphNode(Op_FULLY_CONNECTED, id_)
+ : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(2);
@@ -738,8 +752,11 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
}
template <DType Dtype>
-OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_MATMUL, id_)
+OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
setRequiredRank(3);
@@ -866,8 +883,11 @@ int OpMatMul<Dtype>::eval()
}
template <DType Dtype>
-OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_MAX_POOL2D, id_)
+OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(4);
@@ -936,7 +956,8 @@ int OpMaxPool2d<Dtype>::eval()
int out_width = this->out->getShape()[2];
int out_channels = this->out->getShape()[3];
- ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(in_batch != out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(in_channels != out_channels, "OpMaxPool2d: tensor channel mismatch %d != %d", in_channels, out_channels);
int padding_top = this->attribute->padding()[0];
int padding_bottom = this->attribute->padding()[1];
@@ -1004,10 +1025,11 @@ int OpMaxPool2d<Dtype>::eval()
}
template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_,
+OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
TosaQuantInfoBase* qinfo_,
uint64_t id_)
- : GraphNode(Op_TRANSPOSE_CONV2D, id_)
+ : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
setRequiredRank(4);
@@ -1104,13 +1126,13 @@ int OpTransposeConv2d<InDtype, OutDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
- ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
- ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d",
- f_in_channels, in_channels);
- ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
- f_out_channels, out_channels);
- ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
- b_out_channels, out_channels);
+ ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
+ f_out_channels, out_channels);
+ ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
+ out_channels);
DEBUG_INFO(OP,
"perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 9aaa140..6ffc27d 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -28,7 +28,7 @@ template <int Rank, DType Dtype>
class OpArgMax : public GraphNode
{
public:
- OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpArgMax();
virtual int checkTensorAttributes();
@@ -49,7 +49,7 @@ template <DType Dtype>
class OpAvgPool2d : public GraphNode
{
public:
- OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpAvgPool2d();
virtual int checkTensorAttributes();
@@ -80,7 +80,7 @@ template <DType InDtype, DType WeightDtype>
class OpConv2d : public GraphNode
{
public:
- OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpConv2d();
virtual int checkTensorAttributes() final;
@@ -112,7 +112,7 @@ template <DType InDtype, DType WeightDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
- OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpDepthwiseConv2d();
virtual int checkTensorAttributes() final;
@@ -144,7 +144,7 @@ template <DType InDtype, DType WeightDtype>
class OpFullyConnected : public GraphNode
{
public:
- OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpFullyConnected();
virtual int checkTensorAttributes() final;
@@ -174,7 +174,7 @@ template <DType Dtype>
class OpMatMul : public GraphNode
{
public:
- OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpMatMul();
virtual int checkTensorAttributes() final;
@@ -205,7 +205,7 @@ template <DType Dtype>
class OpMaxPool2d : public GraphNode
{
public:
- OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpMaxPool2d();
virtual int checkTensorAttributes();
@@ -226,7 +226,7 @@ template <DType InDtype, DType WeightDtype>
class OpTransposeConv2d : public GraphNode
{
public:
- OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpTransposeConv2d();
virtual int checkTensorAttributes() final;
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index d988c57..657eebf 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -23,8 +23,11 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, DType InDtype, DType OutDtype>
-OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_RESCALE, id_)
+OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_RESCALE, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
@@ -101,53 +104,68 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
int32_t channel_multiplier, channel_shift;
Eigen::array<Eigen::Index, 2> begin, size;
size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
- for (int32_t i = 0; i < shape_2d[1]; i++)
+ try
{
- begin = Eigen::array<Eigen::Index, 2>({ 0, i });
- curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
- channel_multiplier = multiplier[i];
- channel_shift = shift[i];
- curr_channel_slice_postscaled =
- curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
- double_round, scale32](InEigenType in_val) -> OutEigenType {
- InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled;
- if (scale32)
- scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
- channel_shift, double_round);
- else
- scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
- channel_shift);
- OutEigenType out_val = (OutEigenType)(scaled + output_zp);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
- return out_val;
- });
-
- for (int32_t j = 0; j < shape_2d[0]; j++)
+ for (int32_t i = 0; i < shape_2d[1]; i++)
{
- output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
+ double_round, scale32](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, channel_multiplier,
+ channel_shift, double_round);
+ else
+ scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, channel_multiplier,
+ channel_shift);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+
+ for (int32_t j = 0; j < shape_2d[0]; j++)
+ {
+ output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+ }
}
}
+ catch (std::string desc)
+ {
+ REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
+ }
}
else
{
int32_t tensor_multiplier = multiplier[0];
int32_t tensor_shift = shift[0];
- output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
- scale32](InEigenType in_val) -> OutEigenType {
- InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
- int32_t scaled;
- if (scale32)
- scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
- double_round);
- else
- scaled = TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
- OutEigenType out_val = (OutEigenType)(scaled + output_zp);
- out_val = std::max<OutEigenType>(out_val, QMin);
- out_val = std::min<OutEigenType>(out_val, QMax);
- return out_val;
- });
+ try
+ {
+ output_2d = input_reshaped.unaryExpr([input_zp, output_zp, tensor_multiplier, tensor_shift, double_round,
+ scale32](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled;
+ if (scale32)
+ scaled = TosaReference::QuantUtil::apply_scale_32(input_zp_shifted, tensor_multiplier, tensor_shift,
+ double_round);
+ else
+ scaled =
+ TosaReference::QuantUtil::apply_scale_16(input_zp_shifted, tensor_multiplier, tensor_shift);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+ }
+ catch (std::string desc)
+ {
+ REQUIRE(false, "OpRescale apply_scale_32/16() fails: %s.", desc.c_str());
+ }
}
// reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
@@ -162,8 +180,11 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
}
template <int Rank, DType InDtype, DType OutDtype>
-OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
- : GraphNode(Op_CAST, id_)
+OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_CAST, id_)
{
setRequiredOperands(1, 1);
setRequiredRank(0, 6);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 6ec4d6d..060e14e 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -26,7 +26,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpRescale : public GraphNode
{
public:
- OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpRescale();
virtual int checkTensorAttributes() final;
@@ -140,7 +140,7 @@ template <int Rank, DType InDtype, DType OutDtype>
class OpCast : public GraphNode
{
public:
- OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
virtual ~OpCast();
virtual int checkTensorAttributes() final;
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index c595869..4f6a525 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -44,9 +44,17 @@ public:
static int32_t apply_scale_32(int32_t value, int32_t multiplier, int32_t shift, bool double_round = true)
{
- ASSERT_MSG(multiplier >= 0, "apply_scale_32() error: multiplier should >= 0 but is %d", multiplier);
- ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_32() error: shift should be within [2, 62] but is %d",
- shift);
+ if (multiplier < 0)
+ {
+ std::string desc = "apply_scale_32() error: multiplier should >= 0 but is " + std::to_string(multiplier);
+ throw desc;
+ }
+ if (shift < 2 || shift > 62)
+ {
+ std::string desc =
+ "apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift);
+ throw desc;
+ }
int64_t round = 1L << (shift - 1);
if (double_round)
{
@@ -57,21 +65,35 @@ public:
}
int64_t result = (int64_t)value * multiplier + round;
result = result >> shift;
- ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
- "apply_scale_32() error: scaled result exceed int32 numeric range");
+ if (result < -(1L << 31) || result >= (1L << 31))
+ {
+ std::string desc = "apply_scale_32() error: scaled result exceeds int32 numeric range";
+ throw desc;
+ }
return static_cast<int32_t>(result);
}
static int32_t apply_scale_16(int64_t value, int16_t multiplier, int32_t shift)
{
- ASSERT_MSG(multiplier >= 0, "apply_scale_16() error: multiplier should >= 0 but is %d", multiplier);
- ASSERT_MSG(shift >= 2 && shift <= 62, "apply_scale_16() error: shift should be within [2, 62] but is %d",
- shift);
+ if (multiplier < 0)
+ {
+ std::string desc = "apply_scale_16() error: multiplier should >= 0 but is " + std::to_string(multiplier);
+ throw desc;
+ }
+ if (shift < 2 || shift > 62)
+ {
+ std::string desc =
+ "apply_scale_16(): shift value should stay within [2, 62] but is " + std::to_string(shift);
+ throw desc;
+ }
int64_t round = 1L << (shift - 1);
int64_t result = value * (int64_t)multiplier + round;
result = result >> shift;
- ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
- "apply_scale_16() error: scaled result exceed int32 numeric range");
+ if (result < -(1L << 31) || result >= (1L << 31))
+ {
+ std::string desc = "apply_scale_16() error: scaled result exceeds int32 numeric range";
+ throw desc;
+ }
return static_cast<int32_t>(result);
}
};
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index bdf6fbc..ef7bae6 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -21,6 +21,8 @@ using namespace tosa;
SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
{
+ graph_status = GraphStatus::TOSA_VALID;
+
block = _block;
tsh = _tsh;
@@ -166,7 +168,7 @@ int SubgraphTraverser::initializeGraph()
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
- GraphNode* node = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
+ GraphNode* node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, input_dtype,
input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
if (!node)
{
@@ -221,16 +223,25 @@ int SubgraphTraverser::initializeGraph()
for (auto ts : block->GetTensors())
{
+ // Bail out if any dimension is invalid.
+ for (auto& dim : ts->GetShape())
+ {
+ if (dim <= 0)
+ {
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
+ }
+ }
DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
TosaReference::Tensor* tensor =
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
if (!ts->GetData().empty())
{
if (tensor->allocate())
{
- WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
- return 1;
+ SIMPLE_FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
switch (ts->GetDtype())
@@ -316,7 +327,7 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str());
+ FATAL_ERROR("loadGraphJson: Failed to find input tensor by name %s", input_name.c_str());
}
}
@@ -332,7 +343,7 @@ int SubgraphTraverser::initializeGraph()
}
else
{
- FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str());
+ FATAL_ERROR("loadGraphJson: Failed to find output tensor by name %s", output_name.c_str());
}
}
@@ -395,13 +406,14 @@ int SubgraphTraverser::evaluateNextNode()
if (!tensor->is_allocated())
if (tensor->allocate())
{
- FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str());
+ FATAL_ERROR("Failed to allocate Eigen tensor %s", tensor->getName().c_str());
}
}
if (currNode->eval())
{
- FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID());
+ WARNING("Failed to evaluate node: %lu", currNode->getID());
+ return 1;
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 3f4eecf..4be6c1f 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -16,15 +16,22 @@
#ifndef SUBGRAPH_TRAVERSER_H
#define SUBGRAPH_TRAVERSER_H
-#include "model_common.h"
-
#include "graph_node.h"
+#include "model_common.h"
#include "ops/op_factory.h"
+#include "tensor.h"
#include "tosa_serialization_handler.h"
namespace TosaReference
{
+enum class GraphStatus : int
+{
+ TOSA_VALID = 0,
+ TOSA_UNPREDICTABLE = 1,
+ TOSA_ERROR = 2,
+};
+
class SubgraphTraverser
{
public:
@@ -36,6 +43,15 @@ public:
int evaluateNextNode();
int evaluateAll();
+ GraphStatus getGraphStatus() const
+ {
+ return graph_status;
+ }
+ void setGraphStatus(GraphStatus status)
+ {
+ graph_status = status;
+ }
+
int linkTensorsAndNodes();
int validateGraph();
@@ -59,6 +75,8 @@ private:
GraphNode* getNextNode();
+ GraphStatus graph_status;
+
// pointer to serialization library and corresponding basic block
TosaSerializationBasicBlock* block;
TosaSerializationHandler* tsh;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 1efebe3..f2a3a98 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -20,15 +20,13 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-TosaReference::Tensor::Tensor(std::string tensorName_,
- DType tensorDtype_,
- std::vector<int> shape_)
-{
- tensorName = std::string(tensorName_);
- tensorDtype = tensorDtype_;
- shape = std::vector<int>(shape_);
- producer = nullptr;
- isValid = false;
+TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
+{
+ tensorName = std::string(tensorName_);
+ tensorDtype = tensorDtype_;
+ shape = std::vector<int>(shape_);
+ producer = nullptr;
+ isValid = false;
consumers.clear();
isSubgraphInput = false;
isSubgraphOutput = false;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 6c0622e..e97554f 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -604,13 +604,6 @@ class TensorFactory
public:
static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
{
- // Bail out if any dimension is invalid.
- for (auto& dim : shape_)
- {
- if (dim <= 0)
- goto done;
- }
-
switch (tensorDtype_)
{
case DType_FLOAT:
@@ -697,7 +690,6 @@ public:
break;
}
- done:
std::string shape_str("[");
for (auto& dim : shape_)
{