diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/tosa_serialization_handler.cpp | 60 |
1 files changed, 34 insertions, 26 deletions
diff --git a/src/tosa_serialization_handler.cpp b/src/tosa_serialization_handler.cpp index cbb862f..cb44f17 100644 --- a/src/tosa_serialization_handler.cpp +++ b/src/tosa_serialization_handler.cpp @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,10 +22,11 @@ using namespace tosa; TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name, const flatbuffers::Vector<int32_t>* shape, DType dtype, - const flatbuffers::Vector<uint8_t>* data) + const flatbuffers::Vector<uint8_t>* data, + bool variable) { - _dtype = dtype; - + _dtype = dtype; + _variable = variable; if (shape) { std::copy(shape->begin(), shape->end(), std::back_inserter(_shape)); @@ -43,18 +44,21 @@ TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name TosaSerializationTensor::TosaSerializationTensor(const std::string& name, const std::vector<int32_t>& shape, DType dtype, - const std::vector<uint8_t>& data) + const std::vector<uint8_t>& data, + bool variable) { - _dtype = dtype; - _shape = shape; - _name = name; - _data = data; + _dtype = dtype; + _variable = variable; + _shape = shape; + _name = name; + _data = data; } TosaSerializationTensor::TosaSerializationTensor() { - _dtype = DType_UNKNOWN; - _name = "UNKNOWN"; + _dtype = DType_UNKNOWN; + _variable = false; + _name = "UNKNOWN"; } TosaSerializationTensor::~TosaSerializationTensor() @@ -514,12 +518,14 @@ tosa_err_t TosaSerializationHandler::Deserialize(const uint8_t* buf) { auto curr_tensor = fb_tosa_tensors->Get(j); - auto tensor_name = curr_tensor->name(); - auto tensor_shape = curr_tensor->shape(); - auto tensor_type = curr_tensor->type(); - auto tensor_data = curr_tensor->data(); + auto tensor_name = curr_tensor->name(); + auto tensor_shape = curr_tensor->shape(); + auto tensor_type = curr_tensor->type(); + auto tensor_variable = curr_tensor->variable(); + auto tensor_data = curr_tensor->data(); - new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data); + new_tensor = + new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data, tensor_variable); if (new_tensor) { block_tensors_container.push_back(new_tensor); @@ -676,8 +682,10 @@ tosa_err_t TosaSerializationHandler::Serialize() auto tensor_name = _builder.CreateString(tensor->GetName().c_str()); auto tensor_shape = _builder.CreateVector(tensor->GetShape()); auto tensor_dtype = tensor->GetDtype(); + bool tensor_variable = tensor->GetVariable(); auto tensor_data = _builder.CreateVector(tensor->GetData()); - auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data); + auto fboffset_tensor = + CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data, tensor_variable); fboffset_block_tensors.push_back(fboffset_tensor); } auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors); @@ -702,7 +710,7 @@ tosa_err_t TosaSerializationHandler::Serialize() return TOSA_OK; } -void zero_pad(std::vector<uint8_t>& buf) +void TosaSerializationHandler::ForceAlignTensorData(std::vector<uint8_t>& buf) { while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0) { @@ -721,7 +729,7 @@ tosa_err_t TosaSerializationHandler::ConvertF16toU8(const std::vector<float>& in out.push_back(*val_u16 & 0xFF); out.push_back((*val_u16 >> 8) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -736,7 +744,7 @@ tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -753,7 +761,7 @@ tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>& out.push_back((*val_u64 >> 32) & 0xFF); out.push_back((*val_u64 >> 40) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -768,7 +776,7 @@ tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>& out.push_back((*val_u32 >> 16) & 0xFF); out.push_back((*val_u32 >> 24) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -781,7 +789,7 @@ tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>& out.push_back(*val_u16 & 0xFF); out.push_back((*val_u16 >> 8) & 0xFF); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -793,7 +801,7 @@ tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val); out.push_back(*val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -824,7 +832,7 @@ tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in uint8_t val_u8 = static_cast<uint8_t>(val_packed); out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } @@ -836,7 +844,7 @@ tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in uint8_t val_u8 = val; out.push_back(val_u8); } - zero_pad(out); + ForceAlignTensorData(out); return TOSA_OK; } |