From d0520b9b3a0eaf9dadc6cdb57ed42906e577d32e Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 19 Sep 2023 21:30:18 +0000 Subject: Add variable_name to tensors Signed-off-by: Tai Ly Change-Id: Ia142c8b1a9e0869daefb3eef71100fd0c2a0effc --- src/tosa_serialization_handler.cpp | 39 +++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 7e96313..f96ff60 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -24,7 +24,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name DType dtype, const flatbuffers::Vector* data, const bool variable, - const bool is_unranked) + const bool is_unranked, + const flatbuffers::String* variable_name) { _dtype = dtype; _variable = variable; @@ -41,6 +42,11 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name std::copy(data->begin(), data->end(), std::back_inserter(_data)); } _is_unranked = is_unranked; + + if (variable_name) + { + _variable_name = variable_name->str(); + } } TosaSerializationTensor::TosaSerializationTensor(const std::string& name, @@ -48,14 +54,16 @@ TosaSerializationTensor::TosaSerializationTensor(const std::string& name, DType dtype, const std::vector& data, const bool variable, - const bool is_unranked) + const bool is_unranked, + const std::string& variable_name) { - _dtype = dtype; - _variable = variable; - _shape = shape; - _name = name; - _data = data; - _is_unranked = is_unranked; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; + _is_unranked = is_unranked; + _variable_name = variable_name; } TosaSerializationTensor::TosaSerializationTensor() @@ -697,14 +705,15 @@ tosa_err_t TosaSerializationHandler::Serialize() auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); for (auto tensor : block->GetTensors()) { - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - bool tensor_variable = tensor->GetVariable(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto tensor_is_unranked = tensor->GetIsUnranked(); + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto tensor_is_unranked = tensor->GetIsUnranked(); + auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str()); auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, - tensor_variable, tensor_is_unranked); + tensor_variable, tensor_is_unranked, tensor_variable_name); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); -- cgit v1.2.1