diff options
author | Adam Jalkemo <adam.jalkemo@arm.com> | 2022-10-12 15:14:04 +0200 |
---|---|---|
committer | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-10-13 12:47:22 +0100 |
commit | 1e8187aaf9e62044cdfa960b98fc565f0bc1a1b2 (patch) | |
tree | 989aa23f485ae366b3d44375af578caf7da50baf | |
parent | 251fd955c18434b7aa3f486374c4f1a15bbd160e (diff) | |
download | armnn-1e8187aaf9e62044cdfa960b98fc565f0bc1a1b2.tar.gz |
IVGCVSW-7283 Use stricter file extension check in CreateParser
* I hit issues when the folder name contained "armnn" and
a .tflite model was used, as the wrong parser was selected.
* Now only the file extension, and not the full path string, is
considered when selecting the parser.
Change-Id: If7964d2ce5535f7d25762d2a2d7e810bf1a1ed43
-rw-r--r-- | tests/ExecuteNetwork/ArmNNExecutor.cpp | 12 |
1 file changed, 6 insertions, 6 deletions
diff --git a/tests/ExecuteNetwork/ArmNNExecutor.cpp b/tests/ExecuteNetwork/ArmNNExecutor.cpp index b655ef8bc3..aa71c408d7 100644 --- a/tests/ExecuteNetwork/ArmNNExecutor.cpp +++ b/tests/ExecuteNetwork/ArmNNExecutor.cpp @@ -556,13 +556,13 @@ armnn::IOptimizedNetworkPtr ArmNNExecutor::OptimizeNetwork(armnn::INetwork* netw std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser() { - // If no model format is given check the file name - const std::string& modelFormat = m_Params.m_ModelPath; + const fs::path modelFilename = m_Params.m_ModelPath; + const std::string modelExtension = modelFilename.extension(); - m_Params.m_IsModelBinary = modelFormat.find("json") == std::string::npos ? true : false; + m_Params.m_IsModelBinary = modelExtension != ".json"; std::unique_ptr<IParser> parser = nullptr; // Forward to implementation based on the parser type - if (modelFormat.find("armnn") != std::string::npos) + if (modelExtension == ".armnn") { #if defined(ARMNN_SERIALIZER) parser = std::make_unique<ArmNNDeserializer>(); @@ -570,7 +570,7 @@ std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser() LogAndThrow("Not built with serialization support."); #endif } - else if(modelFormat.find("tflite") != std::string::npos) + else if (modelExtension == ".tflite") { #if defined(ARMNN_TF_LITE_PARSER) parser = std::make_unique<TfliteParser>(m_Params); @@ -578,7 +578,7 @@ std::unique_ptr<ArmNNExecutor::IParser> ArmNNExecutor::CreateParser() LogAndThrow("Not built with Tensorflow-Lite parser support."); #endif } - else if (modelFormat.find("onnx") != std::string::npos) + else if (modelExtension == ".onnx") { #if defined(ARMNN_ONNX_PARSER) parser = std::make_unique<OnnxParser>(); |