/* * Copyright (c) 2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "KwsClassifier.hpp" #include "TensorFlowLiteMicro.hpp" #include "PlatformMath.hpp" #include "log_macros.h" #include "../include/KwsClassifier.hpp" #include #include #include #include #include #include namespace arm { namespace app { bool KwsClassifier::GetClassificationResults(TfLiteTensor* outputTensor, std::vector& vecResults, const std::vector & labels, uint32_t topNCount, bool useSoftmax, std::vector>& resultHistory) { if (outputTensor == nullptr) { printf_err("Output vector is null pointer.\n"); return false; } uint32_t totalOutputSize = 1; for (int inputDim = 0; inputDim < outputTensor->dims->size; inputDim++) { totalOutputSize *= outputTensor->dims->data[inputDim]; } /* Sanity checks. */ if (totalOutputSize < topNCount) { printf_err("Output vector is smaller than %" PRIu32 "\n", topNCount); return false; } else if (totalOutputSize != labels.size()) { printf_err("Output size doesn't match the labels' size\n"); return false; } else if (topNCount == 0) { printf_err("Top N results cannot be zero\n"); return false; } bool resultState; vecResults.clear(); /* De-Quantize Output Tensor */ QuantParams quantParams = GetTensorQuantParams(outputTensor); /* Floating point tensor data to be populated * NOTE: The assumption here is that the output tensor size isn't too * big and therefore, there's neglibible impact on heap usage. */ std::vector resultData(totalOutputSize); resultData.resize(totalOutputSize); /* Populate the floating point buffer */ switch (outputTensor->type) { case kTfLiteUInt8: { uint8_t* tensor_buffer = tflite::GetTensorData(outputTensor); for (size_t i = 0; i < totalOutputSize; ++i) { resultData[i] = quantParams.scale * (static_cast(tensor_buffer[i]) - quantParams.offset); } break; } case kTfLiteInt8: { int8_t* tensor_buffer = tflite::GetTensorData(outputTensor); for (size_t i = 0; i < totalOutputSize; ++i) { resultData[i] = quantParams.scale * (static_cast(tensor_buffer[i]) - quantParams.offset); } break; } case kTfLiteFloat32: { float* tensor_buffer = tflite::GetTensorData(outputTensor); for (size_t i = 0; i < totalOutputSize; ++i) { resultData[i] = tensor_buffer[i]; } break; } default: printf_err("Tensor type %s not supported by classifier\n", TfLiteTypeGetName(outputTensor->type)); return false; } if (useSoftmax) { math::MathUtils::SoftmaxF32(resultData); } /* If keeping track of recent results, update and take an average. */ if (resultHistory.size() > 1) { std::rotate(resultHistory.begin(), resultHistory.begin() + 1, resultHistory.end()); resultHistory.back() = resultData; AveragResults(resultHistory, resultData); } /* Get the top N results. */ resultState = GetTopNResults(resultData, vecResults, topNCount, labels); if (!resultState) { printf_err("Failed to get top N results set\n"); return false; } return true; } void app::KwsClassifier::AveragResults(const std::vector>& resultHistory, std::vector& averageResult) { /* Compute averages of each class across the window length. */ float sum; for (size_t j = 0; j < averageResult.size(); j++) { sum = 0; for (size_t i = 0; i < resultHistory.size(); i++) { sum += resultHistory[i][j]; } averageResult[j] = (sum / resultHistory.size()); } } } /* namespace app */ } /* namespace arm */