summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2021-10-15 10:17:33 +0100
committerIsabella Gottardi <isabella.gottardi@arm.com>2021-10-15 10:33:19 +0100
commit5cbcd9e2c85485db1fe8cf7b2445550e9dc36800 (patch)
treed42ed9864743202c36f90e9b7c5f5777e75d0a0f
parent118f73e0396fe66ee5cc3c0daec0882c7160a7cb (diff)
downloadml-embedded-evaluation-kit-5cbcd9e2c85485db1fe8cf7b2445550e9dc36800.tar.gz
MLECO-2423: [Fix] Case insensitive name clash
* Fixed vww usage warnings (imageData in RunInferene and _GetImageIdx function) Change-Id: I2c37e4e4cc8c8eca841690f2df8d525ed516ecc8
-rw-r--r--scripts/py/gen_test_data_cpp.py2
-rw-r--r--scripts/py/templates/iofmdata.cc.template (renamed from scripts/py/templates/testdata.cc.template)0
-rw-r--r--tests/use_case/vww/InferenceVisualWakeWordModelTests.cc39
3 files changed, 12 insertions, 29 deletions
diff --git a/scripts/py/gen_test_data_cpp.py b/scripts/py/gen_test_data_cpp.py
index ea4bd6f..a58f415 100644
--- a/scripts/py/gen_test_data_cpp.py
+++ b/scripts/py/gen_test_data_cpp.py
@@ -85,7 +85,7 @@ def write_individual_cc_file(filename, cc_filename, header_filename, header_temp
hex_line_generator = (', '.join(map(hex, sub_arr))
for sub_arr in np.array_split(fm_data, math.ceil(len(fm_data) / 20)))
- env.get_template('testdata.cc.template').stream(common_template_header=hdr,
+ env.get_template('iofmdata.cc.template').stream(common_template_header=hdr,
include_h=header_filename,
var_name=array_name,
fm_data=hex_line_generator,
diff --git a/scripts/py/templates/testdata.cc.template b/scripts/py/templates/iofmdata.cc.template
index e3c1dc6..e3c1dc6 100644
--- a/scripts/py/templates/testdata.cc.template
+++ b/scripts/py/templates/iofmdata.cc.template
diff --git a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
index c109a62..3a42dde 100644
--- a/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
+++ b/tests/use_case/vww/InferenceVisualWakeWordModelTests.cc
@@ -15,21 +15,29 @@
* limitations under the License.
*/
-#include <catch.hpp>
-#include <random>
#include "hal.h"
-#include "InputFiles.hpp"
#include "ImageUtils.hpp"
#include "TestData_vww.hpp"
#include "VisualWakeWordModel.hpp"
#include "TensorFlowLiteMicro.hpp"
+#include <catch.hpp>
bool RunInference(arm::app::Model& model, const int8_t* imageData)
{
TfLiteTensor* inputTensor = model.GetInputTensor(0);
REQUIRE(inputTensor);
+ const size_t copySz = inputTensor->bytes < IFM_DATA_SIZE ?
+ inputTensor->bytes :
+ IFM_DATA_SIZE;
+
+ memcpy(inputTensor->data.data, imageData, copySz);
+
+ if(model.IsDataSigned()){
+ convertImgIoInt8(inputTensor->data.data, copySz);
+ }
+
return model.RunInference();
}
@@ -54,28 +62,3 @@ void TestInference(int imageIdx,arm::app::Model& model) {
CHECK(testVal == goldenVal);
}
}
-
-
-/**
- * @brief Given an image name, get its index
- * @param[in] imageName Name of the image expected
- * @return index of the image if valid and (-1) if not found
- */
-static int _GetImageIdx(std::string &imageName)
-{
- int imgIdx = -1;
- for (uint32_t i = 0 ; i < NUMBER_OF_FILES; ++i) {
- if (imageName == std::string(get_filename(i))) {
- info("Image %s exists at index %u\n", get_filename(i), i);
- imgIdx = static_cast<int>(i);
- break;
- }
- }
-
- if (-1 == imgIdx) {
- warn("Image %s not found!\n", imageName.c_str());
- }
-
- return imgIdx;
-}
-