aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-21 17:00:29 +0000
committerTai Ly <tai.ly@arm.com>2023-08-23 15:12:30 +0000
commitc6939a4d269968a34b0ae0aa579f0f0736aaeccc (patch)
tree8925f10184fe6f4c8c179ad34d04a24d9e304bd7 /src
parent442261bf67fa2ec4d86ed3e431a6373787b3e35a (diff)
downloadserialization_lib-c6939a4d269968a34b0ae0aa579f0f0736aaeccc.tar.gz
Add is_unranked flag to Tensors
This adds a bool field, is_unranked to TosaTensor in tosa.fbs to indicate whether a tensor with shape = {} is an unranked tensor or an 0-D tensor. For older tosa files without this field, the default value is false. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69
Diffstat (limited to 'src')
-rw-r--r--src/tosa_serialization_handler.cpp55
1 files changed, 31 insertions, 24 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index cb44f17..3620c16 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -23,7 +23,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- bool variable)
+ const bool variable,
+ const bool is_unranked)
{
_dtype = dtype;
_variable = variable;
@@ -39,26 +40,30 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
{
std::copy(data->begin(), data->end(), std::back_inserter(_data));
}
+ _is_unranked = is_unranked;
}
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- bool variable)
+ const bool variable,
+ const bool is_unranked)
{
- _dtype = dtype;
- _variable = variable;
- _shape = shape;
- _name = name;
- _data = data;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
+ _is_unranked = is_unranked;
}
TosaSerializationTensor::TosaSerializationTensor()
{
- _dtype = DType_UNKNOWN;
- _variable = false;
- _name = "UNKNOWN";
+ _dtype = DType_UNKNOWN;
+ _variable = false;
+ _name = "UNKNOWN";
+ _is_unranked = false;
}
TosaSerializationTensor::~TosaSerializationTensor()
@@ -518,14 +523,15 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
{
auto curr_tensor = fb_tosa_tensors->Get(j);
- auto tensor_name = curr_tensor->name();
- auto tensor_shape = curr_tensor->shape();
- auto tensor_type = curr_tensor->type();
- auto tensor_variable = curr_tensor->variable();
- auto tensor_data = curr_tensor->data();
+ auto tensor_name = curr_tensor->name();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_variable = curr_tensor->variable();
+ auto tensor_data = curr_tensor->data();
+ auto tensor_is_unranked = curr_tensor->is_unranked();
- new_tensor =
- new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable);
+ new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data,
+ tensor_variable, tensor_is_unranked);
if (new_tensor)
{
block_tensors_container.push_back(new_tensor);
@@ -679,13 +685,14 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
for (auto tensor : block->GetTensors())
{
- auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
- auto tensor_shape = _builder.CreateVector(tensor->GetShape());
- auto tensor_dtype = tensor->GetDtype();
- bool tensor_variable = tensor->GetVariable();
- auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto fboffset_tensor =
- CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable);
+ auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
+ auto tensor_shape = _builder.CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
+ auto tensor_data = _builder.CreateVector(tensor->GetData());
+ auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data,
+ tensor_variable, tensor_is_unranked);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);