diff options
Diffstat (limited to 'serialization/tosa_serialization_handler.h')
-rw-r--r-- | serialization/tosa_serialization_handler.h | 423 |
1 files changed, 0 insertions, 423 deletions
diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h deleted file mode 100644 index 124b8e0..0000000 --- a/serialization/tosa_serialization_handler.h +++ /dev/null @@ -1,423 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef _TOSA_SERIALIZATION_HANDLER_H -#define _TOSA_SERIALIZATION_HANDLER_H -#include "attribute.h" -#include "flatbuffers/idl.h" -#include "flatbuffers/util.h" -#include "quant_info.h" -#include "tosa_generated.h" -#include <cstdint> -#include <memory> -#include <string> -#include <vector> - -namespace tosa -{ - -enum tosa_err_t -{ - TOSA_OK, - TOSA_USER_ERROR, - TOSA_FILE_ERROR, - TOSA_MEMORY_ERROR, - TOSA_SCHEMA_MISSING, - TOSA_INTERNAL_ERROR, - TOSA_VERSION_MISMATCH, - NUM_TOSA_ERROR -}; - -struct TosaVersion -{ - int32_t _major; - int32_t _minor; - int32_t _patch; - bool _experimental; - - TosaVersion() = delete; - TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental) - { - _major = major; - _minor = minor; - _patch = patch; - _experimental = experimental; - } - - std::string to_string() const - { - std::string str; - str += std::to_string(_major) + "."; - str += std::to_string(_minor) + "."; - str += std::to_string(_patch); - if (_experimental) - str += "(experimental)"; - return str; - }; - - bool operator==(const TosaVersion& rhs) - { - if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental) - { - return true; - } - return false; - } - - bool operator!=(const TosaVersion& rhs) - { - return !((*this) == rhs); - } -}; - -class TosaSerializationHandler; - -class TosaSerializationTensor -{ -public: - // constructor and destructor - TosaSerializationTensor(const flatbuffers::String* name, - const flatbuffers::Vector<uint32_t>& usage, - const flatbuffers::Vector<int32_t>& shape, - DType dtype, - const flatbuffers::Vector<uint32_t>& format, - const flatbuffers::String* npy_filename); - TosaSerializationTensor(std::string name, - const std::vector<Usage>& usage, - const std::vector<int32_t>& shape, - DType dtype, - const std::vector<Format>& format, - const std::string* npy_filename); - TosaSerializationTensor(); - ~TosaSerializationTensor(); - - // copy constructor/assignment - TosaSerializationTensor(const TosaSerializationTensor& rhs); - TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs); - - // move constructor/assignment - TosaSerializationTensor(TosaSerializationTensor&& rhs); - TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs); - - // accessor - std::string GetName() const - { - return *_name; - } - const std::vector<int32_t>& GetShape() const - { - return *_shape; - } - DType GetDtype() - { - return _dtype; - } - bool HasFormat(Format format) - { - for (Format us : *_format) - { - if (us == format) - return true; - } - return false; - } - std::vector<Format>& GetFormat() - { - return *_format; - } - bool HasUsage(Usage usage) - { - for (Usage us : *_usage) - { - if (us == usage) - return true; - } - return false; - } - std::vector<Usage>& GetUsage() - { - return *_usage; - } - std::string* GetNpyFilePtr() const - { - return _npy_filename; - } - - // modifier - void SetDtype(DType dtype) - { - _dtype = dtype; - } - void SetName(std::string name) - { - *_name = name; - } - -private: - DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ - std::vector<Format>* _format; /* list of possible tensor format */ - std::vector<Usage>* _usage; /* list of possible tensor usage */ - std::vector<int32_t>* _shape; /* shape of the tensor */ - std::string* _name; /* name of the tensor, used for solving dependency */ - std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ -}; - -class TosaSerializationOperator -{ -public: - // use default copy, void constructor - // constructor and destructor - TosaSerializationOperator(Op op_name, - Attribute attribute_type, - const TosaAttributeBase* attribute, - QuantInfo qinfo_type, - const TosaQuantInfoBase* qinfo, - std::vector<std::string> input_tensor_names, - std::vector<std::string> output_tensor_names); - ~TosaSerializationOperator(); - - // accessor - Op GetOp() const - { - return _op; - } - Attribute GetAttributeType() const - { - return _attribute_type; - } - TosaAttributeBase* GetAttribute() const - { - return _attribute; - } - QuantInfo GetQInfoType() const - { - return _qinfo_type; - } - TosaQuantInfoBase* GetQInfo() const - { - return _qinfo; - } - std::vector<std::string>& GetInputTensorNames() const - { - return *_input_tensor_names; - } - std::vector<std::string>& GetOutputTensorNames() const - { - return *_output_tensor_names; - } - std::vector<TosaSerializationTensor*>& GetInputTensors() const - { - return *_input_tensors; - } - std::vector<TosaSerializationTensor*>& GetOutputTensors() const - { - return *_output_tensors; - } - -private: - Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ - Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ - TosaAttributeBase* _attribute; /* real attribute class goes here */ - QuantInfo _qinfo_type; /* QuantInfo enum */ - TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ - std::vector<std::string>* _input_tensor_names; /* array of input tensor names */ - std::vector<std::string>* _output_tensor_names; /* array of output tensor names */ - - std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */ - std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */ -}; - -class TosaSerializationBasicBlock -{ -public: - // constructor and destructor - TosaSerializationBasicBlock(std::string name, - std::vector<TosaSerializationOperator*> operators, - std::vector<TosaSerializationTensor*> tensors, - std::vector<std::string> inputs, - std::vector<std::string> outputs); - ~TosaSerializationBasicBlock(); - - // accessor - std::string GetName() const - { - return *_name; - } - std::vector<TosaSerializationOperator*>& GetOperators() - { - return *_operators; - } - std::vector<TosaSerializationTensor*>& GetTensors() - { - return *_tensors; - } - - TosaSerializationTensor* GetTensorByName(std::string name) - { - TosaSerializationTensor* result = nullptr; - for (auto tensor : GetTensors()) - { - if (tensor->GetName() == name) - { - result = tensor; - break; - } - } - return result; - } - - std::vector<std::string>& GetInputs() - { - return *_inputs; - } - std::vector<std::string>& GetOutputs() - { - return *_outputs; - } - -private: - std::string* _name; /* name of basic block */ - std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */ - std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */ - std::vector<std::string>* _inputs; /* array of string to specify block inputs */ - std::vector<std::string>* _outputs; /* array of string to specify block outputs */ -}; - -/* - * this is a helper class for writing/reading Tosa ISA - * supported format: .tosa (flatbuffer), .json - * and provide high-level std::vector-like interface - * to access internal data structure - */ -class TosaSerializationHandler -{ -public: - // constructor and destructor - TosaSerializationHandler(); - ~TosaSerializationHandler(); - - // file io - tosa_err_t LoadFileJson(const char* filename); - tosa_err_t LoadFileTosaFlatbuffer(const char* filename); - tosa_err_t SaveFileJson(const char* filename); - tosa_err_t SaveFileTosaFlatbuffer(const char* filename); - tosa_err_t LoadFileSchema(const char* filename); - - // version - TosaVersion* GetTosaVersion() const - { - return _version; - } - - // accessor - std::vector<TosaSerializationBasicBlock*>& GetBlocks() - { - return *_blocks; - } - - TosaSerializationBasicBlock* GetBlockByName(std::string name) - { - TosaSerializationBasicBlock* result = nullptr; - for (auto block : GetBlocks()) - { - if (block->GetName() == name) - { - result = block; - break; - } - } - return result; - } - TosaSerializationBasicBlock* GetMainBlock() - { - TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); - assert(main_block); - return main_block; - } - - std::vector<std::string>& GetInputs() - { - return GetMainBlock()->GetInputs(); - } - std::vector<std::string>& GetOutputs() - { - return GetMainBlock()->GetOutputs(); - } - - bool GetSchemaLoaded() const - { - return _schemaLoaded; - } - -protected: - tosa_err_t Clear(); - tosa_err_t InitWithBuf(const uint8_t* buf); - tosa_err_t FreezeBuilder(); - tosa_err_t SetTosaVersion(); - tosa_err_t CheckTosaVersion(const TosaVersion& read_version); - -private: - TosaVersion* _version; /* tosa version */ - flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */ - flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */ - std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */ - bool _schemaLoaded; /* is the schema properly loaded? */ -}; - -class NumpyUtilities -{ -public: - enum NPError - { - NO_ERROR = 0, - FILE_NOT_FOUND, - FILE_IO_ERROR, - FILE_TYPE_MISMATCH, - HEADER_PARSE_ERROR, - BUFFER_SIZE_MISMATCH, - }; - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf); - - static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf); - - static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf); - - static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf); - -private: - static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); - static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str); -}; - -} // namespace tosa - -#endif // _TOSA_SERIALIZATION_HANDLER_H |