From 13c78a67a6a3d743352f0b6e349c52bf36e84468 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 4 Oct 2022 20:32:39 -0700 Subject: [region] Add TosaSerializationRegion to serialization_lib - Rationale: add this fix to support constants access between multiple blocks by another layer of abstraction called Region - Changes: - flatbuffers schema update, regenerate header files - add TosaSerializationRegion for the handler - other relevant fixes Signed-off-by: Jerry Ge Change-Id: I4bb72503abfd629ae017d2f905184efbab244aa8 --- include/tosa_generated.h | 95 +++++++-- include/tosa_serialization_handler.h | 93 ++++++--- python/tosa/TosaGraph.py | 22 +- python/tosa/TosaRegion.py | 77 +++++++ schema/tosa.fbs | 11 +- src/tosa_serialization_handler.cpp | 395 +++++++++++++++++++---------------- 6 files changed, 455 insertions(+), 238 deletions(-) create mode 100644 python/tosa/TosaRegion.py diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 123c4f6..4d231b0 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -80,6 +80,9 @@ struct TosaOperatorBuilder; struct TosaBasicBlock; struct TosaBasicBlockBuilder; +struct TosaRegion; +struct TosaRegionBuilder; + struct TosaGraph; struct TosaGraphBuilder; @@ -2502,25 +2505,91 @@ inline flatbuffers::Offset CreateTosaBasicBlockDirect( outputs__); } +struct TosaRegion FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TosaRegionBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_BLOCKS = 6 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *blocks() const { + return GetPointer> *>(VT_BLOCKS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_BLOCKS) && + verifier.VerifyVector(blocks()) && + verifier.VerifyVectorOfTables(blocks()) && + verifier.EndTable(); + } +}; + +struct TosaRegionBuilder { + typedef TosaRegion Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(TosaRegion::VT_NAME, name); + } + void add_blocks(flatbuffers::Offset>> blocks) { + fbb_.AddOffset(TosaRegion::VT_BLOCKS, blocks); + } + explicit TosaRegionBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTosaRegion( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> blocks = 0) { + TosaRegionBuilder builder_(_fbb); + builder_.add_blocks(blocks); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTosaRegionDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector> *blocks = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto blocks__ = blocks ? _fbb.CreateVector>(*blocks) : 0; + return tosa::CreateTosaRegion( + _fbb, + name__, + blocks__); +} + struct TosaGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TosaGraphBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_VERSION = 4, - VT_BLOCKS = 6 + VT_REGIONS = 6 }; const tosa::Version *version() const { return GetPointer(VT_VERSION); } - const flatbuffers::Vector> *blocks() const { - return GetPointer> *>(VT_BLOCKS); + const flatbuffers::Vector> *regions() const { + return GetPointer> *>(VT_REGIONS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VERSION) && verifier.VerifyTable(version()) && - VerifyOffset(verifier, VT_BLOCKS) && - verifier.VerifyVector(blocks()) && - verifier.VerifyVectorOfTables(blocks()) && + VerifyOffset(verifier, VT_REGIONS) && + verifier.VerifyVector(regions()) && + verifier.VerifyVectorOfTables(regions()) && verifier.EndTable(); } }; @@ -2532,8 +2601,8 @@ struct TosaGraphBuilder { void add_version(flatbuffers::Offset version) { fbb_.AddOffset(TosaGraph::VT_VERSION, version); } - void add_blocks(flatbuffers::Offset>> blocks) { - fbb_.AddOffset(TosaGraph::VT_BLOCKS, blocks); + void add_regions(flatbuffers::Offset>> regions) { + fbb_.AddOffset(TosaGraph::VT_REGIONS, regions); } explicit TosaGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -2549,9 +2618,9 @@ struct TosaGraphBuilder { inline flatbuffers::Offset CreateTosaGraph( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset version = 0, - flatbuffers::Offset>> blocks = 0) { + flatbuffers::Offset>> regions = 0) { TosaGraphBuilder builder_(_fbb); - builder_.add_blocks(blocks); + builder_.add_regions(regions); builder_.add_version(version); return builder_.Finish(); } @@ -2559,12 +2628,12 @@ inline flatbuffers::Offset CreateTosaGraph( inline flatbuffers::Offset CreateTosaGraphDirect( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset version = 0, - const std::vector> *blocks = nullptr) { - auto blocks__ = blocks ? _fbb.CreateVector>(*blocks) : 0; + const std::vector> *regions = nullptr) { + auto regions__ = regions ? _fbb.CreateVector>(*regions) : 0; return tosa::CreateTosaGraph( _fbb, version, - blocks__); + regions__); } inline bool VerifyAttribute(flatbuffers::Verifier &verifier, const void *obj, Attribute type) { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 4cda830..0aa16df 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -216,11 +216,13 @@ class TosaSerializationBasicBlock public: // constructor and destructor TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector& operators, const std::vector& tensors, const std::vector& inputs, const std::vector& outputs); TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector&& operators, std::vector&& tensors, std::vector&& inputs, @@ -232,10 +234,15 @@ public: { return _name; } + std::string GetRegionName() const + { + return _region_name; + } std::vector& GetOperators() { return _operators; } + std::vector& GetTensors() { return _tensors; @@ -259,19 +266,59 @@ public: { return _inputs; } + std::vector& GetOutputs() { return _outputs; } private: - std::string _name; /* name of basic block */ + std::string _name; /* name of basic block */ + std::string _region_name; std::vector _operators; /* TosaSerializationOperator list */ std::vector _tensors; /* TosaSerializationTensor list */ std::vector _inputs; /* array of string to specify block inputs */ std::vector _outputs; /* array of string to specify block outputs */ }; +class TosaSerializationRegion +{ +public: + // constructor and desctructor + TosaSerializationRegion(const std::string& name, const std::vector& blocks); + TosaSerializationRegion(const std::string&& name, const std::vector&& blocks); + ~TosaSerializationRegion(); + + // accessors + std::string GetName() const + { + return this->_name; + } + + std::vector& GetBlocks() + { + return this->_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + +private: + std::string _name; /* name of basic block */ + std::vector _blocks; /* TosaSerializationBasicBlock list */ +}; + /* * this is a helper class for writing/reading Tosa ISA * supported format: .tosa (flatbuffer), .json @@ -319,39 +366,29 @@ public: } // accessor - std::vector& GetBlocks() + std::vector& GetRegions() { - return _blocks; + return _regions; } - TosaSerializationBasicBlock* GetBlockByName(std::string name) + TosaSerializationRegion* GetMainRegion() { - TosaSerializationBasicBlock* result = nullptr; - for (auto block : GetBlocks()) + return _regions[0]; + } + + TosaSerializationRegion* GetRegionByName(std::string name) + { + TosaSerializationRegion* result = nullptr; + for (auto region : GetRegions()) { - if (block->GetName() == name) + if (region->GetName() == name) { - result = block; + result = region; break; } } return result; } - TosaSerializationBasicBlock* GetMainBlock() - { - TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); - assert(main_block); - return main_block; - } - - std::vector& GetInputs() - { - return GetMainBlock()->GetInputs(); - } - std::vector& GetOutputs() - { - return GetMainBlock()->GetOutputs(); - } bool GetSchemaLoaded() const { @@ -365,11 +402,11 @@ protected: TosaVersion ParseTosaSchemaVersion(std::string schema); private: - TosaVersion _version; /* version struct */ - flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ - flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ - std::vector _blocks; /* array structure to store all TosaSerializationBasicBlock */ - bool _schemaLoaded; /* is the schema properly loaded? */ + TosaVersion _version; /* version struct */ + flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ + flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ + std::vector _regions; /* array structure to store all TosaSerializationRegion */ + bool _schemaLoaded; /* is the schema properly loaded? */ }; } // namespace tosa diff --git a/python/tosa/TosaGraph.py b/python/tosa/TosaGraph.py index 164cef2..5ee3304 100644 --- a/python/tosa/TosaGraph.py +++ b/python/tosa/TosaGraph.py @@ -40,27 +40,27 @@ class TosaGraph(object): return None # TosaGraph - def Blocks(self, j): + def Regions(self, j): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: x = self._tab.Vector(o) x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 x = self._tab.Indirect(x) - from tosa.TosaBasicBlock import TosaBasicBlock - obj = TosaBasicBlock() + from tosa.TosaRegion import TosaRegion + obj = TosaRegion() obj.Init(self._tab.Bytes, x) return obj return None # TosaGraph - def BlocksLength(self): + def RegionsLength(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) if o != 0: return self._tab.VectorLen(o) return 0 # TosaGraph - def BlocksIsNone(self): + def RegionsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 @@ -70,12 +70,12 @@ def Start(builder): def TosaGraphAddVersion(builder, version): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(version), 0) def AddVersion(builder, version): return TosaGraphAddVersion(builder, version) -def TosaGraphAddBlocks(builder, blocks): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(blocks), 0) -def AddBlocks(builder, blocks): - return TosaGraphAddBlocks(builder, blocks) -def TosaGraphStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4) -def StartBlocksVector(builder, numElems): - return TosaGraphStartBlocksVector(builder, numElems) +def TosaGraphAddRegions(builder, regions): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(regions), 0) +def AddRegions(builder, regions): + return TosaGraphAddRegions(builder, regions) +def TosaGraphStartRegionsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def StartRegionsVector(builder, numElems): + return TosaGraphStartRegionsVector(builder, numElems) def TosaGraphEnd(builder): return builder.EndObject() def End(builder): return TosaGraphEnd(builder) \ No newline at end of file diff --git a/python/tosa/TosaRegion.py b/python/tosa/TosaRegion.py new file mode 100644 index 0000000..e16ee0e --- /dev/null +++ b/python/tosa/TosaRegion.py @@ -0,0 +1,77 @@ +# automatically generated by the FlatBuffers compiler, do not modify + +# namespace: tosa + +import flatbuffers +from flatbuffers.compat import import_numpy +np = import_numpy() + +class TosaRegion(object): + __slots__ = ['_tab'] + + @classmethod + def GetRootAs(cls, buf, offset=0): + n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + x = TosaRegion() + x.Init(buf, n + offset) + return x + + @classmethod + def GetRootAsTosaRegion(cls, buf, offset=0): + """This method is deprecated. Please switch to GetRootAs.""" + return cls.GetRootAs(buf, offset) + @classmethod + def TosaRegionBufferHasIdentifier(cls, buf, offset, size_prefixed=False): + return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x54\x4F\x53\x41", size_prefixed=size_prefixed) + + # TosaRegion + def Init(self, buf, pos): + self._tab = flatbuffers.table.Table(buf, pos) + + # TosaRegion + def Name(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + + # TosaRegion + def Blocks(self, j): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + x = self._tab.Vector(o) + x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 + x = self._tab.Indirect(x) + from tosa.TosaBasicBlock import TosaBasicBlock + obj = TosaBasicBlock() + obj.Init(self._tab.Bytes, x) + return obj + return None + + # TosaRegion + def BlocksLength(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + if o != 0: + return self._tab.VectorLen(o) + return 0 + + # TosaRegion + def BlocksIsNone(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +def TosaRegionStart(builder): builder.StartObject(2) +def Start(builder): + return TosaRegionStart(builder) +def TosaRegionAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) +def AddName(builder, name): + return TosaRegionAddName(builder, name) +def TosaRegionAddBlocks(builder, blocks): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(blocks), 0) +def AddBlocks(builder, blocks): + return TosaRegionAddBlocks(builder, blocks) +def TosaRegionStartBlocksVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def StartBlocksVector(builder, numElems): + return TosaRegionStartBlocksVector(builder, numElems) +def TosaRegionEnd(builder): return builder.EndObject() +def End(builder): + return TosaRegionEnd(builder) \ No newline at end of file diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 4955139..4d5c611 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -297,7 +297,7 @@ table TosaTensor { table TosaOperator { op:Op; // operator enum - attribute: Attribute; // union structure. operator attribute + attribute:Attribute; // union structure. operator attribute inputs:[string]; // list of input tensor names outputs:[string]; // list of output tensor names } @@ -310,9 +310,14 @@ table TosaBasicBlock { outputs:[string]; // name of graph outputs } +table TosaRegion { + name:string; // name of region + blocks:[TosaBasicBlock]; // basic blocks array +} + table TosaGraph { - version: Version; - blocks:[TosaBasicBlock]; // basic blocks array + version:Version; + regions:[TosaRegion]; // regions array } root_type TosaGraph; diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 170b313..a4410f2 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -113,29 +113,33 @@ TosaSerializationOperator::~TosaSerializationOperator() } TosaSerializationBasicBlock::TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector& operators, const std::vector& tensors, const std::vector& inputs, const std::vector& outputs) { - _name = name; - _operators = operators; - _tensors = tensors; - _inputs = inputs; - _outputs = outputs; + _name = name; + _region_name = region_name; + _operators = operators; + _tensors = tensors; + _inputs = inputs; + _outputs = outputs; } TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector&& operators, std::vector&& tensors, std::vector&& inputs, std::vector&& outputs) { - _name = std::move(name); - _operators = std::move(operators); - _tensors = std::move(tensors); - _inputs = std::move(inputs); - _outputs = std::move(outputs); + _name = std::move(name); + _region_name = std::move(region_name); + _operators = std::move(operators); + _tensors = std::move(tensors); + _inputs = std::move(inputs); + _outputs = std::move(outputs); } TosaSerializationBasicBlock::~TosaSerializationBasicBlock() @@ -153,6 +157,29 @@ TosaSerializationBasicBlock::~TosaSerializationBasicBlock() } } +TosaSerializationRegion::TosaSerializationRegion(const std::string& name, + const std::vector& blocks) +{ + _name = name; + _blocks = blocks; +} + +TosaSerializationRegion::TosaSerializationRegion(const std::string&& name, + const std::vector&& blocks) +{ + _name = std::move(name); + _blocks = std::move(blocks); +} + +TosaSerializationRegion::~TosaSerializationRegion() +{ + // deallocate all blocks + for (auto block : GetBlocks()) + { + delete block; // ~TosaSerializationBasicBlock() + } +} + TosaSerializationHandler::TosaSerializationHandler() { _schemaLoaded = false; @@ -400,11 +427,11 @@ tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename tosa_err_t TosaSerializationHandler::Clear() { // deallocate all basic blocks - for (auto bb : GetBlocks()) + for (auto region : GetRegions()) { - delete bb; + delete region; } - _blocks.clear(); + _regions.clear(); return TOSA_OK; } @@ -417,11 +444,13 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) } auto fb_tosa_graph = GetTosaGraph(buf); auto fb_tosa_version = fb_tosa_graph->version(); - auto fb_tosa_blocks = fb_tosa_graph->blocks(); + auto fb_tosa_regions = fb_tosa_graph->regions(); std::vector operator_inputs_container; std::vector operator_outputs_container; + std::vector region_blocks_container; + std::vector block_operators_container; std::vector block_tensors_container; std::vector block_inputs_container; @@ -431,6 +460,7 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) TosaSerializationOperator* new_operator = NULL; TosaSerializationBasicBlock* new_block = NULL; TosaSerializationTensor* new_tensor = NULL; + TosaSerializationRegion* new_region = NULL; // erase container Clear(); @@ -453,127 +483,137 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) return TOSA_VERSION_MISMATCH; } - for (size_t i = 0; i < fb_tosa_blocks->size(); i++) + for (size_t i = 0; i < fb_tosa_regions->size(); i++) { - auto curr_block = fb_tosa_blocks->Get(i); + auto curr_region = fb_tosa_regions->Get(i); + auto region_name = curr_region->name()->str(); + auto fb_tosa_blocks = curr_region->blocks(); - auto block_name = curr_block->name()->str(); + new_region = new TosaSerializationRegion(curr_region->name()->str(), region_blocks_container); + this->GetRegions().push_back(new_region); - auto fb_tosa_operators = curr_block->operators(); - block_operators_container.clear(); - for (size_t j = 0; j < fb_tosa_operators->size(); j++) + for (size_t i = 0; i < fb_tosa_blocks->size(); i++) { - auto curr_operator = fb_tosa_operators->Get(j); + auto curr_block = fb_tosa_blocks->Get(i); - auto operator_op = curr_operator->op(); - auto attribute_type = curr_operator->attribute_type(); - auto attribute = curr_operator->attribute(); + auto block_name = curr_block->name()->str(); - // input tensors - auto operator_inputs = curr_operator->inputs(); - operator_inputs_container.clear(); - if (operator_inputs) + auto fb_tosa_operators = curr_block->operators(); + block_operators_container.clear(); + for (size_t j = 0; j < fb_tosa_operators->size(); j++) { - for (size_t k = 0; k < operator_inputs->size(); k++) + auto curr_operator = fb_tosa_operators->Get(j); + + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); + + // input tensors + auto operator_inputs = curr_operator->inputs(); + operator_inputs_container.clear(); + if (operator_inputs) { - auto curr_input = operator_inputs->Get(k); - operator_inputs_container.push_back(curr_input->str()); + for (size_t k = 0; k < operator_inputs->size(); k++) + { + auto curr_input = operator_inputs->Get(k); + operator_inputs_container.push_back(curr_input->str()); + } } - } - // output tensors - auto operator_outputs = curr_operator->outputs(); - operator_outputs_container.clear(); - if (operator_outputs) - { - for (size_t k = 0; k < operator_outputs->size(); k++) + // output tensors + auto operator_outputs = curr_operator->outputs(); + operator_outputs_container.clear(); + if (operator_outputs) { - auto curr_output = operator_outputs->Get(k); - operator_outputs_container.push_back(curr_output->str()); + for (size_t k = 0; k < operator_outputs->size(); k++) + { + auto curr_output = operator_outputs->Get(k); + operator_outputs_container.push_back(curr_output->str()); + } } - } - switch (attribute_type) - { - case Attribute_NONE: - typed_attribute = new TosaNoneAttribute(); - break; + switch (attribute_type) + { + case Attribute_NONE: + typed_attribute = new TosaNoneAttribute(); + break; #define DEF_ATTRIBUTE(NAME, ...) \ case Attribute_##NAME##Attribute: \ typed_attribute = new Tosa##NAME##Attribute(attribute); \ break; #include "attribute.def" #undef DEF_ATTRIBUTE - default: - printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; + default: + printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, + operator_inputs_container, operator_outputs_container); + if (new_operator) + { + block_operators_container.push_back(new_operator); + } + else + { + return TOSA_MEMORY_ERROR; + } + + if (typed_attribute) + delete typed_attribute; } - new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, - operator_inputs_container, operator_outputs_container); - if (new_operator) + auto block_inputs = curr_block->inputs(); + auto block_outputs = curr_block->outputs(); + + block_inputs_container.clear(); + block_outputs_container.clear(); + + for (size_t j = 0; j < block_inputs->size(); j++) { - block_operators_container.push_back(new_operator); + auto curr_block_input = block_inputs->Get(j); + block_inputs_container.push_back(curr_block_input->str()); } - else + for (size_t j = 0; j < block_outputs->size(); j++) { - return TOSA_MEMORY_ERROR; + auto curr_block_output = block_outputs->Get(j); + block_outputs_container.push_back(curr_block_output->str()); } - if (typed_attribute) - delete typed_attribute; - } - - auto fb_tosa_tensors = curr_block->tensors(); - block_tensors_container.clear(); - for (size_t j = 0; j < fb_tosa_tensors->size(); j++) - { - auto curr_tensor = fb_tosa_tensors->Get(j); + auto fb_tosa_tensors = curr_block->tensors(); + block_tensors_container.clear(); + for (size_t j = 0; j < fb_tosa_tensors->size(); j++) + { + auto curr_tensor = fb_tosa_tensors->Get(j); - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_data = curr_tensor->data(); + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_data = curr_tensor->data(); - new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data); - if (new_tensor) + new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data); + if (new_tensor) + { + block_tensors_container.push_back(new_tensor); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + new_block = new TosaSerializationBasicBlock(block_name, region_name, block_operators_container, + block_tensors_container, block_inputs_container, + block_outputs_container); + if (new_block) { - block_tensors_container.push_back(new_tensor); + this->GetRegions()[0]->GetBlocks().push_back(new_block); } else { return TOSA_MEMORY_ERROR; } - } - - auto block_inputs = curr_block->inputs(); - auto block_outputs = curr_block->outputs(); - - block_inputs_container.clear(); - block_outputs_container.clear(); - - for (size_t j = 0; j < block_inputs->size(); j++) - { - auto curr_block_input = block_inputs->Get(j); - block_inputs_container.push_back(curr_block_input->str()); - } - for (size_t j = 0; j < block_outputs->size(); j++) - { - auto curr_block_output = block_outputs->Get(j); - block_outputs_container.push_back(curr_block_output->str()); - } - - new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container, - block_inputs_container, block_outputs_container); - if (new_block) - { - this->GetBlocks().push_back(new_block); - } - else - { - return TOSA_MEMORY_ERROR; - } + } // end block for_loop } return TOSA_OK; @@ -581,84 +621,76 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) tosa_err_t TosaSerializationHandler::Serialize() { - std::vector> fboffset_blocks; + // regions + std::vector> fboffset_regions; + // blocks + std::vector> fboffset_blocks; std::vector> fboffset_block_operators; std::vector> fboffset_block_tensors; std::vector> fboffset_block_inputs; std::vector> fboffset_block_outputs; + // operators std::vector> fboffset_operator_inputs; std::vector> fboffset_operator_outputs; // translate TosaFlatbufferOperator to flatbuffers::Offset - for (auto block : GetBlocks()) + for (auto region : GetRegions()) { - fboffset_block_operators.clear(); - fboffset_block_tensors.clear(); - fboffset_block_inputs.clear(); - fboffset_block_outputs.clear(); - - auto block_name = _builder.CreateString(block->GetName().c_str()); - - for (auto tensor_str : block->GetInputs()) - { - auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_block_inputs.push_back(tensor_name); - } - - for (auto tensor_str : block->GetOutputs()) - { - auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_block_outputs.push_back(tensor_name); - } - - auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs); - auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs); - - for (auto op : block->GetOperators()) + for (auto block : region->GetBlocks()) { - fboffset_operator_inputs.clear(); - fboffset_operator_outputs.clear(); - - auto operator_op = op->GetOp(); - auto attribute_type = op->GetAttributeType(); - - for (auto tensor_str : op->GetInputTensorNames()) + fboffset_block_operators.clear(); + fboffset_block_tensors.clear(); + fboffset_block_inputs.clear(); + fboffset_block_outputs.clear(); + auto block_name = _builder.CreateString(block->GetName().c_str()); + for (auto tensor_str : block->GetInputs()) { auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_operator_inputs.push_back(tensor_name); + fboffset_block_inputs.push_back(tensor_name); } - - for (auto tensor_str : op->GetOutputTensorNames()) + for (auto tensor_str : block->GetOutputs()) { auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_operator_outputs.push_back(tensor_name); + fboffset_block_outputs.push_back(tensor_name); } - - auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs); - auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs); - - flatbuffers::Offset fb_attribute; - switch (attribute_type) + auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs); + auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs); + for (auto op : block->GetOperators()) { - case Attribute_NONE: - fb_attribute = 0; - break; - + fboffset_operator_inputs.clear(); + fboffset_operator_outputs.clear(); + auto operator_op = op->GetOp(); + auto attribute_type = op->GetAttributeType(); + for (auto tensor_str : op->GetInputTensorNames()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_inputs.push_back(tensor_name); + } + for (auto tensor_str : op->GetOutputTensorNames()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_outputs.push_back(tensor_name); + } + auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs); + auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs); + flatbuffers::Offset fb_attribute; + switch (attribute_type) + { + case Attribute_NONE: + fb_attribute = 0; + break; #define DEF_ARGS_S_STR(NAME, V) , _builder.CreateString(reinterpret_cast(op->GetAttribute())->V().c_str()) #define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast(op->GetAttribute())->V() - #define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_DType(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V) - #define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V) #define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector(reinterpret_cast(op->GetAttribute())->V()) - #define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) #define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) #define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ @@ -678,7 +710,6 @@ tosa_err_t TosaSerializationHandler::Serialize() case Attribute_##NAME##Attribute: \ fb_attribute = Create##NAME##Attribute(_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \ break; - #include "attribute.def" #undef DEF_ATTRIBUTE #undef DEF_ARGS_1 @@ -698,44 +729,42 @@ tosa_err_t TosaSerializationHandler::Serialize() #undef DEF_ARGS_S_string #undef DEF_ARGS_S_STR #undef DEF_ARGS_S_DEFAULT - default: - printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; + default: + printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, + fb_operator_inputs, fb_operator_outputs); + fboffset_block_operators.push_back(fboffset_operator); } + auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); + for (auto tensor : block->GetTensors()) + { + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data); + fboffset_block_tensors.push_back(fboffset_tensor); + } + auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); + auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors, + fb_block_inputs, fb_block_outputs); + fboffset_blocks.push_back(fboffset_block); + } // end block for_loop + auto fb_blocks = _builder.CreateVector(fboffset_blocks); - auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, - fb_operator_inputs, fb_operator_outputs); - fboffset_block_operators.push_back(fboffset_operator); - } - - auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); - - for (auto tensor : block->GetTensors()) - { - - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - - auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data); - fboffset_block_tensors.push_back(fboffset_tensor); - } - - auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); + auto region_name = _builder.CreateString(region->GetName().c_str()); + auto fboffset_region = CreateTosaRegion(_builder, region_name, fb_blocks); + fboffset_regions.push_back(fboffset_region); + } // end region for_loop - auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors, - fb_block_inputs, fb_block_outputs); - fboffset_blocks.push_back(fboffset_block); - } - - auto fb_blocks = _builder.CreateVector(fboffset_blocks); + auto fb_regions = _builder.CreateVector(fboffset_regions); auto fb_version = CreateVersion(_builder, TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT); - - auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks); + auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_regions); _builder.Finish(fb_graph, TosaGraphIdentifier()); return TOSA_OK; -- cgit v1.2.1