summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
committerRichard Burton <richard.burton@arm.com>2022-04-22 16:14:57 +0100
commitb40ecf8522052809d2351677a96195d69e4d0c16 (patch)
tree8647dfdae7bcae0ec6d9564ba7a971819fdda431 /tests
parentc291144b7f08c21d08cdaf79cc64dc420ca70070 (diff)
downloadml-embedded-evaluation-kit-b40ecf8522052809d2351677a96195d69e4d0c16.tar.gz
MLECO-3174: Minor refactoring to implemented use case APIS
Looks large but it is mainly just many small adjustments Removed the inference runner code as it wasn't used Fixes to doc strings Consistent naming e.g. Asr/Kws instead of ASR/KWS Signed-off-by: Richard Burton <richard.burton@arm.com> Change-Id: I43b620b5c51d7910a29a63b509ac4d8a82c3a8fc
Diffstat (limited to 'tests')
-rw-r--r--tests/use_case/asr/AsrFeaturesTests.cc6
-rw-r--r--tests/use_case/asr/Wav2LetterPostprocessingTest.cc18
-rw-r--r--tests/use_case/asr/Wav2LetterPreprocessingTest.cc4
3 files changed, 14 insertions, 14 deletions
diff --git a/tests/use_case/asr/AsrFeaturesTests.cc b/tests/use_case/asr/AsrFeaturesTests.cc
index 6c23598..fe93c83 100644
--- a/tests/use_case/asr/AsrFeaturesTests.cc
+++ b/tests/use_case/asr/AsrFeaturesTests.cc
@@ -23,19 +23,19 @@
#include <catch.hpp>
#include <random>
-class TestPreprocess : public arm::app::ASRPreProcess {
+class TestPreprocess : public arm::app::AsrPreProcess {
public:
static bool ComputeDeltas(arm::app::Array2d<float>& mfcc,
arm::app::Array2d<float>& delta1,
arm::app::Array2d<float>& delta2)
{
- return ASRPreProcess::ComputeDeltas(mfcc, delta1, delta2);
+ return AsrPreProcess::ComputeDeltas(mfcc, delta1, delta2);
}
static void NormaliseVec(arm::app::Array2d<float>& vec)
{
- return ASRPreProcess::StandardizeVecF32(vec);
+ return AsrPreProcess::StandardizeVecF32(vec);
}
};
diff --git a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
index d0b6505..11c4919 100644
--- a/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
+++ b/tests/use_case/asr/Wav2LetterPostprocessingTest.cc
@@ -24,9 +24,9 @@
template <typename T>
static TfLiteTensor GetTestTensor(
- std::vector <int>& shape,
- T initVal,
- std::vector<T>& vectorBuf)
+ std::vector<int>& shape,
+ T initVal,
+ std::vector<T>& vectorBuf)
{
REQUIRE(0 != shape.size());
@@ -60,7 +60,7 @@ TEST_CASE("Checking return value")
TfLiteTensor tensor = GetTestTensor<int8_t>(
tensorShape, 100, tensorVec);
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
REQUIRE(!post.DoPostProcess());
@@ -80,7 +80,7 @@ TEST_CASE("Checking return value")
TfLiteTensor tensor = GetTestTensor<int8_t>(
tensorShape, 100, tensorVec);
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
@@ -110,7 +110,7 @@ TEST_CASE("Postprocessing - erasing required elements")
{
std::vector<int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(tensorShape, 100, tensorVec);
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
@@ -127,7 +127,7 @@ TEST_CASE("Postprocessing - erasing required elements")
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
tensorShape, 100, tensorVec);
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
@@ -165,7 +165,7 @@ TEST_CASE("Postprocessing - erasing required elements")
std::vector <int8_t> tensorVec;
TfLiteTensor tensor = GetTestTensor<int8_t>(
tensorShape, 100, tensorVec);
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
@@ -208,7 +208,7 @@ TEST_CASE("Postprocessing - erasing required elements")
tensorShape, 100, tensorVec);
/* Should not be able to erase the left context if it is the first iteration. */
- arm::app::ASRPostProcess post{classifier, &tensor, dummyLabels, dummyResult, outputCtxLen,
+ arm::app::AsrPostProcess post{&tensor, classifier, dummyLabels, dummyResult, outputCtxLen,
blankTokenIdx, arm::app::Wav2LetterModel::ms_outputRowsIdx};
/* Copy elements to compare later. */
diff --git a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
index 0280af6..0a44093 100644
--- a/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
+++ b/tests/use_case/asr/Wav2LetterPreprocessingTest.cc
@@ -111,8 +111,8 @@ TEST_CASE("Preprocessing calculation INT8")
tensorVec.data(), dims, quantScale, quantOffset, "preprocessedInput");
/* Initialise pre-processing module. */
- arm::app::ASRPreProcess prep{&inputTensor,
- numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
+ arm::app::AsrPreProcess prep{&inputTensor,
+ numMfccFeatures, numMfccVectors, mfccWindowLen, mfccWindowStride};
/* Invoke pre-processing. */
REQUIRE(prep.DoPreProcess(testWav.data(), testWav.size()));