diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/tosa_generated.h | 20 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 15 |
2 files changed, 29 insertions, 6 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h index b2805a8..2995c3a 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -2205,7 +2205,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_SHAPE = 6, VT_TYPE = 8, VT_DATA = 10, - VT_VARIABLE = 12 + VT_VARIABLE = 12, + VT_IS_UNRANKED = 14 }; const ::flatbuffers::String *name() const { return GetPointer<const ::flatbuffers::String *>(VT_NAME); @@ -2222,6 +2223,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { bool variable() const { return GetField<uint8_t>(VT_VARIABLE, 0) != 0; } + bool is_unranked() const { + return GetField<uint8_t>(VT_IS_UNRANKED, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && @@ -2232,6 +2236,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyOffset(verifier, VT_DATA) && verifier.VerifyVector(data()) && VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) && + VerifyField<uint8_t>(verifier, VT_IS_UNRANKED, 1) && verifier.EndTable(); } }; @@ -2255,6 +2260,9 @@ struct TosaTensorBuilder { void add_variable(bool variable) { fbb_.AddElement<uint8_t>(TosaTensor::VT_VARIABLE, static_cast<uint8_t>(variable), 0); } + void add_is_unranked(bool is_unranked) { + fbb_.AddElement<uint8_t>(TosaTensor::VT_IS_UNRANKED, static_cast<uint8_t>(is_unranked), 0); + } explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2272,12 +2280,14 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensor( ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0, tosa::DType type = tosa::DType_UNKNOWN, ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0, - bool variable = false) { + bool variable = false, + bool is_unranked = false) { TosaTensorBuilder builder_(_fbb); builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); builder_.add_name(name); + builder_.add_is_unranked(is_unranked); builder_.add_variable(variable); return builder_.Finish(); } @@ -2288,7 +2298,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( const std::vector<int32_t> *shape = nullptr, tosa::DType type = tosa::DType_UNKNOWN, const std::vector<uint8_t> *data = nullptr, - bool variable = false) { + bool variable = false, + bool is_unranked = false) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } @@ -2299,7 +2310,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( shape__, type, data__, - variable); + variable, + is_unranked); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index bf44c11..24c77a6 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -113,12 +113,14 @@ public: const flatbuffers::Vector<int32_t>* shape, DType dtype, const flatbuffers::Vector<uint8_t>* data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, const std::vector<uint8_t>& data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -143,6 +145,10 @@ public: { return _data; } + const bool GetIsUnranked() const + { + return _is_unranked; + } // modifier void SetDtype(DType dtype) @@ -161,6 +167,10 @@ public: { _data = std::move(data); } + void SetIsUnranked(const bool value) + { + _is_unranked = value; + } void SetDimSize(size_t dim, uint32_t new_size) { if (dim < 0 || dim >= _shape.size()) @@ -177,6 +187,7 @@ private: std::string _name; /* name of the tensor, used for solving dependency */ bool _variable; /* is this a variable tensor */ std::vector<uint8_t> _data; /* data array */ + bool _is_unranked; /* whether this is an unranked tensor */ }; class TosaSerializationOperator |