From 1e8187aaf9e62044cdfa960b98fc565f0bc1a1b2 Mon Sep 17 00:00:00 2001 From: Adam Jalkemo Date: Wed, 12 Oct 2022 15:14:04 +0200 Subject: IVGCVSW-7283 Use stricter file extension check in CreateParser * I had issues when folder name contained "armnn" and a .tflite model was used, as the wrong parser was selected. * Now only the extension, and not the full string, is considered when selecting parser. Change-Id: If7964d2ce5535f7d25762d2a2d7e810bf1a1ed43 --- tests/ExecuteNetwork/ArmNNExecutor.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ExecuteNetwork/ArmNNExecutor.cpp b/tests/ExecuteNetwork/ArmNNExecutor.cpp index b655ef8bc3..aa71c408d7 100644 --- a/tests/ExecuteNetwork/ArmNNExecutor.cpp +++ b/tests/ExecuteNetwork/ArmNNExecutor.cpp @@ -556,13 +556,13 @@ armnn::IOptimizedNetworkPtr ArmNNExecutor::OptimizeNetwork(armnn::INetwork* netw std::unique_ptr ArmNNExecutor::CreateParser() { - // If no model format is given check the file name - const std::string& modelFormat = m_Params.m_ModelPath; + const fs::path modelFilename = m_Params.m_ModelPath; + const std::string modelExtension = modelFilename.extension(); - m_Params.m_IsModelBinary = modelFormat.find("json") == std::string::npos ? true : false; + m_Params.m_IsModelBinary = modelExtension != ".json"; std::unique_ptr parser = nullptr; // Forward to implementation based on the parser type - if (modelFormat.find("armnn") != std::string::npos) + if (modelExtension == ".armnn") { #if defined(ARMNN_SERIALIZER) parser = std::make_unique(); @@ -570,7 +570,7 @@ std::unique_ptr ArmNNExecutor::CreateParser() LogAndThrow("Not built with serialization support."); #endif } - else if(modelFormat.find("tflite") != std::string::npos) + else if (modelExtension == ".tflite") { #if defined(ARMNN_TF_LITE_PARSER) parser = std::make_unique(m_Params); @@ -578,7 +578,7 @@ std::unique_ptr ArmNNExecutor::CreateParser() LogAndThrow("Not built with Tensorflow-Lite parser support."); #endif } - else if (modelFormat.find("onnx") != std::string::npos) + else if (modelExtension == ".onnx") { #if defined(ARMNN_ONNX_PARSER) parser = std::make_unique(); -- cgit v1.2.1