From 442261bf67fa2ec4d86ed3e431a6373787b3e35a Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Fri, 9 Sep 2022 13:38:56 -0700 Subject: [Serialization_lib] Support StatefulOps for TOSA - Add variable in TosaTensor to schema file - Update TosaSerializationTensor regarding variable change - Rename internal zero_pad() and expose interface as ForceAlignTensorData() Signed-off-by: Jerry Ge Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539 --- include/tosa_generated.h | 20 +++++++++--- include/tosa_serialization_handler.h | 17 +++++++--- python/tosa/TosaTensor.py | 15 ++++++++- schema/tosa.fbs | 1 + src/tosa_serialization_handler.cpp | 60 ++++++++++++++++++++---------------- 5 files changed, 78 insertions(+), 35 deletions(-) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 22819f1..b2805a8 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -2204,7 +2204,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_NAME = 4, VT_SHAPE = 6, VT_TYPE = 8, - VT_DATA = 10 + VT_DATA = 10, + VT_VARIABLE = 12 }; const ::flatbuffers::String *name() const { return GetPointer(VT_NAME); @@ -2218,6 +2219,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector *data() const { return GetPointer *>(VT_DATA); } + bool variable() const { + return GetField(VT_VARIABLE, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && @@ -2227,6 +2231,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyField(verifier, VT_TYPE, 4) && VerifyOffset(verifier, VT_DATA) && verifier.VerifyVector(data()) && + VerifyField(verifier, VT_VARIABLE, 1) && verifier.EndTable(); } }; @@ -2247,6 +2252,9 @@ struct TosaTensorBuilder { void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { fbb_.AddOffset(TosaTensor::VT_DATA, data); } + void add_variable(bool variable) { + fbb_.AddElement(TosaTensor::VT_VARIABLE, static_cast(variable), 0); + } explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2263,12 +2271,14 @@ inline ::flatbuffers::Offset CreateTosaTensor( ::flatbuffers::Offset<::flatbuffers::String> name = 0, ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, tosa::DType type = tosa::DType_UNKNOWN, - ::flatbuffers::Offset<::flatbuffers::Vector> data = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0, + bool variable = false) { TosaTensorBuilder builder_(_fbb); builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); builder_.add_name(name); + builder_.add_variable(variable); return builder_.Finish(); } @@ -2277,7 +2287,8 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( const char *name = nullptr, const std::vector *shape = nullptr, tosa::DType type = tosa::DType_UNKNOWN, - const std::vector *data = nullptr) { + const std::vector *data = nullptr, + bool variable = false) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } @@ -2287,7 +2298,8 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( name__, shape__, type, - data__); + data__, + variable); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index cae6a27..bf44c11 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -112,11 +112,13 @@ public: TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, DType dtype, - const flatbuffers::Vector* data); + const flatbuffers::Vector* data, + bool variable = false); TosaSerializationTensor(const std::string& name, const std::vector& shape, DType dtype, - const std::vector& data); + const std::vector& data, + bool variable = false); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -129,10 +131,14 @@ public: { return _shape; } - DType GetDtype() + DType GetDtype() const { return _dtype; } + bool GetVariable() const + { + return _variable; + } const std::vector& GetData() const { return _data; @@ -169,6 +175,7 @@ private: DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ std::vector _shape; /* shape of the tensor */ std::string _name; /* name of the tensor, used for solving dependency */ + bool _variable; /* is this a variable tensor */ std::vector _data; /* data array */ }; @@ -368,6 +375,8 @@ public: static tosa_err_t ConvertU8toI4(const std::vector& in, uint32_t out_size, std::vector& out); static tosa_err_t ConvertU8toBool(const std::vector& in, uint32_t out_size, std::vector& out); + static void ForceAlignTensorData(std::vector& buf); + // version const TosaVersion& GetVersion() { diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py index 850ff8f..d8264f2 100644 --- a/python/tosa/TosaTensor.py +++ b/python/tosa/TosaTensor.py @@ -96,8 +96,15 @@ class TosaTensor(object): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10)) return o == 0 + # TosaTensor + def Variable(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def TosaTensorStart(builder): - builder.StartObject(4) + builder.StartObject(5) def Start(builder): TosaTensorStart(builder) @@ -138,6 +145,12 @@ def TosaTensorStartDataVector(builder, numElems): def StartDataVector(builder, numElems: int) -> int: return TosaTensorStartDataVector(builder, numElems) +def TosaTensorAddVariable(builder, variable): + builder.PrependBoolSlot(4, variable, 0) + +def AddVariable(builder, variable): + TosaTensorAddVariable(builder, variable) + def TosaTensorEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index f101fa3..0943f11 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -281,6 +281,7 @@ table TosaTensor { shape:[int32]; // shape of the tensor type:DType; // data type of the tensor data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor. + variable: bool; // is this a variable tensor } table TosaOperator { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index cbb862f..cb44f17 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,10 +22,11 @@ using namespace tosa; TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector* shape, DType dtype, - const flatbuffers::Vector* data) + const flatbuffers::Vector* data, + bool variable) { - _dtype = dtype; - + _dtype = dtype; + _variable = variable; if (shape) { std::copy(shape->begin(), shape->end(), std::back_inserter(_shape)); @@ -43,18 +44,21 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name TosaSerializationTensor::TosaSerializationTensor(const std::string& name, const std::vector& shape, DType dtype, - const std::vector& data) + const std::vector& data, + bool variable) { - _dtype = dtype; - _shape = shape; - _name = name; - _data = data; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; } TosaSerializationTensor::TosaSerializationTensor() { - _dtype = DType_UNKNOWN; - _name = "UNKNOWN"; + _dtype = DType_UNKNOWN; + _variable = false; + _name = "UNKNOWN"; } TosaSerializationTensor::~TosaSerializationTensor() @@ -514,12 +518,14 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_tensor = fb_tosa_tensors->Get(j); - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_data = curr_tensor->data(); + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_variable = curr_tensor->variable(); + auto tensor_data = curr_tensor->data(); - new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data); + new_tensor = + new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable); if (new_tensor) { block_tensors_container.push_back(new_tensor); @@ -676,8 +682,10 @@ tosa_err_t TosaSerializationHandler::Serialize() auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); auto tensor_shape = _builder.CreateVector(tensor->GetShape()); auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data); + auto fboffset_tensor = + CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); @@ -702,7 +710,7 @@ tosa_err_t TosaSerializationHandler::Serialize() return TOSA_OK; } -void zero_pad(std::vector& buf) +void TosaSerializationHandler::ForceAlignTensorData(std::vector& buf) { while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0) { @@ -721,7 +729,7 @@ tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector& in out.push_back(*val_u16 & 0xFF); out.push_back((*val_u16 >> 8) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -736,7 +744,7 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector& in out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -753,7 +761,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector& out.push_back((*val_u64 >> 32) & 0xFF); out.push_back((*val_u64 >> 40) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -768,7 +776,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector& out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -781,7 +789,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector& out.push_back(*val_u16 & 0xFF); out.push_back((*val_u16 >> 8) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -793,7 +801,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector& in uint8_t* val_u8 = reinterpret_cast(&val); out.push_back(*val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -824,7 +832,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector& in uint8_t val_u8 = static_cast(val_packed); out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -836,7 +844,7 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector& in uint8_t val_u8 = val; out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } -- cgit v1.2.1