summaryrefslogtreecommitdiff
path: root/source/use_case/img_class/src/ImgClassProcessing.cc
diff options
context:
space:
mode:
Diffstat (limited to 'source/use_case/img_class/src/ImgClassProcessing.cc')
-rw-r--r--source/use_case/img_class/src/ImgClassProcessing.cc33
1 files changed, 13 insertions, 20 deletions
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc
index 6ba88ad..adf9794 100644
--- a/source/use_case/img_class/src/ImgClassProcessing.cc
+++ b/source/use_case/img_class/src/ImgClassProcessing.cc
@@ -21,50 +21,43 @@
namespace arm {
namespace app {
- ImgClassPreProcess::ImgClassPreProcess(Model* model)
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ ImgClassPreProcess::ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8)
+ :m_inputTensor{inputTensor},
+ m_convertToInt8{convertToInt8}
+ {}
bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize)
{
if (data == nullptr) {
printf_err("Data pointer is null");
+ return false;
}
auto input = static_cast<const uint8_t*>(data);
- TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0);
- std::memcpy(inputTensor->data.data, input, inputSize);
+ std::memcpy(this->m_inputTensor->data.data, input, inputSize);
debug("Input tensor populated \n");
- if (this->m_model->IsDataSigned()) {
- image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes);
+ if (this->m_convertToInt8) {
+ image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes);
}
return true;
}
- ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model,
+ ImgClassPostProcess::ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
const std::vector<std::string>& labels,
std::vector<ClassificationResult>& results)
- :m_imgClassifier{classifier},
+ :m_outputTensor{outputTensor},
+ m_imgClassifier{classifier},
m_labels{labels},
m_results{results}
- {
- if (!model->IsInited()) {
- printf_err("Model is not initialised!.\n");
- }
- this->m_model = model;
- }
+ {}
bool ImgClassPostProcess::DoPostProcess()
{
return this->m_imgClassifier.GetClassificationResults(
- this->m_model->GetOutputTensor(0), this->m_results,
+ this->m_outputTensor, this->m_results,
this->m_labels, 5, false);
}