aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-09-09 13:38:56 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-22 17:30:05 +0000
commit442261bf67fa2ec4d86ed3e431a6373787b3e35a (patch)
treed6a802db3c208cb9783cabe58d17aa1ffe12ee52 /src
parent780ffb5f034a4fd6581a44cd9c3b1cf119f33589 (diff)
downloadserialization_lib-442261bf67fa2ec4d86ed3e431a6373787b3e35a.tar.gz
[Serialization_lib] Support StatefulOps for TOSA
- Add variable in TosaTensor to schema file - Update TosaSerializationTensor regarding variable change - Rename internal zero_pad() and expose interface as ForceAlignTensorData() Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539
Diffstat (limited to 'src')
-rw-r--r--src/tosa_serialization_handler.cpp60
1 files changed, 34 insertions, 26 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index cbb862f..cb44f17 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,10 +22,11 @@ using namespace tosa;
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data)
+ const flatbuffers::Vector<uint8_t>* data,
+ bool variable)
{
- _dtype = dtype;
-
+ _dtype = dtype;
+ _variable = variable;
if (shape)
{
std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
@@ -43,18 +44,21 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data)
+ const std::vector<uint8_t>& data,
+ bool variable)
{
- _dtype = dtype;
- _shape = shape;
- _name = name;
- _data = data;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
}
TosaSerializationTensor::TosaSerializationTensor()
{
- _dtype = DType_UNKNOWN;
- _name = "UNKNOWN";
+ _dtype = DType_UNKNOWN;
+ _variable = false;
+ _name = "UNKNOWN";
}
TosaSerializationTensor::~TosaSerializationTensor()
@@ -514,12 +518,14 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
{
auto curr_tensor = fb_tosa_tensors->Get(j);
- auto tensor_name = curr_tensor->name();
- auto tensor_shape = curr_tensor->shape();
- auto tensor_type = curr_tensor->type();
- auto tensor_data = curr_tensor->data();
+ auto tensor_name = curr_tensor->name();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_variable = curr_tensor->variable();
+ auto tensor_data = curr_tensor->data();
- new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data);
+ new_tensor =
+ new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable);
if (new_tensor)
{
block_tensors_container.push_back(new_tensor);
@@ -676,8 +682,10 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
auto tensor_shape = _builder.CreateVector(tensor->GetShape());
auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data);
+ auto fboffset_tensor =
+ CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
@@ -702,7 +710,7 @@ tosa_err_t TosaSerializationHandler::Serialize()
return TOSA_OK;
}
-void zero_pad(std::vector<uint8_t>& buf)
+void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
{
while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0)
{
@@ -721,7 +729,7 @@ tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in
out.push_back(*val_u16 & 0xFF);
out.push_back((*val_u16 >> 8) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -736,7 +744,7 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -753,7 +761,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>&
out.push_back((*val_u64 >> 32) & 0xFF);
out.push_back((*val_u64 >> 40) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -768,7 +776,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>&
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -781,7 +789,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>&
out.push_back(*val_u16 & 0xFF);
out.push_back((*val_u16 >> 8) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -793,7 +801,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in
uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
out.push_back(*val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -824,7 +832,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in
uint8_t val_u8 = static_cast<uint8_t>(val_packed);
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -836,7 +844,7 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
uint8_t val_u8 = val;
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}