author     Davide Grohmann <davide.grohmann@arm.com>  2022-06-14 15:17:18 +0200
committer  Davide Grohmann <davide.grohmann@arm.com>  2022-08-22 14:16:42 +0200
commit     30b17b9b0e73de1dd93e090c68b38f32339d411c (patch)
tree       c4293d6b35c8902a4cee2321ce2be820059a147b
parent     e48fa7a47239d9632dc4390af92bca7d0eac64a2 (diff)
Check the validity of the buffer before parsing the model

If the buffer does not point to a well-defined flatbuffer, parsing segfaults.

Change-Id: Icb8dfef37dc28b2b7a22c6d3804851be8198aa9d
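As a minimal sketch of the pattern this change introduces (the loadVerifiedModel name and the include paths are assumptions, not part of the commit): flatbuffers::Verifier walks the buffer and tflite::VerifyModelBuffer confirms it is a structurally valid TFLite flatbuffer before tflite::GetModel() dereferences anything in it.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    #include "flatbuffers/flatbuffers.h"
    #include "tensorflow/lite/schema/schema_generated.h"

    // Return the model only if the buffer verifies as a well-formed TFLite
    // flatbuffer; otherwise return nullptr instead of letting GetModel()
    // walk an invalid buffer and crash.
    static const tflite::Model *loadVerifiedModel(const void *buffer, size_t size) {
        flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(buffer), size);
        if (!tflite::VerifyModelBuffer(verifier)) {
            printf("Warning: the model is not valid\n");
            return nullptr;
        }
        return tflite::GetModel(buffer);
    }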
-rw-r--r--  applications/inference_process/include/inference_parser.hpp  | 21
-rw-r--r--  applications/inference_process/include/inference_process.hpp |  3
-rw-r--r--  applications/inference_process/src/inference_process.cpp     |  8
3 files changed, 24 insertions(+), 8 deletions(-)
diff --git a/applications/inference_process/include/inference_parser.hpp b/applications/inference_process/include/inference_parser.hpp
index 3d90818..fd7a97a 100644
--- a/applications/inference_process/include/inference_parser.hpp
+++ b/applications/inference_process/include/inference_parser.hpp
@@ -56,15 +56,30 @@ Array<T, U> makeArray(T *const data, U &size, size_t capacity) {
 class InferenceParser {
 public:
-    template <typename T, typename U, size_t S>
-    bool parseModel(const void *buffer, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+    const tflite::Model *getModel(const void *buffer, size_t size) {
+        // Verify buffer
+        flatbuffers::Verifier base_verifier(reinterpret_cast<const uint8_t *>(buffer), size);
+        if (!tflite::VerifyModelBuffer(base_verifier)) {
+            printf("Warning: the model is not valid\n");
+            return nullptr;
+        }
+
         // Create model handle
         const tflite::Model *model = tflite::GetModel(buffer);
         if (model->subgraphs() == nullptr) {
             printf("Warning: nullptr subgraph\n");
-            return true;
+            return nullptr;
         }
+        return model;
+    }
+
+    template <typename T, typename U, size_t S>
+    bool parseModel(const void *buffer, size_t size, char (&description)[S], T &&ifmDims, U &&ofmDims) {
+        const tflite::Model *model = getModel(buffer, size);
+        if (model == nullptr) {
+            return true;
+        }
         strncpy(description, model->description()->c_str(), sizeof(description));
         // Get input dimensions for first subgraph
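For reference, a hedged sketch of how the refactored API would be called (the wrapper function and buffer names below are illustrative, and namespace qualification is omitted): getModel() can now be used on its own to reject a corrupt buffer before any model field is touched, and parseModel() takes the buffer size so it can forward it to the verifier.

    #include "inference_parser.hpp"

    // Illustrative only; networkData/networkSize are assumptions.
    bool rejectInvalidNetwork(const void *networkData, size_t networkSize) {
        InferenceParser parser;
        const tflite::Model *model = parser.getModel(networkData, networkSize);
        if (model == nullptr) {
            // Buffer failed flatbuffer verification or has no subgraphs.
            return true; // error, matching the boolean convention of parseModel()
        }
        return false; // model is safe to inspect further
    }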
diff --git a/applications/inference_process/include/inference_process.hpp b/applications/inference_process/include/inference_process.hpp
index fc54ae0..f8d7fd8 100644
--- a/applications/inference_process/include/inference_process.hpp
+++ b/applications/inference_process/include/inference_process.hpp
@@ -18,6 +18,8 @@
 #pragma once
+#include "inference_parser.hpp"
+
 #include <array>
 #include <queue>
 #include <stdlib.h>
@@ -85,5 +87,6 @@ private:
     uint8_t *tensorArena;
     const size_t tensorArenaSize;
+    InferenceParser parser;
 };
 } // namespace InferenceProcess
diff --git a/applications/inference_process/src/inference_process.cpp b/applications/inference_process/src/inference_process.cpp
index 29254c7..e96d601 100644
--- a/applications/inference_process/src/inference_process.cpp
+++ b/applications/inference_process/src/inference_process.cpp
@@ -119,11 +119,9 @@ bool InferenceProcess::runJob(InferenceJob &job) {
     RegisterDebugLogCallback(tfluDebugLog);
     // Get model handle and verify that the version is correct
-    const tflite::Model *model = ::tflite::GetModel(job.networkModel.data);
-    if (model->version() != TFLITE_SCHEMA_VERSION) {
-        LOG_ERR("Model schema version unsupported: version=%" PRIu32 ", supported=%d.",
-                model->version(),
-                TFLITE_SCHEMA_VERSION);
+    const tflite::Model *model = parser.getModel(job.networkModel.data, job.networkModel.size);
+    if (model == nullptr) {
+        LOG_ERR("Invalid model");
         return true;
     }
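With this change, a job carrying a corrupt network buffer fails cleanly out of runJob() instead of crashing inside GetModel(). A hedged sketch of the caller's view (the setup names below are assumptions beyond the networkModel.data/size fields visible in the hunk above, and the fragment assumes an already constructed InferenceProcess instance):

    // Illustrative only: populate a job and check the boolean error return.
    InferenceProcess::InferenceJob job;
    job.networkModel.data = networkBuffer;     // possibly untrusted bytes
    job.networkModel.size = networkBufferSize;
    if (inferenceProcess.runJob(job)) {
        // Covers "Invalid model" as well as the other runJob() failure paths.
        printf("Inference job failed\n");
    }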