diff options
author | Colm Donelan <colm.donelan@arm.com> | 2023-01-25 21:19:49 +0000 |
---|---|---|
committer | Colm Donelan <colm.donelan@arm.com> | 2023-01-27 18:29:32 +0000 |
commit | 3811a97033be66f7a5d8fc3340b0899e0b60f737 (patch) | |
tree | d32a1352970668a009065b6f6b0575dd608efae8 /tests/ExecuteNetwork/TfliteExecutor.cpp | |
parent | 0c61a755cdcdf4ec2b3b79031c6cdf0432d3e87a (diff) | |
download | armnn-3811a97033be66f7a5d8fc3340b0899e0b60f737.tar.gz |
IVGCVSW-7441 Checking for constant input tensors before populating.
* When the TfLiteExecutor attempted to populate the input tensors, it did
    not check whether a tensor was constant. This was causing
    segmentation faults.
Signed-off-by: Colm Donelan <colm.donelan@arm.com>
Change-Id: I80a4cc788de4ffe08afb2df9185d04fcb8b27c3a
Diffstat (limited to 'tests/ExecuteNetwork/TfliteExecutor.cpp')
-rw-r--r-- | tests/ExecuteNetwork/TfliteExecutor.cpp | 91 |
1 file changed, 50 insertions, 41 deletions
diff --git a/tests/ExecuteNetwork/TfliteExecutor.cpp b/tests/ExecuteNetwork/TfliteExecutor.cpp index 810495fe8c..3c8313b938 100644 --- a/tests/ExecuteNetwork/TfliteExecutor.cpp +++ b/tests/ExecuteNetwork/TfliteExecutor.cpp @@ -4,6 +4,7 @@ // #include "TfliteExecutor.hpp" +#include "tensorflow/lite/kernels/kernel_util.h" TfLiteExecutor::TfLiteExecutor(const ExecuteNetworkParams& params) : m_Params(params) { @@ -51,55 +52,63 @@ TfLiteExecutor::TfLiteExecutor(const ExecuteNetworkParams& params) : m_Params(pa : armnn::MakeOptional<std::string>(m_Params.m_InputTensorDataFilePaths[inputIndex]); int input = m_TfLiteInterpreter->inputs()[inputIndex]; - - TfLiteIntArray* inputDims = m_TfLiteInterpreter->tensor(input)->dims; - - unsigned int inputSize = 1; - for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim) - { - inputSize *= inputDims->data[dim]; - } - const auto& inputName = m_TfLiteInterpreter->tensor(input)->name; - const auto& dataType = m_TfLiteInterpreter->tensor(input)->type; - switch (dataType) + // Before we start, check if the tensor is constant. 
+ if (!tflite::IsConstantTensor(m_TfLiteInterpreter->tensor(input))) { - case kTfLiteFloat32: - { - auto inputData = m_TfLiteInterpreter->typed_tensor<float>(input); - PopulateTensorWithData<float>(inputData, inputSize, dataFile, inputName); - break; - } - case kTfLiteInt32: - { - auto inputData = m_TfLiteInterpreter->typed_tensor<int32_t>(input); - PopulateTensorWithData<int32_t>(inputData, inputSize, dataFile, inputName); - break; - } - case kTfLiteUInt8: - { - auto inputData = m_TfLiteInterpreter->typed_tensor<uint8_t>(input); - PopulateTensorWithData<uint8_t>(inputData, inputSize, dataFile, inputName); - break; - } - case kTfLiteInt16: - { - auto inputData = m_TfLiteInterpreter->typed_tensor<int16_t>(input); - PopulateTensorWithData<int16_t>(inputData, inputSize, dataFile, inputName); - break; - } - case kTfLiteInt8: + TfLiteIntArray* inputDims = m_TfLiteInterpreter->tensor(input)->dims; + + unsigned int inputSize = 1; + for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim) { - auto inputData = m_TfLiteInterpreter->typed_tensor<int8_t>(input); - PopulateTensorWithData<int8_t>(inputData, inputSize, dataFile, inputName); - break; + inputSize *= inputDims->data[dim]; } - default: + + const auto& dataType = m_TfLiteInterpreter->tensor(input)->type; + + switch (dataType) { - LogAndThrow("Unsupported input tensor data type"); + case kTfLiteFloat32: + { + auto inputData = m_TfLiteInterpreter->typed_tensor<float>(input); + PopulateTensorWithData<float>(inputData, inputSize, dataFile, inputName); + break; + } + case kTfLiteInt32: + { + auto inputData = m_TfLiteInterpreter->typed_tensor<int32_t>(input); + PopulateTensorWithData<int32_t>(inputData, inputSize, dataFile, inputName); + break; + } + case kTfLiteUInt8: + { + auto inputData = m_TfLiteInterpreter->typed_tensor<uint8_t>(input); + PopulateTensorWithData<uint8_t>(inputData, inputSize, dataFile, inputName); + break; + } + case kTfLiteInt16: + { + auto inputData = 
m_TfLiteInterpreter->typed_tensor<int16_t>(input); + PopulateTensorWithData<int16_t>(inputData, inputSize, dataFile, inputName); + break; + } + case kTfLiteInt8: + { + auto inputData = m_TfLiteInterpreter->typed_tensor<int8_t>(input); + PopulateTensorWithData<int8_t>(inputData, inputSize, dataFile, inputName); + break; + } + default: + { + LogAndThrow("Unsupported input tensor data type"); + } } } + else + { + ARMNN_LOG(info) << "Input tensor \"" << inputName << "\" is constant and will not be populated with data."; + } } } |