summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRichard Burton <richard.burton@arm.com>2022-05-17 12:52:50 +0100
committerRichard Burton <richard.burton@arm.com>2022-05-20 11:08:32 +0100
commit6f6df0934f991b64fef494b86643b3f5074fca0e (patch)
tree3f04c50c7ee3bcaea4fa4e9c64c81b27d7cfc4fa
parent8c61c0a3cb8d6b534d1e423211e06b89f45bf223 (diff)
downloadml-embedded-evaluation-kit-6f6df0934f991b64fef494b86643b3f5074fca0e.tar.gz
Remove dependency on extern defined constants from OD use case
OD API now takes in these paramaters as part of the constructor Change-Id: I4cce25e364b2a99847b4540440db059997f6a81b
-rw-r--r--source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp47
-rw-r--r--source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc42
-rw-r--r--source/use_case/object_detection/src/UseCaseHandler.cc9
-rw-r--r--tests/use_case/object_detection/InferenceTestYoloFastest.cc6
4 files changed, 48 insertions, 56 deletions
diff --git a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
index 6a53688..b66edbf 100644
--- a/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
+++ b/source/application/api/use_case/object_detection/include/DetectorPostProcessing.hpp
@@ -28,6 +28,18 @@ namespace arm {
namespace app {
namespace object_detection {
+ struct PostProcessParams {
+ int inputImgRows{};
+ int inputImgCols{};
+ int originalImageSize{};
+ const float* anchor1;
+ const float* anchor2;
+ float threshold = 0.5f;
+ float nms = 0.45f;
+ int numClasses = 1;
+ int topN = 0;
+ };
+
struct Branch {
int resolution;
int numBox;
@@ -57,25 +69,15 @@ namespace object_detection {
public:
/**
* @brief Constructor.
- * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
- * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
- * @param[out] results Vector of detected results.
- * @param[in] inputImgRows Number of rows in the input image.
- * @param[in] inputImgCols Number of columns in the input image.
- * @param[in] threshold Post-processing threshold.
- * @param[in] nms Non-maximum Suppression threshold.
- * @param[in] numClasses Number of classes.
- * @param[in] topN Top N for each class.
+ * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
+ * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
+ * @param[out] results Vector of detected results.
+ * @param[in] postProcessParams Struct of various parameters used in post-processing.
**/
explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
TfLiteTensor* outputTensor1,
std::vector<object_detection::DetectionResult>& results,
- int inputImgRows,
- int inputImgCols,
- float threshold = 0.5f,
- float nms = 0.45f,
- int numClasses = 1,
- int topN = 0);
+ const object_detection::PostProcessParams& postProcessParams);
/**
* @brief Should perform YOLO post-processing of the result of inference then
@@ -85,16 +87,11 @@ namespace object_detection {
bool DoPostProcess() override;
private:
- TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
- TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
- std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
- int m_inputImgRows; /* Number of rows for model input. */
- int m_inputImgCols; /* Number of cols for model input. */
- float m_threshold; /* Post-processing threshold. */
- float m_nms; /* NMS threshold. */
- int m_numClasses; /* Number of classes. */
- int m_topN; /* TopN. */
- object_detection::Network m_net; /* YOLO network object. */
+ TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
+ TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
+ std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
+ const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */
+ object_detection::Network m_net; /* YOLO network object. */
/**
* @brief Insert the given Detection in the list.
diff --git a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
index 7610c4f..f555fbb 100644
--- a/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
+++ b/source/application/api/use_case/object_detection/src/DetectorPostProcessing.cc
@@ -26,31 +26,21 @@ namespace app {
TfLiteTensor* modelOutput0,
TfLiteTensor* modelOutput1,
std::vector<object_detection::DetectionResult>& results,
- int inputImgRows,
- int inputImgCols,
- const float threshold,
- const float nms,
- int numClasses,
- int topN)
+ const object_detection::PostProcessParams& postProcessParams)
: m_outputTensor0{modelOutput0},
m_outputTensor1{modelOutput1},
m_results{results},
- m_inputImgRows{inputImgRows},
- m_inputImgCols{inputImgCols},
- m_threshold(threshold),
- m_nms(nms),
- m_numClasses(numClasses),
- m_topN(topN)
+ m_postProcessParams{postProcessParams}
{
/* Init PostProcessing */
this->m_net = object_detection::Network{
- .inputWidth = inputImgCols,
- .inputHeight = inputImgRows,
- .numClasses = numClasses,
+ .inputWidth = postProcessParams.inputImgCols,
+ .inputHeight = postProcessParams.inputImgRows,
+ .numClasses = postProcessParams.numClasses,
.branches =
- {object_detection::Branch{.resolution = inputImgCols / 32,
+ {object_detection::Branch{.resolution = postProcessParams.inputImgCols / 32,
.numBox = 3,
- .anchor = arm::app::object_detection::anchor1,
+ .anchor = postProcessParams.anchor1,
.modelOutput = this->m_outputTensor0->data.int8,
.scale = (static_cast<TfLiteAffineQuantization*>(
this->m_outputTensor0->quantization.params))
@@ -59,9 +49,9 @@ namespace app {
this->m_outputTensor0->quantization.params))
->zero_point->data[0],
.size = this->m_outputTensor0->bytes},
- object_detection::Branch{.resolution = inputImgCols / 16,
+ object_detection::Branch{.resolution = postProcessParams.inputImgCols / 16,
.numBox = 3,
- .anchor = arm::app::object_detection::anchor2,
+ .anchor = postProcessParams.anchor2,
.modelOutput = this->m_outputTensor1->data.int8,
.scale = (static_cast<TfLiteAffineQuantization*>(
this->m_outputTensor1->quantization.params))
@@ -70,21 +60,21 @@ namespace app {
this->m_outputTensor1->quantization.params))
->zero_point->data[0],
.size = this->m_outputTensor1->bytes}},
- .topN = m_topN};
+ .topN = postProcessParams.topN};
/* End init */
}
bool DetectorPostProcess::DoPostProcess()
{
/* Start postprocessing */
- int originalImageWidth = arm::app::object_detection::originalImageSize;
- int originalImageHeight = arm::app::object_detection::originalImageSize;
+ int originalImageWidth = m_postProcessParams.originalImageSize;
+ int originalImageHeight = m_postProcessParams.originalImageSize;
std::forward_list<image::Detection> detections;
- GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections);
+ GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_postProcessParams.threshold, detections);
/* Do nms */
- CalculateNMS(detections, this->m_net.numClasses, m_nms);
+ CalculateNMS(detections, this->m_net.numClasses, this->m_postProcessParams.nms);
for (auto& it: detections) {
float xMin = it.bbox.x - it.bbox.w / 2.0f;
@@ -219,10 +209,10 @@ void DetectorPostProcess::GetNetworkBoxes(
num += 1;
} else if (num == net.topN) {
detections.sort(det_objectness_comparator);
- InsertTopNDetections(detections,det);
+ InsertTopNDetections(detections, det);
num += 1;
} else {
- InsertTopNDetections(detections,det);
+ InsertTopNDetections(detections, det);
}
}
}
diff --git a/source/use_case/object_detection/src/UseCaseHandler.cc b/source/use_case/object_detection/src/UseCaseHandler.cc
index e9bcd4a..a7acb46 100644
--- a/source/use_case/object_detection/src/UseCaseHandler.cc
+++ b/source/use_case/object_detection/src/UseCaseHandler.cc
@@ -27,9 +27,6 @@
namespace arm {
namespace app {
- namespace object_detection {
- extern const int channelsImageDisplayed;
- } /* namespace object_detection */
/**
* @brief Presents inference results along using the data presentation
@@ -102,8 +99,12 @@ namespace app {
DetectorPreProcess preProcess = DetectorPreProcess(inputTensor, true, model.IsDataSigned());
std::vector<object_detection::DetectionResult> results;
+ const object_detection::PostProcessParams postProcessParams {
+ inputImgRows, inputImgCols, object_detection::originalImageSize,
+ object_detection::anchor1, object_detection::anchor2
+ };
DetectorPostProcess postProcess = DetectorPostProcess(outputTensor0, outputTensor1,
- results, inputImgRows, inputImgCols);
+ results, postProcessParams);
do {
/* Ensure there are no results leftover from previous inference when running all. */
results.clear();
diff --git a/tests/use_case/object_detection/InferenceTestYoloFastest.cc b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
index eb92904..d328684 100644
--- a/tests/use_case/object_detection/InferenceTestYoloFastest.cc
+++ b/tests/use_case/object_detection/InferenceTestYoloFastest.cc
@@ -104,7 +104,11 @@ void TestInferenceDetectionResults(int imageIdx, arm::app::Model& model, T toler
REQUIRE(tflite::GetTensorData<T>(output_arr[i]));
}
- arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, nRows, nCols};
+ const arm::app::object_detection::PostProcessParams postProcessParams {
+ nRows, nCols, arm::app::object_detection::originalImageSize,
+ arm::app::object_detection::anchor1, arm::app::object_detection::anchor2
+ };
+ arm::app::DetectorPostProcess postp{output_arr[0], output_arr[1], results, postProcessParams};
postp.DoPostProcess();
std::vector<std::vector<arm::app::object_detection::DetectionResult>> expected_results;