/* * Copyright (c) 2021 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MODEL_HPP #define MODEL_HPP #include "TensorFlowLiteMicro.hpp" #include "BufAttributes.hpp" #include namespace arm { namespace app { /** * @brief NN model class wrapping the underlying TensorFlow-Lite-Micro API. */ class Model { public: /** @brief Constructor. */ Model(); /** @brief Destructor. */ ~Model(); /** @brief Gets the pointer to the model's input tensor at given input index. */ TfLiteTensor* GetInputTensor(size_t index) const; /** @brief Gets the pointer to the model's output tensor at given output index. */ TfLiteTensor* GetOutputTensor(size_t index) const; /** @brief Gets the model's data type. */ TfLiteType GetType() const; /** @brief Gets the pointer to the model's input shape. */ TfLiteIntArray* GetInputShape(size_t index) const; /** @brief Gets the pointer to the model's output shape at given output index. */ TfLiteIntArray* GetOutputShape(size_t index) const; /** @brief Gets the number of input tensors the model has. */ size_t GetNumInputs() const; /** @brief Gets the number of output tensors the model has. */ size_t GetNumOutputs() const; /** @brief Logs the tensor information to stdout. */ void LogTensorInfo(TfLiteTensor* tensor); /** @brief Logs the interpreter information to stdout. */ void LogInterpreterInfo(); /** @brief Initialise the model class object. * @param[in] allocator Optional: a pre-initialised micro allocator pointer, * if available. If supplied, this allocator will be used * to create the interpreter instance. * @return true if initialisation succeeds, false otherwise. **/ bool Init(tflite::MicroAllocator* allocator = nullptr); /** * @brief Gets the allocator pointer for this instance. * @return Pointer to a tflite::MicroAllocator object, if * available; nullptr otherwise. **/ tflite::MicroAllocator* GetAllocator(); /** @brief Checks if this object has been initialised. */ bool IsInited() const; /** @brief Checks if the model uses signed data. */ bool IsDataSigned() const; /** @brief Runs the inference (invokes the interpreter). */ bool RunInference(); /** @brief Model information handler common to all models. * @return true or false based on execution success. **/ bool ShowModelInfoHandler(); /** @brief Gets a pointer to the tensor arena. */ uint8_t* GetTensorArena(); protected: /** @brief Gets the pointer to the NN model data array. * @return Pointer of uint8_t type. **/ virtual const uint8_t* ModelPointer() = 0; /** @brief Gets the model size. * @return size_t, size in bytes. **/ virtual size_t ModelSize() = 0; /** * @brief Gets the op resolver for the model instance. * @return const reference to a tflite::MicroOpResolver object. **/ virtual const tflite::MicroOpResolver& GetOpResolver() = 0; /** * @brief Add all the operators required for the given model. * Implementation of this should come from the use case. * @return true is ops are successfully added, false otherwise. **/ virtual bool EnlistOperations() = 0; /** @brief Gets the total size of tensor arena available for use. */ size_t GetActivationBufferSize(); private: tflite::MicroErrorReporter m_uErrorReporter; /* Error reporter object. */ tflite::ErrorReporter* m_pErrorReporter = nullptr; /* Pointer to the error reporter. */ const tflite::Model* m_pModel = nullptr; /* Tflite model pointer. */ tflite::MicroInterpreter* m_pInterpreter = nullptr; /* Tflite interpreter. */ tflite::MicroAllocator* m_pAllocator = nullptr; /* Tflite micro allocator. */ bool m_inited = false; /* Indicates whether this object has been initialised. */ std::vector m_input = {}; /* Model's input tensor pointers. */ std::vector m_output = {}; /* Model's output tensor pointers. */ TfLiteType m_type = kTfLiteNoType;/* Model's data type. */ }; } /* namespace app */ } /* namespace arm */ #endif /* MODEL_HPP */