diff options
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index dd3c0a32a1..a0a08d31b0 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -474,6 +474,53 @@ int MainImpl(const ExecuteNetworkParams& params, const size_t numOutputs = inferenceModelParams.m_OutputBindings.size(); + // The user is allowed to specify the data type of each output tensor. It is used here to construct the + // result tensors for each iteration. It is possible for the user to specify a type that does not match + // the data type of the corresponding model output. It may not make sense, but it is historically allowed. + // The potential problem here is a buffer overrun when a larger data type is written into the space for a + // smaller one. Issue a warning to highlight the potential problem. + for (unsigned int outputIdx = 0; outputIdx < model.GetOutputBindingInfos().size(); ++outputIdx) + { + armnn::DataType type = model.GetOutputBindingInfo(outputIdx).second.GetDataType(); + switch (type) + { + // --output-type only supports float, int, qasymms8 or qasymmu8. + case armnn::DataType::Float32: + if (params.m_OutputTypes[outputIdx].compare("float") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type Float32. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + case armnn::DataType::QAsymmU8: + if (params.m_OutputTypes[outputIdx].compare("qasymmu8") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type QAsymmU8. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problemsor random failures."; + } + break; + case armnn::DataType::Signed32: + if (params.m_OutputTypes[outputIdx].compare("int") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type Signed32. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + case armnn::DataType::QAsymmS8: + if (params.m_OutputTypes[outputIdx].compare("qasymms8") != 0) + { + ARMNN_LOG(warning) << "Model output index: " << outputIdx << " has data type QAsymmS8. The " << + "corresponding --output-type is " << params.m_OutputTypes[outputIdx] << + ". This may cause unexpected problems or random failures."; + } + break; + default: + break; + } + } for (unsigned int j = 0; j < params.m_Iterations; ++j) { std::vector <armnnUtils::TContainer> outputDataContainers; |