diff options
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 10 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/InputOutputTensorNames.cpp | 67 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/Minimum.cpp | 77 |
3 files changed, 148 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 86688add9d..1ee4950558 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -365,9 +365,15 @@ armnn::TensorInfo ToTensorInfo(TfLiteParser::TensorRawPtr tensorPtr, const std:: } } + std::vector<unsigned int> safeShape = shapes; + if (safeShape.size() == 0) + { + safeShape.push_back(1); + } + // two statements (on purpose) for easier debugging: - armnn::TensorInfo result(static_cast<unsigned int>(shapes.size()), - shapes.data(), + armnn::TensorInfo result(static_cast<unsigned int>(safeShape.size()), + safeShape.data(), type, quantizationScale, quantizationOffset); diff --git a/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp index 6f5016265a..d42ae2e438 100644 --- a/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp +++ b/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp @@ -60,9 +60,19 @@ struct InvalidTensorsFixture : public ParserFlatbuffersFixture "version": 3, "operator_codes": [ ], "subgraphs": [{ - "tensors": [ {}, {}, {}, {} ], - "inputs" : [ 0, 1 ], - "outputs" : [ 2, 3 ], + "tensors": [ { + "shape": [ 1, 1, 1, 1, 1 ], + "type": "FLOAT32", + "name": "In", + "buffer": 0 + }, { + "shape": [ 1, 1, 1, 1, 1 ], + "type": "FLOAT32", + "name": "Out", + "buffer": 1 + }], + "inputs" : [ 0 ], + "outputs" : [ 1 ], }] })"; } @@ -70,7 +80,7 @@ struct InvalidTensorsFixture : public ParserFlatbuffersFixture BOOST_FIXTURE_TEST_CASE(InvalidTensorsThrowException, InvalidTensorsFixture) { - // this throws because it cannot do the input output tensor connections + // Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions BOOST_CHECK_THROW(Setup(), armnn::InvalidArgumentException); } @@ -135,4 +145,53 @@ BOOST_FIXTURE_TEST_CASE(ThrowIfSubgraphIdInvalidForInOutNames, ValidTensorsFixtu BOOST_CHECK_THROW(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException); } +struct Rank0TensorFixture : public ParserFlatbuffersFixture +{ + explicit Rank0TensorFixture() + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ { "builtin_code": "MINIMUM" } ], + "subgraphs": [{ + "tensors": [ { + "shape": [ ], + "type": "FLOAT32", + "name": "In0", + "buffer": 0, + }, { + "shape": [ ], + "type": "FLOAT32", + "name": "In1", + "buffer": 1, + }, { + "shape": [ ], + "type": "FLOAT32", + "name": "Out", + "buffer": 2, + }], + "inputs" : [ 0, 1 ], + "outputs" : [ 2 ], + "operators": [{ + "opcode_index": 0, + "inputs": [ 0, 1 ], + "outputs": [ 2 ], + "custom_options_format": "FLEXBUFFERS" + }] + }] + } + )"; + } +}; + +BOOST_FIXTURE_TEST_CASE(Rank0Tensor, Rank0TensorFixture) +{ + Setup(); + BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0).size(), 2u); + BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u); + BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[0], "In0"); + BOOST_CHECK_EQUAL(m_Parser->GetSubgraphInputTensorNames(0)[1], "In1"); + BOOST_CHECK_EQUAL(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out"); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfLiteParser/test/Minimum.cpp b/src/armnnTfLiteParser/test/Minimum.cpp index 70b249ff5f..8c6db680e7 100644 --- a/src/armnnTfLiteParser/test/Minimum.cpp +++ b/src/armnnTfLiteParser/test/Minimum.cpp @@ -173,4 +173,81 @@ BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast1D4D, MinimumBroadcastFixture1D4D) 5.0f, 6.0f, 7.0f }}}); } +struct MinimumBroadcastFixture2D0D : public ParserFlatbuffersFixture +{ + explicit MinimumBroadcastFixture2D0D() + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ { "builtin_code": "MINIMUM" } ], + "subgraphs": [ { + "tensors": [ + { + "shape": [ 1, 2 ], + "type": "FLOAT32", + "buffer": 0, + "name": "input0", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": [ ], + "type": "FLOAT32", + "buffer": 2, + "name": "input1", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": [ 1, 2 ] , + "type": "FLOAT32", + "buffer": 1, + "name": "output", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + } + ], + "inputs": [ 0 ], + "outputs": [ 2 ], + "operators": [ + { + "opcode_index": 0, + "inputs": [ 0, 1 ], + "outputs": [ 2 ], + "custom_options_format": "FLEXBUFFERS" + } + ], + } ], + "buffers" : [ + { }, + { }, + { "data": [ 0, 0, 0, 64 ] } + ] + } + )"; + Setup(); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseMinimumBroadcast2D0D, MinimumBroadcastFixture2D0D) +{ + RunTest<2, armnn::DataType::Float32>( + 0, + {{"input0", { 1.0f, 5.0f }}}, + {{"output", { 1.0f, 2.0f }}}); +} + BOOST_AUTO_TEST_SUITE_END() |