aboutsummaryrefslogtreecommitdiff
path: root/include/tosa_serialization_handler.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/tosa_serialization_handler.h')
-rw-r--r--include/tosa_serialization_handler.h29
1 files changed, 23 insertions, 6 deletions
diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h
index 398590d..db9481b 100644
--- a/include/tosa_serialization_handler.h
+++ b/include/tosa_serialization_handler.h
@@ -26,6 +26,8 @@
#include <string>
#include <vector>
+#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
+
namespace tosa
{
@@ -108,13 +110,13 @@ class TosaSerializationTensor
public:
// constructor and destructor
TosaSerializationTensor(const flatbuffers::String* name,
- const flatbuffers::Vector<int32_t>& shape,
+ const flatbuffers::Vector<int32_t>* shape,
DType dtype,
- const flatbuffers::String* npy_filename);
+ const flatbuffers::Vector<uint8_t>* data);
TosaSerializationTensor(std::string& name,
const std::vector<int32_t>& shape,
DType dtype,
- const std::string& npy_filename);
+ const std::vector<uint8_t>& data);
TosaSerializationTensor();
~TosaSerializationTensor();
@@ -131,9 +133,9 @@ public:
{
return _dtype;
}
- const std::string& GetNpyFilePtr() const
+ const std::vector<uint8_t>& GetData() const
{
- return _npy_filename;
+ return _data;
}
// modifier
@@ -150,7 +152,7 @@ private:
DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
std::vector<int32_t> _shape; /* shape of the tensor */
std::string _name; /* name of the tensor, used for solving dependency */
- std::string _npy_filename; /* numpy array filename if not null. so null is the distinguisher */
+ std::vector<uint8_t> _data; /* data array */
};
class TosaSerializationOperator
@@ -283,6 +285,21 @@ public:
tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
tosa_err_t LoadFileSchema(const char* schema_filename);
+ // data format conversion. little-endian.
+ static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
+ static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
+
+ static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
+ static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
+ static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
+ static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
+ static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
+ static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
+
// version
const TosaVersion& GetTosaVersion() const
{