aboutsummaryrefslogtreecommitdiff
path: root/src/tosa_serialization_handler.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tosa_serialization_handler.cpp')
-rw-r--r--src/tosa_serialization_handler.cpp794
1 files changed, 502 insertions, 292 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp
index 3a0ce43..0ce6211 100644
--- a/src/tosa_serialization_handler.cpp
+++ b/src/tosa_serialization_handler.cpp
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,18 +14,28 @@
// limitations under the License.
#include "tosa_serialization_handler.h"
+#include "half.hpp"
#include <iostream>
using namespace tosa;
+using fp8e4m3 = ct::cfloat<int8_t, 4, true, true, false>;
+using fp8e5m2 = ct::cfloat<int8_t, 5, true, true, true>;
+
TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::Vector<uint8_t>* data)
+ const flatbuffers::Vector<uint8_t>* data,
+ const bool variable,
+ const bool is_unranked,
+ const flatbuffers::String* variable_name)
{
- _dtype = dtype;
-
- std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
+ _dtype = dtype;
+ _variable = variable;
+ if (shape)
+ {
+ std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
+ }
assert(name);
_name = name->str();
@@ -34,23 +44,37 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name
{
std::copy(data->begin(), data->end(), std::back_inserter(_data));
}
+ _is_unranked = is_unranked;
+
+ if (variable_name)
+ {
+ _variable_name = variable_name->str();
+ }
}
TosaSerializationTensor::TosaSerializationTensor(const std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::vector<uint8_t>& data)
+ const std::vector<uint8_t>& data,
+ const bool variable,
+ const bool is_unranked,
+ const std::string& variable_name)
{
- _dtype = dtype;
- _shape = shape;
- _name = name;
- _data = data;
+ _dtype = dtype;
+ _variable = variable;
+ _shape = shape;
+ _name = name;
+ _data = data;
+ _is_unranked = is_unranked;
+ _variable_name = variable_name;
}
TosaSerializationTensor::TosaSerializationTensor()
{
- _dtype = DType_UNKNOWN;
- _name = "UNKNOWN";
+ _dtype = DType_UNKNOWN;
+ _variable = false;
+ _name = "UNKNOWN";
+ _is_unranked = false;
}
TosaSerializationTensor::~TosaSerializationTensor()
@@ -112,29 +136,33 @@ TosaSerializationOperator::~TosaSerializationOperator()
}
TosaSerializationBasicBlock::TosaSerializationBasicBlock(const std::string& name,
+ const std::string& region_name,
const std::vector<TosaSerializationOperator*>& operators,
const std::vector<TosaSerializationTensor*>& tensors,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs)
{
- _name = name;
- _operators = operators;
- _tensors = tensors;
- _inputs = inputs;
- _outputs = outputs;
+ _name = name;
+ _region_name = region_name;
+ _operators = operators;
+ _tensors = tensors;
+ _inputs = inputs;
+ _outputs = outputs;
}
TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string&& name,
+ std::string&& region_name,
std::vector<TosaSerializationOperator*>&& operators,
std::vector<TosaSerializationTensor*>&& tensors,
std::vector<std::string>&& inputs,
std::vector<std::string>&& outputs)
{
- _name = std::move(name);
- _operators = std::move(operators);
- _tensors = std::move(tensors);
- _inputs = std::move(inputs);
- _outputs = std::move(outputs);
+ _name = std::move(name);
+ _region_name = std::move(region_name);
+ _operators = std::move(operators);
+ _tensors = std::move(tensors);
+ _inputs = std::move(inputs);
+ _outputs = std::move(outputs);
}
TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
@@ -152,65 +180,38 @@ TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
}
}
-TosaSerializationHandler::TosaSerializationHandler()
+TosaSerializationRegion::TosaSerializationRegion(const std::string& name,
+ const std::vector<TosaSerializationBasicBlock*>& blocks)
{
- _schemaLoaded = false;
- _version = TosaVersion(TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT);
+ _name = name;
+ _blocks = blocks;
}
-TosaSerializationHandler::~TosaSerializationHandler()
+TosaSerializationRegion::TosaSerializationRegion(const std::string&& name,
+ const std::vector<TosaSerializationBasicBlock*>&& blocks)
{
- Clear(); // deallocate all basic blocks
+ _name = std::move(name);
+ _blocks = std::move(blocks);
}
-TosaVersion TosaSerializationHandler::ParseTosaSchemaVersion(std::string schema)
+TosaSerializationRegion::~TosaSerializationRegion()
{
- // Parse all 4 version fields in schema file
- static const char* keywords[4] = { "major: int32 = ", "minor: int32 = ", "patch: int32 = ", "draft: bool = " };
- string keyword_str[4];
- size_t search_pos = 0;
- size_t keyword_pos;
- size_t semicolon_pos;
- // parse integer field first
- for (int32_t i = 0; i < 4; i++)
- {
- keyword_pos = schema.find(keywords[i], search_pos);
- if (keyword_pos == std::string::npos)
- {
- printf("ERROR: can't find keyword \"%s\" in schema\n", keywords[i]);
- assert(0);
- }
- semicolon_pos = schema.find(';', keyword_pos);
- if (keyword_pos == std::string::npos)
- {
- printf("ERROR: can't find ';' in schema\n");
- assert(0);
- }
- keyword_str[i] =
- schema.substr(keyword_pos + strlen(keywords[i]), semicolon_pos - keyword_pos - strlen(keywords[i]));
- search_pos = semicolon_pos;
- }
-
- int32_t schema_major = 0;
- int32_t schema_minor = 0;
- int32_t schema_patch = 0;
- bool schema_draft = false;
- try
- {
- schema_major = stoi(keyword_str[0]);
- schema_minor = stoi(keyword_str[1]);
- schema_patch = stoi(keyword_str[2]);
- schema_draft = (keyword_str[3] == "true") ? true : false;
- }
- catch (std::invalid_argument& e)
+ // deallocate all blocks
+ for (auto block : GetBlocks())
{
- printf("ERROR: fail at stoi(): %s\n", e.what());
- assert(0);
+ delete block; // ~TosaSerializationBasicBlock()
}
+}
- TosaVersion schema_version(schema_major, schema_minor, schema_patch, schema_draft);
+TosaSerializationHandler::TosaSerializationHandler()
+{
+ _schemaLoaded = false;
+ _version = TosaVersion(TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT);
+}
- return schema_version;
+TosaSerializationHandler::~TosaSerializationHandler()
+{
+ Clear(); // deallocate all basic blocks
}
tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
@@ -227,23 +228,6 @@ tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
ok = _parser.Parse(schema.c_str());
- TosaVersion schema_version = ParseTosaSchemaVersion(schema);
-
- TosaVersion::compat_t is_compat = schema_version.is_compatible(GetVersion());
- switch (is_compat)
- {
- case TosaVersion::compat_t::COMPLETELY_COMPATIBLE:
- break;
- case TosaVersion::compat_t::PARTIALLY_COMPATIBLE:
- printf("WARNING: Schema flatbuffer version %s is partially compatible with serializer version %s\n",
- schema_version.to_string().c_str(), GetVersion().to_string().c_str());
- break;
- case TosaVersion::compat_t::NOT_COMPATIBLE:
- printf("ERROR: Schema flatbuffer version %s is not compatible with serializer version %s\n",
- schema_version.to_string().c_str(), GetVersion().to_string().c_str());
- return TOSA_VERSION_MISMATCH;
- }
-
if (!ok)
{
printf("Error parsing ISA schema file: %s\n", schema_filename);
@@ -308,7 +292,7 @@ tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
uint8_t* buf = _builder.GetBufferPointer();
- if (!GenerateText(_parser, buf, &jsongen))
+ if (GenText(_parser, buf, &jsongen))
{
printf("Couldn't serialize parsed data to JSON!\n");
return TOSA_FILE_ERROR;
@@ -399,11 +383,11 @@ tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename
tosa_err_t TosaSerializationHandler::Clear()
{
// deallocate all basic blocks
- for (auto bb : GetBlocks())
+ for (auto region : GetRegions())
{
- delete bb;
+ delete region;
}
- _blocks.clear();
+ _regions.clear();
return TOSA_OK;
}
@@ -416,20 +400,13 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
}
auto fb_tosa_graph = GetTosaGraph(buf);
auto fb_tosa_version = fb_tosa_graph->version();
- auto fb_tosa_blocks = fb_tosa_graph->blocks();
-
- std::vector<std::string> operator_inputs_container;
- std::vector<std::string> operator_outputs_container;
-
- std::vector<TosaSerializationOperator*> block_operators_container;
- std::vector<TosaSerializationTensor*> block_tensors_container;
- std::vector<std::string> block_inputs_container;
- std::vector<std::string> block_outputs_container;
+ auto fb_tosa_regions = fb_tosa_graph->regions();
TosaAttributeBase* typed_attribute = NULL;
TosaSerializationOperator* new_operator = NULL;
TosaSerializationBasicBlock* new_block = NULL;
TosaSerializationTensor* new_tensor = NULL;
+ TosaSerializationRegion* new_region = NULL;
// erase container
Clear();
@@ -437,226 +414,241 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf)
TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
fb_tosa_version->_draft());
- TosaVersion::compat_t is_compat = read_version.is_compatible(GetVersion());
+ TosaVersion::compat_t is_compat = TosaVersion::is_compatible(read_version, GetVersion());
switch (is_compat)
{
case TosaVersion::compat_t::COMPLETELY_COMPATIBLE:
break;
- case TosaVersion::compat_t::PARTIALLY_COMPATIBLE:
- printf("WARNING: Read flatbuffer version %s is partially compatible with serializer version %s\n",
+ case TosaVersion::compat_t::BACKWARD_COMPATIBLE:
+ printf("WARNING: Different Tosa flatbuffer and serializer versions detected. Read Tosa flatbuffer version "
+ "%s is backward "
+ "compatible with serializer version %s\n",
read_version.to_string().c_str(), GetVersion().to_string().c_str());
break;
case TosaVersion::compat_t::NOT_COMPATIBLE:
- printf("ERROR: Read flatbuffer version %s is not compatible with serializer version %s\n",
+ printf("ERROR: Read Tosa flatbuffer version %s is not compatible with serializer version %s\n",
read_version.to_string().c_str(), GetVersion().to_string().c_str());
return TOSA_VERSION_MISMATCH;
}
- for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
+ for (size_t i = 0; i < fb_tosa_regions->size(); i++)
{
- auto curr_block = fb_tosa_blocks->Get(i);
+ auto curr_region = fb_tosa_regions->Get(i);
+ auto region_name = curr_region->name()->str();
+ auto fb_tosa_blocks = curr_region->blocks();
- auto block_name = curr_block->name()->str();
+ new_region = new TosaSerializationRegion(curr_region->name()->str(), {});
+ this->GetRegions().push_back(new_region);
- auto fb_tosa_operators = curr_block->operators();
- block_operators_container.clear();
- for (size_t j = 0; j < fb_tosa_operators->size(); j++)
+ for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
{
- auto curr_operator = fb_tosa_operators->Get(j);
+ std::vector<TosaSerializationOperator*> block_operators_container;
+ std::vector<TosaSerializationTensor*> block_tensors_container;
+ std::vector<std::string> block_inputs_container;
+ std::vector<std::string> block_outputs_container;
- auto operator_op = curr_operator->op();
- auto attribute_type = curr_operator->attribute_type();
- auto attribute = curr_operator->attribute();
+ auto curr_block = fb_tosa_blocks->Get(i);
- // input tensors
- auto operator_inputs = curr_operator->inputs();
- operator_inputs_container.clear();
- if (operator_inputs)
+ auto block_name = curr_block->name()->str();
+
+ auto fb_tosa_operators = curr_block->operators();
+ for (size_t j = 0; j < fb_tosa_operators->size(); j++)
{
- for (size_t k = 0; k < operator_inputs->size(); k++)
+ auto curr_operator = fb_tosa_operators->Get(j);
+
+ auto operator_op = curr_operator->op();
+ auto attribute_type = curr_operator->attribute_type();
+ auto attribute = curr_operator->attribute();
+
+ std::vector<std::string> operator_inputs_container;
+ std::vector<std::string> operator_outputs_container;
+
+ // input tensors
+ auto operator_inputs = curr_operator->inputs();
+ if (operator_inputs)
{
- auto curr_input = operator_inputs->Get(k);
- operator_inputs_container.push_back(curr_input->str());
+ for (size_t k = 0; k < operator_inputs->size(); k++)
+ {
+ auto curr_input = operator_inputs->Get(k);
+ operator_inputs_container.push_back(curr_input->str());
+ }
}
- }
- // output tensors
- auto operator_outputs = curr_operator->outputs();
- operator_outputs_container.clear();
- if (operator_outputs)
- {
- for (size_t k = 0; k < operator_outputs->size(); k++)
+ // output tensors
+ auto operator_outputs = curr_operator->outputs();
+ if (operator_outputs)
{
- auto curr_output = operator_outputs->Get(k);
- operator_outputs_container.push_back(curr_output->str());
+ for (size_t k = 0; k < operator_outputs->size(); k++)
+ {
+ auto curr_output = operator_outputs->Get(k);
+ operator_outputs_container.push_back(curr_output->str());
+ }
}
- }
- switch (attribute_type)
- {
- case Attribute_NONE:
- typed_attribute = new TosaNoneAttribute();
- break;
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ typed_attribute = new TosaNoneAttribute();
+ break;
#define DEF_ATTRIBUTE(NAME, ...) \
case Attribute_##NAME##Attribute: \
typed_attribute = new Tosa##NAME##Attribute(attribute); \
break;
#include "attribute.def"
#undef DEF_ATTRIBUTE
- default:
- printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n",
- EnumNamesAttribute()[attribute_type]);
- return TOSA_INTERNAL_ERROR;
+ default:
+ printf("TosaSerializationHandler::Deserialize(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+
+ new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute,
+ operator_inputs_container, operator_outputs_container);
+ if (new_operator)
+ {
+ block_operators_container.push_back(new_operator);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+
+ if (typed_attribute)
+ delete typed_attribute;
}
- new_operator = new TosaSerializationOperator(operator_op, attribute_type, typed_attribute,
- operator_inputs_container, operator_outputs_container);
- if (new_operator)
+ auto block_inputs = curr_block->inputs();
+ auto block_outputs = curr_block->outputs();
+
+ for (size_t j = 0; j < block_inputs->size(); j++)
{
- block_operators_container.push_back(new_operator);
+ auto curr_block_input = block_inputs->Get(j);
+ block_inputs_container.push_back(curr_block_input->str());
}
- else
+ for (size_t j = 0; j < block_outputs->size(); j++)
{
- return TOSA_MEMORY_ERROR;
+ auto curr_block_output = block_outputs->Get(j);
+ block_outputs_container.push_back(curr_block_output->str());
}
- if (typed_attribute)
- delete typed_attribute;
- }
-
- auto fb_tosa_tensors = curr_block->tensors();
- block_tensors_container.clear();
- for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
- {
- auto curr_tensor = fb_tosa_tensors->Get(j);
-
- auto tensor_name = curr_tensor->name();
- auto tensor_shape = curr_tensor->shape();
- auto tensor_type = curr_tensor->type();
- auto tensor_data = curr_tensor->data();
-
- new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data);
- if (new_tensor)
+ auto fb_tosa_tensors = curr_block->tensors();
+ for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
+ {
+ auto curr_tensor = fb_tosa_tensors->Get(j);
+
+ auto tensor_name = curr_tensor->name();
+ auto tensor_shape = curr_tensor->shape();
+ auto tensor_type = curr_tensor->type();
+ auto tensor_variable = curr_tensor->variable();
+ auto tensor_data = curr_tensor->data();
+ auto tensor_is_unranked = curr_tensor->is_unranked();
+ auto tensor_variable_name = curr_tensor->variable_name();
+
+ new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data,
+ tensor_variable, tensor_is_unranked, tensor_variable_name);
+ if (new_tensor)
+ {
+ block_tensors_container.push_back(new_tensor);
+ }
+ else
+ {
+ return TOSA_MEMORY_ERROR;
+ }
+ }
+ new_block = new TosaSerializationBasicBlock(block_name, region_name, block_operators_container,
+ block_tensors_container, block_inputs_container,
+ block_outputs_container);
+ if (new_block)
{
- block_tensors_container.push_back(new_tensor);
+ new_region->GetBlocks().push_back(new_block);
}
else
{
return TOSA_MEMORY_ERROR;
}
- }
-
- auto block_inputs = curr_block->inputs();
- auto block_outputs = curr_block->outputs();
-
- block_inputs_container.clear();
- block_outputs_container.clear();
-
- for (size_t j = 0; j < block_inputs->size(); j++)
- {
- auto curr_block_input = block_inputs->Get(j);
- block_inputs_container.push_back(curr_block_input->str());
- }
- for (size_t j = 0; j < block_outputs->size(); j++)
- {
- auto curr_block_output = block_outputs->Get(j);
- block_outputs_container.push_back(curr_block_output->str());
- }
-
- new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container,
- block_inputs_container, block_outputs_container);
- if (new_block)
- {
- this->GetBlocks().push_back(new_block);
- }
- else
- {
- return TOSA_MEMORY_ERROR;
- }
+ } // end block for_loop
}
return TOSA_OK;
}
-tosa_err_t TosaSerializationHandler::Serialize()
+std::vector<uint8_t> float_to_u8_helper(float f_in)
{
- std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
-
- std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
- std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
+ // Push back a single float value to the buffer with *NO PADDING*
+ // Therefore ConvertF32toU8 function not used
+ std::vector<uint8_t> u8_out;
+ uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&f_in);
+ u8_out.push_back(*val_u32 & 0xFF);
+ u8_out.push_back((*val_u32 >> 8) & 0xFF);
+ u8_out.push_back((*val_u32 >> 16) & 0xFF);
+ u8_out.push_back((*val_u32 >> 24) & 0xFF);
+ return u8_out;
+}
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
- std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
+tosa_err_t TosaSerializationHandler::Serialize()
+{
+ // regions
+ std::vector<flatbuffers::Offset<TosaRegion>> fboffset_regions;
// translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
- for (auto block : GetBlocks())
+ for (auto region : GetRegions())
{
- fboffset_block_operators.clear();
- fboffset_block_tensors.clear();
- fboffset_block_inputs.clear();
- fboffset_block_outputs.clear();
-
- auto block_name = _builder.CreateString(block->GetName().c_str());
-
- for (auto tensor_str : block->GetInputs())
- {
- auto tensor_name = _builder.CreateString(tensor_str.c_str());
- fboffset_block_inputs.push_back(tensor_name);
- }
-
- for (auto tensor_str : block->GetOutputs())
- {
- auto tensor_name = _builder.CreateString(tensor_str.c_str());
- fboffset_block_outputs.push_back(tensor_name);
- }
-
- auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs);
- auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs);
-
- for (auto op : block->GetOperators())
+ std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
+ for (auto block : region->GetBlocks())
{
- fboffset_operator_inputs.clear();
- fboffset_operator_outputs.clear();
-
- auto operator_op = op->GetOp();
- auto attribute_type = op->GetAttributeType();
-
- for (auto tensor_str : op->GetInputTensorNames())
+ std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
+ std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
+ auto block_name = _builder.CreateString(block->GetName().c_str());
+ for (auto tensor_str : block->GetInputs())
{
auto tensor_name = _builder.CreateString(tensor_str.c_str());
- fboffset_operator_inputs.push_back(tensor_name);
+ fboffset_block_inputs.push_back(tensor_name);
}
-
- for (auto tensor_str : op->GetOutputTensorNames())
+ for (auto tensor_str : block->GetOutputs())
{
auto tensor_name = _builder.CreateString(tensor_str.c_str());
- fboffset_operator_outputs.push_back(tensor_name);
+ fboffset_block_outputs.push_back(tensor_name);
}
-
- auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs);
- auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs);
-
- flatbuffers::Offset<void> fb_attribute;
- switch (attribute_type)
+ auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs);
+ auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs);
+ for (auto op : block->GetOperators())
{
- case Attribute_NONE:
- fb_attribute = 0;
- break;
-
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
+ std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
+ auto operator_op = op->GetOp();
+ auto attribute_type = op->GetAttributeType();
+ for (auto tensor_str : op->GetInputTensorNames())
+ {
+ auto tensor_name = _builder.CreateString(tensor_str.c_str());
+ fboffset_operator_inputs.push_back(tensor_name);
+ }
+ for (auto tensor_str : op->GetOutputTensorNames())
+ {
+ auto tensor_name = _builder.CreateString(tensor_str.c_str());
+ fboffset_operator_outputs.push_back(tensor_name);
+ }
+ auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs);
+ auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs);
+ flatbuffers::Offset<void> fb_attribute;
+ switch (attribute_type)
+ {
+ case Attribute_NONE:
+ fb_attribute = 0;
+ break;
#define DEF_ARGS_S_STR(NAME, V) , _builder.CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str())
+#define DEF_ARGS_S_FP_as_U8(NAME, V) \
+ , _builder.CreateVector<uint8_t>(float_to_u8_helper(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()))
#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()
-
#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
-#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_FP_as_U8(NAME, V)
#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
+#define DEF_ARGS_S_DType(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V)
-
#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V)
#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())
-
#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
@@ -672,11 +664,20 @@ tosa_err_t TosaSerializationHandler::Serialize()
#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
+#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7)
+#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
+ V7, T8, F8, V8) \
+ DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
+ DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
+ DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8)
#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
case Attribute_##NAME##Attribute: \
fb_attribute = Create##NAME##Attribute(_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \
break;
-
#include "attribute.def"
#undef DEF_ATTRIBUTE
#undef DEF_ARGS_1
@@ -692,53 +693,56 @@ tosa_err_t TosaSerializationHandler::Serialize()
#undef DEF_ARGS_S_float
#undef DEF_ARGS_S_bool
#undef DEF_ARGS_S_ResizeMode
+#undef DEF_ARGS_S_DType
#undef DEF_ARGS_S_string
#undef DEF_ARGS_S_STR
#undef DEF_ARGS_S_DEFAULT
- default:
- printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n",
- EnumNamesAttribute()[attribute_type]);
- return TOSA_INTERNAL_ERROR;
+ default:
+ printf("TosaSerializationHandler::Serialize(): Attribute %s not implemented yet\n",
+ EnumNamesAttribute()[attribute_type]);
+ return TOSA_INTERNAL_ERROR;
+ }
+ auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute,
+ fb_operator_inputs, fb_operator_outputs);
+ fboffset_block_operators.push_back(fboffset_operator);
}
+ auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
+ for (auto tensor : block->GetTensors())
+ {
+ auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
+ auto tensor_shape = _builder.CreateVector(tensor->GetShape());
+ auto tensor_dtype = tensor->GetDtype();
+ bool tensor_variable = tensor->GetVariable();
+ auto tensor_data = _builder.CreateVector(tensor->GetData());
+ auto tensor_is_unranked = tensor->GetIsUnranked();
+ auto tensor_variable_name = _builder.CreateString(tensor->GetVariableName().c_str());
+ auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data,
+ tensor_variable, tensor_is_unranked, tensor_variable_name);
+ fboffset_block_tensors.push_back(fboffset_tensor);
+ }
+ auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
+ auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors,
+ fb_block_inputs, fb_block_outputs);
+ fboffset_blocks.push_back(fboffset_block);
+ } // end block for_loop
+ auto fb_blocks = _builder.CreateVector(fboffset_blocks);
- auto fboffset_operator = CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute,
- fb_operator_inputs, fb_operator_outputs);
- fboffset_block_operators.push_back(fboffset_operator);
- }
-
- auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
-
- for (auto tensor : block->GetTensors())
- {
-
- auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
- auto tensor_shape = _builder.CreateVector(tensor->GetShape());
- auto tensor_dtype = tensor->GetDtype();
- auto tensor_data = _builder.CreateVector(tensor->GetData());
-
- auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data);
- fboffset_block_tensors.push_back(fboffset_tensor);
- }
-
- auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
-
- auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors,
- fb_block_inputs, fb_block_outputs);
- fboffset_blocks.push_back(fboffset_block);
- }
+ auto region_name = _builder.CreateString(region->GetName().c_str());
+ auto fboffset_region = CreateTosaRegion(_builder, region_name, fb_blocks);
+ fboffset_regions.push_back(fboffset_region);
+ } // end region for_loop
- auto fb_blocks = _builder.CreateVector(fboffset_blocks);
+ auto fb_regions = _builder.CreateVector(fboffset_regions);
auto fb_version =
CreateVersion(_builder, TOSA_VERSION_MAJOR, TOSA_VERSION_MINOR, TOSA_VERSION_PATCH, TOSA_VERSION_DRAFT);
-
- auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks);
+ auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_regions);
_builder.Finish(fb_graph, TosaGraphIdentifier());
return TOSA_OK;
}
-void zero_pad(std::vector<uint8_t>& buf)
+void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf)
{
while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0)
{
@@ -746,6 +750,66 @@ void zero_pad(std::vector<uint8_t>& buf)
}
}
+tosa_err_t TosaSerializationHandler::ConvertBF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Packs each fp32 input as a 2-byte bf16 value by truncation: only the upper 16 bits of the fp32 bit pattern are kept and the lower 16 bits are dropped (no rounding is applied). Always returns TOSA_OK.
+ out.clear(); // previous contents of out are discarded
+ for (auto val : in)
+ {
+ uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val); // view the local copy's raw bits as a 32-bit integer
+ uint8_t f32_byte2 = (*val_u32 >> 16) & 0xFF; // bits 16..23 of the fp32 pattern
+ uint8_t f32_byte3 = (*val_u32 >> 24) & 0xFF; // bits 24..31 of the fp32 pattern
+ // Emit the lower-order byte first (byte2), then the higher-order byte (byte3), i.e. little-endian bf16
+ out.push_back(f32_byte2);
+ out.push_back(f32_byte3);
+ }
+ ForceAlignTensorData(out); // runs until out.size() is a multiple of TENSOR_BUFFER_FORCE_ALIGNMENT; presumably by appending padding bytes (loop body not visible here)
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertFP8E4M3toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Converts each fp32 input to the file-local fp8e4m3 type (a ct::cfloat alias) and stores its bits() value as one uint8_t per element. Always returns TOSA_OK.
+ out.clear(); // previous contents of out are discarded
+ for (auto val : in)
+ {
+ auto f8 = static_cast<fp8e4m3>(val); // fp32 -> FP8 conversion is delegated to ct::cfloat
+ uint8_t b8 = f8.bits(); // presumably the raw 8-bit encoding; bits() is defined outside this file
+ out.push_back(b8);
+ }
+ ForceAlignTensorData(out); // runs until out.size() is a multiple of TENSOR_BUFFER_FORCE_ALIGNMENT; presumably by appending padding bytes (loop body not visible here)
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertFP8E5M2toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Converts each fp32 input to the file-local fp8e5m2 type (a ct::cfloat alias) and stores its bits() value as one uint8_t per element. Always returns TOSA_OK.
+ out.clear(); // previous contents of out are discarded
+ for (auto val : in)
+ {
+ auto f8 = static_cast<fp8e5m2>(val); // fp32 -> FP8 conversion is delegated to ct::cfloat
+ uint8_t b8 = f8.bits(); // presumably the raw 8-bit encoding; bits() is defined outside this file
+ out.push_back(b8);
+ }
+ ForceAlignTensorData(out); // runs until out.size() is a multiple of TENSOR_BUFFER_FORCE_ALIGNMENT; presumably by appending padding bytes (loop body not visible here)
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
+{
+ // Converts each fp32 input to half_float::half via half_cast, then appends the half's two bytes to out, low byte first. Always returns TOSA_OK.
+ out.clear(); // previous contents of out are discarded
+ for (auto val : in)
+ {
+ half_float::half val_f16 = half_float::half_cast<half_float::half, float>(val); // narrowing is delegated to the half library (half.hpp)
+ uint16_t* val_u16 = reinterpret_cast<uint16_t*>(&val_f16); // reads the half object's storage as a 16-bit integer; assumes half is exactly 16 bits wide — TODO confirm
+ out.push_back(*val_u16 & 0xFF); // low byte
+ out.push_back((*val_u16 >> 8) & 0xFF); // high byte
+ }
+ ForceAlignTensorData(out); // runs until out.size() is a multiple of TENSOR_BUFFER_FORCE_ALIGNMENT; presumably by appending padding bytes (loop body not visible here)
+ return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
{
out.clear();
@@ -757,7 +821,26 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
+tosa_err_t TosaSerializationHandler::ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out)
+{
+ out.clear(); // previous contents of out are discarded; each input then contributes 8 bytes, least-significant byte first
+ for (auto val : in)
+ {
+ uint64_t* val_u64 = reinterpret_cast<uint64_t*>(&val); // view the local copy's bits as unsigned so the shifts below are well defined
+ out.push_back(*val_u64 & 0xFF); // bits 0..7
+ out.push_back((*val_u64 >> 8) & 0xFF); // bits 8..15
+ out.push_back((*val_u64 >> 16) & 0xFF); // bits 16..23
+ out.push_back((*val_u64 >> 24) & 0xFF); // bits 24..31
+ out.push_back((*val_u64 >> 32) & 0xFF); // bits 32..39
+ out.push_back((*val_u64 >> 40) & 0xFF); // bits 40..47
+ out.push_back((*val_u64 >> 48) & 0xFF); // bits 48..55
+ out.push_back((*val_u64 >> 56) & 0xFF); // bits 56..63
+ }
+ ForceAlignTensorData(out); // runs until out.size() is a multiple of TENSOR_BUFFER_FORCE_ALIGNMENT; presumably by appending padding bytes (loop body not visible here)
return TOSA_OK; // this conversion has no failure path
}
@@ -774,7 +857,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>&
out.push_back((*val_u64 >> 32) & 0xFF);
out.push_back((*val_u64 >> 40) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -789,7 +872,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>&
out.push_back((*val_u32 >> 16) & 0xFF);
out.push_back((*val_u32 >> 24) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -802,7 +885,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>&
out.push_back(*val_u16 & 0xFF);
out.push_back((*val_u16 >> 8) & 0xFF);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -814,7 +897,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in
uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
out.push_back(*val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -845,7 +928,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in
uint8_t val_u8 = static_cast<uint8_t>(val_packed);
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
return TOSA_OK;
}
@@ -857,7 +940,105 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in
uint8_t val_u8 = val;
out.push_back(val_u8);
}
- zero_pad(out);
+ ForceAlignTensorData(out);
+ return TOSA_OK;
+}
+
+// Decodes `out_size` bf16 values from the byte buffer `in` into `out`.
+//
+// Each value occupies two consecutive bytes, low byte first. The pair becomes
+// the upper 16 bits of an fp32 bit pattern whose lower 16 bits are zero, so
+// the bf16 values are returned in fp32 type.
+//
+// `out` is cleared first. Returns TOSA_USER_ERROR (with `out` left empty) when
+// `in` holds fewer than out_size * sizeof(int16_t) bytes; otherwise TOSA_OK.
+tosa_err_t TosaSerializationHandler::ConvertU8toBF16(const std::vector<uint8_t>& in,
+                                                     uint32_t out_size,
+                                                     std::vector<float>& out)
+{
+    out.clear();
+    if (in.size() < out_size * sizeof(int16_t))
+    {
+        // Both printed values are size_t, hence %zu rather than %ld.
+        printf("TosaSerializationHandler::ConvertU8toBF16(): uint8 buffer size %zu must >= target size %zu\n",
+               in.size(), out_size * sizeof(int16_t));
+        return TOSA_USER_ERROR;
+    }
+
+    out.reserve(out_size);
+    for (uint32_t i = 0; i < out_size; i++)
+    {
+        uint32_t f32_byte2 = in[i * sizeof(int16_t)];
+        uint32_t f32_byte3 = in[i * sizeof(int16_t) + 1];
+        uint32_t val_u32   = (f32_byte2 << 16) + (f32_byte3 << 24);
+
+        // Reinterpret u32 bytes as fp32
+        // NOTE(review): pointer-cast type punning; consider std::memcpy if <cstring> is in scope here.
+        float val_f32 = *(float*)&val_u32;
+        out.push_back(val_f32);
+    }
+    return TOSA_OK;
+}
+
+// Decodes `out_size` FP8E4M3 values from the byte buffer `in` into `out`.
+//
+// Each input byte is taken as the raw bit pattern of one fp8e4m3 value and
+// converted with static_cast<float>, so the values are returned in fp32 type.
+//
+// `out` is cleared first. Returns TOSA_USER_ERROR (with `out` left empty) when
+// `in` holds fewer than out_size bytes; otherwise TOSA_OK.
+tosa_err_t TosaSerializationHandler::ConvertU8toFP8E4M3(const std::vector<uint8_t>& in,
+                                                        uint32_t out_size,
+                                                        std::vector<float>& out)
+{
+    out.clear();
+    if (in.size() < out_size * sizeof(int8_t))
+    {
+        // Report this function's own name; both printed values are size_t, hence %zu.
+        printf("TosaSerializationHandler::ConvertU8toFP8E4M3(): uint8 buffer size %zu must >= target size %zu\n",
+               in.size(), out_size * sizeof(int8_t));
+        return TOSA_USER_ERROR;
+    }
+
+    out.reserve(out_size);
+    for (uint32_t i = 0; i < out_size; i++)
+    {
+        int8_t bits   = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+        auto f8       = fp8e4m3::from_bits(bits);
+        float val_f32 = static_cast<float>(f8);
+        out.push_back(val_f32);
+    }
+    return TOSA_OK;
+}
+
+// Decodes `out_size` FP8E5M2 values from the byte buffer `in` into `out`.
+//
+// Each input byte is taken as the raw bit pattern of one fp8e5m2 value and
+// converted with static_cast<float>, so the values are returned in fp32 type.
+//
+// `out` is cleared first. Returns TOSA_USER_ERROR (with `out` left empty) when
+// `in` holds fewer than out_size bytes; otherwise TOSA_OK.
+tosa_err_t TosaSerializationHandler::ConvertU8toFP8E5M2(const std::vector<uint8_t>& in,
+                                                        uint32_t out_size,
+                                                        std::vector<float>& out)
+{
+    out.clear();
+    if (in.size() < out_size * sizeof(int8_t))
+    {
+        // Report this function's own name; both printed values are size_t, hence %zu.
+        printf("TosaSerializationHandler::ConvertU8toFP8E5M2(): uint8 buffer size %zu must >= target size %zu\n",
+               in.size(), out_size * sizeof(int8_t));
+        return TOSA_USER_ERROR;
+    }
+
+    out.reserve(out_size);
+    for (uint32_t i = 0; i < out_size; i++)
+    {
+        int8_t bits   = static_cast<int8_t>(in[i * sizeof(int8_t)]);
+        auto f8       = fp8e5m2::from_bits(bits);
+        float val_f32 = static_cast<float>(f8);
+        out.push_back(val_f32);
+    }
+    return TOSA_OK;
+}
+
+// Decodes `out_size` fp16 values from the byte buffer `in` into `out`.
+//
+// Each value occupies two consecutive bytes, low byte first. The 16-bit
+// pattern is reinterpreted as half_float::half, so the values are returned as
+// half_float::half (no widening to fp32 happens here).
+//
+// `out` is cleared first. Returns TOSA_USER_ERROR (with `out` left empty) when
+// `in` holds fewer than out_size * sizeof(int16_t) bytes; otherwise TOSA_OK.
+tosa_err_t TosaSerializationHandler::ConvertU8toF16(const std::vector<uint8_t>& in,
+                                                    uint32_t out_size,
+                                                    std::vector<half_float::half>& out)
+{
+    out.clear();
+    if (in.size() < out_size * sizeof(int16_t))
+    {
+        // Both printed values are size_t, hence %zu rather than %ld.
+        printf("TosaSerializationHandler::ConvertU8toF16(): uint8 buffer size %zu must >= target size %zu\n",
+               in.size(), out_size * sizeof(int16_t));
+        return TOSA_USER_ERROR;
+    }
+
+    out.reserve(out_size);
+    for (uint32_t i = 0; i < out_size; i++)
+    {
+        uint16_t f16_byte0 = in[i * sizeof(int16_t)];
+        uint16_t f16_byte1 = in[i * sizeof(int16_t) + 1];
+        uint16_t val_u16   = f16_byte0 + (f16_byte1 << 8);
+
+        // Reinterpret the u16 bit pattern as fp16
+        // NOTE(review): pointer-cast type punning; consider std::memcpy if <cstring> is in scope here.
+        half_float::half val_f16 = *(half_float::half*)&val_u16;
+        out.push_back(val_f16);
+    }
     return TOSA_OK;
 }
@@ -884,6 +1065,35 @@ tosa_err_t
return TOSA_OK;
}
+// Decodes `out_size` int64 values from the byte buffer `in` into `out`.
+//
+// Each value occupies eight consecutive bytes, least-significant byte first.
+//
+// `out` is cleared first. Returns TOSA_USER_ERROR (with `out` left empty) when
+// `in` holds fewer than out_size * sizeof(int64_t) bytes; otherwise TOSA_OK.
+tosa_err_t TosaSerializationHandler::ConvertU8toI64(const std::vector<uint8_t>& in,
+                                                    uint32_t out_size,
+                                                    std::vector<int64_t>& out)
+{
+    out.clear();
+    if (in.size() < out_size * sizeof(int64_t))
+    {
+        // Both printed values are size_t, hence %zu rather than %ld.
+        printf("TosaSerializationHandler::ConvertU8toI64(): uint8 buffer size %zu must >= target size %zu\n",
+               in.size(), out_size * sizeof(int64_t));
+        return TOSA_USER_ERROR;
+    }
+
+    out.reserve(out_size);
+    for (uint32_t i = 0; i < out_size; i++)
+    {
+        const auto base  = i * sizeof(int64_t);
+        uint64_t val_u64 = 0;
+        // Assemble in unsigned arithmetic; byte k lands in bits [8k, 8k+7].
+        for (uint32_t k = 0; k < sizeof(int64_t); k++)
+        {
+            val_u64 |= static_cast<uint64_t>(in[base + k]) << (8 * k);
+        }
+        out.push_back(static_cast<int64_t>(val_u64));
+    }
+    return TOSA_OK;
+}
+
tosa_err_t TosaSerializationHandler::ConvertU8toI48(const std::vector<uint8_t>& in,
uint32_t out_size,
std::vector<int64_t>& out)