aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-21 17:00:29 +0000
committerTai Ly <tai.ly@arm.com>2023-08-23 15:12:30 +0000
commitc6939a4d269968a34b0ae0aa579f0f0736aaeccc (patch)
tree8925f10184fe6f4c8c179ad34d04a24d9e304bd7 /include/tosa_serialization_handler.h
parent442261bf67fa2ec4d86ed3e431a6373787b3e35a (diff)
downloadserialization_lib-c6939a4d269968a34b0ae0aa579f0f0736aaeccc.tar.gz
Add is_unranked flag to Tensors
This adds a bool field, is_unranked to TosaTensor in tosa.fbs to indicate whether a tensor with shape = {} is an unranked tensor or an 0-D tensor. For older tosa files without this field, the default value is false. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h15
1 files changed, 13 insertions, 2 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index bf44c11..24c77a6 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -113,12 +113,14 @@ public:
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- bool variable = false);
+ const bool variable = false,
+ const bool is_unranked = false);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- bool variable = false);
+ const bool variable = false,
+ const bool is_unranked = false);
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -143,6 +145,10 @@ public:
{
return _data;
}
+ const bool GetIsUnranked() const
+ {
+ return _is_unranked;
+ }
// modifier
void SetDtype(DType dtype)
@@ -161,6 +167,10 @@ public:
{
_data = std::move(data);
}
+ void SetIsUnranked(const bool value)
+ {
+ _is_unranked = value;
+ }
void SetDimSize(size_t dim, uint32_t new_size)
{
if (dim < 0 || dim >= _shape.size())
@@ -177,6 +187,7 @@ private:
std::string _name; /* name of the tensor, used for solving dependency */
bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
+ bool _is_unranked; /* whether this is an unranked tensor */
};
class TosaSerializationOperator