aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-19 21:30:18 +0000
committerTai Ly <tai.ly@arm.com>2023-09-19 21:55:11 +0000
commitd0520b9b3a0eaf9dadc6cdb57ed42906e577d32e (patch)
tree1538ac688cbafa4e8e08b7aabe3e10bbe404c767 /src
parent924f3094a745c1955d51fce18b488adfed5ee76b (diff)
downloadserialization_lib-d0520b9b3a0eaf9dadc6cdb57ed42906e577d32e.tar.gz
Add variable_name to tensors
Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia142c8b1a9e0869daefb3eef71100fd0c2a0effc
Diffstat (limited to 'src')
-rw-r--r--src/tosa_serialization_handler.cpp39
1 files changed, 24 insertions, 15 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 7e96313..f96ff60 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -24,7 +24,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
const bool variable,
- const bool is_unranked)
+ const bool is_unranked,
+ const flatbuffers::String* variable_name)
{
_dtype = dtype;
_variable = variable;
@@ -41,6 +42,11 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
std::copy(data->begin(), data->end(), std::back_inserter(_data));
}
_is_unranked = is_unranked;
+
+ if (variable_name)
+ {
+ _variable_name = variable_name->str();
+ }
}
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
@@ -48,14 +54,16 @@ TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
DType dtype,
const std::vector<uint8_t>& data,
const bool variable,
- const bool is_unranked)
+ const bool is_unranked,
+ const std::string& variable_name)
{
- _dtype = dtype;
- _variable = variable;
- _shape = shape;
- _name = name;
- _data = data;
- _is_unranked = is_unranked;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
+ _is_unranked = is_unranked;
+ _variable_name = variable_name;
}
TosaSerializationTensor::TosaSerializationTensor()
@@ -697,14 +705,15 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
for (auto tensor : block->GetTensors())
{
- auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
- auto tensor_shape = _builder.CreateVector(tensor->GetShape());
- auto tensor_dtype = tensor->GetDtype();
- bool tensor_variable = tensor->GetVariable();
- auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
+ auto tensor_shape = _builder.CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
+ auto tensor_data = _builder.CreateVector(tensor->GetData());
+ auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str());
auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data,
- tensor_variable, tensor_is_unranked);
+ tensor_variable, tensor_is_unranked, tensor_variable_name);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);