From 4e002791bc6781b549c6951cfe44f918289d7e82 Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Wed, 4 May 2022 09:45:02 +0100 Subject: MLECO-3173: Add AD, KWS_ASR and Noise reduction use case API's Signed-off-by: Richard Burton Change-Id: I36f61ce74bf17f7b327cdae9704a22ca54144f37 --- tests/use_case/ad/PostProcessTests.cc | 53 -------- tests/use_case/kws_asr/MfccTests.cc | 8 +- .../kws_asr/Wav2LetterPostprocessingTest.cc | 142 ++++++++++++--------- .../kws_asr/Wav2LetterPreprocessingTest.cc | 126 +++++++++--------- .../noise_reduction/RNNoiseProcessingTests.cpp | 8 +- 5 files changed, 151 insertions(+), 186 deletions(-) delete mode 100644 tests/use_case/ad/PostProcessTests.cc (limited to 'tests') diff --git a/tests/use_case/ad/PostProcessTests.cc b/tests/use_case/ad/PostProcessTests.cc deleted file mode 100644 index 62fa9e7..0000000 --- a/tests/use_case/ad/PostProcessTests.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "AdPostProcessing.hpp" -#include - -TEST_CASE("Softmax_vector") { - - std::vector testVec = {1, 2, 3, 4, 1, 2, 3}; - arm::app::Softmax(testVec); - CHECK((testVec[0] - 0.024) == Approx(0.0).margin(0.001)); - CHECK((testVec[1] - 0.064) == Approx(0.0).margin(0.001)); - CHECK((testVec[2] - 0.175) == Approx(0.0).margin(0.001)); - CHECK((testVec[3] - 0.475) == Approx(0.0).margin(0.001)); - CHECK((testVec[4] - 0.024) == Approx(0.0).margin(0.001)); - CHECK((testVec[5] - 0.064) == Approx(0.0).margin(0.001)); - CHECK((testVec[6] - 0.175) == Approx(0.0).margin(0.001)); -} - -TEST_CASE("Output machine index") { - - auto index = arm::app::OutputIndexFromFileName("test_id_00.wav"); - CHECK(index == 0); - - auto index1 = arm::app::OutputIndexFromFileName("test_id_02.wav"); - CHECK(index1 == 1); - - auto index2 = arm::app::OutputIndexFromFileName("test_id_4.wav"); - CHECK(index2 == 2); - - auto index3 = arm::app::OutputIndexFromFileName("test_id_6.wav"); - CHECK(index3 == 3); - - auto index4 = arm::app::OutputIndexFromFileName("test_id_id_00.wav"); - CHECK(index4 == -1); - - auto index5 = arm::app::OutputIndexFromFileName("test_id_7.wav"); - CHECK(index5 == -1); -} \ No newline at end of file diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc index 3ebdcf4..883c215 100644 --- a/tests/use_case/kws_asr/MfccTests.cc +++ b/tests/use_case/kws_asr/MfccTests.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,13 +93,13 @@ const std::vector testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::MicroNetMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples); } template diff --git a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc index 6fd7df3..e343b66 100644 --- a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,15 +16,17 @@ */ #include "Wav2LetterPostprocess.hpp" #include "Wav2LetterModel.hpp" +#include "ClassificationResult.hpp" #include #include #include template -static TfLiteTensor GetTestTensor(std::vector & shape, - T initVal, - std::vector& vectorBuf) +static TfLiteTensor GetTestTensor( + std::vector& shape, + T initVal, + std::vector& vectorBuf) { REQUIRE(0 != shape.size()); @@ -38,91 +40,112 @@ static TfLiteTensor GetTestTensor(std::vector & shape, vectorBuf = std::vector(sizeInBytes, initVal); TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data()); return tflite::testing::CreateQuantizedTensor( - vectorBuf.data(), dims, - 1, 0, "test-tensor"); + vectorBuf.data(), dims, + 1, 0, "test-tensor"); } TEST_CASE("Checking return value") { SECTION("Mismatched post processing parameters and tensor size") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; std::vector tensorShape = {1, 1, 1, 13}; std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); - REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; + + REQUIRE(!post.DoPostProcess()); } SECTION("Post processing succeeds") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - - std::vector tensorShape = {1, 1, 13, 1}; - std::vector tensorVec; + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; + std::vector tensorShape = {1, 1, 13, 1}; + std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); } } + TEST_CASE("Postprocessing - erasing required elements") { - constexpr uint32_t ctxLen = 5; + constexpr uint32_t outputCtxLen = 5; constexpr uint32_t innerLen = 3; - constexpr uint32_t nRows = 2*ctxLen + innerLen; + constexpr uint32_t nRows = 2*outputCtxLen + innerLen; constexpr uint32_t nCols = 10; constexpr uint32_t blankTokenIdx = nCols - 1; - std::vector tensorShape = {1, 1, nRows, nCols}; + std::vector tensorShape = {1, 1, nRows, nCols}; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + std::vector dummyResult; SECTION("First and last iteration") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + std::vector tensorVec; + TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vectororiginalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec == tensorVec); } SECTION("Right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should erase the right context only. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { - /* Check right context elements are zeroed. */ + /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); } /* Check left context is preserved. */ @@ -131,45 +154,47 @@ TEST_CASE("Postprocessing - erasing required elements") } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { CHECK(tensorVec[i] == originalVec[i]); } } SECTION("Left and right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + TfLiteTensor tensor = GetTestTensor( + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* This step should erase right context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The first and last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { /* Check left and right context elements are zeroed. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1); - CHECK(tensorVec[i * nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0); - CHECK(tensorVec[i * nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[i*nCols + j] == 0); } } } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { /* Check left context is preserved. */ CHECK(tensorVec[i] == originalVec[i]); } @@ -177,18 +202,21 @@ TEST_CASE("Postprocessing - erasing required elements") SECTION("Try left context erase") { - /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + /* Should not be able to erase the left context if it is the first iteration. */ + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); + REQUIRE(originalVec == tensorVec); } -} \ No newline at end of file +} diff --git a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc index 26ddb24..372152d 100644 --- a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,64 +16,54 @@ */ #include "Wav2LetterPreprocess.hpp" -#include -#include #include +#include constexpr uint32_t numMfccFeatures = 13; constexpr uint32_t numMfccVectors = 10; /* Test vector output: generated using test-asr-preprocessing.py. */ -int8_t expectedResult[numMfccVectors][numMfccFeatures*3] = { - /* Feature vec 0. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */ - - /* Feature vec 1. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 3. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 4 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, - -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15, - - /* Feature vec 5 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, - -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6, - - /* Feature vec 6. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 7. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 8. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 9. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10 +int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = { + /* Feature vec 0. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */ + /* Feature vec 1. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 3. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 4 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, + -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15}, + /* Feature vec 5 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, + -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6}, + /* Feature vec 6. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 7. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 8. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 9. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10} }; void PopulateTestWavVector(std::vector& vec) @@ -97,17 +87,17 @@ void PopulateTestWavVector(std::vector& vec) TEST_CASE("Preprocessing calculation INT8") { - /* Constants. */ - const uint32_t windowLen = 512; - const uint32_t windowStride = 160; - int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; - const float quantScale = 0.1410219967365265; - const int quantOffset = -11; + const uint32_t mfccWindowLen = 512; + const uint32_t mfccWindowStride = 160; + int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; + const float quantScale = 0.1410219967365265; + const int quantOffset = -11; /* Test wav memory. */ - std::vector testWav((windowStride * numMfccVectors) + - (windowLen - windowStride)); + std::vector testWav((mfccWindowStride * numMfccVectors) + + (mfccWindowLen - mfccWindowStride) + ); /* Populate with dummy input. */ PopulateTestWavVector(testWav); @@ -117,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8") /* Initialise dimensions and the test tensor. */ TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); - TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor( - tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); + TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor( + tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); /* Initialise pre-processing module. */ - arm::app::audio::asr::Preprocess prep{ - numMfccFeatures, windowLen, windowStride, numMfccVectors}; + arm::app::AsrPreProcess prep{&inputTensor, + numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; /* Invoke pre-processing. */ - REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor)); + REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size())); /* Wrap the tensor with a std::vector for ease. */ - int8_t * tensorData = tflite::GetTensorData(&tensor); + auto* tensorData = tflite::GetTensorData(&inputTensor); std::vector vecResults = - std::vector(tensorData, tensorData + tensor.bytes); + std::vector(tensorData, tensorData + inputTensor.bytes); /* Check sizes. */ REQUIRE(vecResults.size() == sizeof(expectedResult)); diff --git a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp index e28a6da..ca5aab1 100644 --- a/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp +++ b/tests/use_case/noise_reduction/RNNoiseProcessingTests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RNNoiseProcess.hpp" +#include "RNNoiseFeatureProcessor.hpp" #include #include @@ -208,7 +208,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]") { SECTION("FP32") { - arm::app::rnn::RNNoiseProcess rnnoiseProcessor; + arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor; arm::app::rnn::FrameFeatures features; rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), features); @@ -223,7 +223,7 @@ TEST_CASE("RNNoise preprocessing calculation test", "[RNNoise]") TEST_CASE("RNNoise postprocessing test", "[RNNoise]") { - arm::app::rnn::RNNoiseProcess rnnoiseProcessor; + arm::app::rnn::RNNoiseFeatureProcessor rnnoiseProcessor; arm::app::rnn::FrameFeatures p; rnnoiseProcessor.PreprocessFrame(testWav0.data(), testWav0.size(), p); std::vector denoised(testWav0.size()); -- cgit v1.2.1