aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-21 17:00:29 +0000
committerTai Ly <tai.ly@arm.com>2023-08-23 15:12:30 +0000
commitc6939a4d269968a34b0ae0aa579f0f0736aaeccc (patch)
tree8925f10184fe6f4c8c179ad34d04a24d9e304bd7
parent442261bf67fa2ec4d86ed3e431a6373787b3e35a (diff)
downloadserialization_lib-c6939a4d269968a34b0ae0aa579f0f0736aaeccc.tar.gz
Add is_unranked flag to Tensors
This adds a bool field, is_unranked to TosaTensor in tosa.fbs to indicate whether a tensor with shape = {} is an unranked tensor or an 0-D tensor. For older tosa files without this field, the default value is false. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I86950050b522565509863c483cd3a3c1c50f8f69
-rw-r--r--include/tosa_generated.h20
-rw-r--r--include/tosa_serialization_handler.h15
-rw-r--r--python/tosa/TosaTensor.py15
-rw-r--r--schema/tosa.fbs1
-rw-r--r--src/tosa_serialization_handler.cpp55
5 files changed, 75 insertions, 31 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index b2805a8..2995c3a 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -2205,7 +2205,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VT_SHAPE = 6,
VT_TYPE = 8,
VT_DATA = 10,
- VT_VARIABLE = 12
+ VT_VARIABLE = 12,
+ VT_IS_UNRANKED = 14
};
const ::flatbuffers::String *name() const {
return GetPointer<const ::flatbuffers::String *>(VT_NAME);
@@ -2222,6 +2223,9 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
bool variable() const {
return GetField<uint8_t>(VT_VARIABLE, 0) != 0;
}
+ bool is_unranked() const {
+ return GetField<uint8_t>(VT_IS_UNRANKED, 0) != 0;
+ }
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NAME) &&
@@ -2232,6 +2236,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VerifyOffset(verifier, VT_DATA) &&
verifier.VerifyVector(data()) &&
VerifyField<uint8_t>(verifier, VT_VARIABLE, 1) &&
+ VerifyField<uint8_t>(verifier, VT_IS_UNRANKED, 1) &&
verifier.EndTable();
}
};
@@ -2255,6 +2260,9 @@ struct TosaTensorBuilder {
void add_variable(bool variable) {
fbb_.AddElement<uint8_t>(TosaTensor::VT_VARIABLE, static_cast<uint8_t>(variable), 0);
}
+ void add_is_unranked(bool is_unranked) {
+ fbb_.AddElement<uint8_t>(TosaTensor::VT_IS_UNRANKED, static_cast<uint8_t>(is_unranked), 0);
+ }
explicit TosaTensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2272,12 +2280,14 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensor(
::flatbuffers::Offset<::flatbuffers::Vector<int32_t>> shape = 0,
tosa::DType type = tosa::DType_UNKNOWN,
::flatbuffers::Offset<::flatbuffers::Vector<uint8_t>> data = 0,
- bool variable = false) {
+ bool variable = false,
+ bool is_unranked = false) {
TosaTensorBuilder builder_(_fbb);
builder_.add_data(data);
builder_.add_type(type);
builder_.add_shape(shape);
builder_.add_name(name);
+ builder_.add_is_unranked(is_unranked);
builder_.add_variable(variable);
return builder_.Finish();
}
@@ -2288,7 +2298,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
const std::vector<int32_t> *shape = nullptr,
tosa::DType type = tosa::DType_UNKNOWN,
const std::vector<uint8_t> *data = nullptr,
- bool variable = false) {
+ bool variable = false,
+ bool is_unranked = false) {
auto name__ = name ? _fbb.CreateString(name) : 0;
auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0;
if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); }
@@ -2299,7 +2310,8 @@ inline ::flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect(
shape__,
type,
data__,
- variable);
+ variable,
+ is_unranked);
}
struct TosaOperator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index bf44c11..24c77a6 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -113,12 +113,14 @@ public:
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- bool variable = false);
+ const bool variable = false,
+ const bool is_unranked = false);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- bool variable = false);
+ const bool variable = false,
+ const bool is_unranked = false);
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -143,6 +145,10 @@ public:
{
return _data;
}
+ const bool GetIsUnranked() const
+ {
+ return _is_unranked;
+ }
// modifier
void SetDtype(DType dtype)
@@ -161,6 +167,10 @@ public:
{
_data = std::move(data);
}
+ void SetIsUnranked(const bool value)
+ {
+ _is_unranked = value;
+ }
void SetDimSize(size_t dim, uint32_t new_size)
{
if (dim < 0 || dim >= _shape.size())
@@ -177,6 +187,7 @@ private:
std::string _name; /* name of the tensor, used for solving dependency */
bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
+ bool _is_unranked; /* whether this is an unranked tensor */
};
class TosaSerializationOperator
diff --git a/python/tosa/TosaTensor.py b/python/tosa/TosaTensor.py
index d8264f2..6613796 100644
--- a/python/tosa/TosaTensor.py
+++ b/python/tosa/TosaTensor.py
@@ -103,8 +103,15 @@ class TosaTensor(object):
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
+ # TosaTensor
+ def IsUnranked(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
+ return False
+
def TosaTensorStart(builder):
- builder.StartObject(5)
+ builder.StartObject(6)
def Start(builder):
TosaTensorStart(builder)
@@ -151,6 +158,12 @@ def TosaTensorAddVariable(builder, variable):
def AddVariable(builder, variable):
TosaTensorAddVariable(builder, variable)
+def TosaTensorAddIsUnranked(builder, isUnranked):
+ builder.PrependBoolSlot(5, isUnranked, 0)
+
+def AddIsUnranked(builder, isUnranked):
+ TosaTensorAddIsUnranked(builder, isUnranked)
+
def TosaTensorEnd(builder):
return builder.EndObject()
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 0943f11..9033351 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -282,6 +282,7 @@ table TosaTensor {
type:DType; // data type of the tensor
data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor.
variable: bool; // is this a variable tensor
+ is_unranked: bool; // whether this is an unranked tensor
}
table TosaOperator {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index cb44f17..3620c16 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -23,7 +23,8 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
const flatbuffers::Vector<uint8_t>* data,
- bool variable)
+ const bool variable,
+ const bool is_unranked)
{
_dtype = dtype;
_variable = variable;
@@ -39,26 +40,30 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
{
std::copy(data->begin(), data->end(), std::back_inserter(_data));
}
+ _is_unranked = is_unranked;
}
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
const std::vector<uint8_t>& data,
- bool variable)
+ const bool variable,
+ const bool is_unranked)
{
- _dtype = dtype;
- _variable = variable;
- _shape = shape;
- _name = name;
- _data = data;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
+ _is_unranked = is_unranked;
}
TosaSerializationTensor::TosaSerializationTensor()
{
- _dtype = DType_UNKNOWN;
- _variable = false;
- _name = "UNKNOWN";
+ _dtype = DType_UNKNOWN;
+ _variable = false;
+ _name = "UNKNOWN";
+ _is_unranked = false;
}
TosaSerializationTensor::~TosaSerializationTensor()
@@ -518,14 +523,15 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
{
auto curr_tensor = fb_tosa_tensors->Get(j);
- auto tensor_name = curr_tensor->name();
- auto tensor_shape = curr_tensor->shape();
- auto tensor_type = curr_tensor->type();
- auto tensor_variable = curr_tensor->variable();
- auto tensor_data = curr_tensor->data();
+ auto tensor_name = curr_tensor->name();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_variable = curr_tensor->variable();
+ auto tensor_data = curr_tensor->data();
+ auto tensor_is_unranked = curr_tensor->is_unranked();
- new_tensor =
- new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable);
+ new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data,
+ tensor_variable, tensor_is_unranked);
if (new_tensor)
{
block_tensors_container.push_back(new_tensor);
@@ -679,13 +685,14 @@ tosa_err_t TosaSerializationHandler::Serialize()
auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
for (auto tensor : block->GetTensors())
{
- auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
- auto tensor_shape = _builder.CreateVector(tensor->GetShape());
- auto tensor_dtype = tensor->GetDtype();
- bool tensor_variable = tensor->GetVariable();
- auto tensor_data = _builder.CreateVector(tensor->GetData());
- auto fboffset_tensor =
- CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable);
+ auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
+ auto tensor_shape = _builder.CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
+ auto tensor_data = _builder.CreateVector(tensor->GetData());
+ auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data,
+ tensor_variable, tensor_is_unranked);
fboffset_block_tensors.push_back(fboffset_tensor);
}
auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);