diff options
-rw-r--r-- | scripts/py/gen_test_data_cpp.py | 2 | ||||
-rw-r--r-- | scripts/py/templates/iofmdata.cc.template (renamed from scripts/py/templates/testdata.cc.template) | 0 | ||||
-rw-r--r-- | tests/use_case/vww/InferenceVisualWakeWordModelTests.cc | 39 |
3 files changed, 12 insertions, 29 deletions
diff --git a/scripts/py/gen_test_data_cpp.py b/scripts/py/gen_test_data_cpp.py index ea4bd6f..a58f415 100644 --- a/scripts/py/gen_test_data_cpp.py +++ b/scripts/py/gen_test_data_cpp.py @@ -85,7 +85,7 @@ def write_individual_cc_file(filename, cc_filename, header_filename, header_temp hex_line_generator = (', '.join(map(hex, sub_arr)) for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20))) - env.get_template('testdata.cc.template').stream(common_template_header=hdr, + env.get_template('iofmdata.cc.template').stream(common_template_header=hdr, include_h=header_filename, var_name=array_name, fm_data=hex_line_generator, diff --git a/scripts/py/templates/testdata.cc.template b/scripts/py/templates/iofmdata.cc.template index e3c1dc6..e3c1dc6 100644 --- a/scripts/py/templates/testdata.cc.template +++ b/scripts/py/templates/iofmdata.cc.template diff --git a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc index c109a62..3a42dde 100644 --- a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc +++ b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc @@ -15,21 +15,29 @@ * limitations under the License. */ -#include <catch.hpp> -#include <random> #include "hal.h" -#include "InputFiles.hpp" #include "ImageUtils.hpp" #include "TestData_vww.hpp" #include "VisualWakeWordModel.hpp" #include "TensorFlowLiteMicro.hpp" +#include <catch.hpp> bool RunInference(arm::app::Model& model, const int8_t* imageData) { TfLiteTensor* inputTensor = model.GetInputTensor(0); REQUIRE(inputTensor); + const size_t copySz = inputTensor->bytes < IFM_DATA_SIZE ? + inputTensor->bytes : + IFM_DATA_SIZE; + + memcpy(inputTensor->data.data, imageData, copySz); + + if(model.IsDataSigned()){ + convertImgIoInt8(inputTensor->data.data, copySz); + } + return model.RunInference(); } @@ -54,28 +62,3 @@ void TestInference(int imageIdx,arm::app::Model& model) { CHECK(testVal == goldenVal); } } - - -/** - * @brief Given an image name, get its index - * @param[in] imageName Name of the image expected - * @return index of the image if valid and (-1) if not found - */ -static int _GetImageIdx(std::string &imageName) -{ - int imgIdx = -1; - for (uint32_t i = 0 ; i < NUMBER_OF_FILES; ++i) { - if (imageName == std::string(get_filename(i))) { - info("Image %s exists at index %u\n", get_filename(i), i); - imgIdx = static_cast<int>(i); - break; - } - } - - if (-1 == imgIdx) { - warn("Image %s not found!\n", imageName.c_str()); - } - - return imgIdx; -} - |