aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h204
1 files changed, 162 insertions, 42 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 53dcf1a..91b1a9d 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#ifndef _TOSA_SERIALIZATION_HANDLER_H
#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
+#include "cfloat.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
@@ -26,8 +27,8 @@
#include <vector>
// Keep version number in sync with the version default value with schema/tosa.fbs
-#define TOSA_VERSION_MAJOR 0
-#define TOSA_VERSION_MINOR 31
+#define TOSA_VERSION_MAJOR 1
+#define TOSA_VERSION_MINOR 0
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
@@ -57,7 +58,7 @@ struct TosaVersion
enum class compat_t
{
COMPLETELY_COMPATIBLE,
- PARTIALLY_COMPATIBLE,
+ BACKWARD_COMPATIBLE,
NOT_COMPATIBLE
};
@@ -86,17 +87,53 @@ struct TosaVersion
return str;
}
- compat_t is_compatible(const TosaVersion& rhs) const
+ static bool less_than(const TosaVersion& version1, const TosaVersion& version2)
{
- if (rhs._major == _major && rhs._minor == _minor)
+ if (version1._major < version2._major)
{
- if (rhs._patch == _patch && rhs._draft == _draft)
+ return true;
+ }
+ else if (version1._major == version2._major)
+ {
+ if (version1._minor < version2._minor)
{
- return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;
+ return true;
}
- else
+ else if (version1._minor == version2._minor)
{
- return TosaVersion::compat_t::PARTIALLY_COMPATIBLE;
+ if (version1._patch < version2._patch)
+ {
+ return true;
+ }
+ else if (version1._patch == version2._patch)
+ {
+ if (version1._draft == true && version2._draft == false)
+ {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ static TosaVersion::compat_t is_compatible(const TosaVersion& tosa_fb_version,
+ const TosaVersion& serializer_version)
+ {
+ bool major_match = (serializer_version._major == tosa_fb_version._major);
+ bool minor_match = (serializer_version._minor == tosa_fb_version._minor);
+ bool patch_match = (serializer_version._patch == tosa_fb_version._patch);
+ bool draft_match = (serializer_version._draft == tosa_fb_version._draft);
+
+ if (major_match && minor_match && patch_match && draft_match)
+ return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;
+
+ // We currently support backward compatibility starting from 0.100.0
+ if ((tosa_fb_version._major == 0 && tosa_fb_version._minor >= 100) || (tosa_fb_version._major > 0))
+ {
+ if (less_than(tosa_fb_version, serializer_version))
+ {
+ return TosaVersion::compat_t::BACKWARD_COMPATIBLE;
}
}
return TosaVersion::compat_t::NOT_COMPATIBLE;
@@ -112,11 +149,17 @@ public:
TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data);
+ const flatbuffers::Vector<uint8_t>* data,
+ const bool variable = false,
+ const bool is_unranked = false,
+ const flatbuffers::String* variable_name = NULL);
TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data);
+ const std::vector<uint8_t>& data,
+ const bool variable = false,
+ const bool is_unranked = false,
+ const std::string& variable_name = "");
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -129,14 +172,26 @@ public:
{
return _shape;
}
- DType GetDtype()
+ DType GetDtype() const
{
return _dtype;
}
+ bool GetVariable() const
+ {
+ return _variable;
+ }
const std::vector<uint8_t>& GetData() const
{
return _data;
}
+ bool GetIsUnranked() const
+ {
+ return _is_unranked;
+ }
+ const std::string GetVariableName() const
+ {
+ return _variable_name;
+ }
// modifier
void SetDtype(DType dtype)
@@ -155,12 +210,28 @@ public:
{
_data = std::move(data);
}
+ void SetIsUnranked(const bool value)
+ {
+ _is_unranked = value;
+ }
+ void SetDimSize(size_t dim, uint32_t new_size)
+ {
+ if (dim >= _shape.size())
+ {
+ printf("dim is out of bound\n");
+ assert(0);
+ }
+ _shape[dim] = new_size;
+ }
private:
DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
std::vector<int32_t> _shape; /* shape of the tensor */
std::string _name; /* name of the tensor, used for solving dependency */
+ bool _variable; /* is this a variable tensor */
std::vector<uint8_t> _data; /* data array */
+ bool _is_unranked; /* whether this is an unranked tensor */
+ std::string _variable_name; /* name for variable tensors */
};
class TosaSerializationOperator
@@ -216,11 +287,13 @@ class TosaSerializationBasicBlock
public:
// constructor and destructor
TosaSerializationBasicBlock(const std::string& name,
+ const std::string& region_name,
const std::vector<TosaSerializationOperator*>& operators,
const std::vector<TosaSerializationTensor*>& tensors,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs);
TosaSerializationBasicBlock(std::string&& name,
+ std::string&& region_name,
std::vector<TosaSerializationOperator*>&& operators,
std::vector<TosaSerializationTensor*>&& tensors,
std::vector<std::string>&& inputs,
@@ -232,10 +305,15 @@ public:
{
return _name;
}
+ std::string GetRegionName() const
+ {
+ return _region_name;
+ }
std::vector<TosaSerializationOperator*>& GetOperators()
{
return _operators;
}
+
std::vector<TosaSerializationTensor*>& GetTensors()
{
return _tensors;
@@ -259,19 +337,59 @@ public:
{
return _inputs;
}
+
std::vector<std::string>& GetOutputs()
{
return _outputs;
}
private:
- std::string _name; /* name of basic block */
+ std::string _name; /* name of basic block */
+ std::string _region_name;
std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
std::vector<std::string> _inputs; /* array of string to specify block inputs */
std::vector<std::string> _outputs; /* array of string to specify block outputs */
};
+class TosaSerializationRegion
+{
+public:
+ // constructor and desctructor
+ TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
+ TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
+ ~TosaSerializationRegion();
+
+ // accessors
+ std::string GetName() const
+ {
+ return this->_name;
+ }
+
+ std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ {
+ return this->_blocks;
+ }
+
+ TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ {
+ TosaSerializationBasicBlock* result = nullptr;
+ for (auto block : GetBlocks())
+ {
+ if (block->GetName() == name)
+ {
+ result = block;
+ break;
+ }
+ }
+ return result;
+ }
+
+private:
+ std::string _name; /* name of basic block */
+ std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
+};
+
/*
* this is a helper class for writing/reading Tosa ISA
* supported format: .tosa (flatbuffer), .json
@@ -294,7 +412,12 @@ public:
tosa_err_t LoadFileSchema(const char* schema_filename);
// data format conversion. little-endian.
+ static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
@@ -302,7 +425,13 @@ public:
static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t
+ ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out);
static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toI64(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
@@ -310,6 +439,8 @@ public:
static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
+ static void ForceAlignTensorData(std::vector<uint8_t>& buf);
+
// version
const TosaVersion& GetVersion()
{
@@ -317,39 +448,29 @@ public:
}
// accessor
- std::vector<TosaSerializationBasicBlock*>& GetBlocks()
+ std::vector<TosaSerializationRegion*>& GetRegions()
{
- return _blocks;
+ return _regions;
}
- TosaSerializationBasicBlock* GetBlockByName(std::string name)
+ TosaSerializationRegion* GetMainRegion()
{
- TosaSerializationBasicBlock* result = nullptr;
- for (auto block : GetBlocks())
+ return _regions[0];
+ }
+
+ TosaSerializationRegion* GetRegionByName(std::string name)
+ {
+ TosaSerializationRegion* result = nullptr;
+ for (auto region : GetRegions())
{
- if (block->GetName() == name)
+ if (region->GetName() == name)
{
- result = block;
+ result = region;
break;
}
}
return result;
}
- TosaSerializationBasicBlock* GetMainBlock()
- {
- TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
- assert(main_block);
- return main_block;
- }
-
- std::vector<std::string>& GetInputs()
- {
- return GetMainBlock()->GetInputs();
- }
- std::vector<std::string>& GetOutputs()
- {
- return GetMainBlock()->GetOutputs();
- }
bool GetSchemaLoaded() const
{
@@ -360,14 +481,13 @@ protected:
tosa_err_t Clear();
tosa_err_t Deserialize(const uint8_t* buf);
tosa_err_t Serialize();
- TosaVersion ParseTosaSchemaVersion(std::string schema);
private:
- TosaVersion _version; /* version struct */
- flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
- flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
- std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
- bool _schemaLoaded; /* is the schema properly loaded? */
+ TosaVersion _version; /* version struct */
+ flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
+ flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
+ std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
+ bool _schemaLoaded; /* is the schema properly loaded? */
};
} // namespace tosa