diff options
author | Tai Ly <tai.ly@arm.com> | 2023-08-21 17:00:29 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-08-23 15:12:30 +0000 |
commit | c6939a4d269968a34b0ae0aa579f0f0736aaeccc (patch) | |
tree | 8925f10184fe6f4c8c179ad34d04a24d9e304bd7 /src | |
parent | 442261bf67fa2ec4d86ed3e431a6373787b3e35a (diff) | |
download | serialization_lib-c6939a4d269968a34b0ae0aa579f0f0736aaeccc.tar.gz |
Add is_unranked flag to Tensors
This adds a bool field, is_unranked to TosaTensor in tosa.fbs
to indicate whether a tensor with shape = {} is an unranked tensor
or an 0-D tensor.
For older tosa files without this field, the default value is false.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69
Diffstat (limited to 'src')
-rw-r--r-- | src/tosa_serialization_handler.cpp | 55 |
1 files changed, 31 insertions, 24 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index cb44f17..3620c16 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -23,7 +23,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name const flatbuffers::Vector<int32_t>* shape, DType dtype, const flatbuffers::Vector<uint8_t>* data, - bool variable) + const bool variable, + const bool is_unranked) { _dtype = dtype; _variable = variable; @@ -39,26 +40,30 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name { std::copy(data->begin(), data->end(), std::back_inserter(_data)); } + _is_unranked = is_unranked; } TosaSerializationTensor::TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, const std::vector<uint8_t>& data, - bool variable) + const bool variable, + const bool is_unranked) { - _dtype = dtype; - _variable = variable; - _shape = shape; - _name = name; - _data = data; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; + _is_unranked = is_unranked; } TosaSerializationTensor::TosaSerializationTensor() { - _dtype = DType_UNKNOWN; - _variable = false; - _name = "UNKNOWN"; + _dtype = DType_UNKNOWN; + _variable = false; + _name = "UNKNOWN"; + _is_unranked = false; } TosaSerializationTensor::~TosaSerializationTensor() @@ -518,14 +523,15 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_tensor = fb_tosa_tensors->Get(j); - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_variable = curr_tensor->variable(); - auto tensor_data = curr_tensor->data(); + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_variable = curr_tensor->variable(); + auto tensor_data = curr_tensor->data(); + auto tensor_is_unranked = curr_tensor->is_unranked(); - new_tensor = - new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable); + new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, + tensor_variable, tensor_is_unranked); if (new_tensor) { block_tensors_container.push_back(new_tensor); @@ -679,13 +685,14 @@ tosa_err_t TosaSerializationHandler::Serialize() auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); for (auto tensor : block->GetTensors()) { - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - bool tensor_variable = tensor->GetVariable(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto fboffset_tensor = - CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable); + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto tensor_is_unranked = tensor->GetIsUnranked(); + auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, + tensor_variable, tensor_is_unranked); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); |