aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-09-19 21:30:18 +0000
committerTai Ly <tai.ly@arm.com>2023-09-19 21:55:11 +0000
commitd0520b9b3a0eaf9dadc6cdb57ed42906e577d32e (patch)
tree1538ac688cbafa4e8e08b7aabe3e10bbe404c767
parent924f3094a745c1955d51fce18b488adfed5ee76b (diff)
downloadserialization_lib-d0520b9b3a0eaf9dadc6cdb57ed42906e577d32e.tar.gz
Add variable_name to tensors
Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia142c8b1a9e0869daefb3eef71100fd0c2a0effc
-rw-r--r--include/tosa_generated.h22
-rw-r--r--include/tosa_serialization_handler.h15
-rw-r--r--python/tosa/TosaTensor.py15
-rw-r--r--schema/tosa.fbs1
-rw-r--r--src/tosa_serialization_handler.cpp39
5 files changed, 68 insertions, 24 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index b07fa8f..a81ff9c 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -2230,7 +2230,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VT_TYPE = 8,
VT_DATA = 10,
VT_VARIABLE = 12,
- VT_IS_UNRANKED = 14
+ VT_IS_UNRANKED = 14,
+ VT_VARIABLE_NAME = 16
};
const ::flatbuffers::String *name() const {
return GetPointer<const ::flatbuffers::String *>(VT_NAME);
@@ -2250,6 +2251,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
bool is_unranked() const {
return GetField<uint8_t>(VT_IS_UNRANKED, 0) != 0;
}
+ const ::flatbuffers::String *variable_name() const {
+ return GetPointer<const ::flatbuffers::String *>(VT_VARIABLE_NAME);
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NAME) &&
@@ -2261,6 +2265,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
verifier.VerifyVector(data()) &&
VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) &&
VerifyField<uint8_t>(verifier, VT_IS_UNRANKED, 1) &&
+ VerifyOffset(verifier, VT_VARIABLE_NAME) &&
+ verifier.VerifyString(variable_name()) &&
verifier.EndTable();
}
};
@@ -2287,6 +2293,9 @@ struct TosaTensorBuilder {
void add_is_unranked(bool is_unranked) {
fbb_.AddElement<uint8_t>(TosaTensor::VT_IS_UNRANKED, static_cast<uint8_t>(is_unranked), 0);
}
+ void add_variable_name(::flatbuffers::Offset<::flatbuffers::String> variable_name) {
+ fbb_.AddOffset(TosaTensor::VT_VARIABLE_NAME, variable_name);
+ }
explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2305,8 +2314,10 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensor(
tosa::DType type = tosa::DType_UNKNOWN,
::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0,
bool variable = false,
- bool is_unranked = false) {
+ bool is_unranked = false,
+ ::flatbuffers::Offset<::flatbuffers::String> variable_name = 0) {
TosaTensorBuilder builder_(_fbb);
+ builder_.add_variable_name(variable_name);
builder_.add_data(data);
builder_.add_type(type);
builder_.add_shape(shape);
@@ -2323,11 +2334,13 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
tosa::DType type = tosa::DType_UNKNOWN,
const std::vector<uint8_t> *data = nullptr,
bool variable = false,
- bool is_unranked = false) {
+ bool is_unranked = false,
+ const char *variable_name = nullptr) {
auto name__ = name ? _fbb.CreateString(name) : 0;
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); }
auto data__ = data ? _fbb.CreateVector<uint8_t>(*data) : 0;
+ auto variable_name__ = variable_name ? _fbb.CreateString(variable_name) : 0;
return tosa::CreateTosaTensor(
_fbb,
name__,
@@ -2335,7 +2348,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
type,
data__,
variable,
- is_unranked);
+ is_unranked,
+ variable_name__);
}
struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 814395b..e5448bc 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -150,14 +150,16 @@ public:
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- const bool variable = false,
- const bool is_unranked = false);
+ const bool variable = false,
+ const bool is_unranked = false,
+ const flatbuffers::String* variable_name = NULL);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- const bool variable = false,
- const bool is_unranked = false);
+ const bool variable = false,
+ const bool is_unranked = false,
+ const std::string& variable_name = "");
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -186,6 +188,10 @@ public:
{
return _is_unranked;
}
+ const std::string GetVariableName() const
+ {
+ return _variable_name;
+ }
// modifier
void SetDtype(DType dtype)
@@ -225,6 +231,7 @@ private:
bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
bool _is_unranked; /* whether this is an unranked tensor */
+ std::string _variable_name; /* name for variable tensors */
};
class TosaSerializationOperator
diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py
index 6613796..3fb9f86 100644
--- a/python/tosa/TosaTensor.py
+++ b/python/tosa/TosaTensor.py
@@ -110,8 +110,15 @@ class TosaTensor(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
+ # TosaTensor
+ def VariableName(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
+ return self._tab.String(o + self._tab.Pos)
+ return None
+
def TosaTensorStart(builder):
- builder.StartObject(6)
+ builder.StartObject(7)
def Start(builder):
TosaTensorStart(builder)
@@ -164,6 +171,12 @@ def TosaTensorAddIsUnranked(builder, isUnranked):
def AddIsUnranked(builder, isUnranked):
TosaTensorAddIsUnranked(builder, isUnranked)
+def TosaTensorAddVariableName(builder, variableName):
+ builder.PrependUOffsetTRelativeSlot(6, flatbuffers.number_types.UOffsetTFlags.py_type(variableName), 0)
+
+def AddVariableName(builder, variableName):
+ TosaTensorAddVariableName(builder, variableName)
+
def TosaTensorEnd(builder):
return builder.EndObject()
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index c2f834f..431efb4 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -285,6 +285,7 @@ table TosaTensor {
data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor.
variable: bool; // is this a variable tensor
is_unranked: bool; // whether this is an unranked tensor
+ variable_name:string; // name for variable attribute
}
table TosaOperator {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 7e96313..f96ff60 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -24,7 +24,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
const bool variable,
- const bool is_unranked)
+ const bool is_unranked,
+ const flatbuffers::String* variable_name)
{
_dtype = dtype;
_variable = variable;
@@ -41,6 +42,11 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
std::copy(data->begin(), data->end(), std::back_inserter(_data));
}
_is_unranked = is_unranked;
+
+ if (variable_name)
+ {
+ _variable_name = variable_name->str();
+ }
}
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
@@ -48,14 +54,16 @@ TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
DType dtype,
const std::vector<uint8_t>& data,
const bool variable,
- const bool is_unranked)
+ const bool is_unranked,
+ const std::string& variable_name)
{
- _dtype = dtype;
- _variable = variable;
- _shape = shape;
- _name = name;
- _data = data;
- _is_unranked = is_unranked;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
+ _is_unranked = is_unranked;
+ _variable_name = variable_name;
}
TosaSerializationTensor::TosaSerializationTensor()
@@ -697,14 +705,15 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
for (auto tensor : block->GetTensors())
{
- auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
- auto tensor_shape = _builder.CreateVector(tensor->GetShape());
- auto tensor_dtype = tensor->GetDtype();
- bool tensor_variable = tensor->GetVariable();
- auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
+ auto tensor_shape = _builder.CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
+ auto tensor_data = _builder.CreateVector(tensor->GetData());
+ auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str());
auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data,
- tensor_variable, tensor_is_unranked);
+ tensor_variable, tensor_is_unranked, tensor_variable_name);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);