diff options
Diffstat (limited to 'tests/use_case/asr/Wav2LetterPostprocessingTest.cc')
-rw-r--r-- | tests/use_case/asr/Wav2LetterPostprocessingTest.cc | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc index d0b6505..11c4919 100644 --- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc +++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc @@ -24,9 +24,9 @@ template <typename T> static TfLiteTensor GetTestTensor( - std::vector <int>& shape, - T initVal, - std::vector<T>& vectorBuf) + std::vector<int>& shape, + T initVal, + std::vector<T>& vectorBuf) { REQUIRE(0 != shape.size()); @@ -60,7 +60,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; REQUIRE(!post.DoPostProcess()); @@ -80,7 +80,7 @@ TEST_CASE("Checking return value") TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -110,7 +110,7 @@ TEST_CASE("Postprocessing - erasing required elements") { std::vector<int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -127,7 +127,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -165,7 +165,7 @@ TEST_CASE("Postprocessing - erasing required elements") std::vector <int8_t> tensorVec; TfLiteTensor tensor = GetTestTensor<int8_t>( tensorShape, 100, tensorVec); - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ @@ -208,7 +208,7 @@ TEST_CASE("Postprocessing - erasing required elements") tensorShape, 100, tensorVec); /* Should not be able to erase the left context if it is the first iteration. */ - arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen, + arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen, blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx}; /* Copy elements to compare later. */ |