diff options
Diffstat (limited to 'src/tosa_serialization_handler.cpp')
-rw-r--r-- | src/tosa_serialization_handler.cpp | 794 |
1 files changed, 502 insertions, 292 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index 3a0ce43..0ce6211 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,18 +14,28 @@ // limitations under the License. #include "tosa_serialization_handler.h" +#include "half.hpp" #include <iostream> using namespace tosa; +using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>; +using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>; + TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, - const flatbuffers::Vector<uint8_t>* data) + const flatbuffers::Vector<uint8_t>* data, + const bool variable, + const bool is_unranked, + const flatbuffers::String* variable_name) { - _dtype = dtype; - - std::copy(shape->begin(), shape->end(), std::back_inserter(_shape)); + _dtype = dtype; + _variable = variable; + if (shape) + { + std::copy(shape->begin(), shape->end(), std::back_inserter(_shape)); + } assert(name); _name = name->str(); @@ -34,23 +44,37 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name { std::copy(data->begin(), data->end(), std::back_inserter(_data)); } + _is_unranked = is_unranked; + + if (variable_name) + { + _variable_name = variable_name->str(); + } } TosaSerializationTensor::TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, - const std::vector<uint8_t>& data) + const std::vector<uint8_t>& data, + const bool variable, + const bool is_unranked, + const std::string& variable_name) { - _dtype = dtype; - _shape = shape; - _name = name; - _data = data; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data 
= data; + _is_unranked = is_unranked; + _variable_name = variable_name; } TosaSerializationTensor::TosaSerializationTensor() { - _dtype = DType_UNKNOWN; - _name = "UNKNOWN"; + _dtype = DType_UNKNOWN; + _variable = false; + _name = "UNKNOWN"; + _is_unranked = false; } TosaSerializationTensor::~TosaSerializationTensor() @@ -112,29 +136,33 @@ TosaSerializationOperator::~TosaSerializationOperator() } TosaSerializationBasicBlock::TosaSerializationBasicBlock(const std::string& name, + const std::string& region_name, const std::vector<TosaSerializationOperator*>& operators, const std::vector<TosaSerializationTensor*>& tensors, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs) { - _name = name; - _operators = operators; - _tensors = tensors; - _inputs = inputs; - _outputs = outputs; + _name = name; + _region_name = region_name; + _operators = operators; + _tensors = tensors; + _inputs = inputs; + _outputs = outputs; } TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string&& name, + std::string&& region_name, std::vector<TosaSerializationOperator*>&& operators, std::vector<TosaSerializationTensor*>&& tensors, std::vector<std::string>&& inputs, std::vector<std::string>&& outputs) { - _name = std::move(name); - _operators = std::move(operators); - _tensors = std::move(tensors); - _inputs = std::move(inputs); - _outputs = std::move(outputs); + _name = std::move(name); + _region_name = std::move(region_name); + _operators = std::move(operators); + _tensors = std::move(tensors); + _inputs = std::move(inputs); + _outputs = std::move(outputs); } TosaSerializationBasicBlock::~TosaSerializationBasicBlock() @@ -152,65 +180,38 @@ TosaSerializationBasicBlock::~TosaSerializationBasicBlock() } } -TosaSerializationHandler::TosaSerializationHandler() +TosaSerializationRegion::TosaSerializationRegion(const std::string& name, + const std::vector<TosaSerializationBasicBlock*>& blocks) { - _schemaLoaded = false; - _version = 
TosaVersion(TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT); + _name = name; + _blocks = blocks; } -TosaSerializationHandler::~TosaSerializationHandler() +TosaSerializationRegion::TosaSerializationRegion(const std::string&& name, + const std::vector<TosaSerializationBasicBlock*>&& blocks) { - Clear(); // deallocate all basic blocks + _name = std::move(name); + _blocks = std::move(blocks); } -TosaVersion TosaSerializationHandler::ParseTosaSchemaVersion(std::string schema) +TosaSerializationRegion::~TosaSerializationRegion() { - // Parse all 4 version fields in schema file - static const char* keywords[4] = { "major: int32 = ", "minor: int32 = ", "patch: int32 = ", "draft: bool = " }; - string keyword_str[4]; - size_t search_pos = 0; - size_t keyword_pos; - size_t semicolon_pos; - // parse integer field first - for (int32_t i = 0; i < 4; i++) - { - keyword_pos = schema.find(keywords[i], search_pos); - if (keyword_pos == std::string::npos) - { - printf("ERROR: can't find keyword \"%s\" in schema\n", keywords[i]); - assert(0); - } - semicolon_pos = schema.find(';', keyword_pos); - if (keyword_pos == std::string::npos) - { - printf("ERROR: can't find ';' in schema\n"); - assert(0); - } - keyword_str[i] = - schema.substr(keyword_pos + strlen(keywords[i]), semicolon_pos - keyword_pos - strlen(keywords[i])); - search_pos = semicolon_pos; - } - - int32_t schema_major = 0; - int32_t schema_minor = 0; - int32_t schema_patch = 0; - bool schema_draft = false; - try - { - schema_major = stoi(keyword_str[0]); - schema_minor = stoi(keyword_str[1]); - schema_patch = stoi(keyword_str[2]); - schema_draft = (keyword_str[3] == "true") ? 
true : false; - } - catch (std::invalid_argument& e) + // deallocate all blocks + for (auto block : GetBlocks()) { - printf("ERROR: fail at stoi(): %s\n", e.what()); - assert(0); + delete block; // ~TosaSerializationBasicBlock() } +} - TosaVersion schema_version(schema_major, schema_minor, schema_patch, schema_draft); +TosaSerializationHandler::TosaSerializationHandler() +{ + _schemaLoaded = false; + _version = TosaVersion(TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT); +} - return schema_version; +TosaSerializationHandler::~TosaSerializationHandler() +{ + Clear(); // deallocate all basic blocks } tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename) @@ -227,23 +228,6 @@ tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename) ok = _parser.Parse(schema.c_str()); - TosaVersion schema_version = ParseTosaSchemaVersion(schema); - - TosaVersion::compat_t is_compat = schema_version.is_compatible(GetVersion()); - switch (is_compat) - { - case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: - break; - case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: - printf("WARNING: Schema flatbuffer version %s is partially compatible with serializer version %s\n", - schema_version.to_string().c_str(), GetVersion().to_string().c_str()); - break; - case TosaVersion::compat_t::NOT_COMPATIBLE: - printf("ERROR: Schema flatbuffer version %s is not compatible with serializer version %s\n", - schema_version.to_string().c_str(), GetVersion().to_string().c_str()); - return TOSA_VERSION_MISMATCH; - } - if (!ok) { printf("Error parsing ISA schema file: %s\n", schema_filename); @@ -308,7 +292,7 @@ tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename) uint8_t* buf = _builder.GetBufferPointer(); - if (!GenerateText(_parser, buf, &jsongen)) + if (GenText(_parser, buf, &jsongen)) { printf("Couldn't serialize parsed data to JSON!\n"); return TOSA_FILE_ERROR; @@ -399,11 +383,11 @@ tosa_err_t 
TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename tosa_err_t TosaSerializationHandler::Clear() { // deallocate all basic blocks - for (auto bb : GetBlocks()) + for (auto region : GetRegions()) { - delete bb; + delete region; } - _blocks.clear(); + _regions.clear(); return TOSA_OK; } @@ -416,20 +400,13 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) } auto fb_tosa_graph = GetTosaGraph(buf); auto fb_tosa_version = fb_tosa_graph->version(); - auto fb_tosa_blocks = fb_tosa_graph->blocks(); - - std::vector<std::string> operator_inputs_container; - std::vector<std::string> operator_outputs_container; - - std::vector<TosaSerializationOperator*> block_operators_container; - std::vector<TosaSerializationTensor*> block_tensors_container; - std::vector<std::string> block_inputs_container; - std::vector<std::string> block_outputs_container; + auto fb_tosa_regions = fb_tosa_graph->regions(); TosaAttributeBase* typed_attribute = NULL; TosaSerializationOperator* new_operator = NULL; TosaSerializationBasicBlock* new_block = NULL; TosaSerializationTensor* new_tensor = NULL; + TosaSerializationRegion* new_region = NULL; // erase container Clear(); @@ -437,226 +414,241 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(), fb_tosa_version->_draft()); - TosaVersion::compat_t is_compat = read_version.is_compatible(GetVersion()); + TosaVersion::compat_t is_compat = TosaVersion::is_compatible(read_version, GetVersion()); switch (is_compat) { case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: break; - case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: - printf("WARNING: Read flatbuffer version %s is partially compatible with serializer version %s\n", + case TosaVersion::compat_t::BACKWARD_COMPATIBLE: + printf("WARNING: Different Tosa flatbuffer and serializer versions detected. 
Read Tosa flatbuffer version " + "%s is backward " + "compatible with serializer version %s\n", read_version.to_string().c_str(), GetVersion().to_string().c_str()); break; case TosaVersion::compat_t::NOT_COMPATIBLE: - printf("ERROR: Read flatbuffer version %s is not compatible with serializer version %s\n", + printf("ERROR: Read Tosa flatbuffer version %s is not compatible with serializer version %s\n", read_version.to_string().c_str(), GetVersion().to_string().c_str()); return TOSA_VERSION_MISMATCH; } - for (size_t i = 0; i < fb_tosa_blocks->size(); i++) + for (size_t i = 0; i < fb_tosa_regions->size(); i++) { - auto curr_block = fb_tosa_blocks->Get(i); + auto curr_region = fb_tosa_regions->Get(i); + auto region_name = curr_region->name()->str(); + auto fb_tosa_blocks = curr_region->blocks(); - auto block_name = curr_block->name()->str(); + new_region = new TosaSerializationRegion(curr_region->name()->str(), {}); + this->GetRegions().push_back(new_region); - auto fb_tosa_operators = curr_block->operators(); - block_operators_container.clear(); - for (size_t j = 0; j < fb_tosa_operators->size(); j++) + for (size_t i = 0; i < fb_tosa_blocks->size(); i++) { - auto curr_operator = fb_tosa_operators->Get(j); + std::vector<TosaSerializationOperator*> block_operators_container; + std::vector<TosaSerializationTensor*> block_tensors_container; + std::vector<std::string> block_inputs_container; + std::vector<std::string> block_outputs_container; - auto operator_op = curr_operator->op(); - auto attribute_type = curr_operator->attribute_type(); - auto attribute = curr_operator->attribute(); + auto curr_block = fb_tosa_blocks->Get(i); - // input tensors - auto operator_inputs = curr_operator->inputs(); - operator_inputs_container.clear(); - if (operator_inputs) + auto block_name = curr_block->name()->str(); + + auto fb_tosa_operators = curr_block->operators(); + for (size_t j = 0; j < fb_tosa_operators->size(); j++) { - for (size_t k = 0; k < operator_inputs->size(); k++) + 
auto curr_operator = fb_tosa_operators->Get(j); + + auto operator_op = curr_operator->op(); + auto attribute_type = curr_operator->attribute_type(); + auto attribute = curr_operator->attribute(); + + std::vector<std::string> operator_inputs_container; + std::vector<std::string> operator_outputs_container; + + // input tensors + auto operator_inputs = curr_operator->inputs(); + if (operator_inputs) { - auto curr_input = operator_inputs->Get(k); - operator_inputs_container.push_back(curr_input->str()); + for (size_t k = 0; k < operator_inputs->size(); k++) + { + auto curr_input = operator_inputs->Get(k); + operator_inputs_container.push_back(curr_input->str()); + } } - } - // output tensors - auto operator_outputs = curr_operator->outputs(); - operator_outputs_container.clear(); - if (operator_outputs) - { - for (size_t k = 0; k < operator_outputs->size(); k++) + // output tensors + auto operator_outputs = curr_operator->outputs(); + if (operator_outputs) { - auto curr_output = operator_outputs->Get(k); - operator_outputs_container.push_back(curr_output->str()); + for (size_t k = 0; k < operator_outputs->size(); k++) + { + auto curr_output = operator_outputs->Get(k); + operator_outputs_container.push_back(curr_output->str()); + } } - } - switch (attribute_type) - { - case Attribute_NONE: - typed_attribute = new TosaNoneAttribute(); - break; + switch (attribute_type) + { + case Attribute_NONE: + typed_attribute = new TosaNoneAttribute(); + break; #define DEF_ATTRIBUTE(NAME, ...) 
\ case Attribute_##NAME##Attribute: \ typed_attribute = new Tosa##NAME##Attribute(attribute); \ break; #include "attribute.def" #undef DEF_ATTRIBUTE - default: - printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; + default: + printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + + new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, + operator_inputs_container, operator_outputs_container); + if (new_operator) + { + block_operators_container.push_back(new_operator); + } + else + { + return TOSA_MEMORY_ERROR; + } + + if (typed_attribute) + delete typed_attribute; } - new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, - operator_inputs_container, operator_outputs_container); - if (new_operator) + auto block_inputs = curr_block->inputs(); + auto block_outputs = curr_block->outputs(); + + for (size_t j = 0; j < block_inputs->size(); j++) { - block_operators_container.push_back(new_operator); + auto curr_block_input = block_inputs->Get(j); + block_inputs_container.push_back(curr_block_input->str()); } - else + for (size_t j = 0; j < block_outputs->size(); j++) { - return TOSA_MEMORY_ERROR; + auto curr_block_output = block_outputs->Get(j); + block_outputs_container.push_back(curr_block_output->str()); } - if (typed_attribute) - delete typed_attribute; - } - - auto fb_tosa_tensors = curr_block->tensors(); - block_tensors_container.clear(); - for (size_t j = 0; j < fb_tosa_tensors->size(); j++) - { - auto curr_tensor = fb_tosa_tensors->Get(j); - - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_data = curr_tensor->data(); - - new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, 
tensor_type, tensor_data); - if (new_tensor) + auto fb_tosa_tensors = curr_block->tensors(); + for (size_t j = 0; j < fb_tosa_tensors->size(); j++) + { + auto curr_tensor = fb_tosa_tensors->Get(j); + + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_variable = curr_tensor->variable(); + auto tensor_data = curr_tensor->data(); + auto tensor_is_unranked = curr_tensor->is_unranked(); + auto tensor_variable_name = curr_tensor->variable_name(); + + new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, + tensor_variable, tensor_is_unranked, tensor_variable_name); + if (new_tensor) + { + block_tensors_container.push_back(new_tensor); + } + else + { + return TOSA_MEMORY_ERROR; + } + } + new_block = new TosaSerializationBasicBlock(block_name, region_name, block_operators_container, + block_tensors_container, block_inputs_container, + block_outputs_container); + if (new_block) { - block_tensors_container.push_back(new_tensor); + new_region->GetBlocks().push_back(new_block); } else { return TOSA_MEMORY_ERROR; } - } - - auto block_inputs = curr_block->inputs(); - auto block_outputs = curr_block->outputs(); - - block_inputs_container.clear(); - block_outputs_container.clear(); - - for (size_t j = 0; j < block_inputs->size(); j++) - { - auto curr_block_input = block_inputs->Get(j); - block_inputs_container.push_back(curr_block_input->str()); - } - for (size_t j = 0; j < block_outputs->size(); j++) - { - auto curr_block_output = block_outputs->Get(j); - block_outputs_container.push_back(curr_block_output->str()); - } - - new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container, - block_inputs_container, block_outputs_container); - if (new_block) - { - this->GetBlocks().push_back(new_block); - } - else - { - return TOSA_MEMORY_ERROR; - } + } // end block for_loop } return TOSA_OK; } -tosa_err_t 
TosaSerializationHandler::Serialize() +std::vector<uint8_t> float_to_u8_helper(float f_in) { - std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks; - - std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators; - std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors; - std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs; - std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs; + // Push back a single float value to the buffer with *NO PADDING* + // Therefore ConvertF32toU8 function not used + std::vector<uint8_t> u8_out; + uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&f_in); + u8_out.push_back(*val_u32 & 0xFF); + u8_out.push_back((*val_u32 >> 8) & 0xFF); + u8_out.push_back((*val_u32 >> 16) & 0xFF); + u8_out.push_back((*val_u32 >> 24) & 0xFF); + return u8_out; +} - std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs; - std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs; +tosa_err_t TosaSerializationHandler::Serialize() +{ + // regions + std::vector<flatbuffers::Offset<TosaRegion>> fboffset_regions; // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator> - for (auto block : GetBlocks()) + for (auto region : GetRegions()) { - fboffset_block_operators.clear(); - fboffset_block_tensors.clear(); - fboffset_block_inputs.clear(); - fboffset_block_outputs.clear(); - - auto block_name = _builder.CreateString(block->GetName().c_str()); - - for (auto tensor_str : block->GetInputs()) - { - auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_block_inputs.push_back(tensor_name); - } - - for (auto tensor_str : block->GetOutputs()) - { - auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_block_outputs.push_back(tensor_name); - } - - auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs); - auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs); - - 
for (auto op : block->GetOperators()) + std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks; + for (auto block : region->GetBlocks()) { - fboffset_operator_inputs.clear(); - fboffset_operator_outputs.clear(); - - auto operator_op = op->GetOp(); - auto attribute_type = op->GetAttributeType(); - - for (auto tensor_str : op->GetInputTensorNames()) + std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators; + std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs; + auto block_name = _builder.CreateString(block->GetName().c_str()); + for (auto tensor_str : block->GetInputs()) { auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_operator_inputs.push_back(tensor_name); + fboffset_block_inputs.push_back(tensor_name); } - - for (auto tensor_str : op->GetOutputTensorNames()) + for (auto tensor_str : block->GetOutputs()) { auto tensor_name = _builder.CreateString(tensor_str.c_str()); - fboffset_operator_outputs.push_back(tensor_name); + fboffset_block_outputs.push_back(tensor_name); } - - auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs); - auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs); - - flatbuffers::Offset<void> fb_attribute; - switch (attribute_type) + auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs); + auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs); + for (auto op : block->GetOperators()) { - case Attribute_NONE: - fb_attribute = 0; - break; - + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs; + std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs; + auto operator_op = op->GetOp(); + auto attribute_type = op->GetAttributeType(); + for (auto tensor_str : op->GetInputTensorNames()) + { + auto 
tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_inputs.push_back(tensor_name); + } + for (auto tensor_str : op->GetOutputTensorNames()) + { + auto tensor_name = _builder.CreateString(tensor_str.c_str()); + fboffset_operator_outputs.push_back(tensor_name); + } + auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs); + auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs); + flatbuffers::Offset<void> fb_attribute; + switch (attribute_type) + { + case Attribute_NONE: + fb_attribute = 0; + break; #define DEF_ARGS_S_STR(NAME, V) , _builder.CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str()) +#define DEF_ARGS_S_FP_as_U8(NAME, V) \ + , _builder.CreateVector<uint8_t>(float_to_u8_helper(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())) #define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V() - #define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) -#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_FP_as_U8(NAME, V) #define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) +#define DEF_ARGS_S_DType(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V) #define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V) - #define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V) #define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()) - #define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0) #define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) #define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \ @@ -672,11 +664,20 @@ tosa_err_t TosaSerializationHandler::Serialize() #define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \ DEF_ARGS_##F0(NAME, T0, V0) 
DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) +#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) +#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \ + V7, T8, F8, V8) \ + DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \ + DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \ + DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) #define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \ case Attribute_##NAME##Attribute: \ fb_attribute = Create##NAME##Attribute(_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \ break; - #include "attribute.def" #undef DEF_ATTRIBUTE #undef DEF_ARGS_1 @@ -692,53 +693,56 @@ tosa_err_t TosaSerializationHandler::Serialize() #undef DEF_ARGS_S_float #undef DEF_ARGS_S_bool #undef DEF_ARGS_S_ResizeMode +#undef DEF_ARGS_S_DType #undef DEF_ARGS_S_string #undef DEF_ARGS_S_STR #undef DEF_ARGS_S_DEFAULT - default: - printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", - EnumNamesAttribute()[attribute_type]); - return TOSA_INTERNAL_ERROR; + default: + printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n", + EnumNamesAttribute()[attribute_type]); + return TOSA_INTERNAL_ERROR; + } + auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, + fb_operator_inputs, fb_operator_outputs); + fboffset_block_operators.push_back(fboffset_operator); } + auto 
fb_block_operators = _builder.CreateVector(fboffset_block_operators); + for (auto tensor : block->GetTensors()) + { + auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); + auto tensor_shape = _builder.CreateVector(tensor->GetShape()); + auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); + auto tensor_data = _builder.CreateVector(tensor->GetData()); + auto tensor_is_unranked = tensor->GetIsUnranked(); + auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str()); + auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, + tensor_variable, tensor_is_unranked, tensor_variable_name); + fboffset_block_tensors.push_back(fboffset_tensor); + } + auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); + auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors, + fb_block_inputs, fb_block_outputs); + fboffset_blocks.push_back(fboffset_block); + } // end block for_loop + auto fb_blocks = _builder.CreateVector(fboffset_blocks); - auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, - fb_operator_inputs, fb_operator_outputs); - fboffset_block_operators.push_back(fboffset_operator); - } - - auto fb_block_operators = _builder.CreateVector(fboffset_block_operators); - - for (auto tensor : block->GetTensors()) - { - - auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); - auto tensor_shape = _builder.CreateVector(tensor->GetShape()); - auto tensor_dtype = tensor->GetDtype(); - auto tensor_data = _builder.CreateVector(tensor->GetData()); - - auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data); - fboffset_block_tensors.push_back(fboffset_tensor); - } - - auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); - - auto fboffset_block = CreateTosaBasicBlock(_builder, 
block_name, fb_block_operators, fb_block_tensors, - fb_block_inputs, fb_block_outputs); - fboffset_blocks.push_back(fboffset_block); - } + auto region_name = _builder.CreateString(region->GetName().c_str()); + auto fboffset_region = CreateTosaRegion(_builder, region_name, fb_blocks); + fboffset_regions.push_back(fboffset_region); + } // end region for_loop - auto fb_blocks = _builder.CreateVector(fboffset_blocks); + auto fb_regions = _builder.CreateVector(fboffset_regions); auto fb_version = CreateVersion(_builder, TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT); - - auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks); + auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_regions); _builder.Finish(fb_graph, TosaGraphIdentifier()); return TOSA_OK; } -void zero_pad(std::vector<uint8_t>& buf) +void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf) { while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0) { @@ -746,6 +750,66 @@ void zero_pad(std::vector<uint8_t>& buf) } } +tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +{ + // Note: Converts fp32->bf16 by ignoring the least significant 16 bits + out.clear(); + for (auto val : in) + { + uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val); + uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; + uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; + // little endian: byte2 followed by byte3 + out.push_back(f32_byte2); + out.push_back(f32_byte3); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +{ + // Note: Converts fp32->FP8E4M3 before converting to uint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast<fp8e4m3>(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t
TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +{ + // Note: Converts fp32->FP8E5M2 before converting to uint8_t + out.clear(); + for (auto val : in) + { + auto f8 = static_cast<fp8e5m2>(val); + uint8_t b8 = f8.bits(); + out.push_back(b8); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out) +{ + // Note: Converts fp32->fp16 before converting to uint8_t + out.clear(); + for (auto val : in) + { + half_float::half val_f16 = half_float::half_cast<half_float::half, float>(val); + uint16_t* val_u16 = reinterpret_cast<uint16_t*>(&val_f16); + out.push_back(*val_u16 & 0xFF); + out.push_back((*val_u16 >> 8) & 0xFF); + } + ForceAlignTensorData(out); + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out) { out.clear(); @@ -757,7 +821,26 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out) +{ + out.clear(); + for (auto val : in) + { + uint64_t* val_u64 = reinterpret_cast<uint64_t*>(&val); + out.push_back(*val_u64 & 0xFF); + out.push_back((*val_u64 >> 8) & 0xFF); + out.push_back((*val_u64 >> 16) & 0xFF); + out.push_back((*val_u64 >> 24) & 0xFF); + out.push_back((*val_u64 >> 32) & 0xFF); + out.push_back((*val_u64 >> 40) & 0xFF); + out.push_back((*val_u64 >> 48) & 0xFF); + out.push_back((*val_u64 >> 56) & 0xFF); + } + ForceAlignTensorData(out); return TOSA_OK; } @@ -774,7 +857,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>& out.push_back((*val_u64 >> 32) & 0xFF); out.push_back((*val_u64 >> 40) & 0xFF); } - 
zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -789,7 +872,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>& out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -802,7 +885,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>& out.push_back(*val_u16 & 0xFF); out.push_back((*val_u16 >> 8) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -814,7 +897,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val); out.push_back(*val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -845,7 +928,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in uint8_t val_u8 = static_cast<uint8_t>(val_packed); out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -857,7 +940,105 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in uint8_t val_u8 = val; out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in, + uint32_t out_size, + std::vector<float>& out) +{ + // Note: bf16 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int16_t)) + { + printf("TosaSerializationHandler::ConvertU8toBF16(): uint8 buffer size %ld must >= target size %ld\n", + in.size(), out_size * sizeof(int16_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + uint32_t f32_byte2 = in[i * sizeof(int16_t)]; + uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1]; + uint32_t val_u32 = (f32_byte2 << 16) + (f32_byte3 << 24); + + // Reinterpret u32 bytes as fp32 + float val_f32 = *(float*)&val_u32; + out.push_back(val_f32); + } + return 
TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in, + uint32_t out_size, + std::vector<float>& out) +{ + // Note: FP8E4M3 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e4m3::from_bits(bits); + float val_f32 = static_cast<float>(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in, + uint32_t out_size, + std::vector<float>& out) +{ + // Note: FP8E5M2 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int8_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int8_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + int8_t bits = static_cast<int8_t>(in[i * sizeof(int8_t)]); + auto f8 = fp8e5m2::from_bits(bits); + float val_f32 = static_cast<float>(f8); + out.push_back(val_f32); + } + return TOSA_OK; +} + +tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& in, + uint32_t out_size, + std::vector<half_float::half>& out) +{ + // Note: fp16 values returned in fp32 type + out.clear(); + if (in.size() < out_size * sizeof(int16_t)) + { + printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int16_t)); + return TOSA_USER_ERROR; + } + + for (uint32_t i = 0; i < out_size; i++) + { + uint16_t f16_byte0 = in[i * sizeof(int16_t)]; + uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1]; + uint16_t val_u16 = f16_byte0 + (f16_byte1 << 8); + + // 
Reinterpret u16 byte as fp16 then convert to fp32 + half_float::half val_f16 = *(half_float::half*)&val_u16; + out.push_back(val_f16); + } return TOSA_OK; } @@ -884,6 +1065,35 @@ tosa_err_t return TOSA_OK; } +tosa_err_t TosaSerializationHandler::ConvertU8toI64(const std::vector<uint8_t>& in, + uint32_t out_size, + std::vector<int64_t>& out) +{ + out.clear(); + if (in.size() < out_size * sizeof(int64_t)) + { + printf("TosaSerializationHandler::ConvertU8toI64(): uint8 buffer size %ld must >= target size %ld\n", in.size(), + out_size * sizeof(int64_t)); + return TOSA_USER_ERROR; + } + for (uint32_t i = 0; i < out_size; i++) + { + uint64_t byte0 = in[i * sizeof(int64_t)]; + uint64_t byte1 = in[i * sizeof(int64_t) + 1]; + uint64_t byte2 = in[i * sizeof(int64_t) + 2]; + uint64_t byte3 = in[i * sizeof(int64_t) + 3]; + uint64_t byte4 = in[i * sizeof(int64_t) + 4]; + uint64_t byte5 = in[i * sizeof(int64_t) + 5]; + uint64_t byte6 = in[i * sizeof(int64_t) + 6]; + uint64_t byte7 = in[i * sizeof(int64_t) + 7]; + uint64_t val_u64 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24) + (byte4 << 32) + (byte5 << 40) + + (byte6 << 48) + (byte7 << 56); + int64_t* val_i64 = reinterpret_cast<int64_t*>(&val_u64); + out.push_back(*val_i64); + } + return TOSA_OK; +} + tosa_err_t TosaSerializationHandler::ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out) |