diff options
Diffstat (limited to 'tests/use_case/kws')
-rw-r--r-- | tests/use_case/kws/InferenceTestMicroNetKws.cc | 44 |
1 files changed, 21 insertions, 23 deletions
diff --git a/tests/use_case/kws/InferenceTestMicroNetKws.cc b/tests/use_case/kws/InferenceTestMicroNetKws.cc index 27c6f96..ace6684 100644 --- a/tests/use_case/kws/InferenceTestMicroNetKws.cc +++ b/tests/use_case/kws/InferenceTestMicroNetKws.cc @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates <open-source-office@arm.com> - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright 2021 Arm Limited and/or its affiliates + * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,10 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "BufAttributes.hpp" #include "MicroNetKwsModel.hpp" -#include "TestData_kws.hpp" #include "TensorFlowLiteMicro.hpp" -#include "BufAttributes.hpp" +#include "TestData_kws.hpp" #include <catch.hpp> #include <random> @@ -39,9 +39,8 @@ bool RunInference(arm::app::Model& model, const int8_t vec[]) TfLiteTensor* inputTensor = model.GetInputTensor(0); REQUIRE(inputTensor); - const size_t copySz = inputTensor->bytes < IFM_0_DATA_SIZE ? - inputTensor->bytes : - IFM_0_DATA_SIZE; + const size_t copySz = + inputTensor->bytes < IFM_0_DATA_SIZE ? inputTensor->bytes : IFM_0_DATA_SIZE; memcpy(inputTensor->data.data, vec, copySz); return model.RunInference(); @@ -54,11 +53,9 @@ bool RunInferenceRandom(arm::app::Model& model) std::random_device rndDevice; std::mt19937 mersenneGen{rndDevice()}; - std::uniform_int_distribution<short> dist {-128, 127}; + std::uniform_int_distribution<short> dist{-128, 127}; - auto gen = [&dist, &mersenneGen](){ - return dist(mersenneGen); - }; + auto gen = [&dist, &mersenneGen]() { return dist(mersenneGen); }; std::vector<int8_t> randomAudio(inputTensor->bytes); std::generate(std::begin(randomAudio), std::end(randomAudio), gen); @@ -67,7 +64,7 @@ bool RunInferenceRandom(arm::app::Model& model) return true; } -template<typename T> +template <typename T> void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app::Model& model) { REQUIRE(RunInference(model, input_goldenFV)); @@ -84,15 +81,16 @@ void TestInference(const T* input_goldenFV, const T* output_goldenFV, arm::app:: } } -TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", "[MicroNetKws]") +TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsModel Int8", + "[MicroNetKws]") { arm::app::MicroNetKwsModel model{}; REQUIRE_FALSE(model.IsInited()); REQUIRE(model.Init(arm::app::tensorArena, - sizeof(arm::app::tensorArena), - arm::app::kws::GetModelPointer(), - arm::app::kws::GetModelLen())); + sizeof(arm::app::tensorArena), + arm::app::kws::GetModelPointer(), + arm::app::kws::GetModelLen())); REQUIRE(model.IsInited()); REQUIRE(RunInferenceRandom(model)); @@ -101,9 +99,10 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and MicroNetKwsMo TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int8", "[MicroNetKws]") { REQUIRE(NUMBER_OF_IFM_FILES == NUMBER_OF_OFM_FILES); - for (uint32_t i = 0 ; i < NUMBER_OF_IFM_FILES; ++i) { - const int8_t* input_goldenFV = get_ifm_data_array(i);; - const int8_t* output_goldenFV = get_ofm_data_array(i); + for (uint32_t i = 0; i < NUMBER_OF_IFM_FILES; ++i) { + const int8_t* input_goldenFV = GetIfmDataArray(i); + ; + const int8_t* output_goldenFV = GetOfmDataArray(i); DYNAMIC_SECTION("Executing inference with re-init " << i) { @@ -111,13 +110,12 @@ TEST_CASE("Running inference with TensorFlow Lite Micro and MicroNetKwsModel int REQUIRE_FALSE(model.IsInited()); REQUIRE(model.Init(arm::app::tensorArena, - sizeof(arm::app::tensorArena), - arm::app::kws::GetModelPointer(), - arm::app::kws::GetModelLen())); + sizeof(arm::app::tensorArena), + arm::app::kws::GetModelPointer(), + arm::app::kws::GetModelLen())); REQUIRE(model.IsInited()); TestInference<int8_t>(input_goldenFV, output_goldenFV, model); - } } } |