From 4762564da970eb1883a54aa66582e05c0dbd2b81 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 7 Sep 2023 20:49:09 +0000 Subject: [reference_model] Support StatefulOps and the tests for CallOnceOp Signed-off-by: Jerry Ge Change-Id: I03cb878736ccd7e1f5e1f780d7171949a19a9de2 --- examples/test_stateful_op/desc.json | 16 +++ examples/test_stateful_op/placeholder_0.npy | Bin 0 -> 132 bytes examples/test_stateful_op/test_variable_add.tosa | Bin 0 -> 616 bytes examples/test_stateful_op/variable_0.npy | Bin 0 -> 132 bytes reference_model/include/func_config.h | 40 +++--- reference_model/src/command_line_utils.h | 5 + reference_model/src/graph_node.cc | 14 +- reference_model/src/graph_node.h | 20 +++ reference_model/src/main.cpp | 152 ++++++++++++++++++++- reference_model/src/subgraph_traverser.cc | 130 ++++++++++++++++-- reference_model/src/subgraph_traverser.h | 9 +- reference_model/src/tensor.cc | 7 + reference_model/src/tensor.h | 27 ++-- verif/frameworks/test_builder.py | 38 +++++- verif/frameworks/tosa_verif_framework_generator.py | 33 ++++- 15 files changed, 442 insertions(+), 49 deletions(-) create mode 100644 examples/test_stateful_op/desc.json create mode 100644 examples/test_stateful_op/placeholder_0.npy create mode 100644 examples/test_stateful_op/test_variable_add.tosa create mode 100644 examples/test_stateful_op/variable_0.npy diff --git a/examples/test_stateful_op/desc.json b/examples/test_stateful_op/desc.json new file mode 100644 index 0000000..1f1459f --- /dev/null +++ b/examples/test_stateful_op/desc.json @@ -0,0 +1,16 @@ +{ + "tosa_file": "test_variable_add.tosa", + "ifm_name": [ + "TosaInput_0" + ], + "ifm_file": [ + "placeholder_0.npy" + ], + "variable_name": [ + "Variable_0" + ], + "variable_file": [ + "variable_0.npy" + ], + "expected_failure": false +} \ No newline at end of file diff --git a/examples/test_stateful_op/placeholder_0.npy b/examples/test_stateful_op/placeholder_0.npy new file mode 100644 index 0000000..f9688e6 Binary files /dev/null and b/examples/test_stateful_op/placeholder_0.npy differ diff --git a/examples/test_stateful_op/test_variable_add.tosa b/examples/test_stateful_op/test_variable_add.tosa new file mode 100644 index 0000000..2be3044 Binary files /dev/null and b/examples/test_stateful_op/test_variable_add.tosa differ diff --git a/examples/test_stateful_op/variable_0.npy b/examples/test_stateful_op/variable_0.npy new file mode 100644 index 0000000..c183a15 Binary files /dev/null and b/examples/test_stateful_op/variable_0.npy differ diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h index 1e93b89..860d8c6 100644 --- a/reference_model/include/func_config.h +++ b/reference_model/include/func_config.h @@ -35,24 +35,28 @@ struct tosa_level_t struct func_config_t { - std::string operator_fbs = "tosa.fbs"; - std::string test_desc = "desc.json"; - std::string flatbuffer_dir = ""; - std::string output_dir = ""; - std::string tosa_file = ""; - std::string ifm_name = ""; - std::string ifm_file = ""; - std::string ofm_name = ""; - std::string ofm_file = ""; - uint32_t eval = 1; - uint32_t validate_only = 0; - uint32_t output_tensors = 1; - uint32_t tosa_profile = 1; - uint32_t dump_intermediates = 0; - std::string fp_format = "0.5"; - uint32_t precise_mode = 0; - bool abs_mode = 0; // set in main as second run of precise_mode - bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian() + std::string operator_fbs = "tosa.fbs"; + std::string test_desc = "desc.json"; + std::string flatbuffer_dir = ""; + std::string output_dir = ""; + std::string tosa_file = ""; + std::string ifm_name = ""; + std::string ifm_file = ""; + std::string ofm_name = ""; + std::string ofm_file = ""; + std::string variable_name = ""; + std::string variable_file = ""; + + uint32_t eval = 1; + uint32_t validate_only = 0; + uint32_t output_tensors = 1; + uint32_t tosa_profile = 1; + uint32_t dump_intermediates = 0; + uint32_t initialize_variable_tensor_from_numpy = 0; + std::string fp_format = "0.5"; + uint32_t precise_mode = 0; + bool abs_mode = 0; // set in main as second run of precise_mode + bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian() tosa_level_t tosa_level; static constexpr tosa_level_t EIGHTK = { 6, 8192, 8192, 64 }; diff --git a/reference_model/src/command_line_utils.h b/reference_model/src/command_line_utils.h index dcb0564..f8031d9 100644 --- a/reference_model/src/command_line_utils.h +++ b/reference_model/src/command_line_utils.h @@ -48,6 +48,10 @@ int func_model_parse_cmd_line( cxxopts::value(func_config.ofm_name)) ("ofm_file", "Output tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", cxxopts::value(func_config.ofm_file)) + ("variable_name", "Region tensor name. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.variable_name)) + ("variable_file", "Region tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.variable_file)) ("eval", "Evaluate the network (0/1)", cxxopts::value(func_config.eval)) ("fp_format", "Floating-point number dump format string (printf-style format, e.g. 0.5)", cxxopts::value(func_config.fp_format)) @@ -60,6 +64,7 @@ int func_model_parse_cmd_line( cxxopts::value(func_config.tosa_level)) ("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value(func_config.dump_intermediates)) ("p,precise_mode", "Calculate floating point operations in FP64 (0/1)", cxxopts::value(func_config.precise_mode)) + ("initialize_variable_tensor_from_numpy", "Initialize variable tensors from flatbuffer (0, default) or numpy (1)", cxxopts::value(func_config.initialize_variable_tensor_from_numpy)) ("v,version", "print model version") ("i,input_tensor_file", "specify input tensor files", cxxopts::value>()) ("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value()) diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc index 1781e40..c8a0b29 100644 --- a/reference_model/src/graph_node.cc +++ b/reference_model/src/graph_node.cc @@ -31,6 +31,7 @@ GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const clearNodeMarked(); evalCount = 0; clearOnNextNodeList(); + clearEvaluated(); setRequiredOperands(-1, -1); setRequiredRank(-1); inMainBlock = false; @@ -102,6 +103,12 @@ int GraphNode::hasAllOutputsReady() const { if (!outputs[i]->getIsValid()) return false; + if (outputs[i]->getIsVariable()) + { + // when output is a variable tensor + // isValid is not reliable indicator of this node having been evaluated + return false; + } } return true; @@ -110,8 +117,8 @@ int GraphNode::hasAllOutputsReady() const int GraphNode::dumpNode(FILE* out) { int i; - fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType], - nodeId, evalCount, onNextNodeList, isMarked); + fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Evaluated: %d Is marked: %d\n", + EnumNamesOp()[nodeType], nodeId, evalCount, onNextNodeList, evaluated, isMarked); i = 0; for (Tensor* ins : inputs) @@ -135,7 +142,8 @@ int GraphNode::dumpNode(std::ostream& out) int i; out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount - << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl; + << " On next node list: " << onNextNodeList << "Evaluated: " << evaluated << " Is marked: " << isMarked + << std::endl; out << " Inputs:"; for (std::string& name : inputNames) diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index 900f4b8..b3fe8d6 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -254,6 +254,23 @@ public: return inMainBlock; } + int getEvaluated() const + { + return evaluated; + } + + int setEvaluated() + { + evaluated = true; + return 0; + } + + int clearEvaluated() + { + evaluated = false; + return 0; + } + // Helper functions. int idiv_check(int input1, int input2, int& result); @@ -317,6 +334,9 @@ protected: // next-node list. int onNextNodeList; + // Flag indicating that this node has been evaluated before + int evaluated; + // Required input/output tensor counts for node validation // -1 means any number is allowed int requiredInputCount; diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index cb7f0a2..62b8f6f 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -33,8 +33,11 @@ using namespace tosa; using json = nlohmann::json; int initTestDesc(json& test_desc); + int readInputTensors(SubgraphTraverser& gt, json& test_desc); int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& filename_prefix); +int readVariableTensors(SubgraphTraverser& gt, json test_desc); +int writeVariableTensors(SubgraphTraverser& gt, json test_desc); int loadGraph(TosaSerializationHandler& tsh, json& test_desc); void parse_value(const std::string& text, tosa_level_t& value); const std::string getResultFilenamePrefix(); @@ -131,6 +134,14 @@ int main(int argc, char** argv) goto done; } + if (g_func_config.initialize_variable_tensor_from_numpy) + { + if (readVariableTensors(main_gt, test_desc)) + { + FATAL_ERROR("Unable to read variable tensors"); + } + } + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. if (main_gt.evaluateAll()) { @@ -177,6 +188,11 @@ int main(int argc, char** argv) { WARNING("Errors encountered in saving output tensors"); } + + if (writeVariableTensors(main_gt, test_desc)) + { + WARNING("Errors encountered in writing variable tensors"); + } } } @@ -312,7 +328,7 @@ int readInputTensors(SubgraphTraverser& gt, json& test_desc) // Push ready consumers to the next node list for (auto gn : tensor->getConsumers()) { - if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) { gt.addToNextNodeList(gn); } @@ -395,6 +411,124 @@ int writeFinalTensors(SubgraphTraverser& gt, json& test_desc, const std::string& return 0; } +int readVariableTensors(SubgraphTraverser& gt, json test_desc) +{ + int tensorCount = gt.getNumVariableTensors(); + Tensor* tensor; + char filename[1024]; + + try + { + if ((tensorCount != (int)test_desc["variable_name"].size()) || + (tensorCount != (int)test_desc["variable_file"].size())) + { + WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld)in test descriptor.", tensorCount, + test_desc["variable_name"].size(), test_desc["variable_file"].size()); + return 1; + } + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get().c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.flatbuffer_dir.c_str(), + test_desc["variable_file"][i].get().c_str()); + + DEBUG_MED(GT, "Loading variable tensor %s from filename: %s", tensor->getName().c_str(), filename); + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readFromNpyFile(filename)) + { + WARNING("Unable to read variable tensor %s from filename: %s", tensor->getName().c_str(), filename); + tensor->dumpTensorParams(g_func_debug.func_debug_file); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) + { + gt.addToNextNodeList(gn); + } + } + } + } + catch (nlohmann::json::type_error& e) + { + WARNING("Fail accessing test descriptor: %s", e.what()); + return 1; + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + gt.dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int writeVariableTensors(SubgraphTraverser& gt, json test_desc) +{ + int tensorCount = gt.getNumVariableTensors(); + const Tensor* tensor; + char filename[1024]; + + try + { + if ((tensorCount != (int)test_desc["variable_name"].size()) || + (tensorCount != (int)test_desc["variable_file"].size())) + { + WARNING("Number of variable tensors(%d) doesn't match name(%ld)/file(%ld) in test descriptor.", tensorCount, + test_desc["variable_name"].size(), test_desc["variable_file"].size()); + return 1; + } + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getVariableTensorByName(test_desc["variable_name"][i].get()); + if (!tensor) + { + WARNING("Unable to find variable tensor %s", test_desc["variable_name"][i].get().c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.output_dir.c_str(), + test_desc["variable_file"][i].get().c_str()); + + DEBUG_MED(GT, "Writing variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is no longer allocated", tensor->getName().c_str()); + return 1; + } + if (tensor->writeToNpyFile(filename)) + { + WARNING("Unable to write variable tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), + filename); + return 1; + } + } + } + catch (nlohmann::json::type_error& e) + { + WARNING("Fail accessing test descriptor: %s", e.what()); + return 1; + } + + return 0; +} + // Read "foo,bar,..." and return std::vector({foo, bar, ...}) std::vector parseFromString(std::string raw_str) { @@ -489,6 +623,22 @@ int initTestDesc(json& test_desc) test_desc["ofm_file"] = ofm_file_vec; } + // Overwrite test_desc["variable_name"] if --variable_name= specified. + std::string variable_name_str(g_func_config.variable_name); + if (!variable_name_str.empty()) + { + std::vector variable_name_vec = parseFromString(variable_name_str); + test_desc["variable_name"] = variable_name_vec; + } + + // Overwrite test_desc["variable_file"] if --variable_file= specified. + std::string variable_file_str(g_func_config.variable_file); + if (!variable_file_str.empty()) + { + std::vector variable_file_vec = parseFromString(variable_file_str); + test_desc["variable_file"] = variable_file_vec; + } + return 0; } diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 186cb8b..745213e 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -64,6 +64,11 @@ SubgraphTraverser::~SubgraphTraverser() for (TosaReference::Tensor* t : tensors) { + if (t->getIsVariable() && parent_sgt) + { + // variable tensors are owned by top level sgt + continue; + } if (t->is_allocated()) { t->deallocate(); @@ -119,6 +124,51 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin return nullptr; } +int SubgraphTraverser::getNumVariableTensors() const +{ + return variableTensors.size(); +} + +TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const +{ + return variableTensors[idx]; +} + +// find variable tensor by name in top level sgt's @a variableTensors +TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const +{ + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->getVariableTensorByName(name); + } + + for (auto t : variableTensors) + { + if (t->getName() == name) + { + return t; + } + } + + return nullptr; +} + +// add variable tensor to top level sgt's @a variableTensors +int SubgraphTraverser::registerVariableTensor(Tensor* tensor) +{ + SUBGRAPH_ERROR_IF(!tensor->getIsVariable(), + "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable", + tensor->getName().c_str()); + // variable tensors are owned by top level sgt + if (parent_sgt) + { + return parent_sgt->registerVariableTensor(tensor); + } + variableTensors.push_back(tensor); + return 0; +} + int SubgraphTraverser::initializeGraph() { int idx = 0; @@ -321,19 +371,18 @@ int SubgraphTraverser::initializeGraph() non_const_node_vec.push_back(node); } + // Bug fix: add the ready node in main block for evaluation + if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated()) + { + addToNextNodeList(node); + } + idx++; } for (auto ts : block->GetTensors()) { - DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); - TosaReference::Tensor* tensor = - TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); - - addTensor(tensor); + addTensor(ts); } DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str()); @@ -406,6 +455,22 @@ int SubgraphTraverser::allocateInputTensors() this->allocateTensor(input_tensor_name); } + // allocate variable tensors if not already allocated + for (auto ts : block->GetTensors()) + { + if (ts->GetVariable()) + { + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.", + ts->GetName().c_str()); + if (!tensor->is_allocated()) + { + DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str()); + this->allocateTensor(ts->GetName()); + } + } + } + return 0; } @@ -447,6 +512,8 @@ int SubgraphTraverser::allocateTensor(std::string name) if (!ts->GetData().empty()) { + if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy) + return 0; DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str()); auto serialization_dtype = ts->GetDtype(); switch (serialization_dtype) @@ -549,8 +616,16 @@ int SubgraphTraverser::allocateTensor(std::string name) SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", EnumNameDType(ts->GetDtype())); } + tensor->setIsValid(); + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated()) + { + addToNextNodeList(gn); + } + } } - return 0; } @@ -619,6 +694,8 @@ int SubgraphTraverser::evaluateNextNode() return 1; } + currNode->setEvaluated(); + // free input tensor if all of its consumers have all of their outputs ready and it's not block's output for (auto tensor : currNode->getInputs()) { @@ -631,6 +708,12 @@ int SubgraphTraverser::evaluateNextNode() continue; } + if (tensor->getIsVariable()) + { + // if tensor is a Variable, we cannot free it + continue; + } + for (auto node : tensor->getConsumers()) { // If the node is inside a loop, the input tensor is still needed @@ -660,7 +743,7 @@ int SubgraphTraverser::evaluateNextNode() { for (GraphNode* node : tensor->getConsumers()) { - if (!node->getOnNextNodeList() && node->hasAllInputsReady()) + if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated()) { addToNextNodeList(node); } @@ -716,8 +799,31 @@ int SubgraphTraverser::clearAllNodeMarkings() return false; } -int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor) +int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts) { + TosaReference::Tensor* tensor = nullptr; + + // variable tensors are shared: make new tensor only if not found + if (ts->GetVariable()) + { + tensor = getVariableTensorByName(ts->GetName()); + } + + if (!tensor) + { + DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str()); + tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); + + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", + ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); + + if (ts->GetVariable()) + { + tensor->setIsVariable(); + registerVariableTensor(tensor); + } + } + // Enforce no duplicate tensors/tensor names // O(N), but the number of tensors is small for (TosaReference::Tensor* currTensor : tensors) @@ -751,7 +857,7 @@ int SubgraphTraverser::addNode(GraphNode* newNode) { if (currNode == newNode) { - FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph"); + FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph"); return 1; } } diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index ef6ea42..d6b0e8d 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -74,10 +74,14 @@ public: int getNumOutputTensors() const; Tensor* getOutputTensor(const unsigned int idx) const; Tensor* getOutputTensorByName(const std::string name) const; + int getNumVariableTensors() const; + Tensor* getVariableTensor(const unsigned int idx) const; + Tensor* getVariableTensorByName(const std::string name) const; + int registerVariableTensor(Tensor* tensor); int addToNextNodeList(GraphNode*); private: - int addTensor(Tensor* ct); + int addTensor(const TosaSerializationTensor* ts); int addNode(GraphNode* cn); Tensor* findTensorByName(const std::string& name) const; @@ -103,6 +107,9 @@ private: // The subset of tensors that are also output tensors std::vector outputTensors; + // The subset of tensors that are also variable tensors + std::vector variableTensors; + // The definitive list of all nodes in the graph std::vector nodes; diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 1aabe5b..5fffa8a 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -36,6 +36,7 @@ TosaReference::Tensor::Tensor(const std::string tensorName_, isSubgraphInput = false; isSubgraphOutput = false; isParentGraphOutput = false; + isVariable = false; } TosaReference::Tensor::~Tensor() @@ -59,6 +60,12 @@ int TosaReference::Tensor::setIsSubgraphOutput() return 0; } +int TosaReference::Tensor::setIsVariable() +{ + isVariable = true; + return 0; +} + int TosaReference::Tensor::setProducer(GraphNode* node) { ASSERT_MSG(node, "Tensor::setProducer: no node passed in"); diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index fe7336d..203cfec 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -42,21 +42,27 @@ public: int setIsSubgraphOutput(); int setIsParentGraphOutput(); - int getIsParentGraphOutput() const + bool getIsParentGraphOutput() const { return isParentGraphOutput; } + int setIsVariable(); - int getIsSubgraphInput() const + bool getIsSubgraphInput() const { return isSubgraphInput; } - int getIsSubgraphOutput() const + bool getIsSubgraphOutput() const { return isSubgraphOutput; } + bool getIsVariable() const + { + return isVariable; + } + int setProducer(GraphNode* node); int addConsumer(GraphNode* node); @@ -269,18 +275,19 @@ public: return in ? true_str : false_str; } - virtual int allocate() = 0; - virtual int deallocate() = 0; - virtual bool is_allocated() = 0; + virtual int allocate() = 0; + virtual int deallocate() = 0; + virtual bool is_allocated() const = 0; protected: const std::string tensorName; const DType serializationDtype; std::vector shape; const TOSA_REF_TYPE tensorDtype; - int isValid; - int isSubgraphInput; - int isSubgraphOutput; + bool isValid; + bool isSubgraphInput; + bool isSubgraphOutput; + bool isVariable; bool isAllocated; bool isParentGraphOutput; @@ -332,7 +339,7 @@ public: return 0; } - virtual bool is_allocated() + virtual bool is_allocated() const { if (tensor) { diff --git a/verif/frameworks/test_builder.py b/verif/frameworks/test_builder.py index fcd72a3..3554e40 100644 --- a/verif/frameworks/test_builder.py +++ b/verif/frameworks/test_builder.py @@ -1175,7 +1175,7 @@ class TBuilder: return result[0] - class LSTM: + class LSTM(tf.Module): def __init__(self, name): self.result_name = name self.lstm = tf.keras.layers.LSTM( @@ -1191,6 +1191,23 @@ class TBuilder: def eval(self, a): return self.lstm(a) + class SLSTM(tf.Module): + def __init__(self, name): + self.result_name = name + self.lstm = tf.keras.layers.LSTM( + 2, + stateful=True, + activation="tanh", + unroll=False, + recurrent_activation="sigmoid", + use_bias=True, + recurrent_initializer="ones", + kernel_initializer="ones", + ) + + def eval(self, a): + return self.lstm(a) + class GRU: def __init__(self, name): self.result_name = name @@ -1256,3 +1273,22 @@ class TBuilder: def eval(self, a): return tf.broadcast_to(a, shape=self.shape, name=self.result_name) + + class CallOnce(tf.Module): + def __init__(self, name): + print(tf.__version__) + self.result_name = name + self.var = tf.Variable([1.0]) + + @tf.function( + input_signature=[ + tf.TensorSpec( + shape=[ + 1, + ], + dtype=tf.float32, + ) + ] + ) + def eval(self, a): + return self.var.assign([2.0]) diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py index ec009c6..ffe373b 100755 --- a/verif/frameworks/tosa_verif_framework_generator.py +++ b/verif/frameworks/tosa_verif_framework_generator.py @@ -28,6 +28,7 @@ from frameworks.test_gen_utils import ( # noqa: E402 get_tf_dtype, get_shape_str, ) # noqa: E402 + from tensorflow.lite.python.interpreter import OpResolverType # noqa: E402 # All of the supported frameworks @@ -829,6 +830,15 @@ TF_OP_LIST = { ] }, }, + "lstm_stateful": { + "operands": (1, 0), + "build_fcn": (TBuilder.SLSTM, TGen.tgRecurrent, ArgGen.agNone), + "types": { + "tflite": [ + tf.float32, + ] + }, + }, "gru": { "operands": (1, 0), "build_fcn": (TBuilder.GRU, TGen.tgRecurrent, ArgGen.agNone), @@ -848,6 +858,17 @@ TF_OP_LIST = { ] }, }, + "callonce": { + "operands": (1, 0), + "build_fcn": (TBuilder.CallOnce, TGen.tgBasic, ArgGen.agNone), + "types": { + "tflite": [tf.float32], + }, + "custom_shapes": { + "custom_shape_only": True, + "shape_list": [(1,)], + }, + }, "rfft2d": { "operands": (1, 0), "build_fcn": (TBuilder.RFFT2d, TGen.tgRFFT2d, ArgGen.agRFFT2d), @@ -1219,9 +1240,15 @@ def run_unit_test( if "tflite" not in excluded_framework_list: # Convert the model to TFLite flatbuffer module = tf.Module() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function], module - ) + + if op_name == "callonce" or op_name == "lstm_stateful": + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], fcn_node + ) + else: + converter = tf.lite.TFLiteConverter.from_concrete_functions( + [concrete_function], module + ) converter.experimental_new_converter = True -- cgit v1.2.1