diff options
Diffstat (limited to 'tests/use_case/kws_asr')
-rw-r--r-- | tests/use_case/kws_asr/MfccTests.cc | 8 | ||||
-rw-r--r-- | tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc | 142 | ||||
-rw-r--r-- | tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc | 126 |
3 files changed, 147 insertions, 129 deletions
diff --git a/tests/use_case/kws_asr/MfccTests.cc b/tests/use_case/kws_asr/MfccTests.cc index 3ebdcf4..883c215 100644 --- a/tests/use_case/kws_asr/MfccTests.cc +++ b/tests/use_case/kws_asr/MfccTests.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -93,13 +93,13 @@ const std::vector<float> testWavMfcc { -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, }; -arm::app::audio::MicroNetMFCC GetMFCCInstance() { - const int sampFreq = arm::app::audio::MicroNetMFCC::ms_defaultSamplingFreq; +arm::app::audio::MicroNetKwsMFCC GetMFCCInstance() { + const int sampFreq = arm::app::audio::MicroNetKwsMFCC::ms_defaultSamplingFreq; const int frameLenMs = 40; const int frameLenSamples = sampFreq * frameLenMs * 0.001; const int numMfccFeats = 10; - return arm::app::audio::MicroNetMFCC(numMfccFeats, frameLenSamples); + return arm::app::audio::MicroNetKwsMFCC(numMfccFeats, frameLenSamples); } template <class T> diff --git a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc index 6fd7df3..e343b66 100644 --- a/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPostprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,15 +16,17 @@ */ #include "Wav2LetterPostprocess.hpp" #include "Wav2LetterModel.hpp" +#include "ClassificationResult.hpp" #include <algorithm> #include <catch.hpp> #include <limits> template <typename T> -static TfLiteTensor GetTestTensor(std::vector <int>& shape, - T initVal, - std::vector<T>& vectorBuf) +static TfLiteTensor GetTestTensor( + std::vector<int>& shape, + T initVal, + std::vector<T>& vectorBuf) { REQUIRE(0 != shape.size()); @@ -38,91 +40,112 @@ static TfLiteTensor GetTestTensor(std::vector <int>& shape, vectorBuf = std::vector<T>(sizeInBytes, initVal); TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(shape.data()); return tflite::testing::CreateQuantizedTensor( - vectorBuf.data(), dims, - 1, 0, "test-tensor"); + vectorBuf.data(), dims, + 1, 0, "test-tensor"); } TEST_CASE("Checking return value") { SECTION("Mismatched post processing parameters and tensor size") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector<std::string> dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector<arm::app::ClassificationResult> dummyResult; std::vector <int> tensorShape = {1, 1, 1, 13}; std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( - tensorShape, 100, tensorVec); - REQUIRE(false == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; + + REQUIRE(!post.DoPostProcess()); } SECTION("Post processing succeeds") { - const uint32_t ctxLen = 5; - const uint32_t innerLen = 3; - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, 0}; - - std::vector <int> tensorShape = {1, 1, 13, 1}; - std::vector <int8_t> tensorVec; + const uint32_t outputCtxLen = 5; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector<std::string> dummyLabels = {"a", "b", "$"}; + const uint32_t blankTokenIdx = 2; + std::vector<arm::app::ClassificationResult> dummyResult; + std::vector<int> tensorShape = {1, 1, 13, 1}; + std::vector<int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector <int8_t> originalVec = tensorVec; + std::vector<int8_t> originalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); } } + TEST_CASE("Postprocessing - erasing required elements") { - constexpr uint32_t ctxLen = 5; + constexpr uint32_t outputCtxLen = 5; constexpr uint32_t innerLen = 3; - constexpr uint32_t nRows = 2*ctxLen + innerLen; + constexpr uint32_t nRows = 2*outputCtxLen + innerLen; constexpr uint32_t nCols = 10; constexpr uint32_t blankTokenIdx = nCols - 1; - std::vector <int> tensorShape = {1, 1, nRows, nCols}; + std::vector<int> tensorShape = {1, 1, nRows, nCols}; + arm::app::AsrClassifier classifier; + arm::app::Wav2LetterModel model; + model.Init(); + std::vector<std::string> dummyLabels = {"a", "b", "$"}; + std::vector<arm::app::ClassificationResult> dummyResult; SECTION("First and last iteration") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector <int8_t> tensorVec; - TfLiteTensor tensor = GetTestTensor<int8_t>( - tensorShape, 100, tensorVec); + std::vector<int8_t> tensorVec; + TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector <int8_t> originalVec = tensorVec; + std::vector<int8_t>originalVec = tensorVec; /* This step should not erase anything. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec == tensorVec); } SECTION("Right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ - std::vector <int8_t> originalVec = tensorVec; + std::vector<int8_t> originalVec = tensorVec; /* This step should erase the right context only. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { - /* Check right context elements are zeroed. */ + /* Check right context elements are zeroed. Blank token idx should be set to 1 when erasing. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); } /* Check left context is preserved. */ @@ -131,45 +154,47 @@ TEST_CASE("Postprocessing - erasing required elements") } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { CHECK(tensorVec[i] == originalVec[i]); } } SECTION("Left and right context erase") { - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector <int8_t> tensorVec; - TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec); + TfLiteTensor tensor = GetTestTensor<int8_t>( + tensorShape, 100, tensorVec); + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector <int8_t> originalVec = tensorVec; /* This step should erase right context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + post.m_lastIteration = false; + REQUIRE(post.DoPostProcess()); /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, false)); + REQUIRE(post.DoPostProcess()); REQUIRE(originalVec != tensorVec); /* The first and last ctxLen * 10 elements should be gone. */ - for (size_t i = 0; i < ctxLen; ++i) { + for (size_t i = 0; i < outputCtxLen; ++i) { for (size_t j = 0; j < nCols; ++j) { /* Check left and right context elements are zeroed. */ if (j == blankTokenIdx) { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 1); - CHECK(tensorVec[i * nCols + j] == 1); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 1); + CHECK(tensorVec[i*nCols + j] == 1); } else { - CHECK(tensorVec[(ctxLen + innerLen) * nCols + i * nCols + j] == 0); - CHECK(tensorVec[i * nCols + j] == 0); + CHECK(tensorVec[(outputCtxLen + innerLen) * nCols + i*nCols + j] == 0); + CHECK(tensorVec[i*nCols + j] == 0); } } } /* Check inner elements are preserved. */ - for (size_t i = ctxLen * nCols; i < (ctxLen + innerLen) * nCols; ++i) { + for (size_t i = outputCtxLen * nCols; i < (outputCtxLen + innerLen) * nCols; ++i) { /* Check left context is preserved. */ CHECK(tensorVec[i] == originalVec[i]); } @@ -177,18 +202,21 @@ TEST_CASE("Postprocessing - erasing required elements") SECTION("Try left context erase") { - /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::audio::asr::Postprocess post{ctxLen, innerLen, blankTokenIdx}; - std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( - tensorShape, 100, tensorVec); + tensorShape, 100, tensorVec); + + /* Should not be able to erase the left context if it is the first iteration. */ + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, + blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ std::vector <int8_t> originalVec = tensorVec; /* Calling it the second time should erase the left context. */ - REQUIRE(true == post.Invoke(&tensor, arm::app::Wav2LetterModel::ms_outputRowsIdx, true)); + post.m_lastIteration = true; + REQUIRE(post.DoPostProcess()); + REQUIRE(originalVec == tensorVec); } -}
\ No newline at end of file +} diff --git a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc index 26ddb24..372152d 100644 --- a/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc +++ b/tests/use_case/kws_asr/Wav2LetterPreprocessingTest.cc @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. All rights reserved. + * Copyright (c) 2021-2022 Arm Limited. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,64 +16,54 @@ */ #include "Wav2LetterPreprocess.hpp" -#include <algorithm> -#include <catch.hpp> #include <limits> +#include <catch.hpp> constexpr uint32_t numMfccFeatures = 13; constexpr uint32_t numMfccVectors = 10; /* Test vector output: generated using test-asr-preprocessing.py. */ -int8_t expectedResult[numMfccVectors][numMfccFeatures*3] = { - /* Feature vec 0. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, /* Delta 2. */ - - /* Feature vec 1. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 3. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 4 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, - -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15, - - /* Feature vec 5 : this should have valid delta 1 and delta 2. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, - -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, - -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6, - - /* Feature vec 6. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 7. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 8. */ - -32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, - - /* Feature vec 9. */ - -31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, - -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, - -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10 +int8_t expectedResult[numMfccVectors][numMfccFeatures * 3] = { + /* Feature vec 0. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, /* MFCCs. */ + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, /* Delta 1. */ + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, /* Delta 2. */ + /* Feature vec 1. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -12, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 3. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 4 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -9, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -38, -29, -9, 1, -2, -7, -8, -8, -12, -16, -14, -5, 5, + -68, -50, -13, 5, 0, -9, -9, -8, -13, -20, -19, -3, 15}, + /* Feature vec 5 : this should have valid delta 1 and delta 2. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -11, -12, -12, + -62, -45, -11, 5, 0, -8, -9, -8, -12, -19, -17, -3, 13, + -27, -22, -13, -9, -11, -12, -12, -11, -11, -13, -13, -10, -6}, + /* Feature vec 6. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 7. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 8. */ + {-32, 4, -9, -8, -10, -10, -11, -11, -11, -12, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10}, + /* Feature vec 9. */ + {-31, 4, -9, -8, -10, -10, -11, -11, -11, -11, -12, -11, -11, + -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, -11, + -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10} }; void PopulateTestWavVector(std::vector<int16_t>& vec) @@ -97,17 +87,17 @@ void PopulateTestWavVector(std::vector<int16_t>& vec) TEST_CASE("Preprocessing calculation INT8") { - /* Constants. */ - const uint32_t windowLen = 512; - const uint32_t windowStride = 160; - int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; - const float quantScale = 0.1410219967365265; - const int quantOffset = -11; + const uint32_t mfccWindowLen = 512; + const uint32_t mfccWindowStride = 160; + int dimArray[] = {3, 1, numMfccFeatures * 3, numMfccVectors}; + const float quantScale = 0.1410219967365265; + const int quantOffset = -11; /* Test wav memory. */ - std::vector <int16_t> testWav((windowStride * numMfccVectors) + - (windowLen - windowStride)); + std::vector<int16_t> testWav((mfccWindowStride * numMfccVectors) + + (mfccWindowLen - mfccWindowStride) + ); /* Populate with dummy input. */ PopulateTestWavVector(testWav); @@ -117,20 +107,20 @@ TEST_CASE("Preprocessing calculation INT8") /* Initialise dimensions and the test tensor. */ TfLiteIntArray* dims= tflite::testing::IntArrayFromInts(dimArray); - TfLiteTensor tensor = tflite::testing::CreateQuantizedTensor( - tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); + TfLiteTensor inputTensor = tflite::testing::CreateQuantizedTensor( + tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); /* Initialise pre-processing module. */ - arm::app::audio::asr::Preprocess prep{ - numMfccFeatures, windowLen, windowStride, numMfccVectors}; + arm::app::AsrPreProcess prep{&inputTensor, + numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; /* Invoke pre-processing. */ - REQUIRE(prep.Invoke(testWav.data(), testWav.size(), &tensor)); + REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size())); /* Wrap the tensor with a std::vector for ease. */ - int8_t * tensorData = tflite::GetTensorData<int8_t>(&tensor); + auto* tensorData = tflite::GetTensorData<int8_t>(&inputTensor); std::vector <int8_t> vecResults = - std::vector<int8_t>(tensorData, tensorData + tensor.bytes); + std::vector<int8_t>(tensorData, tensorData + inputTensor.bytes); /* Check sizes. */ REQUIRE(vecResults.size() == sizeof(expectedResult)); |