diff options
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r-- | src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp | 40 |
1 files changed, 14 insertions, 26 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp index b4653cd8db..a237d2fc14 100644 --- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp +++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp @@ -376,25 +376,18 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId, m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors); + // Set flag so that the correct comparison function is called if the output is boolean. + bool isBoolean = armnnType2 == armnn::DataType::Boolean ? true : false; + // Compare each output tensor to the expected values for (auto&& it : expectedOutputData) { armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first); auto outputExpected = it.second; - if (std::is_same<DataType2, uint8_t>::value) - { - auto result = CompareTensors(outputExpected, outputStorage[it.first], - bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), - true, isDynamic); - CHECK_MESSAGE(result.m_Result, result.m_Message.str()); - } - else - { - auto result = CompareTensors(outputExpected, outputStorage[it.first], - bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), - false, isDynamic); - CHECK_MESSAGE(result.m_Result, result.m_Message.str()); - } + auto result = CompareTensors(outputExpected, outputStorage[it.first], + bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), + isBoolean, isDynamic); + CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } if (isDynamic) @@ -504,22 +497,17 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId, m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors); + // Set flag so that the correct comparison function is called if the output is boolean. + bool isBoolean = outputType == armnn::DataType::Boolean ? true : false; + // Compare each output tensor to the expected values for (auto&& it : expectedOutputData) { armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first); auto outputExpected = it.second; - if (std::is_same<DataType2, uint8_t>::value) - { - auto result = CompareTensors(outputExpected, outputStorage[it.first], - bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), true); - CHECK_MESSAGE(result.m_Result, result.m_Message.str()); - } - else - { - auto result = CompareTensors(outputExpected, outputStorage[it.first], - bindingInfo.second.GetShape(), bindingInfo.second.GetShape()); - CHECK_MESSAGE(result.m_Result, result.m_Message.str()); - } + auto result = CompareTensors(outputExpected, outputStorage[it.first], + bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), + isBoolean); + CHECK_MESSAGE(result.m_Result, result.m_Message.str()); } }
\ No newline at end of file |