diff options
author | Tai Ly <tai.ly@arm.com> | 2023-08-21 17:00:29 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-08-23 15:12:30 +0000 |
commit | c6939a4d269968a34b0ae0aa579f0f0736aaeccc (patch) | |
tree | 8925f10184fe6f4c8c179ad34d04a24d9e304bd7 /include/tosa_serialization_handler.h | |
parent | 442261bf67fa2ec4d86ed3e431a6373787b3e35a (diff) | |
download | serialization_lib-c6939a4d269968a34b0ae0aa579f0f0736aaeccc.tar.gz |
Add is_unranked flag to Tensors
This adds a bool field, is_unranked to TosaTensor in tosa.fbs
to indicate whether a tensor with shape = {} is an unranked tensor
or an 0-D tensor.
For older tosa files without this field, the default value is false.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r-- | include/tosa_serialization_handler.h | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index bf44c11..24c77a6 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -113,12 +113,14 @@ public: const flatbuffers::Vector<int32_t>* shape, DType dtype, const flatbuffers::Vector<uint8_t>* data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, const std::vector<uint8_t>& data, - bool variable = false); + const bool variable = false, + const bool is_unranked = false); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -143,6 +145,10 @@ public: { return _data; } + const bool GetIsUnranked() const + { + return _is_unranked; + } // modifier void SetDtype(DType dtype) @@ -161,6 +167,10 @@ public: { _data = std::move(data); } + void SetIsUnranked(const bool value) + { + _is_unranked = value; + } void SetDimSize(size_t dim, uint32_t new_size) { if (dim < 0 || dim >= _shape.size()) @@ -177,6 +187,7 @@ private: std::string _name; /* name of the tensor, used for solving dependency */ bool _variable; /* is this a variable tensor */ std::vector<uint8_t> _data; /* data array */ + bool _is_unranked; /* whether this is an unranked tensor */ }; class TosaSerializationOperator |