diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/use_case/asr/AsrFeaturesTests.cc | 6 | ||||
-rw-r--r-- | tests/use_case/asr/Wav2LetterPostprocessingTest.cc | 18 | ||||
-rw-r--r-- | tests/use_case/asr/Wav2LetterPreprocessingTest.cc | 4 |
3 files changed, 14 insertions, 14 deletions
diff --git a/tests/use_case/asr/AsrFeaturesTests.cc b/tests/use_case/asr/AsrFeaturesTests.cc index 6c23598..fe93c83 100644 --- a/tests/use_case/asr/AsrFeaturesTests.cc +++ b/tests/use_case/asr/AsrFeaturesTests.cc @@ -23,19 +23,19 @@ #include <catch.hpp> #include <random> -class TestPreprocess : public arm::app::ASRPreProcess { +class TestPreprocess : public arm::app::AsrPreProcess { public: static bool ComputeDeltas(arm::app::Array2d<float>& mfcc, arm::app::Array2d<float>& delta1, arm::app::Array2d<float>& delta2) { - return ASRPreProcess::ComputeDeltas(mfcc, delta1, delta2); + return AsrPreProcess::ComputeDeltas(mfcc, delta1, delta2); } static void NormaliseVec(arm::app::Array2d<float>& vec) { - return ASRPreProcess::StandardizeVecF32(vec); + return AsrPreProcess::StandardizeVecF32(vec); } }; diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc index d0b6505..11c4919 100644 --- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc @@ -24,9 +24,9 @@ template <typename T> static TfLiteTensor GetTestTensor( - std::vector <int>& shape, - T initVal, - std::vector<T>& vectorBuf) + std::vector<int>& shape, + T initVal, + std::vector<T>& vectorBuf) { REQUIRE(0 != shape.size()); @@ -60,7 +60,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; REQUIRE(!post.DoPostProcess()); @@ -80,7 +80,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -110,7 +110,7 @@ TEST_CASE("Postprocessing - erasing required elements") { std::vector<int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -127,7 +127,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -165,7 +165,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -208,7 +208,7 @@ TEST_CASE("Postprocessing - erasing required elements") tensorShape, 100, tensorVec); /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ diff --git a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc index 0280af6..0a44093 100644 --- a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc @@ -111,8 +111,8 @@ TEST_CASE("Preprocessing calculation INT8") tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput"); /* Initialise pre-processing module. */ - arm::app::ASRPreProcess prep{&inputTensor, - numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; + arm::app::AsrPreProcess prep{&inputTensor, + numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride}; /* Invoke pre-processing. */ REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size())); |