diff options
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r-- | include/tosa_serialization_handler.h | 204 |
1 files changed, 162 insertions, 42 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 53dcf1a..91b1a9d 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #ifndef _TOSA_SERIALIZATION_HANDLER_H #define _TOSA_SERIALIZATION_HANDLER_H #include "attribute.h" +#include "cfloat.h" #include "flatbuffers/idl.h" #include "flatbuffers/util.h" #include "numpy_utils.h" @@ -26,8 +27,8 @@ #include <vector> // Keep version number in sync with the version default value with schema/tosa.fbs -#define TOSA_VERSION_MAJOR 0 -#define TOSA_VERSION_MINOR 31 +#define TOSA_VERSION_MAJOR 1 +#define TOSA_VERSION_MINOR 0 #define TOSA_VERSION_PATCH 0 #define TOSA_VERSION_DRAFT true #define TENSOR_BUFFER_FORCE_ALIGNMENT 8 @@ -57,7 +58,7 @@ struct TosaVersion enum class compat_t { COMPLETELY_COMPATIBLE, - PARTIALLY_COMPATIBLE, + BACKWARD_COMPATIBLE, NOT_COMPATIBLE }; @@ -86,17 +87,53 @@ struct TosaVersion return str; } - compat_t is_compatible(const TosaVersion& rhs) const + static bool less_than(const TosaVersion& version1, const TosaVersion& version2) { - if (rhs._major == _major && rhs._minor == _minor) + if (version1._major < version2._major) { - if (rhs._patch == _patch && rhs._draft == _draft) + return true; + } + else if (version1._major == version2._major) + { + if (version1._minor < version2._minor) { - return TosaVersion::compat_t::COMPLETELY_COMPATIBLE; + return true; } - else + else if (version1._minor == version2._minor) { - return TosaVersion::compat_t::PARTIALLY_COMPATIBLE; + if (version1._patch < version2._patch) + { + return true; + } + else if (version1._patch == version2._patch) + { + if (version1._draft == true && version2._draft == false) + { + return true; + } + } + } + } + return false; + } + + static TosaVersion::compat_t is_compatible(const TosaVersion& tosa_fb_version, + const TosaVersion& serializer_version) + { + bool major_match = (serializer_version._major == tosa_fb_version._major); + bool minor_match = (serializer_version._minor == tosa_fb_version._minor); + bool patch_match = (serializer_version._patch == tosa_fb_version._patch); + bool draft_match = (serializer_version._draft == tosa_fb_version._draft); + + if (major_match && minor_match && patch_match && draft_match) + return TosaVersion::compat_t::COMPLETELY_COMPATIBLE; + + // We currently support backward compatibility starting from 0.100.0 + if ((tosa_fb_version._major == 0 && tosa_fb_version._minor >= 100) || (tosa_fb_version._major > 0)) + { + if (less_than(tosa_fb_version, serializer_version)) + { + return TosaVersion::compat_t::BACKWARD_COMPATIBLE; } } return TosaVersion::compat_t::NOT_COMPATIBLE; @@ -112,11 +149,17 @@ public: TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, - const flatbuffers::Vector<uint8_t>* data); + const flatbuffers::Vector<uint8_t>* data, + const bool variable = false, + const bool is_unranked = false, + const flatbuffers::String* variable_name = NULL); TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, - const std::vector<uint8_t>& data); + const std::vector<uint8_t>& data, + const bool variable = false, + const bool is_unranked = false, + const std::string& variable_name = ""); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -129,14 +172,26 @@ public: { return _shape; } - DType GetDtype() + DType GetDtype() const { return _dtype; } + bool GetVariable() const + { + return _variable; + } const std::vector<uint8_t>& GetData() const { return _data; } + bool GetIsUnranked() const + { + return _is_unranked; + } + const std::string GetVariableName() const + { + return _variable_name; + } // modifier void SetDtype(DType dtype) @@ -155,12 +210,28 @@ public: { _data = std::move(data); } + void SetIsUnranked(const bool value) + { + _is_unranked = value; + } + void SetDimSize(size_t dim, uint32_t new_size) + { + if (dim >= _shape.size()) + { + printf("dim is out of bound\n"); + assert(0); + } + _shape[dim] = new_size; + } private: DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ std::vector<int32_t> _shape; /* shape of the tensor */ std::string _name; /* name of the tensor, used for solving dependency */ + bool _variable; /* is this a variable tensor */ std::vector<uint8_t> _data; /* data array */ + bool _is_unranked; /* whether this is an unranked tensor */ + std::string _variable_name; /* name for variable tensors */ }; class TosaSerializationOperator @@ -216,11 +287,13 @@ class TosaSerializationBasicBlock public: // constructor and destructor TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector<TosaSerializationOperator*>& operators, const std::vector<TosaSerializationTensor*>& tensors, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs); TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector<TosaSerializationOperator*>&& operators, std::vector<TosaSerializationTensor*>&& tensors, std::vector<std::string>&& inputs, @@ -232,10 +305,15 @@ public: { return _name; } + std::string GetRegionName() const + { + return _region_name; + } std::vector<TosaSerializationOperator*>& GetOperators() { return _operators; } + std::vector<TosaSerializationTensor*>& GetTensors() { return _tensors; @@ -259,19 +337,59 @@ public: { return _inputs; } + std::vector<std::string>& GetOutputs() { return _outputs; } private: - std::string _name; /* name of basic block */ + std::string _name; /* name of basic block */ + std::string _region_name; std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */ std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */ std::vector<std::string> _inputs; /* array of string to specify block inputs */ std::vector<std::string> _outputs; /* array of string to specify block outputs */ }; +class TosaSerializationRegion +{ +public: + // constructor and desctructor + TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks); + TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks); + ~TosaSerializationRegion(); + + // accessors + std::string GetName() const + { + return this->_name; + } + + std::vector<TosaSerializationBasicBlock*>& GetBlocks() + { + return this->_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + +private: + std::string _name; /* name of basic block */ + std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */ +}; + /* * this is a helper class for writing/reading Tosa ISA * supported format: .tosa (flatbuffer), .json @@ -294,7 +412,12 @@ public: tosa_err_t LoadFileSchema(const char* schema_filename); // data format conversion. little-endian. + static tosa_err_t ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out); @@ -302,7 +425,13 @@ public: static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertU8toBF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t + ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<half_float::half>& out); static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toI64(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out); static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out); static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out); static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out); @@ -310,6 +439,8 @@ public: static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out); static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out); + static void ForceAlignTensorData(std::vector<uint8_t>& buf); + // version const TosaVersion& GetVersion() { @@ -317,39 +448,29 @@ public: } // accessor - std::vector<TosaSerializationBasicBlock*>& GetBlocks() + std::vector<TosaSerializationRegion*>& GetRegions() { - return _blocks; + return _regions; } - TosaSerializationBasicBlock* GetBlockByName(std::string name) + TosaSerializationRegion* GetMainRegion() { - TosaSerializationBasicBlock* result = nullptr; - for (auto block : GetBlocks()) + return _regions[0]; + } + + TosaSerializationRegion* GetRegionByName(std::string name) + { + TosaSerializationRegion* result = nullptr; + for (auto region : GetRegions()) { - if (block->GetName() == name) + if (region->GetName() == name) { - result = block; + result = region; break; } } return result; } - TosaSerializationBasicBlock* GetMainBlock() - { - TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); - assert(main_block); - return main_block; - } - - std::vector<std::string>& GetInputs() - { - return GetMainBlock()->GetInputs(); - } - std::vector<std::string>& GetOutputs() - { - return GetMainBlock()->GetOutputs(); - } bool GetSchemaLoaded() const { @@ -360,14 +481,13 @@ protected: tosa_err_t Clear(); tosa_err_t Deserialize(const uint8_t* buf); tosa_err_t Serialize(); - TosaVersion ParseTosaSchemaVersion(std::string schema); private: - TosaVersion _version; /* version struct */ - flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ - flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ - std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */ - bool _schemaLoaded; /* is the schema properly loaded? */ + TosaVersion _version; /* version struct */ + flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */ + flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */ + std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */ + bool _schemaLoaded; /* is the schema properly loaded? */ }; } // namespace tosa |