From c291144b7f08c21d08cdaf79cc64dc420ca70070 Mon Sep 17 00:00:00 2001 From: Richard Burton Date: Fri, 22 Apr 2022 09:08:21 +0100 Subject: MLECO-3077: Add ASR use case API * Minor adjustments to doc strings in KWS * Remove unused score threshold in KWS Signed-off-by: Richard Burton Change-Id: Ie1c5bf6f7bdbebb853b6a10cb7ba1c4a1d9a76c9 --- tests/common/PlatformMathTests.cpp | 33 +++++- tests/use_case/asr/AsrFeaturesTests.cc | 52 +-------- tests/use_case/asr/Wav2LetterPostprocessingTest.cc | 124 ++++++++++++--------- tests/use_case/asr/Wav2LetterPreprocessingTest.cc | 120 ++++++++++---------- 4 files changed, 166 insertions(+), 163 deletions(-) (limited to 'tests') diff --git a/tests/common/PlatformMathTests.cpp b/tests/common/PlatformMathTests.cpp index ab1153f..c07cbf1 100644 --- a/tests/common/PlatformMathTests.cpp +++ b/tests/common/PlatformMathTests.cpp @@ -150,13 +150,28 @@ TEST_CASE("Test SqrtF32") TEST_CASE("Test MeanF32") { - /*Test Constants: */ + /* Test Constants: */ std::vector input {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000}; /* Manually calculated mean of above vector */ float expectedResult = 0.100; CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size()))); + + /* Mean of 0 */ + std::vector input2{1, 2, -1, -2}; + float expectedResult2 = 0.0f; + CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size()))); + + /* All 0s */ + std::vector input3 = std::vector(9, 0); + float expectedResult3 = 0.0f; + CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size()))); + + /* All 1s */ + std::vector input4 = std::vector(9, 1); + float expectedResult4 = 1.0f; + CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size()))); } TEST_CASE("Test StdDevF32") @@ -184,6 +199,22 @@ TEST_CASE("Test StdDevF32") float expectedResult = 0.969589282958136; CHECK (expectedResult == Approx(output)); + + /* All 0s should have 0 std dev. */ + std::vector input2 = std::vector(4, 0); + float expectedResult2 = 0.0f; + CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f))); + + /* All 1s should have 0 std dev. */ + std::vector input3 = std::vector(4, 1); + float expectedResult3 = 0.0f; + CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f))); + + /* Manually calclualted std value */ + std::vector input4 {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}; + float mean2 = (std::accumulate(input4.begin(), input4.end(), 0.0f))/float(input4.size()); + float expectedResult4 = 2.872281323; + CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2))); } TEST_CASE("Test FFT32") diff --git a/tests/use_case/asr/AsrFeaturesTests.cc b/tests/use_case/asr/AsrFeaturesTests.cc index 940c25f..6c23598 100644 --- a/tests/use_case/asr/AsrFeaturesTests.cc +++ b/tests/use_case/asr/AsrFeaturesTests.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -23,29 +23,19 @@ #include #include -class TestPreprocess : public arm::app::audio::asr::Preprocess { +class TestPreprocess : public arm::app::ASRPreProcess { public: static bool ComputeDeltas(arm::app::Array2d& mfcc, arm::app::Array2d& delta1, arm::app::Array2d& delta2) { - return Preprocess::ComputeDeltas(mfcc, delta1, delta2); - } - - static float GetMean(arm::app::Array2d& vec) - { - return Preprocess::GetMean(vec); - } - - static float GetStdDev(arm::app::Array2d& vec, const float mean) - { - return Preprocess::GetStdDev(vec, mean); + return ASRPreProcess::ComputeDeltas(mfcc, delta1, delta2); } static void NormaliseVec(arm::app::Array2d& vec) { - return Preprocess::NormaliseVec(vec); + return ASRPreProcess::StandardizeVecF32(vec); } }; @@ -126,40 +116,6 @@ TEST_CASE("Floating point asr features calculation", "[ASR]") } - SECTION("Mean") - { - std::vector> mean1vec{{1, 2}, - {-1, -2}}; - arm::app::Array2d mean1(2,2); /* {{1, 2},{-1, -2}} */ - populateArray2dWithVectorOfVector(mean1vec, mean1); - REQUIRE(0 == Approx(TestPreprocess::GetMean(mean1))); - - arm::app::Array2d mean2(2, 2); - std::fill(mean2.begin(), mean2.end(), 0.f); - REQUIRE(0 == Approx(TestPreprocess::GetMean(mean2))); - - arm::app::Array2d mean3(3,3); - std::fill(mean3.begin(), mean3.end(), 1.f); - REQUIRE(1 == Approx(TestPreprocess::GetMean(mean3))); - } - - SECTION("Std") - { - arm::app::Array2d std1(2, 2); - std::fill(std1.begin(), std1.end(), 0.f); /* {{0, 0}, {0, 0}} */ - REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std1, 0))); - - std::vector> std2vec{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 0}}; - arm::app::Array2d std2(2,5); - populateArray2dWithVectorOfVector(std2vec, std2); - const float mean = TestPreprocess::GetMean(std2); - REQUIRE(2.872281323 == Approx(TestPreprocess::GetStdDev(std2, mean))); - - arm::app::Array2d std3(2,2); - std::fill(std3.begin(), std3.end(), 1.f); /* std3{{1, 1}, {1, 1}}; */ - REQUIRE(0 == Approx(TestPreprocess::GetStdDev(std3, 1))); - } - SECTION("Norm") { auto checker = [&](arm::app::Array2d& d, std::vector& g) { TestPreprocess::NormaliseVec(d); diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc index 9ed2e1b..d0b6505 100644 --- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,7 @@ */ #include "Wav2LetterPostprocess.hpp" #include "Wav2LetterModel.hpp" +#include "ClassificationResult.hpp" #include #include @@ -47,85 +48,105 @@ TEST_CASE("Checking return value") { SECTION("Mismatched post processing parameters and tensor size") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; std::vector tensorShape = {1, 1, 1, 13}; std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); - REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + tensorShape, 100, tensorVec); + + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; + + REQUIRE(!post.DoPostProcess()); } SECTION("Post processing succeeds") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - - std::vector tensorShape = {1, 1, 13, 1}; - std::vector tensorVec; + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector dummyResult; + std::vector tensorShape = {1, 1, 13, 1}; + std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); } } TEST_CASE("Postprocessing - erasing required elements") { - constexpr uint32_t ctxLen = 5; + constexpr uint32_t outputCtxLen = 5; constexpr uint32_t innerLen = 3; - constexpr uint32_t nRows = 2*ctxLen + innerLen; + constexpr uint32_t nRows = 2*outputCtxLen + innerLen; constexpr uint32_t nCols = 10; constexpr uint32_t blankTokenIdx = nCols - 1; - std::vector tensorShape = {1, 1, nRows, nCols}; + std::vector tensorShape = {1, 1, nRows, nCols}; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector dummyLabels = {"a", "b", "$"}; + std::vector dummyResult; SECTION("First and last iteration") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - - std::vector tensorVec; - TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + std::vector tensorVec; + TfLiteTensor tensor = GetTestTensor(tensorShape, 100, tensorVec); + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vectororiginalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec == tensorVec); } SECTION("Right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector originalVec = tensorVec; + std::vector originalVec = tensorVec; + //auto tensorData = tflite::GetTensorData(tensor); /* This step should erase the right context only. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { - /* Check right context elements are zeroed. */ + /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); } /* Check left context is preserved. */ @@ -134,46 +155,47 @@ TEST_CASE("Postprocessing - erasing required elements") } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { CHECK(tensorVec[i] == originalVec[i]); } } SECTION("Left and right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* This step should erase right context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The first and last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { /* Check left and right context elements are zeroed. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); CHECK(tensorVec[i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); CHECK(tensorVec[i*nCols + j] == 0); } } } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { /* Check left context is preserved. */ CHECK(tensorVec[i] == originalVec[i]); } @@ -181,18 +203,20 @@ TEST_CASE("Postprocessing - erasing required elements") SECTION("Try left context erase") { - /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector tensorVec; TfLiteTensor tensor = GetTestTensor( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + /* Should not be able to erase the left context if it is the first iteration. */ + arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector originalVec = tensorVec; /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec == tensorVec); } diff --git a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc index 457257f..0280af6 100644 --- a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,55 +24,46 @@ constexpr uint32_t numMfccVectors = 10; /* Test vector output: generated using test-asr-preprocessing.py. */ int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = { - /* Feature vec 0. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */ - - /* Feature vec 1. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 3. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 4 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, - -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15, - - /* Feature vec 5 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, - -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6, - - /* Feature vec 6. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 7. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 8. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 9. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10 + /* Feature vec 0. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */ + /* Feature vec 1. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 3. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 4 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, + -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15}, + /* Feature vec 5 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, + -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6}, + /* Feature vec 6. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 7. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 8. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 9. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10} }; void PopulateTestWavVector(std::vector& vec) @@ -97,15 +88,16 @@ void PopulateTestWavVector(std::vector& vec) TEST_CASE("Preprocessing calculation INT8") { /* Constants. */ - const uint32_t windowLen = 512; - const uint32_t windowStride = 160; - int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; - const float quantScale = 0.1410219967365265; - const int quantOffset = -11; + const uint32_t mfccWindowLen = 512; + const uint32_t mfccWindowStride = 160; + int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; + const float quantScale = 0.1410219967365265; + const int quantOffset = -11; /* Test wav memory. */ - std::vector testWav((windowStride * numMfccVectors) + - (windowLen - windowStride)); + std::vector testWav((mfccWindowStride * numMfccVectors) + + (mfccWindowLen - mfccWindowStride) + ); /* Populate with dummy input. */ PopulateTestWavVector(testWav); @@ -115,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8") /* Initialise dimensions and the test tensor. */ TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); - TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor( - tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); + TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor( + tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); /* Initialise pre-processing module. */ - arm::app::audio::asr::Preprocess prep{ - numMfccFeatures, windowLen, windowStride, numMfccVectors}; + arm::app::ASRPreProcess prep{&inputTensor, + numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; /* Invoke pre-processing. */ - REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor)); + REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size())); /* Wrap the tensor with a std::vector for ease. */ - auto* tensorData = tflite::GetTensorData(&tensor); + auto* tensorData = tflite::GetTensorData(&inputTensor); std::vector vecResults = - std::vector(tensorData, tensorData + tensor.bytes); + std::vector(tensorData, tensorData + inputTensor.bytes); /* Check sizes. */ REQUIRE(vecResults.size() == sizeof(expectedResult)); -- cgit v1.2.1