/* * Copyright (c) 2021 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "RNNoiseModel.hpp" #include "log_macros.h" const tflite::MicroOpResolver& arm::app::RNNoiseModel::GetOpResolver() { return this->m_opResolver; } bool arm::app::RNNoiseModel::EnlistOperations() { this->m_opResolver.AddUnpack(); this->m_opResolver.AddFullyConnected(); this->m_opResolver.AddSplit(); this->m_opResolver.AddSplitV(); this->m_opResolver.AddAdd(); this->m_opResolver.AddLogistic(); this->m_opResolver.AddMul(); this->m_opResolver.AddSub(); this->m_opResolver.AddTanh(); this->m_opResolver.AddPack(); this->m_opResolver.AddReshape(); this->m_opResolver.AddQuantize(); this->m_opResolver.AddConcatenation(); this->m_opResolver.AddRelu(); #if defined(ARM_NPU) if (kTfLiteOk == this->m_opResolver.AddEthosU()) { info("Added %s support to op resolver\n", tflite::GetString_ETHOSU()); } else { printf_err("Failed to add Arm NPU support to op resolver."); return false; } #endif /* ARM_NPU */ return true; } extern uint8_t* GetModelPointer(); const uint8_t* arm::app::RNNoiseModel::ModelPointer() { return GetModelPointer(); } extern size_t GetModelLen(); size_t arm::app::RNNoiseModel::ModelSize() { return GetModelLen(); } bool arm::app::RNNoiseModel::RunInference() { return Model::RunInference(); } void arm::app::RNNoiseModel::ResetGruState() { for (auto& stateMapping: this->m_gruStateMap) { TfLiteTensor* inputGruStateTensor = this->GetInputTensor(stateMapping.second); auto* inputGruState = tflite::GetTensorData(inputGruStateTensor); /* Initial value of states is 0, but this is affected by quantization zero point. */ auto quantParams = arm::app::GetTensorQuantParams(inputGruStateTensor); memset(inputGruState, quantParams.offset, inputGruStateTensor->bytes); } } bool arm::app::RNNoiseModel::CopyGruStates() { std::vector>> tempOutGruStates; /* Saving output states before copying them to input states to avoid output states modification in the tensor. * tflu shares input and output tensors memory, thus writing to input tensor can change output tensor values. */ for (auto& stateMapping: this->m_gruStateMap) { TfLiteTensor* outputGruStateTensor = this->GetOutputTensor(stateMapping.first); std::vector tempOutGruState(outputGruStateTensor->bytes); auto* outGruState = tflite::GetTensorData(outputGruStateTensor); memcpy(tempOutGruState.data(), outGruState, outputGruStateTensor->bytes); /* Index of the input tensor and the data to copy. */ tempOutGruStates.emplace_back(stateMapping.second, std::move(tempOutGruState)); } /* Updating input GRU states with saved GRU output states. */ for (auto& stateMapping: tempOutGruStates) { auto outputGruStateTensorData = stateMapping.second; TfLiteTensor* inputGruStateTensor = this->GetInputTensor(stateMapping.first); if (outputGruStateTensorData.size() != inputGruStateTensor->bytes) { printf_err("Unexpected number of bytes for GRU state mapping. Input = %zuz, output = %zuz.\n", inputGruStateTensor->bytes, outputGruStateTensorData.size()); return false; } auto* inputGruState = tflite::GetTensorData(inputGruStateTensor); auto* outGruState = outputGruStateTensorData.data(); memcpy(inputGruState, outGruState, inputGruStateTensor->bytes); } return true; }