aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h15
1 files changed, 11 insertions, 4 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 814395b..e5448bc 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -150,14 +150,16 @@ public:
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- const bool variable = false,
- const bool is_unranked = false);
+ const bool variable = false,
+ const bool is_unranked = false,
+ const flatbuffers::String* variable_name = NULL);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- const bool variable = false,
- const bool is_unranked = false);
+ const bool variable = false,
+ const bool is_unranked = false,
+ const std::string& variable_name = "");
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -186,6 +188,10 @@ public:
{
return _is_unranked;
}
+ const std::string GetVariableName() const
+ {
+ return _variable_name;
+ }
// modifier
void SetDtype(DType dtype)
@@ -225,6 +231,7 @@ private:
bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
bool _is_unranked; /* whether this is an unranked tensor */
+ std::string _variable_name; /* name for variable tensors */
};
class TosaSerializationOperator