aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-09-09 13:38:56 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-22 17:30:05 +0000
commit442261bf67fa2ec4d86ed3e431a6373787b3e35a (patch)
treed6a802db3c208cb9783cabe58d17aa1ffe12ee52
parent780ffb5f034a4fd6581a44cd9c3b1cf119f33589 (diff)
downloadserialization_lib-442261bf67fa2ec4d86ed3e431a6373787b3e35a.tar.gz
[Serialization_lib] Support StatefulOps for TOSA
- Add variable in TosaTensor to schema file - Update TosaSerializationTensor regarding variable change - Rename internal zero_pad() and expose interface as ForceAlignTensorData() Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539
-rw-r--r--include/tosa_generated.h20
-rw-r--r--include/tosa_serialization_handler.h17
-rw-r--r--python/tosa/TosaTensor.py15
-rw-r--r--schema/tosa.fbs1
-rw-r--r--src/tosa_serialization_handler.cpp60
5 files changed, 78 insertions, 35 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 22819f1..b2805a8 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -2204,7 +2204,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VT_NAME = 4,
VT_SHAPE = 6,
VT_TYPE = 8,
- VT_DATA = 10
+ VT_DATA = 10,
+ VT_VARIABLE = 12
};
const ::flatbuffers::String *name() const {
return GetPointer<const ::flatbuffers::String *>(VT_NAME);
@@ -2218,6 +2219,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
const ::flatbuffers::Vector<uint8_t> *data() const {
return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA);
}
+ bool variable() const {
+ return GetField<uint8_t>(VT_VARIABLE, 0) != 0;
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NAME) &&
@@ -2227,6 +2231,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VerifyField<uint32_t>(verifier, VT_TYPE, 4) &&
VerifyOffset(verifier, VT_DATA) &&
verifier.VerifyVector(data()) &&
+ VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) &&
verifier.EndTable();
}
};
@@ -2247,6 +2252,9 @@ struct TosaTensorBuilder {
void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data) {
fbb_.AddOffset(TosaTensor::VT_DATA, data);
}
+ void add_variable(bool variable) {
+ fbb_.AddElement<uint8_t>(TosaTensor::VT_VARIABLE, static_cast<uint8_t>(variable), 0);
+ }
explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2263,12 +2271,14 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensor(
::flatbuffers::Offset<::flatbuffers::String> name = 0,
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0,
tosa::DType type = tosa::DType_UNKNOWN,
- ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0) {
+ ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0,
+ bool variable = false) {
TosaTensorBuilder builder_(_fbb);
builder_.add_data(data);
builder_.add_type(type);
builder_.add_shape(shape);
builder_.add_name(name);
+ builder_.add_variable(variable);
return builder_.Finish();
}
@@ -2277,7 +2287,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
const char *name = nullptr,
const std::vector<int32_t> *shape = nullptr,
tosa::DType type = tosa::DType_UNKNOWN,
- const std::vector<uint8_t> *data = nullptr) {
+ const std::vector<uint8_t> *data = nullptr,
+ bool variable = false) {
auto name__ = name ? _fbb.CreateString(name) : 0;
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); }
@@ -2287,7 +2298,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
name__,
shape__,
type,
- data__);
+ data__,
+ variable);
}
struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index cae6a27..bf44c11 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -112,11 +112,13 @@ public:
TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data);
+ const flatbuffers::Vector<uint8_t>* data,
+ bool variable = false);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data);
+ const std::vector<uint8_t>& data,
+ bool variable = false);
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -129,10 +131,14 @@ public:
{
return _shape;
}
- DType GetDtype()
+ DType GetDtype() const
{
return _dtype;
}
+ bool GetVariable() const
+ {
+ return _variable;
+ }
const std::vector<uint8_t>& GetData() const
{
return _data;
@@ -169,6 +175,7 @@ private:
DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
std::vector<int32_t> _shape; /* shape of the tensor */
std::string _name; /* name of the tensor, used for solving dependency */
+ bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
};
@@ -368,6 +375,8 @@ public:
static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
+ static void ForceAlignTensorData(std::vector<uint8_t>& buf);
+
// version
const TosaVersion& GetVersion()
{
diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py
index 850ff8f..d8264f2 100644
--- a/python/tosa/TosaTensor.py
+++ b/python/tosa/TosaTensor.py
@@ -96,8 +96,15 @@ class TosaTensor(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
return o == 0
+ # TosaTensor
+ def Variable(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
def TosaTensorStart(builder):
- builder.StartObject(4)
+ builder.StartObject(5)
def Start(builder):
TosaTensorStart(builder)
@@ -138,6 +145,12 @@ def TosaTensorStartDataVector(builder, numElems):
def StartDataVector(builder, numElems: int) -> int:
return TosaTensorStartDataVector(builder, numElems)
+def TosaTensorAddVariable(builder, variable):
+ builder.PrependBoolSlot(4, variable, 0)
+
+def AddVariable(builder, variable):
+ TosaTensorAddVariable(builder, variable)
+
def TosaTensorEnd(builder):
return builder.EndObject()
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index f101fa3..0943f11 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -281,6 +281,7 @@ table TosaTensor {
shape:[int32]; // shape of the tensor
type:DType; // data type of the tensor
data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor.
+ variable: bool; // is this a variable tensor
}
table TosaOperator {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index cbb862f..cb44f17 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,10 +22,11 @@ using namespace tosa;
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data)
+ const flatbuffers::Vector<uint8_t>* data,
+ bool variable)
{
- _dtype = dtype;
-
+ _dtype = dtype;
+ _variable = variable;
if (shape)
{
std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
@@ -43,18 +44,21 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data)
+ const std::vector<uint8_t>& data,
+ bool variable)
{
- _dtype = dtype;
- _shape = shape;
- _name = name;
- _data = data;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
}
TosaSerializationTensor::TosaSerializationTensor()
{
- _dtype = DType_UNKNOWN;
- _name = "UNKNOWN";
+ _dtype = DType_UNKNOWN;
+ _variable = false;
+ _name = "UNKNOWN";
}
TosaSerializationTensor::~TosaSerializationTensor()
@@ -514,12 +518,14 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
{
auto curr_tensor = fb_tosa_tensors->Get(j);
- auto tensor_name = curr_tensor->name();
- auto tensor_shape = curr_tensor->shape();
- auto tensor_type = curr_tensor->type();
- auto tensor_data = curr_tensor->data();
+ auto tensor_name = curr_tensor->name();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_variable = curr_tensor->variable();
+ auto tensor_data = curr_tensor->data();
- new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data);
+ new_tensor =
+ new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable);
if (new_tensor)
{
block_tensors_container.push_back(new_tensor);
@@ -676,8 +682,10 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
auto tensor_shape = _builder.CreateVector(tensor->GetShape());
auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data);
+ auto fboffset_tensor =
+ CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
@@ -702,7 +710,7 @@ tosa_err_t TosaSerializationHandler::Serialize()
return TOSA_OK;
}
-void zero_pad(std::vector<uint8_t>& buf)
+void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
{
while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0)
{
@@ -721,7 +729,7 @@ tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in
out.push_back(*val_u16 & 0xFF);
out.push_back((*val_u16 >> 8) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -736,7 +744,7 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -753,7 +761,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>&
out.push_back((*val_u64 >> 32) & 0xFF);
out.push_back((*val_u64 >> 40) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -768,7 +776,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>&
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -781,7 +789,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>&
out.push_back(*val_u16 & 0xFF);
out.push_back((*val_u16 >> 8) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -793,7 +801,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in
uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
out.push_back(*val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -824,7 +832,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in
uint8_t val_u8 = static_cast<uint8_t>(val_packed);
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -836,7 +844,7 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
uint8_t val_u8 = val;
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}