aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-07 20:49:09 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-28 16:38:53 -0800
commit4762564da970eb1883a54aa66582e05c0dbd2b81 (patch)
tree657e14aa711f1c9e55a5fde15fac3f7f9f77e536
parent09ae449db8a45ab7c48af4541b43cb3dc80f9a30 (diff)
downloadreference_model-4762564da970eb1883a54aa66582e05c0dbd2b81.tar.gz
[reference_model] Support StatefulOps and the tests for CallOnceOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2
-rw-r--r--examples/test_stateful_op/desc.json16
-rw-r--r--examples/test_stateful_op/placeholder_0.npybin0 -> 132 bytes
-rw-r--r--examples/test_stateful_op/test_variable_add.tosabin0 -> 616 bytes
-rw-r--r--examples/test_stateful_op/variable_0.npybin0 -> 132 bytes
-rw-r--r--reference_model/include/func_config.h40
-rw-r--r--reference_model/src/command_line_utils.h5
-rw-r--r--reference_model/src/graph_node.cc14
-rw-r--r--reference_model/src/graph_node.h20
-rw-r--r--reference_model/src/main.cpp152
-rw-r--r--reference_model/src/subgraph_traverser.cc130
-rw-r--r--reference_model/src/subgraph_traverser.h9
-rw-r--r--reference_model/src/tensor.cc7
-rw-r--r--reference_model/src/tensor.h27
-rw-r--r--verif/frameworks/test_builder.py38
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py33
15 files changed, 442 insertions, 49 deletions
diff --git a/examples/test_stateful_op/desc.json b/examples/test_stateful_op/desc.json
new file mode 100644
index 0000000..1f1459f
--- /dev/null
+++ b/examples/test_stateful_op/desc.json
@@ -0,0 +1,16 @@
+{
+ "tosa_file": "test_variable_add.tosa",
+ "ifm_name": [
+ "TosaInput_0"
+ ],
+ "ifm_file": [
+ "placeholder_0.npy"
+ ],
+ "variable_name": [
+ "Variable_0"
+ ],
+ "variable_file": [
+ "variable_0.npy"
+ ],
+ "expected_failure": false
+} \ No newline at end of file
diff --git a/examples/test_stateful_op/placeholder_0.npy b/examples/test_stateful_op/placeholder_0.npy
new file mode 100644
index 0000000..f9688e6
--- /dev/null
+++ b/examples/test_stateful_op/placeholder_0.npy
Binary files differ
diff --git a/examples/test_stateful_op/test_variable_add.tosa b/examples/test_stateful_op/test_variable_add.tosa
new file mode 100644
index 0000000..2be3044
--- /dev/null
+++ b/examples/test_stateful_op/test_variable_add.tosa
Binary files differ
diff --git a/examples/test_stateful_op/variable_0.npy b/examples/test_stateful_op/variable_0.npy
new file mode 100644
index 0000000..c183a15
--- /dev/null
+++ b/examples/test_stateful_op/variable_0.npy
Binary files differ
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index 1e93b89..860d8c6 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -35,24 +35,28 @@ struct tosa_level_t
struct func_config_t
{
- std::string operator_fbs = "tosa.fbs";
- std::string test_desc = "desc.json";
- std::string flatbuffer_dir = "";
- std::string output_dir = "";
- std::string tosa_file = "";
- std::string ifm_name = "";
- std::string ifm_file = "";
- std::string ofm_name = "";
- std::string ofm_file = "";
- uint32_t eval = 1;
- uint32_t validate_only = 0;
- uint32_t output_tensors = 1;
- uint32_t tosa_profile = 1;
- uint32_t dump_intermediates = 0;
- std::string fp_format = "0.5";
- uint32_t precise_mode = 0;
- bool abs_mode = 0; // set in main as second run of precise_mode
- bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()
+ std::string operator_fbs = "tosa.fbs";
+ std::string test_desc = "desc.json";
+ std::string flatbuffer_dir = "";
+ std::string output_dir = "";
+ std::string tosa_file = "";
+ std::string ifm_name = "";
+ std::string ifm_file = "";
+ std::string ofm_name = "";
+ std::string ofm_file = "";
+ std::string variable_name = "";
+ std::string variable_file = "";
+
+ uint32_t eval = 1;
+ uint32_t validate_only = 0;
+ uint32_t output_tensors = 1;
+ uint32_t tosa_profile = 1;
+ uint32_t dump_intermediates = 0;
+ uint32_t initialize_variable_tensor_from_numpy = 0;
+ std::string fp_format = "0.5";
+ uint32_t precise_mode = 0;
+ bool abs_mode = 0; // set in main as second run of precise_mode
+ bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()
tosa_level_t tosa_level;
static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 64 };
diff --git a/reference_model/src/command_line_utils.h b/reference_model/src/command_line_utils.h
index dcb0564..f8031d9 100644
--- a/reference_model/src/command_line_utils.h
+++ b/reference_model/src/command_line_utils.h
@@ -48,6 +48,10 @@ int func_model_parse_cmd_line(
cxxopts::value<std::string>(func_config.ofm_name))
("ofm_file", "Output tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.",
cxxopts::value<std::string>(func_config.ofm_file))
+ ("variable_name", "Region tensor name. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.",
+ cxxopts::value<std::string>(func_config.variable_name))
+ ("variable_file", "Region tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.",
+ cxxopts::value<std::string>(func_config.variable_file))
("eval", "Evaluate the network (0/1)", cxxopts::value<uint32_t>(func_config.eval))
("fp_format", "Floating-point number dump format string (printf-style format, e.g. 0.5)",
cxxopts::value<std::string>(func_config.fp_format))
@@ -60,6 +64,7 @@ int func_model_parse_cmd_line(
cxxopts::value<tosa_level_t>(func_config.tosa_level))
("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value<uint32_t>(func_config.dump_intermediates))
("p,precise_mode", "Calculate floating point operations in FP64 (0/1)", cxxopts::value<uint32_t>(func_config.precise_mode))
+ ("initialize_variable_tensor_from_numpy", "Initialize variable tensors from flatbuffer (0, default) or numpy (1)", cxxopts::value<uint32_t>(func_config.initialize_variable_tensor_from_numpy))
("v,version", "print model version")
("i,input_tensor_file", "specify input tensor files", cxxopts::value<std::vector<std::string>>())
("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value<std::string>())
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
index 1781e40..c8a0b29 100644
--- a/reference_model/src/graph_node.cc
+++ b/reference_model/src/graph_node.cc
@@ -31,6 +31,7 @@ GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const
clearNodeMarked();
evalCount = 0;
clearOnNextNodeList();
+ clearEvaluated();
setRequiredOperands(-1, -1);
setRequiredRank(-1);
inMainBlock = false;
@@ -102,6 +103,12 @@ int GraphNode::hasAllOutputsReady() const
{
if (!outputs[i]->getIsValid())
return false;
+ if (outputs[i]->getIsVariable())
+ {
+ // when output is a variable tensor
+ // isValid is not reliable indicator of this node having been evaluated
+ return false;
+ }
}
return true;
@@ -110,8 +117,8 @@ int GraphNode::hasAllOutputsReady() const
int GraphNode::dumpNode(FILE* out)
{
int i;
- fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
- nodeId, evalCount, onNextNodeList, isMarked);
+ fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Evaluated: %d Is marked: %d\n",
+ EnumNamesOp()[nodeType], nodeId, evalCount, onNextNodeList, evaluated, isMarked);
i = 0;
for (Tensor* ins : inputs)
@@ -135,7 +142,8 @@ int GraphNode::dumpNode(std::ostream& out)
int i;
out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
- << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
+ << " On next node list: " << onNextNodeList << "Evaluated: " << evaluated << " Is marked: " << isMarked
+ << std::endl;
out << " Inputs:";
for (std::string& name : inputNames)
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 900f4b8..b3fe8d6 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -254,6 +254,23 @@ public:
return inMainBlock;
}
+ int getEvaluated() const
+ {
+ return evaluated;
+ }
+
+ int setEvaluated()
+ {
+ evaluated = true;
+ return 0;
+ }
+
+ int clearEvaluated()
+ {
+ evaluated = false;
+ return 0;
+ }
+
// Helper functions.
int idiv_check(int input1, int input2, int& result);
@@ -317,6 +334,9 @@ protected:
// next-node list.
int onNextNodeList;
+ // Flag indicating that this node has been evaluated before
+ int evaluated;
+
// Required input/output tensor counts for node validation
// -1 means any number is allowed
int requiredInputCount;
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index cb7f0a2..62b8f6f 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -33,8 +33,11 @@ using namespace tosa;
using json = nlohmann::json;
int initTestDesc(json& test_desc);
+
int readInputTensors(SubgraphTraverser& gt, json& test_desc);
int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix);
+int readVariableTensors(SubgraphTraverser& gt, json test_desc);
+int writeVariableTensors(SubgraphTraverser& gt, json test_desc);
int loadGraph(TosaSerializationHandler& tsh, json& test_desc);
void parse_value(const std::string& text, tosa_level_t& value);
const std::string getResultFilenamePrefix();
@@ -131,6 +134,14 @@ int main(int argc, char** argv)
goto done;
}
+ if (g_func_config.initialize_variable_tensor_from_numpy)
+ {
+ if (readVariableTensors(main_gt, test_desc))
+ {
+ FATAL_ERROR("Unable to read variable tensors");
+ }
+ }
+
// evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
if (main_gt.evaluateAll())
{
@@ -177,6 +188,11 @@ int main(int argc, char** argv)
{
WARNING("Errors encountered in saving output tensors");
}
+
+ if (writeVariableTensors(main_gt, test_desc))
+ {
+ WARNING("Errors encountered in writing variable tensors");
+ }
}
}
@@ -312,7 +328,7 @@ int readInputTensors(SubgraphTraverser& gt, json& test_desc)
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{
- if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
{
gt.addToNextNodeList(gn);
}
@@ -395,6 +411,124 @@ int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string&
return 0;
}
+int readVariableTensors(SubgraphTraverser& gt, json test_desc)
+{
+ int tensorCount = gt.getNumVariableTensors();
+ Tensor* tensor;
+ char filename[1024];
+
+ try
+ {
+ if ((tensorCount != (int)test_desc["variable_name"].size()) ||
+ (tensorCount != (int)test_desc["variable_file"].size()))
+ {
+ WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount,
+ test_desc["variable_name"].size(), test_desc["variable_file"].size());
+ return 1;
+ }
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get<std::string>().c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir.c_str(),
+ test_desc["variable_file"][i].get<std::string>().c_str());
+
+ DEBUG_MED(GT, "Loading variable tensor %s from filename: %s", tensor->getName().c_str(), filename);
+
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readFromNpyFile(filename))
+ {
+ WARNING("Unable to read variable tensor %s from filename: %s", tensor->getName().c_str(), filename);
+ tensor->dumpTensorParams(g_func_debug.func_debug_file);
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+ }
+ catch (nlohmann::json::type_error& e)
+ {
+ WARNING("Fail accessing test descriptor: %s", e.what());
+ return 1;
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int writeVariableTensors(SubgraphTraverser& gt, json test_desc)
+{
+ int tensorCount = gt.getNumVariableTensors();
+ const Tensor* tensor;
+ char filename[1024];
+
+ try
+ {
+ if ((tensorCount != (int)test_desc["variable_name"].size()) ||
+ (tensorCount != (int)test_desc["variable_file"].size()))
+ {
+ WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount,
+ test_desc["variable_name"].size(), test_desc["variable_file"].size());
+ return 1;
+ }
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get<std::string>());
+ if (!tensor)
+ {
+ WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get<std::string>().c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(),
+ test_desc["variable_file"][i].get<std::string>().c_str());
+
+ DEBUG_MED(GT, "Writing variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is no longer allocated", tensor->getName().c_str());
+ return 1;
+ }
+ if (tensor->writeToNpyFile(filename))
+ {
+ WARNING("Unable to write variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(),
+ filename);
+ return 1;
+ }
+ }
+ }
+ catch (nlohmann::json::type_error& e)
+ {
+ WARNING("Fail accessing test descriptor: %s", e.what());
+ return 1;
+ }
+
+ return 0;
+}
+
// Read "foo,bar,..." and return std::vector({foo, bar, ...})
std::vector<std::string> parseFromString(std::string raw_str)
{
@@ -489,6 +623,22 @@ int initTestDesc(json& test_desc)
test_desc["ofm_file"] = ofm_file_vec;
}
+ // Overwrite test_desc["variable_name"] if --variable_name= specified.
+ std::string variable_name_str(g_func_config.variable_name);
+ if (!variable_name_str.empty())
+ {
+ std::vector<std::string> variable_name_vec = parseFromString(variable_name_str);
+ test_desc["variable_name"] = variable_name_vec;
+ }
+
+ // Overwrite test_desc["variable_file"] if --variable_file= specified.
+ std::string variable_file_str(g_func_config.variable_file);
+ if (!variable_file_str.empty())
+ {
+ std::vector<std::string> variable_file_vec = parseFromString(variable_file_str);
+ test_desc["variable_file"] = variable_file_vec;
+ }
+
return 0;
}
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 186cb8b..745213e 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -64,6 +64,11 @@ SubgraphTraverser::~SubgraphTraverser()
for (TosaReference::Tensor* t : tensors)
{
+ if (t->getIsVariable() && parent_sgt)
+ {
+ // variable tensors are owned by top level sgt
+ continue;
+ }
if (t->is_allocated())
{
t->deallocate();
@@ -119,6 +124,51 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin
return nullptr;
}
+int SubgraphTraverser::getNumVariableTensors() const
+{
+ return variableTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const
+{
+ return variableTensors[idx];
+}
+
+// find variable tensor by name in top level sgt's @a variableTensors
+TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const
+{
+ // variable tensors are owned by top level sgt
+ if (parent_sgt)
+ {
+ return parent_sgt->getVariableTensorByName(name);
+ }
+
+ for (auto t : variableTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+// add variable tensor to top level sgt's @a variableTensors
+int SubgraphTraverser::registerVariableTensor(Tensor* tensor)
+{
+ SUBGRAPH_ERROR_IF(!tensor->getIsVariable(),
+ "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable",
+ tensor->getName().c_str());
+ // variable tensors are owned by top level sgt
+ if (parent_sgt)
+ {
+ return parent_sgt->registerVariableTensor(tensor);
+ }
+ variableTensors.push_back(tensor);
+ return 0;
+}
+
int SubgraphTraverser::initializeGraph()
{
int idx = 0;
@@ -321,19 +371,18 @@ int SubgraphTraverser::initializeGraph()
non_const_node_vec.push_back(node);
}
+ // Bug fix: add the ready node in main block for evaluation
+ if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated())
+ {
+ addToNextNodeList(node);
+ }
+
idx++;
}
for (auto ts : block->GetTensors())
{
- DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
- TosaReference::Tensor* tensor =
- TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
-
- SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
- ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
-
- addTensor(tensor);
+ addTensor(ts);
}
DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
@@ -406,6 +455,22 @@ int SubgraphTraverser::allocateInputTensors()
this->allocateTensor(input_tensor_name);
}
+ // allocate variable tensors if not already allocated
+ for (auto ts : block->GetTensors())
+ {
+ if (ts->GetVariable())
+ {
+ TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.",
+ ts->GetName().c_str());
+ if (!tensor->is_allocated())
+ {
+ DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str());
+ this->allocateTensor(ts->GetName());
+ }
+ }
+ }
+
return 0;
}
@@ -447,6 +512,8 @@ int SubgraphTraverser::allocateTensor(std::string name)
if (!ts->GetData().empty())
{
+ if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
+ return 0;
DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
auto serialization_dtype = ts->GetDtype();
switch (serialization_dtype)
@@ -549,8 +616,16 @@ int SubgraphTraverser::allocateTensor(std::string name)
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
EnumNameDType(ts->GetDtype()));
}
+ tensor->setIsValid();
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
+ {
+ addToNextNodeList(gn);
+ }
+ }
}
-
return 0;
}
@@ -619,6 +694,8 @@ int SubgraphTraverser::evaluateNextNode()
return 1;
}
+ currNode->setEvaluated();
+
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
for (auto tensor : currNode->getInputs())
{
@@ -631,6 +708,12 @@ int SubgraphTraverser::evaluateNextNode()
continue;
}
+ if (tensor->getIsVariable())
+ {
+ // if tensor is a Variable, we cannot free it
+ continue;
+ }
+
for (auto node : tensor->getConsumers())
{
// If the node is inside a loop, the input tensor is still needed
@@ -660,7 +743,7 @@ int SubgraphTraverser::evaluateNextNode()
{
for (GraphNode* node : tensor->getConsumers())
{
- if (!node->getOnNextNodeList() && node->hasAllInputsReady())
+ if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated())
{
addToNextNodeList(node);
}
@@ -716,8 +799,31 @@ int SubgraphTraverser::clearAllNodeMarkings()
return false;
}
-int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
+int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts)
{
+ TosaReference::Tensor* tensor = nullptr;
+
+ // variable tensors are shared: make new tensor only if not found
+ if (ts->GetVariable())
+ {
+ tensor = getVariableTensorByName(ts->GetName());
+ }
+
+ if (!tensor)
+ {
+ DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+ tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
+
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
+ ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
+
+ if (ts->GetVariable())
+ {
+ tensor->setIsVariable();
+ registerVariableTensor(tensor);
+ }
+ }
+
// Enforce no duplicate tensors/tensor names
// O(N), but the number of tensors is small
for (TosaReference::Tensor* currTensor : tensors)
@@ -751,7 +857,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode)
{
if (currNode == newNode)
{
- FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph");
+ FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph");
return 1;
}
}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index ef6ea42..d6b0e8d 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -74,10 +74,14 @@ public:
int getNumOutputTensors() const;
Tensor* getOutputTensor(const unsigned int idx) const;
Tensor* getOutputTensorByName(const std::string name) const;
+ int getNumVariableTensors() const;
+ Tensor* getVariableTensor(const unsigned int idx) const;
+ Tensor* getVariableTensorByName(const std::string name) const;
+ int registerVariableTensor(Tensor* tensor);
int addToNextNodeList(GraphNode*);
private:
- int addTensor(Tensor* ct);
+ int addTensor(const TosaSerializationTensor* ts);
int addNode(GraphNode* cn);
Tensor* findTensorByName(const std::string& name) const;
@@ -103,6 +107,9 @@ private:
// The subset of tensors that are also output tensors
std::vector<Tensor*> outputTensors;
+ // The subset of tensors that are also variable tensors
+ std::vector<Tensor*> variableTensors;
+
// The definitive list of all nodes in the graph
std::vector<GraphNode*> nodes;
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 1aabe5b..5fffa8a 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -36,6 +36,7 @@ TosaReference::Tensor::Tensor(const std::string tensorName_,
isSubgraphInput = false;
isSubgraphOutput = false;
isParentGraphOutput = false;
+ isVariable = false;
}
TosaReference::Tensor::~Tensor()
@@ -59,6 +60,12 @@ int TosaReference::Tensor::setIsSubgraphOutput()
return 0;
}
+int TosaReference::Tensor::setIsVariable()
+{
+ isVariable = true;
+ return 0;
+}
+
int TosaReference::Tensor::setProducer(GraphNode* node)
{
ASSERT_MSG(node, "Tensor::setProducer: no node passed in");
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index fe7336d..203cfec 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -42,21 +42,27 @@ public:
int setIsSubgraphOutput();
int setIsParentGraphOutput();
- int getIsParentGraphOutput() const
+ bool getIsParentGraphOutput() const
{
return isParentGraphOutput;
}
+ int setIsVariable();
- int getIsSubgraphInput() const
+ bool getIsSubgraphInput() const
{
return isSubgraphInput;
}
- int getIsSubgraphOutput() const
+ bool getIsSubgraphOutput() const
{
return isSubgraphOutput;
}
+ bool getIsVariable() const
+ {
+ return isVariable;
+ }
+
int setProducer(GraphNode* node);
int addConsumer(GraphNode* node);
@@ -269,18 +275,19 @@ public:
return in ? true_str : false_str;
}
- virtual int allocate() = 0;
- virtual int deallocate() = 0;
- virtual bool is_allocated() = 0;
+ virtual int allocate() = 0;
+ virtual int deallocate() = 0;
+ virtual bool is_allocated() const = 0;
protected:
const std::string tensorName;
const DType serializationDtype;
std::vector<int> shape;
const TOSA_REF_TYPE tensorDtype;
- int isValid;
- int isSubgraphInput;
- int isSubgraphOutput;
+ bool isValid;
+ bool isSubgraphInput;
+ bool isSubgraphOutput;
+ bool isVariable;
bool isAllocated;
bool isParentGraphOutput;
@@ -332,7 +339,7 @@ public:
return 0;
}
- virtual bool is_allocated()
+ virtual bool is_allocated() const
{
if (tensor)
{
diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py
index fcd72a3..3554e40 100644
--- a/verif/frameworks/test_builder.py
+++ b/verif/frameworks/test_builder.py
@@ -1175,7 +1175,7 @@ class TBuilder:
return result[0]
- class LSTM:
+ class LSTM(tf.Module):
def __init__(self, name):
self.result_name = name
self.lstm = tf.keras.layers.LSTM(
@@ -1191,6 +1191,23 @@ class TBuilder:
def eval(self, a):
return self.lstm(a)
+ class SLSTM(tf.Module):
+ def __init__(self, name):
+ self.result_name = name
+ self.lstm = tf.keras.layers.LSTM(
+ 2,
+ stateful=True,
+ activation="tanh",
+ unroll=False,
+ recurrent_activation="sigmoid",
+ use_bias=True,
+ recurrent_initializer="ones",
+ kernel_initializer="ones",
+ )
+
+ def eval(self, a):
+ return self.lstm(a)
+
class GRU:
def __init__(self, name):
self.result_name = name
@@ -1256,3 +1273,22 @@ class TBuilder:
def eval(self, a):
return tf.broadcast_to(a, shape=self.shape, name=self.result_name)
+
+ class CallOnce(tf.Module):
+ def __init__(self, name):
+ print(tf.__version__)
+ self.result_name = name
+ self.var = tf.Variable([1.0])
+
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(
+ shape=[
+ 1,
+ ],
+ dtype=tf.float32,
+ )
+ ]
+ )
+ def eval(self, a):
+ return self.var.assign([2.0])
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index ec009c6..ffe373b 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402
get_tf_dtype,
get_shape_str,
) # noqa: E402
+
from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402
# All of the supported frameworks
@@ -829,6 +830,15 @@ TF_OP_LIST = {
]
},
},
+ "lstm_stateful": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone),
+ "types": {
+ "tflite": [
+ tf.float32,
+ ]
+ },
+ },
"gru": {
"operands": (1, 0),
"build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone),
@@ -848,6 +858,17 @@ TF_OP_LIST = {
]
},
},
+ "callonce": {
+ "operands": (1, 0),
+ "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone),
+ "types": {
+ "tflite": [tf.float32],
+ },
+ "custom_shapes": {
+ "custom_shape_only": True,
+ "shape_list": [(1,)],
+ },
+ },
"rfft2d": {
"operands": (1, 0),
"build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d),
@@ -1219,9 +1240,15 @@ def run_unit_test(
if "tflite" not in excluded_framework_list:
# Convert the model to TFLite flatbuffer
module = tf.Module()
- converter = tf.lite.TFLiteConverter.from_concrete_functions(
- [concrete_function], module
- )
+
+ if op_name == "callonce" or op_name == "lstm_stateful":
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], fcn_node
+ )
+ else:
+ converter = tf.lite.TFLiteConverter.from_concrete_functions(
+ [concrete_function], module
+ )
converter.experimental_new_converter = True