aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Jalkemo <adam.jalkemo@arm.com>2022-10-12 15:14:04 +0200
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-10-13 12:47:22 +0100
commit1e8187aaf9e62044cdfa960b98fc565f0bc1a1b2 (patch)
tree989aa23f485ae366b3d44375af578caf7da50baf
parent251fd955c18434b7aa3f486374c4f1a15bbd160e (diff)
downloadarmnn-1e8187aaf9e62044cdfa960b98fc565f0bc1a1b2.tar.gz
IVGCVSW-7283 Use stricter file extension check in CreateParser
* I had issues when folder name contained "armnn" and a .tflite model was used, as the wrong parser was selected. * Now only the extension, and not the full string, is considered when selecting parser. Change-Id: If7964d2ce5535f7d25762d2a2d7e810bf1a1ed43
-rw-r--r--tests/ExecuteNetwork/ArmNNExecutor.cpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.cpp b/tests/ExecuteNetwork/ArmNNExecutor.cpp
index b655ef8bc3..aa71c408d7 100644
--- a/tests/ExecuteNetwork/ArmNNExecutor.cpp
+++ b/tests/ExecuteNetwork/ArmNNExecutor.cpp
@@ -556,13 +556,13 @@ armnn::IOptimizedNetworkPtr ArmNNExecutor::OptimizeNetwork(armnn::INetwork* netw
std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser()
{
- // If no model format is given check the file name
- const std::string& modelFormat = m_Params.m_ModelPath;
+ const fs::path modelFilename = m_Params.m_ModelPath;
+ const std::string modelExtension = modelFilename.extension();
- m_Params.m_IsModelBinary = modelFormat.find("json") == std::string::npos ? true : false;
+ m_Params.m_IsModelBinary = modelExtension != ".json";
std::unique_ptr<IParser> parser = nullptr;
// Forward to implementation based on the parser type
- if (modelFormat.find("armnn") != std::string::npos)
+ if (modelExtension == ".armnn")
{
#if defined(ARMNN_SERIALIZER)
parser = std::make_unique<ArmNNDeserializer>();
@@ -570,7 +570,7 @@ std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser()
LogAndThrow("Not built with serialization support.");
#endif
}
- else if(modelFormat.find("tflite") != std::string::npos)
+ else if (modelExtension == ".tflite")
{
#if defined(ARMNN_TF_LITE_PARSER)
parser = std::make_unique<TfliteParser>(m_Params);
@@ -578,7 +578,7 @@ std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser()
LogAndThrow("Not built with Tensorflow-Lite parser support.");
#endif
}
- else if (modelFormat.find("onnx") != std::string::npos)
+ else if (modelExtension == ".onnx")
{
#if defined(ARMNN_ONNX_PARSER)
parser = std::make_unique<OnnxParser>();