From bf99b5f05514d3a717df36c6039dcfc4a9f5b9ba Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Thu, 27 May 2021 09:55:43 +0100 Subject: IVGCVSW-6079 Fix circular dependency Signed-off-by: Narumol Prangnawarat Change-Id: I29793ece7b6bfc015c643be3ed16529ab50f0d7d --- src/armnnTfLiteParser/TfLiteParser.cpp | 5 +- src/armnnTfLiteParser/test/Prelu.cpp | 223 +++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 3 deletions(-) (limited to 'src/armnnTfLiteParser') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index d4a0a6e865..8941ee93f5 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1963,10 +1963,9 @@ void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex) if (IsConstTensor(inputs[1])) { - armnn::IInputSlot* slot = &(layer->GetInputSlot(0)); - RegisterConsumerOfTensor(subgraphIndex, 0, slot); - auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex)); + armnn::IInputSlot* slot = &(layer->GetInputSlot(0)); + RegisterConsumerOfTensor(subgraphIndex, inputTensorIndexes[0], slot); auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo); std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name); diff --git a/src/armnnTfLiteParser/test/Prelu.cpp b/src/armnnTfLiteParser/test/Prelu.cpp index 48a86dcefc..83c4088377 100644 --- a/src/armnnTfLiteParser/test/Prelu.cpp +++ b/src/armnnTfLiteParser/test/Prelu.cpp @@ -103,6 +103,221 @@ struct PreluFixture : public ParserFlatbuffersFixture } }; +struct PreluNetworkFixture : public ParserFlatbuffersFixture +{ + explicit PreluNetworkFixture() + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ + { + "builtin_code": "PRELU", + "version": 1 + }, + { + "builtin_code": "MUL", + "version": 1 + }, + { + "builtin_code": "ADD", + "version": 1 + } + ], + "subgraphs": [ + { + "tensors": [ + { + "shape": [ + 1, + 2, + 3 + ], + "type": "FLOAT32", + "buffer": 6, + "name": "output", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + }, + }, + { + "shape": [ + 1, + 2, + 3 + ], + "type": "FLOAT32", + "buffer": 5, + "name": "mul", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + } + }, + { + "shape": [ + 1, + 2, + 3 + ], + "type": "FLOAT32", + "buffer": 1, + "name": "input0", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + } + }, + { + "shape": [ + 2, + 3 + ], + "type": "FLOAT32", + "buffer": 2, + "name": "alpha", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + } + }, + { + "shape": [ + 1 + ], + "type": "FLOAT32", + "buffer": 3, + "name": "const0", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + } + }, + { + "shape": [ + 1, + 2, + 3 + ], + "type": "FLOAT32", + "buffer": 4, + "name": "prelumul", + "quantization": { + "details_type": "NONE", + "quantized_dimension": 0 + } + } + ], + "inputs": [ + 2 + ], + "outputs": [ + 0 + ], + "operators": [ + { + "opcode_index": 0, + "inputs": [ + 2, + 3 + ], + "outputs": [ + 5 + ], + "builtin_options_type": "NONE", + "custom_options_format": "FLEXBUFFERS" + }, + { + "opcode_index": 1, + "inputs": [ + 5, + 4 + ], + "outputs": [ + 1 + ], + "builtin_options_type": "MulOptions", + "builtin_options": { + "fused_activation_function": "NONE" + }, + "custom_options_format": "FLEXBUFFERS" + }, + { + "opcode_index": 2, + "inputs": [ + 5, + 1 + ], + "outputs": [ + 0 + ], + "builtin_options_type": "AddOptions", + "builtin_options": { + "fused_activation_function": "NONE" + }, + "custom_options_format": "FLEXBUFFERS" + } + ], + "name": "main" + } + ], + "buffers": [ + { + }, + { + }, + { + "data": [ + 0, + 0, + 128, + 62, + 0, + 0, + 128, + 62, + 0, + 0, + 128, + 62, + 0, + 0, + 128, + 62, + 0, + 0, + 128, + 62, + 0, + 0, + 128, + 62 + ] + }, + { + "data": [ + 0, + 0, + 160, + 64 + ] + }, + { + }, + { + }, + { + }, + { + } + ], + } + )"; + Setup(); + } +}; + struct SimplePreluFixture : PreluFixture { SimplePreluFixture() : PreluFixture("[ 2, 3 ]", @@ -174,4 +389,12 @@ BOOST_FIXTURE_TEST_CASE(PreluDynamicTensor, PreluDynamicTensorFixture) true); } +BOOST_FIXTURE_TEST_CASE(PreluNetwork, PreluNetworkFixture) +{ + RunTest<3, armnn::DataType::Float32>( + 0, + {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}}, + {{"output", { -21.f, 12.f, 0.f, 6.f, -7.5f, 84.f }}}); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1