diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/tosa_generated.h | 25 | ||||
-rw-r--r-- | include/tosa_serialization_handler.h | 29 |
2 files changed, 36 insertions, 18 deletions
diff --git a/include/tosa_generated.h b/include/tosa_generated.h index 735aca8..0f73819 100644 --- a/include/tosa_generated.h +++ b/include/tosa_generated.h @@ -1951,7 +1951,7 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_NAME = 4, VT_SHAPE = 6, VT_TYPE = 8, - VT_NPY_FILENAME = 10 + VT_DATA = 10 }; const flatbuffers::String *name() const { return GetPointer<const flatbuffers::String *>(VT_NAME); @@ -1962,8 +1962,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { tosa::DType type() const { return static_cast<tosa::DType>(GetField<uint32_t>(VT_TYPE, 0)); } - const flatbuffers::String *npy_filename() const { - return GetPointer<const flatbuffers::String *>(VT_NPY_FILENAME); + const flatbuffers::Vector<uint8_t> *data() const { + return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_DATA); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1972,8 +1972,8 @@ struct TosaTensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_SHAPE) && verifier.VerifyVector(shape()) && VerifyField<uint32_t>(verifier, VT_TYPE) && - VerifyOffset(verifier, VT_NPY_FILENAME) && - verifier.VerifyString(npy_filename()) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && verifier.EndTable(); } }; @@ -1991,8 +1991,8 @@ struct TosaTensorBuilder { void add_type(tosa::DType type) { fbb_.AddElement<uint32_t>(TosaTensor::VT_TYPE, static_cast<uint32_t>(type), 0); } - void add_npy_filename(flatbuffers::Offset<flatbuffers::String> npy_filename) { - fbb_.AddOffset(TosaTensor::VT_NPY_FILENAME, npy_filename); + void add_data(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data) { + fbb_.AddOffset(TosaTensor::VT_DATA, data); } explicit TosaTensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { @@ -2011,9 +2011,9 @@ inline flatbuffers::Offset<TosaTensor> CreateTosaTensor( flatbuffers::Offset<flatbuffers::String> name = 0, flatbuffers::Offset<flatbuffers::Vector<int32_t>> shape = 0, tosa::DType type = tosa::DType_UNKNOWN, - flatbuffers::Offset<flatbuffers::String> npy_filename = 0) { + flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data = 0) { TosaTensorBuilder builder_(_fbb); - builder_.add_npy_filename(npy_filename); + builder_.add_data(data); builder_.add_type(type); builder_.add_shape(shape); builder_.add_name(name); @@ -2025,16 +2025,17 @@ inline flatbuffers::Offset<TosaTensor> CreateTosaTensorDirect( const char *name = nullptr, const std::vector<int32_t> *shape = nullptr, tosa::DType type = tosa::DType_UNKNOWN, - const char *npy_filename = nullptr) { + const std::vector<uint8_t> *data = nullptr) { auto name__ = name ? _fbb.CreateString(name) : 0; auto shape__ = shape ? _fbb.CreateVector<int32_t>(*shape) : 0; - auto npy_filename__ = npy_filename ? _fbb.CreateString(npy_filename) : 0; + if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 8); } + auto data__ = data ? _fbb.CreateVector<uint8_t>(*data) : 0; return tosa::CreateTosaTensor( _fbb, name__, shape__, type, - npy_filename__); + data__); } struct TosaOperator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { diff --git a/include/tosa_serialization_handler.h b/include/tosa_serialization_handler.h index 398590d..db9481b 100644 --- a/include/tosa_serialization_handler.h +++ b/include/tosa_serialization_handler.h @@ -26,6 +26,8 @@ #include <string> #include <vector> +#define TENSOR_BUFFER_FORCE_ALIGNMENT 8 + namespace tosa { @@ -108,13 +110,13 @@ class TosaSerializationTensor public: // constructor and destructor TosaSerializationTensor(const flatbuffers::String* name, - const flatbuffers::Vector<int32_t>& shape, + const flatbuffers::Vector<int32_t>* shape, DType dtype, - const flatbuffers::String* npy_filename); + const flatbuffers::Vector<uint8_t>* data); TosaSerializationTensor(std::string& name, const std::vector<int32_t>& shape, DType dtype, - const std::string& npy_filename); + const std::vector<uint8_t>& data); TosaSerializationTensor(); ~TosaSerializationTensor(); @@ -131,9 +133,9 @@ public: { return _dtype; } - const std::string& GetNpyFilePtr() const + const std::vector<uint8_t>& GetData() const { - return _npy_filename; + return _data; } // modifier @@ -150,7 +152,7 @@ private: DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ std::vector<int32_t> _shape; /* shape of the tensor */ std::string _name; /* name of the tensor, used for solving dependency */ - std::string _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ + std::vector<uint8_t> _data; /* data array */ }; class TosaSerializationOperator @@ -283,6 +285,21 @@ public: tosa_err_t SaveFileTosaFlatbuffer(const char* filename); tosa_err_t LoadFileSchema(const char* schema_filename); + // data format conversion. little-endian. + static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out); + static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out); + + static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out); + static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out); + static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out); + static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out); + static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out); + static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out); + // version const TosaVersion& GetTosaVersion() const { |