aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp')
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp40
1 files changed, 14 insertions, 26 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index b4653cd8db..a237d2fc14 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -376,25 +376,18 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
+ // Set flag so that the correct comparison function is called if the output is boolean.
+ bool isBoolean = armnnType2 == armnn::DataType::Boolean ? true : false;
+
// Compare each output tensor to the expected values
for (auto&& it : expectedOutputData)
{
armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
auto outputExpected = it.second;
- if (std::is_same<DataType2, uint8_t>::value)
- {
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
- true, isDynamic);
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
- }
- else
- {
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
- false, isDynamic);
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
- }
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
+ isBoolean, isDynamic);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
if (isDynamic)
@@ -504,22 +497,17 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
+ // Set flag so that the correct comparison function is called if the output is boolean.
+ bool isBoolean = outputType == armnn::DataType::Boolean ? true : false;
+
// Compare each output tensor to the expected values
for (auto&& it : expectedOutputData)
{
armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
auto outputExpected = it.second;
- if (std::is_same<DataType2, uint8_t>::value)
- {
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), true);
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
- }
- else
- {
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape());
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
- }
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
+ isBoolean);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
}
} \ No newline at end of file