diff options
author | Jerry Ge <jerry.ge@arm.com> | 2022-09-09 13:38:56 -0700 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-08-22 17:30:05 +0000 |
commit | 442261bf67fa2ec4d86ed3e431a6373787b3e35a (patch) | |
tree | d6a802db3c208cb9783cabe58d17aa1ffe12ee52 /include | |
parent | 780ffb5f034a4fd6581a44cd9c3b1cf119f33589 (diff) | |
download | serialization_lib-442261bf67fa2ec4d86ed3e431a6373787b3e35a.tar.gz |
[Serialization_lib] Support StatefulOps for TOSA
- Add variable in TosaTensor to schema file
- Update TosaSerializationTensor regarding variable change
- Rename internal zero_pad() and expose interface as ForceAlignTensorData()
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539
Diffstat (limited to 'include')
-rw-r--r-- | include/tosa_generated.h | 20 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 17 |
2 files changed, 29 insertions, 8 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 22819f1..b2805a8 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -2204,7 +2204,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_NAME = 4, VT_SHAPE = 6, VT_TYPE = 8, - VT_DATA = 10 + VT_DATA = 10, + VT_VARIABLE = 12 }; const ::flatbuffers::String *name() const { return GetPointer<const ::flatbuffers::String *>(VT_NAME); @@ -2218,6 +2219,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::Vector<uint8_t> *data() const { return GetPointer<const ::flatbuffers::Vector<uint8_t> *>(VT_DATA); } + bool variable() const { + return GetField<uint8_t>(VT_VARIABLE, 0) != 0; + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NAME) && @@ -2227,6 +2231,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyField<uint32_t>(verifier, VT_TYPE, 4) && VerifyOffset(verifier, VT_DATA) && verifier.VerifyVector(data()) && + VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) && verifier.EndTable(); } }; @@ -2247,6 +2252,9 @@ struct TosaTensorBuilder { void add_data(::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data) { fbb_.AddOffset(TosaTensor::VT_DATA, data); } + void add_variable(bool variable) { + fbb_.AddElement<uint8_t>(TosaTensor::VT_VARIABLE, static_cast<uint8_t>(variable), 0); + } explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2263,12 +2271,14 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensor( ::flatbuffers::Offset<::flatbuffers::String> name = 0, ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0, tosa::DType type = tosa::DType_UNKNOWN, - ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0) { + ::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0, + bool variable = false) { TosaTensorBuilder builder_(_fbb); builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); builder_.add_name(name); + builder_.add_variable(variable); return builder_.Finish(); } @@ -2277,7 +2287,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( const char *name = nullptr, const std::vector<int32_t> *shape = nullptr, tosa::DType type = tosa::DType_UNKNOWN, - const std::vector<uint8_t> *data = nullptr) { + const std::vector<uint8_t> *data = nullptr, + bool variable = false) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } @@ -2287,7 +2298,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( name__, shape__, type, - data__); + data__, + variable); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index cae6a27..bf44c11 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -112,11 +112,13 @@ public: TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, - const flatbuffers::Vector<uint8_t>* data); + const flatbuffers::Vector<uint8_t>* data, + bool variable = false); TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, - const std::vector<uint8_t>& data); + const std::vector<uint8_t>& data, + bool variable = false); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -129,10 +131,14 @@ public: { return _shape; } - DType GetDtype() + DType GetDtype() const { return _dtype; } + bool GetVariable() const + { + return _variable; + } const std::vector<uint8_t>& GetData() const { return _data; @@ -169,6 +175,7 @@ private: DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ std::vector<int32_t> _shape; /* shape of the tensor */ std::string _name; /* name of the tensor, used for solving dependency */ + bool _variable; /* is this a variable tensor */ std::vector<uint8_t> _data; /* data array */ }; @@ -368,6 +375,8 @@ public: static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out); static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out); + static void ForceAlignTensorData(std::vector<uint8_t>& buf); + // version const TosaVersion& GetVersion() { |