summary refs log tree commit diff
path: root/source/use_case/vww
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/vww')
-rw-r--r-- source/use_case/vww/include/VisualWakeWordModel.hpp 8
-rw-r--r-- source/use_case/vww/src/UseCaseHandler.cc 26
2 files changed, 21 insertions, 13 deletions
diff --git a/source/use_case/vww/include/VisualWakeWordModel.hpp b/source/use_case/vww/include/VisualWakeWordModel.hpp
index ee3a7bf..1ed9202 100644
--- a/source/use_case/vww/include/VisualWakeWordModel.hpp
+++ b/source/use_case/vww/include/VisualWakeWordModel.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited. All rights reserved.
+ * Copyright (c) 2021 - 2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,6 +24,12 @@ namespace app {
class VisualWakeWordModel : public Model {
+ public:
+ /* Indices for the expected model - based on input tensor shape */
+ static constexpr uint32_t ms_inputRowsIdx = 1;
+ static constexpr uint32_t ms_inputColsIdx = 2;
+ static constexpr uint32_t ms_inputChannelsIdx = 3;
+
protected:
/** @brief Gets the reference to op resolver interface class. */
const tflite::MicroOpResolver& GetOpResolver() override;
diff --git a/source/use_case/vww/src/UseCaseHandler.cc b/source/use_case/vww/src/UseCaseHandler.cc
index dbfe92b..e4dc479 100644
--- a/source/use_case/vww/src/UseCaseHandler.cc
+++ b/source/use_case/vww/src/UseCaseHandler.cc
@@ -50,8 +50,6 @@ namespace app {
constexpr uint32_t dataPsnTxtInfStartX = 150;
constexpr uint32_t dataPsnTxtInfStartY = 70;
- time_t infTimeMs = 0;
-
auto& model = ctx.Get<Model&>("model");
/* If the request has a valid size, set the image index. */
@@ -78,9 +76,13 @@ namespace app {
return false;
}
TfLiteIntArray* inputShape = model.GetInputShape(0);
- const uint32_t nCols = inputShape->data[2];
- const uint32_t nRows = inputShape->data[1];
- const uint32_t nChannels = (inputShape->size == 4) ? inputShape->data[3] : 1;
+ const uint32_t nCols = inputShape->data[arm::app::VisualWakeWordModel::ms_inputColsIdx];
+ const uint32_t nRows = inputShape->data[arm::app::VisualWakeWordModel::ms_inputRowsIdx];
+ if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast<uint32_t>(inputShape->size)) {
+ printf_err("Invalid channel index.\n");
+ return false;
+ }
+ const uint32_t nChannels = inputShape->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx];
std::vector<ClassificationResult> results;
@@ -163,7 +165,11 @@ namespace app {
return false;
}
- const uint32_t nChannels = (inputTensor->dims->size == 4) ? inputTensor->dims->data[3] : 1;
+ if (arm::app::VisualWakeWordModel::ms_inputChannelsIdx >= static_cast<uint32_t>(inputTensor->dims->size)) {
+ printf_err("Invalid channel index.\n");
+ return false;
+ }
+ const uint32_t nChannels = inputTensor->dims->data[arm::app::VisualWakeWordModel::ms_inputChannelsIdx];
const uint8_t* srcPtr = get_img_array(imIdx);
auto* dstPtr = static_cast<uint8_t *>(inputTensor->data.data);
@@ -172,11 +178,7 @@ namespace app {
* Visual Wake Word model accepts only one channel =>
* Convert image to grayscale here
**/
- for (size_t i = 0; i < copySz; ++i, srcPtr += 3) {
- *dstPtr++ = 0.2989*(*srcPtr) +
- 0.587*(*(srcPtr+1)) +
- 0.114*(*(srcPtr+2));
- }
+ image::RgbToGrayscale(srcPtr, dstPtr, copySz);
} else {
memcpy(inputTensor->data.data, srcPtr, copySz);
}
@@ -186,4 +188,4 @@ namespace app {
}
} /* namespace app */
-} /* namespace arm */ \ No newline at end of file
+} /* namespace arm */