aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2022-09-09 13:38:56 -0700
committerEric Kunze <eric.kunze@arm.com>2023-08-22 17:30:05 +0000
commit442261bf67fa2ec4d86ed3e431a6373787b3e35a (patch)
treed6a802db3c208cb9783cabe58d17aa1ffe12ee52 /include/tosa_serialization_handler.h
parent780ffb5f034a4fd6581a44cd9c3b1cf119f33589 (diff)
downloadserialization_lib-442261bf67fa2ec4d86ed3e431a6373787b3e35a.tar.gz
[Serialization_lib] Support StatefulOps for TOSA
- Add variable in TosaTensor to schema file - Update TosaSerializationTensor regarding variable change - Rename internal zero_pad() and expose interface as ForceAlignTensorData() Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I36fa64eb0802cb5b8d3564ea7233460ef8c9f539
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h17
1 files changed, 13 insertions, 4 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index cae6a27..bf44c11 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -112,11 +112,13 @@ public:
TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data);
+ const flatbuffers::Vector<uint8_t>* data,
+ bool variable = false);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data);
+ const std::vector<uint8_t>& data,
+ bool variable = false);
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -129,10 +131,14 @@ public:
{
return _shape;
}
- DType GetDtype()
+ DType GetDtype() const
{
return _dtype;
}
+ bool GetVariable() const
+ {
+ return _variable;
+ }
const std::vector<uint8_t>& GetData() const
{
return _data;
@@ -169,6 +175,7 @@ private:
DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
std::vector<int32_t> _shape; /* shape of the tensor */
std::string _name; /* name of the tensor, used for solving dependency */
+ bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
};
@@ -368,6 +375,8 @@ public:
static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
+ static void ForceAlignTensorData(std::vector<uint8_t>& buf);
+
// version
const TosaVersion& GetVersion()
{