From 13c78a67a6a3d743352f0b6e349c52bf36e84468 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 4 Oct 2022 20:32:39 -0700 Subject: [region] Add TosaSerializationRegion to serialization_lib - Rationale: add this fix to support constants access between multiple blocks by another layer of abstraction called Region - Changes: - flatbuffers schema update, regenerate header files - add TosaSerializationRegion for the handler - other relevant fixes Signed-off-by: Jerry Ge Change-Id: I4bb72503abfd629ae017d2f905184efbab244aa8 --- include/tosa_generated.h | 95 +++++++++++++++++++++++++++++++----- include/tosa_serialization_handler.h | 93 ++++++++++++++++++++++++----------- 2 files changed, 147 insertions(+), 41 deletions(-) (limited to 'include') diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 123c4f6..4d231b0 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -80,6 +80,9 @@ struct TosaOperatorBuilder; struct TosaBasicBlock; struct TosaBasicBlockBuilder; +struct TosaRegion; +struct TosaRegionBuilder; + struct TosaGraph; struct TosaGraphBuilder; @@ -2502,25 +2505,91 @@ inline flatbuffers::Offset CreateTosaBasicBlockDirect( outputs__); } +struct TosaRegion FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaRegionBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_BLOCKS = 6 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *blocks() const { + return GetPointer> *>(VT_BLOCKS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_BLOCKS) && + verifier.VerifyVector(blocks()) && + verifier.VerifyVectorOfTables(blocks()) && + verifier.EndTable(); + } +}; + +struct TosaRegionBuilder { + typedef TosaRegion Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(TosaRegion::VT_NAME, name); + } + void add_blocks(flatbuffers::Offset>> blocks) { + fbb_.AddOffset(TosaRegion::VT_BLOCKS, blocks); + } + explicit TosaRegionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaRegion( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> blocks = 0) { + TosaRegionBuilder builder_(_fbb); + builder_.add_blocks(blocks); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaRegionDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector> *blocks = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto blocks__ = blocks ? _fbb.CreateVector>(*blocks) : 0; + return tosa::CreateTosaRegion( + _fbb, + name__, + blocks__); +} + struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TosaGraphBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_VERSION = 4, - VT_BLOCKS = 6 + VT_REGIONS = 6 }; const tosa::Version *version() const { return GetPointer(VT_VERSION); } - const flatbuffers::Vector> *blocks() const { - return GetPointer> *>(VT_BLOCKS); + const flatbuffers::Vector> *regions() const { + return GetPointer> *>(VT_REGIONS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VERSION) && verifier.VerifyTable(version()) && - VerifyOffset(verifier, VT_BLOCKS) && - verifier.VerifyVector(blocks()) && - verifier.VerifyVectorOfTables(blocks()) && + VerifyOffset(verifier, VT_REGIONS) && + verifier.VerifyVector(regions()) && + verifier.VerifyVectorOfTables(regions()) && verifier.EndTable(); } }; @@ -2532,8 +2601,8 @@ struct TosaGraphBuilder { void add_version(flatbuffers::Offset version) { fbb_.AddOffset(TosaGraph::VT_VERSION, version); } - void add_blocks(flatbuffers::Offset>> blocks) { - fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks); + void add_regions(flatbuffers::Offset>> regions) { + fbb_.AddOffset(TosaGraph::VT_REGIONS, regions); } explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -2549,9 +2618,9 @@ struct TosaGraphBuilder { inline flatbuffers::Offset CreateTosaGraph( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset version = 0, - flatbuffers::Offset>> blocks = 0) { + flatbuffers::Offset>> regions = 0) { TosaGraphBuilder builder_(_fbb); - builder_.add_blocks(blocks); + builder_.add_regions(regions); builder_.add_version(version); return builder_.Finish(); } @@ -2559,12 +2628,12 @@ inline flatbuffers::Offset CreateTosaGraph( inline flatbuffers::Offset CreateTosaGraphDirect( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset version = 0, - const std::vector> *blocks = nullptr) { - auto blocks__ = blocks ? _fbb.CreateVector>(*blocks) : 0; + const std::vector> *regions = nullptr) { + auto regions__ = regions ? _fbb.CreateVector>(*regions) : 0; return tosa::CreateTosaGraph( _fbb, version, - blocks__); + regions__); } inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 4cda830..0aa16df 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -216,11 +216,13 @@ class TosaSerializationBasicBlock public: // constructor and destructor TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector& operators, const std::vector& tensors, const std::vector& inputs, const std::vector& outputs); TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector&& operators, std::vector&& tensors, std::vector&& inputs, @@ -232,10 +234,15 @@ public: { return _name; } + std::string GetRegionName() const + { + return _region_name; + } std::vector& GetOperators() { return _operators; } + std::vector& GetTensors() { return _tensors; @@ -259,19 +266,59 @@ public: { return _inputs; } + std::vector& GetOutputs() { return _outputs; } private: - std::string _name; /* name of basic block */ + std::string _name; /* name of basic block */ + std::string _region_name; std::vector _operators; /* TosaSerializationOperator list */ std::vector _tensors; /* TosaSerializationTensor list */ std::vector _inputs; /* array of string to specify block inputs */ std::vector _outputs; /* array of string to specify block outputs */ }; +class TosaSerializationRegion +{ +public: + // constructor and desctructor + TosaSerializationRegion(const std::string& name, const std::vector& blocks); + TosaSerializationRegion(const std::string&& name, const std::vector&& blocks); + ~TosaSerializationRegion(); + + // accessors + std::string GetName() const + { + return this->_name; + } + + std::vector& GetBlocks() + { + return this->_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + +private: + std::string _name; /* name of basic block */ + std::vector _blocks; /* TosaSerializationBasicBlock list */ +}; + /* * this is a helper class for writing/reading Tosa ISA * supported format: .tosa (flatbuffer), .json @@ -319,39 +366,29 @@ public: } // accessor - std::vector& GetBlocks() + std::vector& GetRegions() { - return _blocks; + return _regions; } - TosaSerializationBasicBlock* GetBlockByName(std::string name) + TosaSerializationRegion* GetMainRegion() { - TosaSerializationBasicBlock* result = nullptr; - for (auto block : GetBlocks()) + return _regions[0]; + } + + TosaSerializationRegion* GetRegionByName(std::string name) + { + TosaSerializationRegion* result = nullptr; + for (auto region : GetRegions()) { - if (block->GetName() == name) + if (region->GetName() == name) { - result = block; + result = region; break; } } return result; } - TosaSerializationBasicBlock* GetMainBlock() - { - TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); - assert(main_block); - return main_block; - } - - std::vector& GetInputs() - { - return GetMainBlock()->GetInputs(); - } - std::vector& GetOutputs() - { - return GetMainBlock()->GetOutputs(); - } bool GetSchemaLoaded() const { @@ -365,11 +402,11 @@ protected: TosaVersion ParseTosaSchemaVersion(std::string schema); private: - TosaVersion _version; /* version struct */ - flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ - flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ - std::vector _blocks; /* array structure to store all TosaSerializationBasicBlock */ - bool _schemaLoaded; /* is the schema properly loaded? */ + TosaVersion _version; /* version struct */ + flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ + flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ + std::vector _regions; /* array structure to store all TosaSerializationRegion */ + bool _schemaLoaded; /* is the schema properly loaded? */ }; } // namespace tosa -- cgit v1.2.1