diff options
Diffstat (limited to 'source/use_case/img_class/src')
-rw-r--r-- | source/use_case/img_class/src/ImgClassProcessing.cc | 33 | ||||
-rw-r--r-- | source/use_case/img_class/src/UseCaseHandler.cc | 22 |
2 files changed, 24 insertions, 31 deletions
diff --git a/source/use_case/img_class/src/ImgClassProcessing.cc b/source/use_case/img_class/src/ImgClassProcessing.cc index 6ba88ad..adf9794 100644 --- a/source/use_case/img_class/src/ImgClassProcessing.cc +++ b/source/use_case/img_class/src/ImgClassProcessing.cc @@ -21,50 +21,43 @@ namespace arm { namespace app { - ImgClassPreProcess::ImgClassPreProcess(Model* model) - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + ImgClassPreProcess::ImgClassPreProcess(TfLiteTensor* inputTensor, bool convertToInt8) + :m_inputTensor{inputTensor}, + m_convertToInt8{convertToInt8} + {} bool ImgClassPreProcess::DoPreProcess(const void* data, size_t inputSize) { if (data == nullptr) { printf_err("Data pointer is null"); + return false; } auto input = static_cast<const uint8_t*>(data); - TfLiteTensor* inputTensor = this->m_model->GetInputTensor(0); - std::memcpy(inputTensor->data.data, input, inputSize); + std::memcpy(this->m_inputTensor->data.data, input, inputSize); debug("Input tensor populated \n"); - if (this->m_model->IsDataSigned()) { - image::ConvertImgToInt8(inputTensor->data.data, inputTensor->bytes); + if (this->m_convertToInt8) { + image::ConvertImgToInt8(this->m_inputTensor->data.data, this->m_inputTensor->bytes); } return true; } - ImgClassPostProcess::ImgClassPostProcess(Classifier& classifier, Model* model, + ImgClassPostProcess::ImgClassPostProcess(TfLiteTensor* outputTensor, Classifier& classifier, const std::vector<std::string>& labels, std::vector<ClassificationResult>& results) - :m_imgClassifier{classifier}, + :m_outputTensor{outputTensor}, + m_imgClassifier{classifier}, m_labels{labels}, m_results{results} - { - if (!model->IsInited()) { - printf_err("Model is not initialised!.\n"); - } - this->m_model = model; - } + {} bool ImgClassPostProcess::DoPostProcess() { return this->m_imgClassifier.GetClassificationResults( - this->m_model->GetOutputTensor(0), this->m_results, + this->m_outputTensor, this->m_results, this->m_labels, 5, false); } diff --git a/source/use_case/img_class/src/UseCaseHandler.cc b/source/use_case/img_class/src/UseCaseHandler.cc index c68d816..5cc3959 100644 --- a/source/use_case/img_class/src/UseCaseHandler.cc +++ b/source/use_case/img_class/src/UseCaseHandler.cc @@ -59,6 +59,7 @@ namespace app { } TfLiteTensor* inputTensor = model.GetInputTensor(0); + TfLiteTensor* outputTensor = model.GetOutputTensor(0); if (!inputTensor->dims) { printf_err("Invalid input tensor dims\n"); return false; @@ -74,13 +75,12 @@ namespace app { const uint32_t nChannels = inputShape->data[arm::app::MobileNetModel::ms_inputChannelsIdx]; /* Set up pre and post-processing. */ - ImgClassPreProcess preprocess = ImgClassPreProcess(&model); + ImgClassPreProcess preProcess = ImgClassPreProcess(inputTensor, model.IsDataSigned()); std::vector<ClassificationResult> results; - ImgClassPostProcess postprocess = ImgClassPostProcess(ctx.Get<ImgClassClassifier&>("classifier"), &model, - ctx.Get<std::vector<std::string>&>("labels"), results); - - UseCaseRunner runner = UseCaseRunner(&preprocess, &postprocess, &model); + ImgClassPostProcess postProcess = ImgClassPostProcess(outputTensor, + ctx.Get<ImgClassClassifier&>("classifier"), ctx.Get<std::vector<std::string>&>("labels"), + results); do { hal_lcd_clear(COLOR_BLACK); @@ -113,17 +113,18 @@ namespace app { inputTensor->bytes : IMAGE_DATA_SIZE; /* Run the pre-processing, inference and post-processing. */ - if (!runner.PreProcess(imgSrc, imgSz)) { + if (!preProcess.DoPreProcess(imgSrc, imgSz)) { + printf_err("Pre-processing failed."); return false; } - profiler.StartProfiling("Inference"); - if (!runner.RunInference()) { + if (!RunInference(model, profiler)) { + printf_err("Inference failed."); return false; } - profiler.StopProfiling(); - if (!runner.PostProcess()) { + if (!postProcess.DoPostProcess()) { + printf_err("Post-processing failed."); return false; } @@ -136,7 +137,6 @@ namespace app { ctx.Set<std::vector<ClassificationResult>>("results", results); #if VERIFY_TEST_OUTPUT - TfLiteTensor* outputTensor = model.GetOutputTensor(0); arm::app::DumpTensor(outputTensor); #endif /* VERIFY_TEST_OUTPUT */ |