diff options
3 files changed, 24 insertions, 8 deletions
diff --git a/applications/inference_process/include/inference_parser.hpp b/applications/inference_process/include/inference_parser.hpp index 3d90818..fd7a97a 100644 --- a/applications/inference_process/include/inference_parser.hpp +++ b/applications/inference_process/include/inference_parser.hpp @@ -56,15 +56,30 @@ Array<T, U> makeArray(T *const data, U &size, size_t capacity) { class InferenceParser { public: - template <typename T, typename U, size_t S> - bool parseModel(const void *buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) { + const tflite::Model *getModel(const void *buffer, size_t size) { + // Verify buffer + flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t *>(buffer), size); + if (!tflite::VerifyModelBuffer(base_verifier)) { + printf("Warning: the model is not valid\n"); + return nullptr; + } + // Create model handle const tflite::Model *model = tflite::GetModel(buffer); if (model->subgraphs() == nullptr) { printf("Warning: nullptr subgraph\n"); - return true; + return nullptr; } + return model; + } + + template <typename T, typename U, size_t S> + bool parseModel(const void *buffer, size_t size, char (&description)[S], T &&ifmDims, U &&ofmDims) { + const tflite::Model *model = getModel(buffer, size); + if (model == nullptr) { + return true; + } strncpy(description, model->description()->c_str(), sizeof(description)); // Get input dimensions for first subgraph diff --git a/applications/inference_process/include/inference_process.hpp b/applications/inference_process/include/inference_process.hpp index fc54ae0..f8d7fd8 100644 --- a/applications/inference_process/include/inference_process.hpp +++ b/applications/inference_process/include/inference_process.hpp @@ -18,6 +18,8 @@ #pragma once +#include "inference_parser.hpp" + #include <array> #include <queue> #include <stdlib.h> @@ -85,5 +87,6 @@ private: uint8_t *tensorArena; const size_t tensorArenaSize; + InferenceParser parser; }; } // namespace InferenceProcess diff --git 
a/applications/inference_process/src/inference_process.cpp b/applications/inference_process/src/inference_process.cpp index 29254c7..e96d601 100644 --- a/applications/inference_process/src/inference_process.cpp +++ b/applications/inference_process/src/inference_process.cpp @@ -119,11 +119,9 @@ bool InferenceProcess::runJob(InferenceJob &job) { RegisterDebugLogCallback(tfluDebugLog); // Get model handle and verify that the version is correct - const tflite::Model *model = ::tflite::GetModel(job.networkModel.data); - if (model->version() != TFLITE_SCHEMA_VERSION) { - LOG_ERR("Model schema version unsupported: version=%" PRIu32 ", supported=%d.", - model->version(), - TFLITE_SCHEMA_VERSION); + const tflite::Model *model = parser.getModel(job.networkModel.data, job.networkModel.size); + if (model == nullptr) { + LOG_ERR("Invalid model"); return true; }