diff options
Diffstat (limited to 'tests/use_case/kws_asr/InferenceTestWav2Letter.cc')
-rw-r--r-- | tests/use_case/kws_asr/InferenceTestWav2Letter.cc | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/tests/use_case/kws_asr/InferenceTestWav2Letter.cc b/tests/use_case/kws_asr/InferenceTestWav2Letter.cc index 1c5f20a..5d30211 100644 --- a/tests/use_case/kws_asr/InferenceTestWav2Letter.cc +++ b/tests/use_case/kws_asr/InferenceTestWav2Letter.cc @@ -17,10 +17,22 @@ #include "TensorFlowLiteMicro.hpp" #include "Wav2LetterModel.hpp" #include "TestData_asr.hpp" +#include "BufAttributes.hpp" #include <catch.hpp> #include <random> +namespace arm { + namespace app { + static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; + + namespace asr { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); + } + } /* namespace app */ +} /* namespace arm */ + namespace test { namespace asr { @@ -59,7 +71,10 @@ TEST_CASE("Running random inference with Tflu and Wav2LetterModel Int8", "[Wav2L arm::app::Wav2LetterModel model{}; REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); + REQUIRE(model.Init(arm::app::tensorArena, + sizeof(arm::app::tensorArena), + arm::app::asr::GetModelPointer(), + arm::app::asr::GetModelLen())); REQUIRE(model.IsInited()); REQUIRE(RunInferenceRandom(model)); @@ -98,7 +113,10 @@ TEST_CASE("Running inference with Tflu and Wav2LetterModel Int8", "[Wav2Letter]" arm::app::Wav2LetterModel model{}; REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); + REQUIRE(model.Init(arm::app::tensorArena, + sizeof(arm::app::tensorArena), + arm::app::asr::GetModelPointer(), + arm::app::asr::GetModelLen())); REQUIRE(model.IsInited()); TestInference<int8_t>(input_goldenFV, output_goldenFV, model); |