From 30b17b9b0e73de1dd93e090c68b38f32339d411c Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann@arm.com>
Date: Tue, 14 Jun 2022 15:17:18 +0200
Subject: Check the validity of the buffer before parsing the model

If the buffer does not point to a well defined flatbuffer the parsing
segfaults.

Change-Id: Icb8dfef37dc28b2b7a22c6d3804851be8198aa9d
---
 .../inference_process/include/inference_parser.hpp  | 21 ++++++++++++++++++---
 .../inference_process/include/inference_process.hpp |  3 +++
 .../inference_process/src/inference_process.cpp     |  8 +++-----
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/applications/inference_process/include/inference_parser.hpp b/applications/inference_process/include/inference_parser.hpp
index 3d90818..fd7a97a 100644
--- a/applications/inference_process/include/inference_parser.hpp
+++ b/applications/inference_process/include/inference_parser.hpp
@@ -56,15 +56,30 @@ Array<T> makeArray(T *const data, U &size, size_t capacity) {
 
 class InferenceParser {
 public:
-    template <size_t S, typename T, typename U>
-    bool parseModel(const void *buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+    const tflite::Model *getModel(const void *buffer, size_t size) {
+        // Verify buffer
+        flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t *>(buffer), size);
+        if (!tflite::VerifyModelBuffer(base_verifier)) {
+            printf("Warning: the model is not valid\n");
+            return nullptr;
+        }
+
+        // Create model handle
         const tflite::Model *model = tflite::GetModel(buffer);
         if (model->subgraphs() == nullptr) {
             printf("Warning: nullptr subgraph\n");
-            return true;
+            return nullptr;
         }
+        return model;
+    }
+
+    template <size_t S, typename T, typename U>
+    bool parseModel(const void *buffer, size_t size, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+        const tflite::Model *model = getModel(buffer, size);
+        if (model == nullptr) {
+            return true;
+        }
 
         strncpy(description, model->description()->c_str(), sizeof(description));
 
         // Get input dimensions for first subgraph
diff --git a/applications/inference_process/include/inference_process.hpp b/applications/inference_process/include/inference_process.hpp
index fc54ae0..f8d7fd8 100644
--- a/applications/inference_process/include/inference_process.hpp
+++ b/applications/inference_process/include/inference_process.hpp
@@ -18,6 +18,8 @@
 
 #pragma once
 
+#include "inference_parser.hpp"
+
 #include <queue>
 #include <stdint.h>
 #include <string>
@@ -85,5 +87,6 @@ private:
 
     uint8_t *tensorArena;
     const size_t tensorArenaSize;
+    InferenceParser parser;
 };
 } // namespace InferenceProcess
diff --git a/applications/inference_process/src/inference_process.cpp b/applications/inference_process/src/inference_process.cpp
index 29254c7..e96d601 100644
--- a/applications/inference_process/src/inference_process.cpp
+++ b/applications/inference_process/src/inference_process.cpp
@@ -119,11 +119,9 @@ bool InferenceProcess::runJob(InferenceJob &job) {
     RegisterDebugLogCallback(tfluDebugLog);
 
     // Get model handle and verify that the version is correct
-    const tflite::Model *model = ::tflite::GetModel(job.networkModel.data);
-    if (model->version() != TFLITE_SCHEMA_VERSION) {
-        LOG_ERR("Model schema version unsupported: version=%" PRIu32 ", supported=%d.",
-                model->version(),
-                TFLITE_SCHEMA_VERSION);
+    const tflite::Model *model = parser.getModel(job.networkModel.data, job.networkModel.size);
+    if (model == nullptr) {
+        LOG_ERR("Invalid model");
         return true;
     }
-- 
cgit v1.2.1