diff options
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r-- | include/tosa_serialization_handler.h | 93 |
1 files changed, 65 insertions, 28 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 4cda830..0aa16df 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -216,11 +216,13 @@ class TosaSerializationBasicBlock public: // constructor and destructor TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector<TosaSerializationOperator*>& operators, const std::vector<TosaSerializationTensor*>& tensors, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs); TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector<TosaSerializationOperator*>&& operators, std::vector<TosaSerializationTensor*>&& tensors, std::vector<std::string>&& inputs, @@ -232,10 +234,15 @@ public: { return _name; } + std::string GetRegionName() const + { + return _region_name; + } std::vector<TosaSerializationOperator*>& GetOperators() { return _operators; } + std::vector<TosaSerializationTensor*>& GetTensors() { return _tensors; @@ -259,19 +266,59 @@ public: { return _inputs; } + std::vector<std::string>& GetOutputs() { return _outputs; } private: - std::string _name; /* name of basic block */ + std::string _name; /* name of basic block */ + std::string _region_name; std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */ std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */ std::vector<std::string> _inputs; /* array of string to specify block inputs */ std::vector<std::string> _outputs; /* array of string to specify block outputs */ }; +class TosaSerializationRegion +{ +public: + // constructor and desctructor + TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks); + TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks); + ~TosaSerializationRegion(); + + // accessors + std::string GetName() const + { + return this->_name; + } + + std::vector<TosaSerializationBasicBlock*>& GetBlocks() + { + return this->_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + +private: + std::string _name; /* name of basic block */ + std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */ +}; + /* * this is a helper class for writing/reading Tosa ISA * supported format: .tosa (flatbuffer), .json @@ -319,39 +366,29 @@ public: } // accessor - std::vector<TosaSerializationBasicBlock*>& GetBlocks() + std::vector<TosaSerializationRegion*>& GetRegions() { - return _blocks; + return _regions; } - TosaSerializationBasicBlock* GetBlockByName(std::string name) + TosaSerializationRegion* GetMainRegion() { - TosaSerializationBasicBlock* result = nullptr; - for (auto block : GetBlocks()) + return _regions[0]; + } + + TosaSerializationRegion* GetRegionByName(std::string name) + { + TosaSerializationRegion* result = nullptr; + for (auto region : GetRegions()) { - if (block->GetName() == name) + if (region->GetName() == name) { - result = block; + result = region; break; } } return result; } - TosaSerializationBasicBlock* GetMainBlock() - { - TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); - assert(main_block); - return main_block; - } - - std::vector<std::string>& GetInputs() - { - return GetMainBlock()->GetInputs(); - } - std::vector<std::string>& GetOutputs() - { - return GetMainBlock()->GetOutputs(); - } bool GetSchemaLoaded() const { @@ -365,11 +402,11 @@ protected: TosaVersion ParseTosaSchemaVersion(std::string schema); private: - TosaVersion _version; /* version struct */ - flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ - flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ - std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */ - bool _schemaLoaded; /* is the schema properly loaded? */ + TosaVersion _version; /* version struct */ + flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ + flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ + std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */ + bool _schemaLoaded; /* is the schema properly loaded? */ }; } // namespace tosa |