From 725728e7d46c1e672bbdc72cf86e22db6fb210ee Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Thu, 5 May 2022 13:33:33 +0100 Subject: IVGCVSW-6938 Do not add Floor when FloorDiv is int32 in Tfliteparser Signed-off-by: Teresa Charlin Change-Id: I7ce633a66e2ecb72a9cdd1bff690c4195a9a449f --- src/armnnTfLiteParser/TfLiteParser.cpp | 11 ++++++- src/armnnTfLiteParser/test/FloorDiv.cpp | 56 +++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 7fe954d901..5f71ebcff6 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -4341,12 +4341,21 @@ armnn::IConnectableLayer* TfLiteParserImpl::AddFusedActivationLayer(armnn::IConn armnn::IConnectableLayer* TfLiteParserImpl::AddFusedFloorLayer(armnn::IConnectableLayer* prevLayer, unsigned int outputSlot) { + + auto& prevOutputSlot = prevLayer->GetOutputSlot(outputSlot); + DataType dataType = prevOutputSlot.GetTensorInfo().GetDataType(); + + if (dataType == DataType::Signed32) + { + return prevLayer; + } + std::string layerName = prevLayer->GetName(); IConnectableLayer* floorLayer = m_Network->AddFloorLayer(layerName.c_str()); - auto & prevOutputSlot = prevLayer->GetOutputSlot(outputSlot); prevOutputSlot.Connect(floorLayer->GetInputSlot(0)); floorLayer->GetOutputSlot(0).SetTensorInfo(prevOutputSlot.GetTensorInfo()); + return floorLayer; } diff --git a/src/armnnTfLiteParser/test/FloorDiv.cpp b/src/armnnTfLiteParser/test/FloorDiv.cpp index dfd7b14bf4..bcc2e2facd 100644 --- a/src/armnnTfLiteParser/test/FloorDiv.cpp +++ b/src/armnnTfLiteParser/test/FloorDiv.cpp @@ -6,7 +6,7 @@ #include "ParserFlatbuffersFixture.hpp" -TEST_SUITE("TensorflowLiteParser_Div") +TEST_SUITE("TensorflowLiteParser_FloorDiv") { struct FloorDivFixture : public ParserFlatbuffersFixture { @@ -15,7 +15,8 @@ struct FloorDivFixture : public ParserFlatbuffersFixture const std::string& outputShape, const std::string& inputShapeSignature1, const std::string& inputShapeSignature2, - const std::string& outputShapeSignature) + const std::string& outputShapeSignature, + const std::string& dataType = "FLOAT32") { m_JsonString = R"( { @@ -32,7 +33,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape1 + R"(, - "type": "FLOAT32", + "type": )" + dataType + R"(, "buffer": 1, "name": "inputTensor1", "quantization": { @@ -44,7 +45,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture }, { "shape": )" + inputShape2 + R"(, - "type": "FLOAT32", + "type": )" + dataType + R"(, "buffer": 2, "name": "inputTensor2", "quantization": { @@ -56,7 +57,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "FLOAT32", + "type": )" + dataType + R"(, "buffer": 3, "name": "outputTensor", "quantization": { @@ -155,6 +156,29 @@ TEST_CASE_FIXTURE(SimpleFloorDivFixture, "ParseFloorDiv") 1.0f, 1.0f, -1.0f } } }); } +struct SimpleFloorDivInt32Fixture : public FloorDivFixture +{ + SimpleFloorDivInt32Fixture() : FloorDivFixture("[ 1, 3, 4 ]", "[ 1, 3, 4 ]", "[ 1, 3, 4 ]", + "[ -1, 3, 4 ]", "[ -1, 3, 4 ]", "[ -1, 3, 4 ]", "INT32") {} +}; +TEST_CASE_FIXTURE(SimpleFloorDivInt32Fixture, "ParseFloorDivInt32") +{ + using armnn::DataType; + + RunTest<3, DataType::Signed32>(0, {{ "inputTensor1", { 1, 1, 2, + 3, 4, 5, + 6, -7, 8, + 9, 10, -11 } }, + { "inputTensor2", { 1, 1, 4, + 3, 40, 5, + 6, 2, 8, + 9, 10, 11} } }, + {{ "outputTensor", { 1, 1, 0, + 1, 0, 1, + 1, -4, 1, + 1, 1, -1 } } }); +} + struct DynamicFloorDivFixture : public FloorDivFixture { @@ -169,17 +193,17 @@ TEST_CASE_FIXTURE(DynamicFloorDivFixture, "ParseDynamicFloorDiv") float NaN = std::numeric_limits::quiet_NaN(); RunTest<3, DataType::Float32, DataType::Float32>(0, {{ "inputTensor1", { 0.0f, 1.0f, 2.0f, - 3.0f, 4.0f, 5.0f, - 6.0f, -7.0f, 8.0f, - 9.0f, 10.0f, -11.0f } }, - { "inputTensor2", { 0.0f, 0.0f, 4.0f, - 3.0f, 40.0f, 5.0f, - 6.0f, 2.0f, 8.0f, - 9.0f, 10.0f, 11.0f} } }, - {{ "outputTensor", { NaN, Inf, 0.0f, - 1.0f, 0.0f, 1.0f, - 1.0f, -4.0f, 1.0f, - 1.0f, 1.0f, -1.0f } } }, true); + 3.0f, 4.0f, 5.0f, + 6.0f, -7.0f, 8.0f, + 9.0f, 10.0f, -11.0f } }, + { "inputTensor2", { 0.0f, 0.0f, 4.0f, + 3.0f, 40.0f, 5.0f, + 6.0f, 2.0f, 8.0f, + 9.0f, 10.0f, 11.0f} } }, + {{ "outputTensor", { NaN, Inf, 0.0f, + 1.0f, 0.0f, 1.0f, + 1.0f, -4.0f, 1.0f, + 1.0f, 1.0f, -1.0f } } }, true); } } -- cgit v1.2.1