aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-14 11:53:39 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-10-19 09:06:57 -0700
commitb97cb1d46690321235f814e5a52cb8186380bce3 (patch)
tree137bc8121ee88771a00fdb197d54b113ed036276
parent6b078cac3ff2b33fd6d01c5e849424fbd9b2ac58 (diff)
downloadserialization_lib-b97cb1d46690321235f814e5a52cb8186380bce3.tar.gz
Both serializer and schema carry version info now
- version number now encoded in serializer as well - rename experimental to draft - rename internal function from FreezeBuilder/InitWithBuf to Serialize/Deserialize Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I70cb07832fdf66c6bde3d18aadf1f3646765887e
-rw-r--r--include/tosa_generated.h22
-rw-r--r--include/tosa_serialization_handler.h75
-rw-r--r--python/tosa/Version.py10
-rw-r--r--python/tosa_serializer.py41
-rw-r--r--schema/tosa.fbs4
-rw-r--r--src/tosa_serialization_handler.cpp77
6 files changed, 65 insertions, 164 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h
index 56ebbf2..51e33ce 100644
--- a/include/tosa_generated.h
+++ b/include/tosa_generated.h
@@ -1879,26 +1879,26 @@ struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT__MAJOR = 4,
VT__MINOR = 6,
VT__PATCH = 8,
- VT__EXPERIMENTAL = 10
+ VT__DRAFT = 10
};
int32_t _major() const {
return GetField<int32_t>(VT__MAJOR, 0);
}
int32_t _minor() const {
- return GetField<int32_t>(VT__MINOR, 22);
+ return GetField<int32_t>(VT__MINOR, 23);
}
int32_t _patch() const {
return GetField<int32_t>(VT__PATCH, 0);
}
- bool _experimental() const {
- return GetField<uint8_t>(VT__EXPERIMENTAL, 0) != 0;
+ bool _draft() const {
+ return GetField<uint8_t>(VT__DRAFT, 1) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT__MAJOR) &&
VerifyField<int32_t>(verifier, VT__MINOR) &&
VerifyField<int32_t>(verifier, VT__PATCH) &&
- VerifyField<uint8_t>(verifier, VT__EXPERIMENTAL) &&
+ VerifyField<uint8_t>(verifier, VT__DRAFT) &&
verifier.EndTable();
}
};
@@ -1911,13 +1911,13 @@ struct VersionBuilder {
fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0);
}
void add__minor(int32_t _minor) {
- fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 22);
+ fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 23);
}
void add__patch(int32_t _patch) {
fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0);
}
- void add__experimental(bool _experimental) {
- fbb_.AddElement<uint8_t>(Version::VT__EXPERIMENTAL, static_cast<uint8_t>(_experimental), 0);
+ void add__draft(bool _draft) {
+ fbb_.AddElement<uint8_t>(Version::VT__DRAFT, static_cast<uint8_t>(_draft), 1);
}
explicit VersionBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
@@ -1934,14 +1934,14 @@ struct VersionBuilder {
inline flatbuffers::Offset<Version> CreateVersion(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t _major = 0,
- int32_t _minor = 22,
+ int32_t _minor = 23,
int32_t _patch = 0,
- bool _experimental = false) {
+ bool _draft = true) {
VersionBuilder builder_(_fbb);
builder_.add__patch(_patch);
builder_.add__minor(_minor);
builder_.add__major(_major);
- builder_.add__experimental(_experimental);
+ builder_.add__draft(_draft);
return builder_.Finish();
}
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 73254dd..7fd8282 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -26,6 +26,10 @@
#include <string>
#include <vector>
+#define TOSA_VERSION_MAJOR 0
+#define TOSA_VERSION_MINOR 23
+#define TOSA_VERSION_PATCH 0
+#define TOSA_VERSION_DRAFT true
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
namespace tosa
@@ -43,66 +47,6 @@ enum tosa_err_t
NUM_TOSA_ERROR
};
-struct TosaVersion
-{
- int32_t _major;
- int32_t _minor;
- int32_t _patch;
- bool _experimental;
- bool _valid;
-
- TosaVersion()
- {
- _valid = false;
- }
-
- TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
- {
- set_version(major, minor, patch, experimental);
- }
-
- void set_version(int32_t major, int32_t minor, int32_t patch, bool experimental)
- {
- _major = major;
- _minor = minor;
- _patch = patch;
- _experimental = experimental;
- _valid = true;
- }
-
- std::string to_string() const
- {
- std::string str;
- assert(_valid);
- str += std::to_string(_major) + ".";
- str += std::to_string(_minor) + ".";
- str += std::to_string(_patch);
- if (_experimental)
- str += "(experimental)";
- return str;
- };
-
- bool operator==(const TosaVersion& rhs)
- {
- assert(_valid);
- if (!_valid)
- return false;
- if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental)
- {
- return true;
- }
- return false;
- }
-
- bool operator!=(const TosaVersion& rhs)
- {
- assert(_valid);
- if (!_valid)
- return true;
- return !((*this) == rhs);
- }
-};
-
class TosaSerializationHandler;
class TosaSerializationTensor
@@ -303,7 +247,7 @@ public:
static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
// version
- const TosaVersion& GetTosaVersion() const
+ const std::string& GetVersionStr()
{
return _version;
}
@@ -350,13 +294,12 @@ public:
protected:
tosa_err_t Clear();
- tosa_err_t InitWithBuf(const uint8_t* buf);
- tosa_err_t FreezeBuilder();
- tosa_err_t SetTosaVersion();
- tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
+ tosa_err_t Deserialize(const uint8_t* buf);
+ tosa_err_t Serialize();
+ std::string VersionToStr(int32_t major, int32_t minor, int32_t patch, bool draft);
private:
- TosaVersion _version; /* tosa version */
+ std::string _version; /* version string */
flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
diff --git a/python/tosa/Version.py b/python/tosa/Version.py
index a63c482..8414c8c 100644
--- a/python/tosa/Version.py
+++ b/python/tosa/Version.py
@@ -51,7 +51,7 @@ class Version(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
- return 22
+ return 23
# Version
def _patch(self):
@@ -61,15 +61,15 @@ class Version(object):
return 0
# Version
- def _experimental(self):
+ def _draft(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
- return False
+ return True
def VersionStart(builder): builder.StartObject(4)
def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0)
-def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 22)
+def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 23)
def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0)
-def VersionAdd_experimental(builder, Experimental): builder.PrependBoolSlot(3, Experimental, 0)
+def VersionAdd_draft(builder, Draft): builder.PrependBoolSlot(3, Draft, 1)
def VersionEnd(builder): return builder.EndObject()
diff --git a/python/tosa_serializer.py b/python/tosa_serializer.py
index 04b2fc4..f0d7c63 100644
--- a/python/tosa_serializer.py
+++ b/python/tosa_serializer.py
@@ -35,6 +35,14 @@ from tosa_ref_run import TosaReturnCode
import tosa
+TOSA_VERSION_MAJOR = 0
+TOSA_VERSION_MINOR = 23
+TOSA_VERSION_PATCH = 0
+TOSA_VERSION_DRAFT = True
+TOSA_VERSION = [TOSA_VERSION_MAJOR,
+ TOSA_VERSION_MINOR,
+ TOSA_VERSION_PATCH,
+ TOSA_VERSION_DRAFT]
# With the way flatc generates its python types, there is no programatic way
# to get string names for the integer types. Manually maintain a string table
# here.
@@ -570,10 +578,6 @@ class TosaSerializer:
def __init__(self, pathPrefix):
# Get the global TOSA version if not already defined
- try:
- TOSA_VERSION
- except NameError:
- TosaSerializer.setTosaVersion()
self.builder = flatbuffers.Builder(0)
@@ -688,7 +692,7 @@ class TosaSerializer:
Version.VersionAdd_major(builder, TOSA_VERSION[0])
Version.VersionAdd_minor(builder, TOSA_VERSION[1])
Version.VersionAdd_patch(builder, TOSA_VERSION[2])
- Version.VersionAdd_experimental(builder, TOSA_VERSION[3])
+ Version.VersionAdd_draft(builder, TOSA_VERSION[3])
version = Version.VersionEnd(builder)
fbv_bb = TosaSerializer.serializeObjVec(
@@ -805,30 +809,3 @@ class TosaSerializer:
else:
return [val]
- @staticmethod
- def setTosaVersion():
- # Create a dummy flatbuffers file with the default version information
- # There does not appear to be a better way to get a constant from a
- # flatbuffer schema file
- builder = flatbuffers.Builder(0)
- Version.VersionStart(builder)
- ver = Version.VersionEnd(builder)
- TosaGraph.TosaGraphStart(builder)
- TosaGraph.TosaGraphAddVersion(builder, ver)
- gr = TosaGraph.TosaGraphEnd(builder)
- builder.Finish(gr)
-
- out = builder.Output()
-
- gr = TosaGraph.TosaGraph()
- root = gr.GetRootAsTosaGraph(out, 0)
-
- # Store the version as a global variable so that it only needs to be
- # generated once per process.
- global TOSA_VERSION
- TOSA_VERSION = [
- root.Version()._major(),
- root.Version()._minor(),
- root.Version()._patch(),
- root.Version()._experimental(),
- ]
diff --git a/schema/tosa.fbs b/schema/tosa.fbs
index 6e84b22..2e77fe5 100644
--- a/schema/tosa.fbs
+++ b/schema/tosa.fbs
@@ -270,9 +270,9 @@ table PadQuantInfo {
table Version {
_major: int32 = 0;
- _minor: int32 = 22;
+ _minor: int32 = 23;
_patch: int32 = 0;
- _experimental: bool = false;
+ _draft: bool = true;
}
table TosaTensor {
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 4d69396..fced242 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -148,8 +148,7 @@ TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
TosaSerializationHandler::TosaSerializationHandler()
{
_schemaLoaded = false;
-
- SetTosaVersion();
+ _version = VersionToStr(TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT);
}
TosaSerializationHandler::~TosaSerializationHandler()
@@ -157,26 +156,6 @@ TosaSerializationHandler::~TosaSerializationHandler()
Clear(); // deallocate all basic blocks
}
-tosa_err_t TosaSerializationHandler::SetTosaVersion()
-{
- // version is specified within .fbs
- // and it's encoded as defaulted value of CreateTosaVersion()
- // need to write out one object to read that value out
- // TODO: very costly now. is there any better way to encode constant in .fbs?
- auto fboffset_version = CreateVersion(_builder);
- auto fboffset_tosa_graph = CreateTosaGraphDirect(_builder, fboffset_version, nullptr);
- _builder.Finish(fboffset_tosa_graph);
- std::string jsongen;
- uint8_t* buf = _builder.GetBufferPointer();
- auto fb_tosa_graph = GetTosaGraph(buf);
- auto fb_tosa_version = fb_tosa_graph->version();
-
- _version.set_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
- fb_tosa_version->_experimental());
-
- return TOSA_OK;
-}
-
tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
{
std::string schema;
@@ -227,7 +206,7 @@ tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
uint8_t* buf = _parser.builder_.GetBufferPointer();
- err = InitWithBuf(buf);
+ err = Deserialize(buf);
if (err != TOSA_OK)
{
return err;
@@ -246,7 +225,7 @@ tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
return TOSA_SCHEMA_MISSING;
}
- err = FreezeBuilder();
+ err = Serialize();
if (err != TOSA_OK)
{
return err;
@@ -297,7 +276,7 @@ tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename
buf = (uint8_t*)read_buffer.data();
- err = InitWithBuf(buf);
+ err = Deserialize(buf);
if (err != TOSA_OK)
{
return err;
@@ -310,7 +289,7 @@ tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename
{
tosa_err_t err;
- err = FreezeBuilder();
+ err = Serialize();
if (err != TOSA_OK)
{
return err;
@@ -340,19 +319,18 @@ tosa_err_t TosaSerializationHandler::Clear()
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version)
+std::string TosaSerializationHandler::VersionToStr(int32_t major, int32_t minor, int32_t patch, bool draft)
{
- if (_version != read_version)
- {
- printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(),
- _version.to_string().c_str());
- return TOSA_VERSION_MISMATCH;
- }
-
- return TOSA_OK;
+ std::string str;
+ str += std::to_string(major) + ".";
+ str += std::to_string(minor) + ".";
+ str += std::to_string(patch);
+ if (draft)
+ str += "d";
+ return str;
}
-tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
+tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
{
auto fb_tosa_graph = GetTosaGraph(buf);
auto fb_tosa_version = fb_tosa_graph->version();
@@ -375,12 +353,15 @@ tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
// erase container
Clear();
- TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
- fb_tosa_version->_experimental());
- tosa_err_t err = CheckTosaVersion(read_version);
+ std::string read_version = VersionToStr(fb_tosa_version->_major(), fb_tosa_version->_minor(),
+ fb_tosa_version->_patch(), fb_tosa_version->_draft());
- if (err != TOSA_OK)
- return err;
+ if (read_version != GetVersionStr())
+ {
+ printf("Read flatbuffer version %s doesn't match serializer version %s\n", read_version.c_str(),
+ GetVersionStr().c_str());
+ return TOSA_VERSION_MISMATCH;
+ }
for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
{
@@ -436,7 +417,7 @@ tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
#include "attribute.def"
#undef DEF_ATTRIBUTE
default:
- printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n",
+ printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n",
EnumNamesAttribute()[attribute_type]);
return TOSA_INTERNAL_ERROR;
}
@@ -454,7 +435,7 @@ tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
#include "quant_info.def"
#undef DEF_QUANTIZATION_INFO
default:
- printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n",
+ printf("TosaSerializationHandler::Deserialize(): QuantInfo %s not implemented yet\n",
EnumNamesQuantInfo()[operator_qinfo_type]);
return TOSA_INTERNAL_ERROR;
}
@@ -531,7 +512,7 @@ tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::FreezeBuilder()
+tosa_err_t TosaSerializationHandler::Serialize()
{
std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
@@ -649,7 +630,7 @@ tosa_err_t TosaSerializationHandler::FreezeBuilder()
#undef DEF_ARGS_S_STR
#undef DEF_ARGS_S_DEFAULT
default:
- printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n",
EnumNamesAttribute()[attribute_type]);
return TOSA_INTERNAL_ERROR;
}
@@ -715,7 +696,7 @@ tosa_err_t TosaSerializationHandler::FreezeBuilder()
#undef DEF_ARGS_S
#undef DEF_ARGS_V
default:
- printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
+ printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n",
EnumNamesAttribute()[attribute_type]);
return TOSA_INTERNAL_ERROR;
}
@@ -749,8 +730,8 @@ tosa_err_t TosaSerializationHandler::FreezeBuilder()
auto fb_blocks = _builder.CreateVector(fboffset_blocks);
- auto fb_version = CreateVersion(_builder, GetTosaVersion()._major, GetTosaVersion()._minor, GetTosaVersion()._patch,
- GetTosaVersion()._experimental);
+ auto fb_version =
+ CreateVersion(_builder, TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT);
auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks);
_builder.Finish(fb_graph);