aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-10-04 20:32:39 -0700
committerJerry Ge <jerry.ge@arm.com>2022-12-15 23:23:20 +0000
commit13c78a67a6a3d743352f0b6e349c52bf36e84468 (patch)
tree11e2966816fded27b37618ce08cc03b7f4ef2fa2 /include/tosa_serialization_handler.h
parent6388a097de4350cc70472921c272074190fd7c93 (diff)
downloadserialization_lib-13c78a67a6a3d743352f0b6e349c52bf36e84468.tar.gz
[region] Add TosaSerializationRegion to serialization_lib
- Rationale: add this fix to support constants access between multiple blocks by another layer of abstraction called Region - Changes: - flatbuffers schema update, regenerate header files - add TosaSerializationRegion for the handler - other relevant fixes Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I4bb72503abfd629ae017d2f905184efbab244aa8
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h93
1 files changed, 65 insertions, 28 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 4cda830..0aa16df 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -216,11 +216,13 @@ class TosaSerializationBasicBlock
public:
// constructor and destructor
TosaSerializationBasicBlock(const std::string& name,
+ const std::string& region_name,
const std::vector<TosaSerializationOperator*>& operators,
const std::vector<TosaSerializationTensor*>& tensors,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs);
TosaSerializationBasicBlock(std::string&& name,
+ std::string&& region_name,
std::vector<TosaSerializationOperator*>&& operators,
std::vector<TosaSerializationTensor*>&& tensors,
std::vector<std::string>&& inputs,
@@ -232,10 +234,15 @@ public:
{
return _name;
}
+ std::string GetRegionName() const
+ {
+ return _region_name;
+ }
std::vector<TosaSerializationOperator*>& GetOperators()
{
return _operators;
}
+
std::vector<TosaSerializationTensor*>& GetTensors()
{
return _tensors;
@@ -259,19 +266,59 @@ public:
{
return _inputs;
}
+
std::vector<std::string>& GetOutputs()
{
return _outputs;
}
private:
- std::string _name; /* name of basic block */
+ std::string _name; /* name of basic block */
+ std::string _region_name;
std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
std::vector<std::string> _inputs; /* array of string to specify block inputs */
std::vector<std::string> _outputs; /* array of string to specify block outputs */
};
+class TosaSerializationRegion
+{
+public:
+ // constructor and desctructor
+ TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
+ TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
+ ~TosaSerializationRegion();
+
+ // accessors
+ std::string GetName() const
+ {
+ return this->_name;
+ }
+
+ std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ {
+ return this->_blocks;
+ }
+
+ TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ {
+ TosaSerializationBasicBlock* result = nullptr;
+ for (auto block : GetBlocks())
+ {
+ if (block->GetName() == name)
+ {
+ result = block;
+ break;
+ }
+ }
+ return result;
+ }
+
+private:
+ std::string _name; /* name of basic block */
+ std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
+};
+
/*
* this is a helper class for writing/reading Tosa ISA
* supported format: .tosa (flatbuffer), .json
@@ -319,39 +366,29 @@ public:
}
// accessor
- std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ std::vector<TosaSerializationRegion*>& GetRegions()
{
- return _blocks;
+ return _regions;
}
- TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ TosaSerializationRegion* GetMainRegion()
{
- TosaSerializationBasicBlock* result = nullptr;
- for (auto block : GetBlocks())
+ return _regions[0];
+ }
+
+ TosaSerializationRegion* GetRegionByName(std::string name)
+ {
+ TosaSerializationRegion* result = nullptr;
+ for (auto region : GetRegions())
{
- if (block->GetName() == name)
+ if (region->GetName() == name)
{
- result = block;
+ result = region;
break;
}
}
return result;
}
- TosaSerializationBasicBlock* GetMainBlock()
- {
- TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
- assert(main_block);
- return main_block;
- }
-
- std::vector<std::string>& GetInputs()
- {
- return GetMainBlock()->GetInputs();
- }
- std::vector<std::string>& GetOutputs()
- {
- return GetMainBlock()->GetOutputs();
- }
bool GetSchemaLoaded() const
{
@@ -365,11 +402,11 @@ protected:
TosaVersion ParseTosaSchemaVersion(std::string schema);
private:
- TosaVersion _version; /* version struct */
- flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
- flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
- std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
- bool _schemaLoaded; /* is the schema properly loaded? */
+ TosaVersion _version; /* version struct */
+ flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
+ flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
+ std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
+ bool _schemaLoaded; /* is the schema properly loaded? */
};
} // namespace tosa