aboutsummaryrefslogtreecommitdiff
path: root/serialization/tosa_serialization_handler.h
diff options
context:
space:
mode:
Diffstat (limited to 'serialization/tosa_serialization_handler.h')
-rw-r--r--serialization/tosa_serialization_handler.h423
1 files changed, 0 insertions, 423 deletions
diff --git a/serialization/tosa_serialization_handler.h b/serialization/tosa_serialization_handler.h
deleted file mode 100644
index 124b8e0..0000000
--- a/serialization/tosa_serialization_handler.h
+++ /dev/null
@@ -1,423 +0,0 @@
-
-// Copyright (c) 2020, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef _TOSA_SERIALIZATION_HANDLER_H
-#define _TOSA_SERIALIZATION_HANDLER_H
-#include "attribute.h"
-#include "flatbuffers/idl.h"
-#include "flatbuffers/util.h"
-#include "quant_info.h"
-#include "tosa_generated.h"
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace tosa
-{
-
-enum tosa_err_t
-{
- TOSA_OK,
- TOSA_USER_ERROR,
- TOSA_FILE_ERROR,
- TOSA_MEMORY_ERROR,
- TOSA_SCHEMA_MISSING,
- TOSA_INTERNAL_ERROR,
- TOSA_VERSION_MISMATCH,
- NUM_TOSA_ERROR
-};
-
-struct TosaVersion
-{
- int32_t _major;
- int32_t _minor;
- int32_t _patch;
- bool _experimental;
-
- TosaVersion() = delete;
- TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
- {
- _major = major;
- _minor = minor;
- _patch = patch;
- _experimental = experimental;
- }
-
- std::string to_string() const
- {
- std::string str;
- str += std::to_string(_major) + ".";
- str += std::to_string(_minor) + ".";
- str += std::to_string(_patch);
- if (_experimental)
- str += "(experimental)";
- return str;
- };
-
- bool operator==(const TosaVersion& rhs)
- {
- if (rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch && rhs._experimental == _experimental)
- {
- return true;
- }
- return false;
- }
-
- bool operator!=(const TosaVersion& rhs)
- {
- return !((*this) == rhs);
- }
-};
-
-class TosaSerializationHandler;
-
-class TosaSerializationTensor
-{
-public:
- // constructor and destructor
- TosaSerializationTensor(const flatbuffers::String* name,
- const flatbuffers::Vector<uint32_t>& usage,
- const flatbuffers::Vector<int32_t>& shape,
- DType dtype,
- const flatbuffers::Vector<uint32_t>& format,
- const flatbuffers::String* npy_filename);
- TosaSerializationTensor(std::string name,
- const std::vector<Usage>& usage,
- const std::vector<int32_t>& shape,
- DType dtype,
- const std::vector<Format>& format,
- const std::string* npy_filename);
- TosaSerializationTensor();
- ~TosaSerializationTensor();
-
- // copy constructor/assignment
- TosaSerializationTensor(const TosaSerializationTensor& rhs);
- TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs);
-
- // move constructor/assignment
- TosaSerializationTensor(TosaSerializationTensor&& rhs);
- TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs);
-
- // accessor
- std::string GetName() const
- {
- return *_name;
- }
- const std::vector<int32_t>& GetShape() const
- {
- return *_shape;
- }
- DType GetDtype()
- {
- return _dtype;
- }
- bool HasFormat(Format format)
- {
- for (Format us : *_format)
- {
- if (us == format)
- return true;
- }
- return false;
- }
- std::vector<Format>& GetFormat()
- {
- return *_format;
- }
- bool HasUsage(Usage usage)
- {
- for (Usage us : *_usage)
- {
- if (us == usage)
- return true;
- }
- return false;
- }
- std::vector<Usage>& GetUsage()
- {
- return *_usage;
- }
- std::string* GetNpyFilePtr() const
- {
- return _npy_filename;
- }
-
- // modifier
- void SetDtype(DType dtype)
- {
- _dtype = dtype;
- }
- void SetName(std::string name)
- {
- *_name = name;
- }
-
-private:
- DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
- std::vector<Format>* _format; /* list of possible tensor format */
- std::vector<Usage>* _usage; /* list of possible tensor usage */
- std::vector<int32_t>* _shape; /* shape of the tensor */
- std::string* _name; /* name of the tensor, used for solving dependency */
- std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */
-};
-
-class TosaSerializationOperator
-{
-public:
- // use default copy, void constructor
- // constructor and destructor
- TosaSerializationOperator(Op op_name,
- Attribute attribute_type,
- const TosaAttributeBase* attribute,
- QuantInfo qinfo_type,
- const TosaQuantInfoBase* qinfo,
- std::vector<std::string> input_tensor_names,
- std::vector<std::string> output_tensor_names);
- ~TosaSerializationOperator();
-
- // accessor
- Op GetOp() const
- {
- return _op;
- }
- Attribute GetAttributeType() const
- {
- return _attribute_type;
- }
- TosaAttributeBase* GetAttribute() const
- {
- return _attribute;
- }
- QuantInfo GetQInfoType() const
- {
- return _qinfo_type;
- }
- TosaQuantInfoBase* GetQInfo() const
- {
- return _qinfo;
- }
- std::vector<std::string>& GetInputTensorNames() const
- {
- return *_input_tensor_names;
- }
- std::vector<std::string>& GetOutputTensorNames() const
- {
- return *_output_tensor_names;
- }
- std::vector<TosaSerializationTensor*>& GetInputTensors() const
- {
- return *_input_tensors;
- }
- std::vector<TosaSerializationTensor*>& GetOutputTensors() const
- {
- return *_output_tensors;
- }
-
-private:
- Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
- Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
- TosaAttributeBase* _attribute; /* real attribute class goes here */
- QuantInfo _qinfo_type; /* QuantInfo enum */
- TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
- std::vector<std::string>* _input_tensor_names; /* array of input tensor names */
- std::vector<std::string>* _output_tensor_names; /* array of output tensor names */
-
- std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */
- std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */
-};
-
-class TosaSerializationBasicBlock
-{
-public:
- // constructor and destructor
- TosaSerializationBasicBlock(std::string name,
- std::vector<TosaSerializationOperator*> operators,
- std::vector<TosaSerializationTensor*> tensors,
- std::vector<std::string> inputs,
- std::vector<std::string> outputs);
- ~TosaSerializationBasicBlock();
-
- // accessor
- std::string GetName() const
- {
- return *_name;
- }
- std::vector<TosaSerializationOperator*>& GetOperators()
- {
- return *_operators;
- }
- std::vector<TosaSerializationTensor*>& GetTensors()
- {
- return *_tensors;
- }
-
- TosaSerializationTensor* GetTensorByName(std::string name)
- {
- TosaSerializationTensor* result = nullptr;
- for (auto tensor : GetTensors())
- {
- if (tensor->GetName() == name)
- {
- result = tensor;
- break;
- }
- }
- return result;
- }
-
- std::vector<std::string>& GetInputs()
- {
- return *_inputs;
- }
- std::vector<std::string>& GetOutputs()
- {
- return *_outputs;
- }
-
-private:
- std::string* _name; /* name of basic block */
- std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */
- std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */
- std::vector<std::string>* _inputs; /* array of string to specify block inputs */
- std::vector<std::string>* _outputs; /* array of string to specify block outputs */
-};
-
-/*
- * this is a helper class for writing/reading Tosa ISA
- * supported format: .tosa (flatbuffer), .json
- * and provide high-level std::vector-like interface
- * to access internal data structure
- */
-class TosaSerializationHandler
-{
-public:
- // constructor and destructor
- TosaSerializationHandler();
- ~TosaSerializationHandler();
-
- // file io
- tosa_err_t LoadFileJson(const char* filename);
- tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
- tosa_err_t SaveFileJson(const char* filename);
- tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
- tosa_err_t LoadFileSchema(const char* filename);
-
- // version
- TosaVersion* GetTosaVersion() const
- {
- return _version;
- }
-
- // accessor
- std::vector<TosaSerializationBasicBlock*>& GetBlocks()
- {
- return *_blocks;
- }
-
- TosaSerializationBasicBlock* GetBlockByName(std::string name)
- {
- TosaSerializationBasicBlock* result = nullptr;
- for (auto block : GetBlocks())
- {
- if (block->GetName() == name)
- {
- result = block;
- break;
- }
- }
- return result;
- }
- TosaSerializationBasicBlock* GetMainBlock()
- {
- TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
- assert(main_block);
- return main_block;
- }
-
- std::vector<std::string>& GetInputs()
- {
- return GetMainBlock()->GetInputs();
- }
- std::vector<std::string>& GetOutputs()
- {
- return GetMainBlock()->GetOutputs();
- }
-
- bool GetSchemaLoaded() const
- {
- return _schemaLoaded;
- }
-
-protected:
- tosa_err_t Clear();
- tosa_err_t InitWithBuf(const uint8_t* buf);
- tosa_err_t FreezeBuilder();
- tosa_err_t SetTosaVersion();
- tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
-
-private:
- TosaVersion* _version; /* tosa version */
- flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */
- flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */
- std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */
- bool _schemaLoaded; /* is the schema properly loaded? */
-};
-
-class NumpyUtilities
-{
-public:
- enum NPError
- {
- NO_ERROR = 0,
- FILE_NOT_FOUND,
- FILE_IO_ERROR,
- FILE_TYPE_MISMATCH,
- HEADER_PARSE_ERROR,
- BUFFER_SIZE_MISMATCH,
- };
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf);
-
- static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf);
-
- static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf);
-
- static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf);
-
-private:
- static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
- static NPError writeNpyHeader(FILE* infile, const std::vector<int32_t>& shape, const char* dtype_str);
-};
-
-} // namespace tosa
-
-#endif // _TOSA_SERIALIZATION_HANDLER_H