aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h93
1 files changed, 65 insertions, 28 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 4cda830..0aa16df 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -216,11 +216,13 @@ class TosaSerializationBasicBlock
public:
// constructor and destructor
TosaSerializationBasicBlock(const std::string& name,
+ const std::string& region_name,
const std::vector<TosaSerializationOperator*>& operators,
const std::vector<TosaSerializationTensor*>& tensors,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs);
TosaSerializationBasicBlock(std::string&& name,
+ std::string&& region_name,
std::vector<TosaSerializationOperator*>&& operators,
std::vector<TosaSerializationTensor*>&& tensors,
std::vector<std::string>&& inputs,
@@ -232,10 +234,15 @@ public:
{
return _name;
}
+ std::string GetRegionName() const
+ {
+ return _region_name;
+ }
std::vector<TosaSerializationOperator*>& GetOperators()
{
return _operators;
}
+
std::vector<TosaSerializationTensor*>& GetTensors()
{
return _tensors;
@@ -259,19 +266,59 @@ public:
{
return _inputs;
}
+
std::vector<std::string>& GetOutputs()
{
return _outputs;
}
private:
- std::string _name; /* name of basic block */
+ std::string _name; /* name of basic block */
+ std::string _region_name;
std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
std::vector<std::string> _inputs; /* array of string to specify block inputs */
std::vector<std::string> _outputs; /* array of string to specify block outputs */
};
+class TosaSerializationRegion
+{
+public:
+ // constructor and desctructor
+ TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
+ TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
+ ~TosaSerializationRegion();
+
+ // accessors
+ std::string GetName() const
+ {
+ return this->_name;
+ }
+
+ std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ {
+ return this->_blocks;
+ }
+
+ TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ {
+ TosaSerializationBasicBlock* result = nullptr;
+ for (auto block : GetBlocks())
+ {
+ if (block->GetName() == name)
+ {
+ result = block;
+ break;
+ }
+ }
+ return result;
+ }
+
+private:
+ std::string _name; /* name of basic block */
+ std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
+};
+
/*
* this is a helper class for writing/reading Tosa ISA
* supported format: .tosa (flatbuffer), .json
@@ -319,39 +366,29 @@ public:
}
// accessor
- std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ std::vector<TosaSerializationRegion*>& GetRegions()
{
- return _blocks;
+ return _regions;
}
- TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ TosaSerializationRegion* GetMainRegion()
{
- TosaSerializationBasicBlock* result = nullptr;
- for (auto block : GetBlocks())
+ return _regions[0];
+ }
+
+ TosaSerializationRegion* GetRegionByName(std::string name)
+ {
+ TosaSerializationRegion* result = nullptr;
+ for (auto region : GetRegions())
{
- if (block->GetName() == name)
+ if (region->GetName() == name)
{
- result = block;
+ result = region;
break;
}
}
return result;
}
- TosaSerializationBasicBlock* GetMainBlock()
- {
- TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
- assert(main_block);
- return main_block;
- }
-
- std::vector<std::string>& GetInputs()
- {
- return GetMainBlock()->GetInputs();
- }
- std::vector<std::string>& GetOutputs()
- {
- return GetMainBlock()->GetOutputs();
- }
bool GetSchemaLoaded() const
{
@@ -365,11 +402,11 @@ protected:
TosaVersion ParseTosaSchemaVersion(std::string schema);
private:
- TosaVersion _version; /* version struct */
- flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
- flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
- std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
- bool _schemaLoaded; /* is the schema properly loaded? */
+ TosaVersion _version; /* version struct */
+ flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
+ flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
+ std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
+ bool _schemaLoaded; /* is the schema properly loaded? */
};
} // namespace tosa