From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- serialization/tosa_serialization_handler.h | 423 +++++++++++++++++++++++++++++ 1 file changed, 423 insertions(+) create mode 100644 serialization/tosa_serialization_handler.h (limited to 'serialization/tosa_serialization_handler.h') diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h new file mode 100644 index 0000000..124b8e0 --- /dev/null +++ b/serialization/tosa_serialization_handler.h @@ -0,0 +1,423 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _TOSA_SERIALIZATION_HANDLER_H +#define _TOSA_SERIALIZATION_HANDLER_H +#include "attribute.h" +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "quant_info.h" +#include "tosa_generated.h" +#include +#include +#include +#include + +namespace tosa +{ + +enum tosa_err_t +{ + TOSA_OK, + TOSA_USER_ERROR, + TOSA_FILE_ERROR, + TOSA_MEMORY_ERROR, + TOSA_SCHEMA_MISSING, + TOSA_INTERNAL_ERROR, + TOSA_VERSION_MISMATCH, + NUM_TOSA_ERROR +}; + +struct TosaVersion +{ + int32_t _major; + int32_t _minor; + int32_t _patch; + bool _experimental; + + TosaVersion() = delete; + TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental) + { + _major = major; + _minor = minor; + _patch = patch; + _experimental = experimental; + } + + std::string to_string() const + { + std::string str; + str += std::to_string(_major) + "."; + str += std::to_string(_minor) + "."; + str += std::to_string(_patch); + if (_experimental) + str += "(experimental)"; + return str; + }; + + bool operator==(const TosaVersion& rhs) + { + if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental) + { + return true; + } + return false; + } + + bool operator!=(const TosaVersion& rhs) + { + return !((*this) == rhs); + } +}; + +class TosaSerializationHandler; + +class TosaSerializationTensor +{ +public: + // constructor and destructor + TosaSerializationTensor(const flatbuffers::String* name, + const flatbuffers::Vector& usage, + const flatbuffers::Vector& shape, + DType dtype, + const flatbuffers::Vector& format, + const flatbuffers::String* npy_filename); + TosaSerializationTensor(std::string name, + const std::vector& usage, + const std::vector& shape, + DType dtype, + const std::vector& format, + const std::string* npy_filename); + TosaSerializationTensor(); + ~TosaSerializationTensor(); + + // copy constructor/assignment + TosaSerializationTensor(const TosaSerializationTensor& rhs); + TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs); + + // move constructor/assignment + TosaSerializationTensor(TosaSerializationTensor&& rhs); + TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs); + + // accessor + std::string GetName() const + { + return *_name; + } + const std::vector& GetShape() const + { + return *_shape; + } + DType GetDtype() + { + return _dtype; + } + bool HasFormat(Format format) + { + for (Format us : *_format) + { + if (us == format) + return true; + } + return false; + } + std::vector& GetFormat() + { + return *_format; + } + bool HasUsage(Usage usage) + { + for (Usage us : *_usage) + { + if (us == usage) + return true; + } + return false; + } + std::vector& GetUsage() + { + return *_usage; + } + std::string* GetNpyFilePtr() const + { + return _npy_filename; + } + + // modifier + void SetDtype(DType dtype) + { + _dtype = dtype; + } + void SetName(std::string name) + { + *_name = name; + } + +private: + DType _dtype; /* data type enumeration, see tosa_isa_generated.h */ + std::vector* _format; /* list of possible tensor format */ + std::vector* _usage; /* list of possible tensor usage */ + std::vector* _shape; /* shape of the tensor */ + std::string* _name; /* name of the tensor, used for solving dependency */ + std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */ +}; + +class TosaSerializationOperator +{ +public: + // use default copy, void constructor + // constructor and destructor + TosaSerializationOperator(Op op_name, + Attribute attribute_type, + const TosaAttributeBase* attribute, + QuantInfo qinfo_type, + const TosaQuantInfoBase* qinfo, + std::vector input_tensor_names, + std::vector output_tensor_names); + ~TosaSerializationOperator(); + + // accessor + Op GetOp() const + { + return _op; + } + Attribute GetAttributeType() const + { + return _attribute_type; + } + TosaAttributeBase* GetAttribute() const + { + return _attribute; + } + QuantInfo GetQInfoType() const + { + return _qinfo_type; + } + TosaQuantInfoBase* GetQInfo() const + { + return _qinfo; + } + std::vector& GetInputTensorNames() const + { + return *_input_tensor_names; + } + std::vector& GetOutputTensorNames() const + { + return *_output_tensor_names; + } + std::vector& GetInputTensors() const + { + return *_input_tensors; + } + std::vector& GetOutputTensors() const + { + return *_output_tensors; + } + +private: + Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */ + Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */ + TosaAttributeBase* _attribute; /* real attribute class goes here */ + QuantInfo _qinfo_type; /* QuantInfo enum */ + TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */ + std::vector* _input_tensor_names; /* array of input tensor names */ + std::vector* _output_tensor_names; /* array of output tensor names */ + + std::vector* _input_tensors; /* array of input TosaSerializationTensor */ + std::vector* _output_tensors; /* array of output TosaSerializationTensor */ +}; + +class TosaSerializationBasicBlock +{ +public: + // constructor and destructor + TosaSerializationBasicBlock(std::string name, + std::vector operators, + std::vector tensors, + std::vector inputs, + std::vector outputs); + ~TosaSerializationBasicBlock(); + + // accessor + std::string GetName() const + { + return *_name; + } + std::vector& GetOperators() + { + return *_operators; + } + std::vector& GetTensors() + { + return *_tensors; + } + + TosaSerializationTensor* GetTensorByName(std::string name) + { + TosaSerializationTensor* result = nullptr; + for (auto tensor : GetTensors()) + { + if (tensor->GetName() == name) + { + result = tensor; + break; + } + } + return result; + } + + std::vector& GetInputs() + { + return *_inputs; + } + std::vector& GetOutputs() + { + return *_outputs; + } + +private: + std::string* _name; /* name of basic block */ + std::vector* _operators; /* TosaSerializationOperator list */ + std::vector* _tensors; /* TosaSerializationTensor list */ + std::vector* _inputs; /* array of string to specify block inputs */ + std::vector* _outputs; /* array of string to specify block outputs */ +}; + +/* + * this is a helper class for writing/reading Tosa ISA + * supported format: .tosa (flatbuffer), .json + * and provide high-level std::vector-like interface + * to access internal data structure + */ +class TosaSerializationHandler +{ +public: + // constructor and destructor + TosaSerializationHandler(); + ~TosaSerializationHandler(); + + // file io + tosa_err_t LoadFileJson(const char* filename); + tosa_err_t LoadFileTosaFlatbuffer(const char* filename); + tosa_err_t SaveFileJson(const char* filename); + tosa_err_t SaveFileTosaFlatbuffer(const char* filename); + tosa_err_t LoadFileSchema(const char* filename); + + // version + TosaVersion* GetTosaVersion() const + { + return _version; + } + + // accessor + std::vector& GetBlocks() + { + return *_blocks; + } + + TosaSerializationBasicBlock* GetBlockByName(std::string name) + { + TosaSerializationBasicBlock* result = nullptr; + for (auto block : GetBlocks()) + { + if (block->GetName() == name) + { + result = block; + break; + } + } + return result; + } + TosaSerializationBasicBlock* GetMainBlock() + { + TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main")); + assert(main_block); + return main_block; + } + + std::vector& GetInputs() + { + return GetMainBlock()->GetInputs(); + } + std::vector& GetOutputs() + { + return GetMainBlock()->GetOutputs(); + } + + bool GetSchemaLoaded() const + { + return _schemaLoaded; + } + +protected: + tosa_err_t Clear(); + tosa_err_t InitWithBuf(const uint8_t* buf); + tosa_err_t FreezeBuilder(); + tosa_err_t SetTosaVersion(); + tosa_err_t CheckTosaVersion(const TosaVersion& read_version); + +private: + TosaVersion* _version; /* tosa version */ + flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */ + flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */ + std::vector* _blocks; /* array structure to store all TosaSerializationBasicBlock */ + bool _schemaLoaded; /* is the schema properly loaded? */ +}; + +class NumpyUtilities +{ +public: + enum NPError + { + NO_ERROR = 0, + FILE_NOT_FOUND, + FILE_IO_ERROR, + FILE_TYPE_MISMATCH, + HEADER_PARSE_ERROR, + BUFFER_SIZE_MISMATCH, + }; + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf); + + static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf); + + static NPError writeToNpyFile(const char* filename, const std::vector& shape, const float* buf); + + static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf); + +private: + static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str); + static NPError writeNpyHeader(FILE* infile, const std::vector& shape, const char* dtype_str); +}; + +} // namespace tosa + +#endif // _TOSA_SERIALIZATION_HANDLER_H -- cgit v1.2.1