summaryrefslogtreecommitdiff
path: root/tests/use_case/noise_reduction/RNNoiseModelTests.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tests/use_case/noise_reduction/RNNoiseModelTests.cc')
-rw-r--r--tests/use_case/noise_reduction/RNNoiseModelTests.cc29
1 files changed, 15 insertions, 14 deletions
diff --git a/tests/use_case/noise_reduction/RNNoiseModelTests.cc b/tests/use_case/noise_reduction/RNNoiseModelTests.cc
index 9720ba5..7bd83b1 100644
--- a/tests/use_case/noise_reduction/RNNoiseModelTests.cc
+++ b/tests/use_case/noise_reduction/RNNoiseModelTests.cc
@@ -23,14 +23,15 @@
#include <random>
namespace arm {
- namespace app {
- static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
- } /* namespace app */
+namespace app {
+ static uint8_t tensorArena[ACTIVATION_BUF_SZ] ACTIVATION_BUF_ATTRIBUTE;
+ namespace rnn {
+ extern uint8_t* GetModelPointer();
+ extern size_t GetModelLen();
+ } /* namespace rnn */
+} /* namespace app */
} /* namespace arm */
-extern uint8_t* GetModelPointer();
-extern size_t GetModelLen();
-
bool RunInference(arm::app::Model& model, std::vector<int8_t> vec,
const size_t sizeRequired, const size_t dataInputIndex)
{
@@ -73,8 +74,8 @@ TEST_CASE("Running random inference with TensorFlow Lite Micro and RNNoiseModel
REQUIRE_FALSE(model.IsInited());
REQUIRE(model.Init(arm::app::tensorArena,
sizeof(arm::app::tensorArena),
- GetModelPointer(),
- GetModelLen()));
+ arm::app::rnn::GetModelPointer(),
+ arm::app::rnn::GetModelLen()));
REQUIRE(model.IsInited());
model.ResetGruState();
@@ -128,9 +129,9 @@ TEST_CASE("Test initial GRU out state is 0", "[RNNoise]")
{
TestRNNoiseModel model{};
model.Init(arm::app::tensorArena,
- sizeof(arm::app::tensorArena),
- GetModelPointer(),
- GetModelLen());
+ sizeof(arm::app::tensorArena),
+ arm::app::rnn::GetModelPointer(),
+ arm::app::rnn::GetModelLen());
auto map = model.GetStateMap();
@@ -152,9 +153,9 @@ TEST_CASE("Test GRU state copy", "[RNNoise]")
{
TestRNNoiseModel model{};
model.Init(arm::app::tensorArena,
- sizeof(arm::app::tensorArena),
- GetModelPointer(),
- GetModelLen());
+ sizeof(arm::app::tensorArena),
+ arm::app::rnn::GetModelPointer(),
+ arm::app::rnn::GetModelLen());
REQUIRE(RunInferenceRandom(model, 0));
auto map = model.GetStateMap();