From c6939a4d269968a34b0ae0aa579f0f0736aaeccc Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Mon, 21 Aug 2023 17:00:29 +0000 Subject: Add is_unranked flag to Tensors This adds a bool field, is_unranked to TosaTensor in tosa.fbs to indicate whether a tensor with shape = {} is an unranked tensor or an 0-D tensor. For older tosa files without this field, the default value is false. Signed-off-by: Tai Ly Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69 --- include/tosa_generated.h | 20 ++++++++++--- include/tosa_serialization_handler.h | 15 ++++++++-- python/tosa/TosaTensor.py | 15 +++++++++- schema/tosa.fbs | 1 + src/tosa_serialization_handler.cpp | 55 ++++++++++++++++++++---------------- 5 files changed, 75 insertions(+), 31 deletions(-) diff --git a/include/tosa_generated.h b/include/tosa_generated.h index b2805a8..2995c3a 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -2205,7 +2205,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_SHAPE = 6, VT_TYPE = 8, VT_DATA = 10, - VT_VARIABLE = 12 + VT_VARIABLE = 12, + VT_IS_UNRANKED = 14 }; const ::flatbuffers::String *name() const { return GetPointer(VT_NAME); @@ -2222,6 +2223,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool variable() const { return GetField(VT_VARIABLE, 0) != 0; } + bool is_unranked() const { + return GetField(VT_IS_UNRANKED, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && @@ -2232,6 +2236,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_DATA) && verifier.VerifyVector(data()) && VerifyField(verifier, VT_VARIABLE, 1) && + VerifyField(verifier, VT_IS_UNRANKED, 1) && verifier.EndTable(); } }; @@ -2255,6 +2260,9 @@ struct TosaTensorBuilder { void add_variable(bool variable) { fbb_.AddElement(TosaTensor::VT_VARIABLE, static_cast(variable), 0); } + void add_is_unranked(bool is_unranked) { + fbb_.AddElement(TosaTensor::VT_IS_UNRANKED, static_cast(is_unranked), 0); + } explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2272,12 +2280,14 @@ inline ::flatbuffers::Offset CreateTosaTensor( ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, tosa::DType type = tosa::DType_UNKNOWN, ::flatbuffers::Offset<::flatbuffers::Vector> data = 0, - bool variable = false) { + bool variable = false, + bool is_unranked = false) { TosaTensorBuilder builder_(_fbb); builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); builder_.add_name(name); + builder_.add_is_unranked(is_unranked); builder_.add_variable(variable); return builder_.Finish(); } @@ -2288,7 +2298,8 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( const std::vector *shape = nullptr, tosa::DType type = tosa::DType_UNKNOWN, const std::vector *data = nullptr, - bool variable = false) { + bool variable = false, + bool is_unranked = false) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } @@ -2299,7 +2310,8 @@ inline ::flatbuffers::Offset CreateTosaTensorDirect( shape__, type, data__, - variable); + variable, + is_unranked); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index bf44c11..24c77a6 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -113,12 +113,14 @@ public: const flatbuffers::Vector* shape, DType dtype, const flatbuffers::Vector* data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(const std::string& name, const std::vector& shape, DType dtype, const std::vector& data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -143,6 +145,10 @@ public: { return _data; } + const bool GetIsUnranked() const + { + return _is_unranked; + } // modifier void SetDtype(DType dtype) @@ -161,6 +167,10 @@ public: { _data = std::move(data); } + void SetIsUnranked(const bool value) + { + _is_unranked = value; + } void SetDimSize(size_t dim, uint32_t new_size) { if (dim < 0 || dim >= _shape.size()) @@ -177,6 +187,7 @@ private: std::string _name; /* name of the tensor, used for solving dependency */ bool _variable; /* is this a variable tensor */ std::vector _data; /* data array */ + bool _is_unranked; /* whether this is an unranked tensor */ }; class TosaSerializationOperator diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py index d8264f2..6613796 100644 --- a/python/tosa/TosaTensor.py +++ b/python/tosa/TosaTensor.py @@ -103,8 +103,15 @@ class TosaTensor(object): return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) return False + # TosaTensor + def IsUnranked(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + def TosaTensorStart(builder): - builder.StartObject(5) + builder.StartObject(6) def Start(builder): TosaTensorStart(builder) @@ -151,6 +158,12 @@ def TosaTensorAddVariable(builder, variable): def AddVariable(builder, variable): TosaTensorAddVariable(builder, variable) +def TosaTensorAddIsUnranked(builder, isUnranked): + builder.PrependBoolSlot(5, isUnranked, 0) + +def AddIsUnranked(builder, isUnranked): + TosaTensorAddIsUnranked(builder, isUnranked) + def TosaTensorEnd(builder): return builder.EndObject() diff --git a/schema/tosa.fbs b/schema/tosa.fbs index 0943f11..9033351 100644 --- a/schema/tosa.fbs +++ b/schema/tosa.fbs @@ -282,6 +282,7 @@ table TosaTensor { type:DType; // data type of the tensor data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor. variable: bool; // is this a variable tensor + is_unranked: bool; // whether this is an unranked tensor } table TosaOperator { diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index cb44f17..3620c16 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -23,7 +23,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name const flatbuffers::Vector* shape, DType dtype, const flatbuffers::Vector* data, - bool variable) + const bool variable, + const bool is_unranked) { _dtype = dtype; _variable = variable; @@ -39,26 +40,30 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name { std::copy(data->begin(), data->end(), std::back_inserter(_data)); } + _is_unranked = is_unranked; } TosaSerializationTensor::TosaSerializationTensor(const std::string& name, const std::vector& shape, DType dtype, const std::vector& data, - bool variable) + const bool variable, + const bool is_unranked) { - _dtype = dtype; - _variable = variable; - _shape = shape; - _name = name; - _data = data; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; + _is_unranked = is_unranked; } TosaSerializationTensor::TosaSerializationTensor() { - _dtype = DType_UNKNOWN; - _variable = false; - _name = "UNKNOWN"; + _dtype = DType_UNKNOWN; + _variable = false; + _name = "UNKNOWN"; + _is_unranked = false; } TosaSerializationTensor::~TosaSerializationTensor() @@ -518,14 +523,15 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_tensor = fb_tosa_tensors->Get(j); - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_variable = curr_tensor->variable(); - auto tensor_data = curr_tensor->data(); + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_variable = curr_tensor->variable(); + auto tensor_data = curr_tensor->data(); + auto tensor_is_unranked = curr_tensor->is_unranked(); - new_tensor = - new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable); + new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, + tensor_variable, tensor_is_unranked); if (new_tensor) { block_tensors_container.push_back(new_tensor); @@ -679,13 +685,14 @@ tosa_err_t TosaSerializationHandler::Serialize() auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); for (auto tensor : block->GetTensors()) { - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - bool tensor_variable = tensor->GetVariable(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto fboffset_tensor = - CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable); + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto tensor_is_unranked = tensor->GetIsUnranked(); + auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, + tensor_variable, tensor_is_unranked); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); -- cgit v1.2.1