aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2021-07-11 21:49:00 -0300
committerMatthew Sloyan <matthew.sloyan@arm.com>2021-07-21 11:53:13 +0000
commit9021125ff83a035dae05e5c7c0e6b1455b71af1a (patch)
treedbe534ebd4641d9bb625408cdd69126f653e2dd9
parent2d0eb86a5756fb9402bd31d3f5adc5438305f676 (diff)
downloadarmnn-9021125ff83a035dae05e5c7c0e6b1455b71af1a.tar.gz
Fixed RunTest's TfliteParser with boolean output
Tests for TfLiteParser are not working when the expected outputs have boolean type Signed-off-by: Bruno Goncalves <bruno.slackware@gmail.com> Change-Id: I16890f82e8e581f53e6e8464668c5adf3374bf2f
-rw-r--r--src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp33
1 files changed, 26 insertions, 7 deletions
diff --git a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
index b0bfdfc016..c4c75594a3 100644
--- a/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
+++ b/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp
@@ -322,10 +322,20 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
{
armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
auto outputExpected = it.second;
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
- false, isDynamic);
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ if (std::is_same<DataType2, uint8_t>::value)
+ {
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
+ true, isDynamic);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
+ else
+ {
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
+ false, isDynamic);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
}
}
@@ -424,8 +434,17 @@ void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
{
armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
auto outputExpected = it.second;
- auto result = CompareTensors(outputExpected, outputStorage[it.first],
- bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), false);
- CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ if (std::is_same<DataType2, uint8_t>::value)
+ {
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape(), true);
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
+ else
+ {
+ auto result = CompareTensors(outputExpected, outputStorage[it.first],
+ bindingInfo.second.GetShape(), bindingInfo.second.GetShape());
+ CHECK_MESSAGE(result.m_Result, result.m_Message.str());
+ }
}
} \ No newline at end of file