From d0520b9b3a0eaf9dadc6cdb57ed42906e577d32e Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 19 Sep 2023 21:30:18 +0000 Subject: Add variable_name to tensors Signed-off-by: Tai Ly Change-Id: Ia142c8b1a9e0869daefb3eef71100fd0c2a0effc --- include/tosa_generated.h | 22 ++++++++++++++++---- include/tosa_serialization_handler.h | 15 ++++++++++---- python/tosa/TosaTensor.py | 15 +++++++++++++- schema/tosa.fbs | 1 + src/tosa_serialization_handler.cpp | 39 ++++++++++++++++++++++-------------- 5 files changed, 68 insertions(+), 24 deletions(-) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index b07fa8f..a81ff9c 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -2230,7 +2230,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_TYPE = 8, VT_DATA = 10, VT_VARIABLE = 12, - VT_IS_UNRANKED = 14 + VT_IS_UNRANKED = 14, + VT_VARIABLE_NAME = 16 }; const ::flatbuffers::String *name() const { return GetPointer(VT_NAME); @@ -2250,6 +2251,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool is_unranked() const { return GetField(VT_IS_UNRANKED, 0) != 0; } + const ::flatbuffers::String *variable_name() const { + return GetPointer(VT_VARIABLE_NAME); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && @@ -2261,6 +2265,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { verifier.VerifyVector(data()) && VerifyField(verifier, VT_VARIABLE, 1) && VerifyField(verifier, VT_IS_UNRANKED, 1) && + VerifyOffset(verifier, VT_VARIABLE_NAME) && + verifier.VerifyString(variable_name()) && verifier.EndTable(); } }; @@ -2287,6 +2293,9 @@ struct TosaTensorBuilder { void add_is_unranked(bool is_unranked) { fbb_.AddElement(TosaTensor::VT_IS_UNRANKED, static_cast(is_unranked), 0); } + void add_variable_name(::flatbuffers::Offset<::flatbuffers::String> variable_name) { + fbb_.AddOffset(TosaTensor::VT_VARIABLE_NAME, variable_name); + } explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2305,8 +2314,10 @@ inline ::flatbuffers::Offset CreateTosaTensor( tosa::DType type = tosa::DType_UNKNOWN, ::flatbuffers::Offset<::flatbuffers::Vector> data = 0, bool variable = false, - bool is_unranked = false) { + bool is_unranked = false, + ::flatbuffers::Offset<::flatbuffers::String> variable_name = 0) { TosaTensorBuilder builder_(_fbb); + builder_.add_variable_name(variable_name); builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); @@ -2323,11 +2334,13 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( tosa::DType type = tosa::DType_UNKNOWN, const std::vector *data = nullptr, bool variable = false, - bool is_unranked = false) { + bool is_unranked = false, + const char *variable_name = nullptr) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } auto data__ = data ? _fbb.CreateVector(*data) : 0; + auto variable_name__ = variable_name ? _fbb.CreateString(variable_name) : 0; return tosa::CreateTosaTensor( _fbb, name__, @@ -2335,7 +2348,8 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( type, data__, variable, - is_unranked); + is_unranked, + variable_name__); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 814395b..e5448bc 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -150,14 +150,16 @@ public: const flatbuffers::Vector* shape, DType dtype, const flatbuffers::Vector* data, - const bool variable = false, - const bool is_unranked = false); + const bool variable = false, + const bool is_unranked = false, + const flatbuffers::String* variable_name = NULL); TosaSerializationTensor(const std::string& name, const std::vector& shape, DType dtype, const std::vector& data, - const bool variable = false, - const bool is_unranked = false); + const bool variable = false, + const bool is_unranked = false, + const std::string& variable_name = ""); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -186,6 +188,10 @@ public: { return _is_unranked; } + const std::string GetVariableName() const + { + return _variable_name; + } // modifier void SetDtype(DType dtype) @@ -225,6 +231,7 @@ private: bool _variable; /* is this a variable tensor */ std::vector _data; /* data array */ bool _is_unranked; /* whether this is an unranked tensor */ + std::string _variable_name; /* name for variable tensors */ }; class TosaSerializationOperator diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py index 6613796..3fb9f86 100644 --- a/python/tosa/TosaTensor.py +++ b/python/tosa/TosaTensor.py @@ -110,8 +110,15 @@ class TosaTensor(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # TosaTensor + def VariableName(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16)) + if o != 0: + return self._tab.String(o + self._tab.Pos) + return None + def TosaTensorStart(builder): - builder.StartObject(6) + builder.StartObject(7) def Start(builder): TosaTensorStart(builder) @@ -164,6 +171,12 @@ def TosaTensorAddIsUnranked(builder, isUnranked): def AddIsUnranked(builder, isUnranked): TosaTensorAddIsUnranked(builder, isUnranked) +def TosaTensorAddVariableName(builder, variableName): + builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(variableName), 0) + +def AddVariableName(builder, variableName): + TosaTensorAddVariableName(builder, variableName) + def TosaTensorEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index c2f834f..431efb4 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -285,6 +285,7 @@ table TosaTensor { data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor. variable: bool; // is this a variable tensor is_unranked: bool; // whether this is an unranked tensor + variable_name:string; // name for variable attribute } table TosaOperator { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 7e96313..f96ff60 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -24,7 +24,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name DType dtype, const flatbuffers::Vector* data, const bool variable, - const bool is_unranked) + const bool is_unranked, + const flatbuffers::String* variable_name) { _dtype = dtype; _variable = variable; @@ -41,6 +42,11 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name std::copy(data->begin(), data->end(), std::back_inserter(_data)); } _is_unranked = is_unranked; + + if (variable_name) + { + _variable_name = variable_name->str(); + } } TosaSerializationTensor::TosaSerializationTensor(const std::string& name, @@ -48,14 +54,16 @@ TosaSerializationTensor::TosaSerializationTensor(const std::string& name, DType dtype, const std::vector& data, const bool variable, - const bool is_unranked) + const bool is_unranked, + const std::string& variable_name) { - _dtype = dtype; - _variable = variable; - _shape = shape; - _name = name; - _data = data; - _is_unranked = is_unranked; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; + _is_unranked = is_unranked; + _variable_name = variable_name; } TosaSerializationTensor::TosaSerializationTensor() @@ -697,14 +705,15 @@ tosa_err_t TosaSerializationHandler::Serialize() auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); for (auto tensor : block->GetTensors()) { - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - bool tensor_variable = tensor->GetVariable(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto tensor_is_unranked = tensor->GetIsUnranked(); + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto tensor_is_unranked = tensor->GetIsUnranked(); + auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str()); auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, - tensor_variable, tensor_is_unranked); + tensor_variable, tensor_is_unranked, tensor_variable_name); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); -- cgit v1.2.1