summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/use_case/vww/InferenceVisualWakeWordModelTests.cc39
1 files changed, 11 insertions, 28 deletions
diff --git a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
index c109a62..3a42dde 100644
--- a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
+++ b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
@@ -15,21 +15,29 @@
* limitations under the License.
*/
-#include <catch.hpp>
-#include <random>
#include "hal.h"
-#include "InputFiles.hpp"
#include "ImageUtils.hpp"
#include "TestData_vww.hpp"
#include "VisualWakeWordModel.hpp"
#include "TensorFlowLiteMicro.hpp"
+#include <catch.hpp>
bool RunInference(arm::app::Model& model, const int8_t* imageData)
{
TfLiteTensor* inputTensor = model.GetInputTensor(0);
REQUIRE(inputTensor);
+ const size_t copySz = inputTensor->bytes < IFM_DATA_SIZE ?
+ inputTensor->bytes :
+ IFM_DATA_SIZE;
+
+ memcpy(inputTensor->data.data, imageData, copySz);
+
+ if(model.IsDataSigned()){
+ convertImgIoInt8(inputTensor->data.data, copySz);
+ }
+
return model.RunInference();
}
@@ -54,28 +62,3 @@ void TestInference(int imageIdx,arm::app::Model& model) {
CHECK(testVal == goldenVal);
}
}
-
-
-/**
- * @brief Given an image name, get its index
- * @param[in] imageName Name of the image expected
- * @return index of the image if valid and (-1) if not found
- */
-static int _GetImageIdx(std::string &imageName)
-{
- int imgIdx = -1;
- for (uint32_t i = 0 ; i < NUMBER_OF_FILES; ++i) {
- if (imageName == std::string(get_filename(i))) {
- info("Image %s exists at index %u\n", get_filename(i), i);
- imgIdx = static_cast<int>(i);
- break;
- }
- }
-
- if (-1 == imgIdx) {
- warn("Image %s not found!\n", imageName.c_str());
- }
-
- return imgIdx;
-}
-