diff options
Diffstat (limited to 'tests/use_case/noise_reduction/RNNoiseModelTests.cc')
-rw-r--r-- | tests/use_case/noise_reduction/RNNoiseModelTests.cc | 29 |
1 files changed, 15 insertions, 14 deletions
diff --git a/tests/use_case/noise_reduction/RNNoiseModelTests.cc b/tests/use_case/noise_reduction/RNNoiseModelTests.cc index 9720ba5..7bd83b1 100644 --- a/tests/use_case/noise_reduction/RNNoiseModelTests.cc +++ b/tests/use_case/noise_reduction/RNNoiseModelTests.cc @@ -23,14 +23,15 @@ #include <random> namespace arm { - namespace app { - static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; - } /* namespace app */ +namespace app { + static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE; + namespace rnn { + extern uint8_t* GetModelPointer(); + extern size_t GetModelLen(); + } /* namespace rnn */ +} /* namespace app */ } /* namespace arm */ -extern uint8_t* GetModelPointer(); -extern size_t GetModelLen(); - bool RunInference(arm::app::Model& model, std::vector<int8_t> vec, const size_t sizeRequired, const size_t dataInputIndex) { @@ -73,8 +74,8 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and RNNoiseModel REQUIRE_FALSE(model.IsInited()); REQUIRE(model.Init(arm::app::tensorArena, sizeof(arm::app::tensorArena), - GetModelPointer(), - GetModelLen())); + arm::app::rnn::GetModelPointer(), + arm::app::rnn::GetModelLen())); REQUIRE(model.IsInited()); model.ResetGruState(); @@ -128,9 +129,9 @@ TEST_CASE("Test initial GRU out state is 0", "[RNNoise]") { TestRNNoiseModel model{}; model.Init(arm::app::tensorArena, - sizeof(arm::app::tensorArena), - GetModelPointer(), - GetModelLen()); + sizeof(arm::app::tensorArena), + arm::app::rnn::GetModelPointer(), + arm::app::rnn::GetModelLen()); auto map = model.GetStateMap(); @@ -152,9 +153,9 @@ TEST_CASE("Test GRU state copy", "[RNNoise]") { TestRNNoiseModel model{}; model.Init(arm::app::tensorArena, - sizeof(arm::app::tensorArena), - GetModelPointer(), - GetModelLen()); + sizeof(arm::app::tensorArena), + arm::app::rnn::GetModelPointer(), + arm::app::rnn::GetModelLen()); REQUIRE(RunInferenceRandom(model, 0)); auto map = model.GetStateMap(); |