diff options
Diffstat (limited to 'tests/use_case/asr/InferenceTestWav2Letter.cc')
-rw-r--r-- | tests/use_case/asr/InferenceTestWav2Letter.cc | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/tests/use_case/asr/InferenceTestWav2Letter.cc b/tests/use_case/asr/InferenceTestWav2Letter.cc index 53c92ab..643f805 100644 --- a/tests/use_case/asr/InferenceTestWav2Letter.cc +++ b/tests/use_case/asr/InferenceTestWav2Letter.cc @@ -17,10 +17,21 @@ #include "TensorFlowLiteMicro.hpp" #include "Wav2LetterModel.hpp" #include "TestData_asr.hpp" +#include "BufAttributes.hpp" #include <catch.hpp> #include <random> +namespace arm { +namespace app { + static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; + namespace asr { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); + } /* namespace asr */ +} /* namespace app */ +} /* namespace arm */ + using namespace test; bool RunInference(arm::app::Model& model, const int8_t vec[], const size_t copySz) @@ -58,7 +69,10 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and Wav2LetterMod arm::app::Wav2LetterModel model{}; REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); + REQUIRE(model.Init(arm::app::tensorArena, + sizeof(arm::app::tensorArena), + arm::app::asr::GetModelPointer(), + arm::app::asr::GetModelLen())); REQUIRE(model.IsInited()); REQUIRE(RunInferenceRandom(model)); @@ -96,11 +110,14 @@ TEST_CASE("Running inference with Tflu and Wav2LetterModel Int8", "[Wav2Letter]" arm::app::Wav2LetterModel model{}; REQUIRE_FALSE(model.IsInited()); - REQUIRE(model.Init()); + REQUIRE(model.Init(arm::app::tensorArena, + sizeof(arm::app::tensorArena), + arm::app::asr::GetModelPointer(), + arm::app::asr::GetModelLen())); REQUIRE(model.IsInited()); TestInference<int8_t>(input_goldenFV, output_goldenFV, model); } } -}
\ No newline at end of file +} |