aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-04-17 11:22:38 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-04-17 11:55:30 +0100
commit4818d465973d72979a7f6e783e1b55b320781710 (patch)
tree7c6cce52a728b6fe3ea2733b808555ff27edbab2
parent0790dcea1056298d63f97dec904c8ade5d21f439 (diff)
downloadarmnn-4818d465973d72979a7f6e783e1b55b320781710.tar.gz
IVGCVSW-2849 Add TfLite Parser support for Rank-0 operands and unit tests
Change-Id: I6dab12aed395a30466d66421c6e5a12659fedac8 Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp10
-rw-r--r--src/armnnTfLiteParser/test/InputOutputTensorNames.cpp67
-rw-r--r--src/armnnTfLiteParser/test/Minimum.cpp77
3 files changed, 148 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 86688add9d..1ee4950558 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -365,9 +365,15 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std::
}
}
+ std::vector<unsigned int> safeShape = shapes;
+ if (safeShape.size() == 0)
+ {
+ safeShape.push_back(1);
+ }
+
// two statements (on purpose) for easier debugging:
- armnn::TensorInfo result(static_cast<unsigned int>(shapes.size()),
- shapes.data(),
+ armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()),
+ safeShape.data(),
type,
quantizationScale,
quantizationOffset);
diff --git a/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
index 6f5016265a..d42ae2e438 100644
--- a/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
+++ b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp
@@ -60,9 +60,19 @@ struct InvalidTensorsFixture : public ParserFlatbuffersFixture
"version": 3,
"operator_codes": [ ],
"subgraphs": [{
- "tensors": [ {}, {}, {}, {} ],
- "inputs" : [ 0, 1 ],
- "outputs" : [ 2, 3 ],
+ "tensors": [ {
+ "shape": [ 1, 1, 1, 1, 1 ],
+ "type": "FLOAT32",
+ "name": "In",
+ "buffer": 0
+ }, {
+ "shape": [ 1, 1, 1, 1, 1 ],
+ "type": "FLOAT32",
+ "name": "Out",
+ "buffer": 1
+ }],
+ "inputs" : [ 0 ],
+ "outputs" : [ 1 ],
}]
})";
}
@@ -70,7 +80,7 @@ struct InvalidTensorsFixture : public ParserFlatbuffersFixture
BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture)
{
- // this throws because it cannot do the input output tensor connections
+ // Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions
BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException);
}
@@ -135,4 +145,53 @@ BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixtu
BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
}
+struct Rank0TensorFixture : public ParserFlatbuffersFixture
+{
+ explicit Rank0TensorFixture()
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "MINIMUM" } ],
+ "subgraphs": [{
+ "tensors": [ {
+ "shape": [ ],
+ "type": "FLOAT32",
+ "name": "In0",
+ "buffer": 0,
+ }, {
+ "shape": [ ],
+ "type": "FLOAT32",
+ "name": "In1",
+ "buffer": 1,
+ }, {
+ "shape": [ ],
+ "type": "FLOAT32",
+ "name": "Out",
+ "buffer": 2,
+ }],
+ "inputs" : [ 0, 1 ],
+ "outputs" : [ 2 ],
+ "operators": [{
+ "opcode_index": 0,
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "custom_options_format": "FLEXBUFFERS"
+ }]
+ }]
+ }
+ )";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(Rank0Tensor, Rank0TensorFixture)
+{
+ Setup();
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 2u);
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In0");
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[1], "In1");
+ BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfLiteParser/test/Minimum.cpp b/src/armnnTfLiteParser/test/Minimum.cpp
index 70b249ff5f..8c6db680e7 100644
--- a/src/armnnTfLiteParser/test/Minimum.cpp
+++ b/src/armnnTfLiteParser/test/Minimum.cpp
@@ -173,4 +173,81 @@ BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast1D4D, MinimumBroadcastFixture1D4D)
5.0f, 6.0f, 7.0f }}});
}
+struct MinimumBroadcastFixture2D0D : public ParserFlatbuffersFixture
+{
+ explicit MinimumBroadcastFixture2D0D()
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "MINIMUM" } ],
+ "subgraphs": [ {
+ "tensors": [
+ {
+ "shape": [ 1, 2 ],
+ "type": "FLOAT32",
+ "buffer": 0,
+ "name": "input0",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": [ ],
+ "type": "FLOAT32",
+ "buffer": 2,
+ "name": "input1",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": [ 1, 2 ] ,
+ "type": "FLOAT32",
+ "buffer": 1,
+ "name": "output",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0 ],
+ "outputs": [ 2 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ],
+ } ],
+ "buffers" : [
+ { },
+ { },
+ { "data": [ 0, 0, 0, 64 ] }
+ ]
+ }
+ )";
+ Setup();
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast2D0D, MinimumBroadcastFixture2D0D)
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ {{"input0", { 1.0f, 5.0f }}},
+ {{"output", { 1.0f, 2.0f }}});
+}
+
BOOST_AUTO_TEST_SUITE_END()