diff options
author | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
---|---|---|
committer | Richard Burton <richard.burton@arm.com> | 2022-04-22 16:14:57 +0100 |
commit | b40ecf8522052809d2351677a96195d69e4d0c16 (patch) | |
tree | 8647dfdae7bcae0ec6d9564ba7a971819fdda431 /tests/use_case/asr/Wav2LetterPostprocessingTest.cc | |
parent | c291144b7f08c21d08cdaf79cc64dc420ca70070 (diff) | |
download | ml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz |
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments
Removed the inference runner code as it wasn't used
Fixes to doc strings
Consistent naming e.g. Asr/Kws instead of ASR/KWS
Signed-off-by: Richard Burton <richard.burton@arm.com>
Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'tests/use_case/asr/Wav2LetterPostprocessingTest.cc')
-rw-r--r-- | tests/use_case/asr/Wav2LetterPostprocessingTest.cc | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc index d0b6505..11c4919 100644 --- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc @@ -24,9 +24,9 @@ template <typename T> static TfLiteTensor GetTestTensor( - std::vector <int>& shape, - T initVal, - std::vector<T>& vectorBuf) + std::vector<int>& shape, + T initVal, + std::vector<T>& vectorBuf) { REQUIRE(0 != shape.size()); @@ -60,7 +60,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; REQUIRE(!post.DoPostProcess()); @@ -80,7 +80,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -110,7 +110,7 @@ TEST_CASE("Postprocessing - erasing required elements") { std::vector<int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -127,7 +127,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -165,7 +165,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -208,7 +208,7 @@ TEST_CASE("Postprocessing - erasing required elements") tensorShape, 100, tensorVec); /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ |