aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-13 21:04:11 +0000
committerTai Ly <tai.ly@arm.com>2023-04-06 16:01:11 +0000
commitcfcb20d08c4c409bbcd2d2dde6ca5ecdac299454 (patch)
tree3765790fc7993ab95a910f9d9f17833f9243c2c7
parent3ef34fb300e7f62bdb397c605ab6c3bd30682cf8 (diff)
downloadserialization_lib-cfcb20d08c4c409bbcd2d2dde6ca5ecdac299454.tar.gz
Fix serialize/deserialize bug when there are two or more regions
Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Icd865b03765f672e3aa28ddaf6a290617ec3f530
-rw-r--r--src/tosa_serialization_handler.cpp53
1 files changed, 17 insertions, 36 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index d84e0ab..d782213 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -446,16 +446,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
auto fb_tosa_version = fb_tosa_graph->version();
auto fb_tosa_regions = fb_tosa_graph->regions();
- std::vector<std::string> operator_inputs_container;
- std::vector<std::string> operator_outputs_container;
-
- std::vector<TosaSerializationBasicBlock*> region_blocks_container;
-
- std::vector<TosaSerializationOperator*> block_operators_container;
- std::vector<TosaSerializationTensor*> block_tensors_container;
- std::vector<std::string> block_inputs_container;
- std::vector<std::string> block_outputs_container;
-
TosaAttributeBase* typed_attribute = NULL;
TosaSerializationOperator* new_operator = NULL;
TosaSerializationBasicBlock* new_block = NULL;
@@ -489,17 +479,21 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
auto region_name = curr_region->name()->str();
auto fb_tosa_blocks = curr_region->blocks();
- new_region = new TosaSerializationRegion(curr_region->name()->str(), region_blocks_container);
+ new_region = new TosaSerializationRegion(curr_region->name()->str(), {});
this->GetRegions().push_back(new_region);
for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
{
+ std::vector<TosaSerializationOperator*> block_operators_container;
+ std::vector<TosaSerializationTensor*> block_tensors_container;
+ std::vector<std::string> block_inputs_container;
+ std::vector<std::string> block_outputs_container;
+
auto curr_block = fb_tosa_blocks->Get(i);
auto block_name = curr_block->name()->str();
auto fb_tosa_operators = curr_block->operators();
- block_operators_container.clear();
for (size_t j = 0; j < fb_tosa_operators->size(); j++)
{
auto curr_operator = fb_tosa_operators->Get(j);
@@ -508,9 +502,11 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
auto attribute_type = curr_operator->attribute_type();
auto attribute = curr_operator->attribute();
+ std::vector<std::string> operator_inputs_container;
+ std::vector<std::string> operator_outputs_container;
+
// input tensors
auto operator_inputs = curr_operator->inputs();
- operator_inputs_container.clear();
if (operator_inputs)
{
for (size_t k = 0; k < operator_inputs->size(); k++)
@@ -522,7 +518,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
// output tensors
auto operator_outputs = curr_operator->outputs();
- operator_outputs_container.clear();
if (operator_outputs)
{
for (size_t k = 0; k < operator_outputs->size(); k++)
@@ -567,9 +562,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
auto block_inputs = curr_block->inputs();
auto block_outputs = curr_block->outputs();
- block_inputs_container.clear();
- block_outputs_container.clear();
-
for (size_t j = 0; j < block_inputs->size(); j++)
{
auto curr_block_input = block_inputs->Get(j);
@@ -582,7 +574,6 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
}
auto fb_tosa_tensors = curr_block->tensors();
- block_tensors_container.clear();
for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
{
auto curr_tensor = fb_tosa_tensors->Get(j);
@@ -607,7 +598,7 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
block_outputs_container);
if (new_block)
{
- this->GetRegions()[0]->GetBlocks().push_back(new_block);
+ new_region->GetBlocks().push_back(new_block);
}
else
{
@@ -637,26 +628,16 @@ tosa_err_t TosaSerializationHandler::Serialize()
// regions
std::vector<flatbuffers::Offset<TosaRegion>> fboffset_regions;
- // blocks
- std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
- std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
- std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
-
- // operators
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
-
// translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
for (auto region : GetRegions())
{
+ std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
for (auto block : region->GetBlocks())
{
- fboffset_block_operators.clear();
- fboffset_block_tensors.clear();
- fboffset_block_inputs.clear();
- fboffset_block_outputs.clear();
+ std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
+ std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
auto block_name = _builder.CreateString(block->GetName().c_str());
for (auto tensor_str : block->GetInputs())
{
@@ -672,8 +653,8 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs);
for (auto op : block->GetOperators())
{
- fboffset_operator_inputs.clear();
- fboffset_operator_outputs.clear();
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
auto operator_op = op->GetOp();
auto attribute_type = op->GetAttributeType();
for (auto tensor_str : op->GetInputTensorNames())