diff options
author | Eric Kunze <eric.kunze@arm.com> | 2020-10-13 16:11:07 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2020-10-14 11:11:43 -0700 |
commit | e5e2676409a936431f87d31fb74d825257b20804 (patch) | |
tree | 304d93d993ef6417b02a515025f9030367682774 /serialization/tosa_serialization_handler.h | |
parent | 88b7860f180f91b5b66764c61cfd97d8bc53cece (diff) | |
download | reference_model-e5e2676409a936431f87d31fb74d825257b20804.tar.gz |
Initial checkin of TOSA reference_model and tests
Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6
Signed-off-by: Eric Kunze <eric.kunze@arm.com>
Diffstat (limited to 'serialization/tosa_serialization_handler.h')
-rw-r--r-- | serialization/tosa_serialization_handler.h | 423 |
1 files changed, 423 insertions, 0 deletions
diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h new file mode 100644 index 0000000..124b8e0 --- /dev/null +++ b/serialization/tosa_serialization_handler.h @@ -0,0 +1,423 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_HANDLER_H +#define _TOSA_SERIALIZATION_HANDLER_H +#include "attribute.h" +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "quant_info.h" +#include "tosa_generated.h" +#include <cstdint> +#include <memory> +#include <string> +#include <vector> + +namespace tosa +{ + +enum tosa_err_t +{ + TOSA_OK, + TOSA_USER_ERROR, + TOSA_FILE_ERROR, + TOSA_MEMORY_ERROR, + TOSA_SCHEMA_MISSING, + TOSA_INTERNAL_ERROR, + TOSA_VERSION_MISMATCH, + NUM_TOSA_ERROR +}; + +struct TosaVersion +{ + int32_t _major; + int32_t _minor; + int32_t _patch; + bool _experimental; + + TosaVersion() = delete; + TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental) + { + _major = major; + _minor = minor; + _patch = patch; + _experimental = experimental; + } + + std::string to_string() const + { + std::string str; + str += std::to_string(_major) + "."; + str += std::to_string(_minor) + "."; + str += std::to_string(_patch); + if (_experimental) + str += "(experimental)"; + return str; + }; + + bool operator==(const TosaVersion& rhs) + { + if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental) + { + return true; + } + return false; + } + + bool operator!=(const TosaVersion& rhs) + { + return !((*this) == rhs); + } +}; + +class TosaSerializationHandler; + +class TosaSerializationTensor +{ +public: + // constructor and destructor + TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector<uint32_t>& usage, + const flatbuffers::Vector<int32_t>& shape, + DType dtype, + const flatbuffers::Vector<uint32_t>& format, + const flatbuffers::String* npy_filename); + TosaSerializationTensor(std::string name, + const std::vector<Usage>& usage, + const std::vector<int32_t>& shape, + DType dtype, + const std::vector<Format>& format, + const std::string* npy_filename); + TosaSerializationTensor(); + ~TosaSerializationTensor(); + + // copy constructor/assignment + TosaSerializationTensor(const TosaSerializationTensor& rhs); + TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs); + + // move constructor/assignment + TosaSerializationTensor(TosaSerializationTensor&& rhs); + TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs); + + // accessor + std::string GetName() const + { + return *_name; + } + const std::vector<int32_t>& GetShape() const + { + return *_shape; + } + DType GetDtype() + { + return _dtype; + } + bool HasFormat(Format format) + { + for (Format us : *_format) + { + if (us == format) + return true; + } + return false; + } + std::vector<Format>& GetFormat() + { + return *_format; + } + bool HasUsage(Usage usage) + { + for (Usage us : *_usage) + { + if (us == usage) + return true; + } + return false; + } + std::vector<Usage>& GetUsage() + { + return *_usage; + } + std::string* GetNpyFilePtr() const + { + return _npy_filename; + } + + // modifier + void SetDtype(DType dtype) + { + _dtype = dtype; + } + void SetName(std::string name) + { + *_name = name; + } + +private: + DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ + std::vector<Format>* _format; /* list of possible tensor format */ + std::vector<Usage>* _usage; /* list of possible tensor usage */ + std::vector<int32_t>* _shape; /* shape of the tensor */ + std::string* _name; /* name of the tensor, used for solving dependency */ + std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ +}; + +class TosaSerializationOperator +{ +public: + // use default copy, void constructor + // constructor and destructor + TosaSerializationOperator(Op op_name, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector<std::string> input_tensor_names, + std::vector<std::string> output_tensor_names); + ~TosaSerializationOperator(); + + // accessor + Op GetOp() const + { + return _op; + } + Attribute GetAttributeType() const + { + return _attribute_type; + } + TosaAttributeBase* GetAttribute() const + { + return _attribute; + } + QuantInfo GetQInfoType() const + { + return _qinfo_type; + } + TosaQuantInfoBase* GetQInfo() const + { + return _qinfo; + } + std::vector<std::string>& GetInputTensorNames() const + { + return *_input_tensor_names; + } + std::vector<std::string>& GetOutputTensorNames() const + { + return *_output_tensor_names; + } + std::vector<TosaSerializationTensor*>& GetInputTensors() const + { + return *_input_tensors; + } + std::vector<TosaSerializationTensor*>& GetOutputTensors() const + { + return *_output_tensors; + } + +private: + Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ + Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ + TosaAttributeBase* _attribute; /* real attribute class goes here */ + QuantInfo _qinfo_type; /* QuantInfo enum */ + TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ + std::vector<std::string>* _input_tensor_names; /* array of input tensor names */ + std::vector<std::string>* _output_tensor_names; /* array of output tensor names */ + + std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */ + std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */ +}; + +class TosaSerializationBasicBlock +{ +public: + // constructor and destructor + TosaSerializationBasicBlock(std::string name, + std::vector<TosaSerializationOperator*> operators, + std::vector<TosaSerializationTensor*> tensors, + std::vector<std::string> inputs, + std::vector<std::string> outputs); + ~TosaSerializationBasicBlock(); + + // accessor + std::string GetName() const + { + return *_name; + } + std::vector<TosaSerializationOperator*>& GetOperators() + { + return *_operators; + } + std::vector<TosaSerializationTensor*>& GetTensors() + { + return *_tensors; + } + + TosaSerializationTensor* GetTensorByName(std::string name) + { + TosaSerializationTensor* result = nullptr; + for (auto tensor : GetTensors()) + { + if (tensor->GetName() == name) + { + result = tensor; + break; + } + } + return result; + } + + std::vector<std::string>& GetInputs() + { + return *_inputs; + } + std::vector<std::string>& GetOutputs() + { + return *_outputs; + } + +private: + std::string* _name; /* name of basic block */ + std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */ + std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */ + std::vector<std::string>* _inputs; /* array of string to specify block inputs */ + std::vector<std::string>* _outputs; /* array of string to specify block outputs */ +}; + +/* + * this is a helper class for writing/reading Tosa ISA + * supported format: .tosa (flatbuffer), .json + * and provide high-level std::vector-like interface + * to access internal data structure + */ +class TosaSerializationHandler +{ +public: + // constructor and destructor + TosaSerializationHandler(); + ~TosaSerializationHandler(); + + // file io + tosa_err_t LoadFileJson(const char* filename); + tosa_err_t LoadFileTosaFlatbuffer(const char* filename); + tosa_err_t SaveFileJson(const char* filename); + tosa_err_t SaveFileTosaFlatbuffer(const char* filename); + tosa_err_t LoadFileSchema(const char* filename); + + // version + TosaVersion* GetTosaVersion() const + { + return _version; + } + + // accessor + std::vector<TosaSerializationBasicBlock*>& GetBlocks() + { + return *_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + TosaSerializationBasicBlock* GetMainBlock() + { + TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); + assert(main_block); + return main_block; + } + + std::vector<std::string>& GetInputs() + { + return GetMainBlock()->GetInputs(); + } + std::vector<std::string>& GetOutputs() + { + return GetMainBlock()->GetOutputs(); + } + + bool GetSchemaLoaded() const + { + return _schemaLoaded; + } + +protected: + tosa_err_t Clear(); + tosa_err_t InitWithBuf(const uint8_t* buf); + tosa_err_t FreezeBuilder(); + tosa_err_t SetTosaVersion(); + tosa_err_t CheckTosaVersion(const TosaVersion& read_version); + +private: + TosaVersion* _version; /* tosa version */ + flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */ + flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */ + std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */ + bool _schemaLoaded; /* is the schema properly loaded? */ +}; + +class NumpyUtilities +{ +public: + enum NPError + { + NO_ERROR = 0, + FILE_NOT_FOUND, + FILE_IO_ERROR, + FILE_TYPE_MISMATCH, + HEADER_PARSE_ERROR, + BUFFER_SIZE_MISMATCH, + }; + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf); + +private: + static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); + static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str); +}; + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_HANDLER_H |