From c5e419826ea5c94d6841a59f618d70230a5b56ef Mon Sep 17 00:00:00 2001 From: Colm Donelan Date: Thu, 28 Oct 2021 20:19:43 +0100 Subject: IVGCVSW-6473 Add warnings to ExecuteNetwork on invalid output tensor type. * In ExecuteNetwork MainImpl compare the data types of outputs on the loaded model with those specified by the user through --output-type. Issue a warning if there is a mismatch. Signed-off-by: Colm Donelan Change-Id: Ic5add9734dc239eddca0972a9e560e54abdb1093 --- tests/ExecuteNetwork/ExecuteNetwork.cpp | 47 +++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index dd3c0a32a1..a0a08d31b0 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -474,6 +474,53 @@ int MainImpl(const ExecuteNetworkParams& params, const size_t numOutputs = inferenceModelParams.m_OutputBindings.size(); + // The user is allowed to specify the data type of each output tensor. It is used here to construct the + // result tensors for each iteration. It is possible for the user to specify a type that does not match + // the data type of the corresponding model output. It may not make sense, but it is historically allowed. + // The potential problem here is a buffer overrun when a larger data type is written into the space for a + // smaller one. Issue a warning to highlight the potential problem. + for (unsigned int outputIdx = 0; outputIdx < model.GetOutputBindingInfos().size(); ++outputIdx) + { + armnn::DataType type = model.GetOutputBindingInfo(outputIdx).second.GetDataType(); + switch (type) + { + // --output-type only supports float, int, qasymms8 or qasymmu8. + case armnn::DataType::Float32: + if (params.m_OutputTypes[outputIdx].compare("float") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type Float32. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + case armnn::DataType::QAsymmU8: + if (params.m_OutputTypes[outputIdx].compare("qasymmu8") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type QAsymmU8. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problemsor random failures."; + } + break; + case armnn::DataType::Signed32: + if (params.m_OutputTypes[outputIdx].compare("int") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type Signed32. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + case armnn::DataType::QAsymmS8: + if (params.m_OutputTypes[outputIdx].compare("qasymms8") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type QAsymmS8. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + default: + break; + } + } for (unsigned int j = 0; j < params.m_Iterations; ++j) { std::vector outputDataContainers; -- cgit v1.2.1