aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-05-05 13:33:33 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-05-05 13:33:33 +0100
commit725728e7d46c1e672bbdc72cf86e22db6fb210ee (patch)
treeb93901e1308f6b01bb41577237f05864518058a5
parentbd22c7d8d71bb9d6fdebcd07a472d66c7616abad (diff)
downloadarmnn-725728e7d46c1e672bbdc72cf86e22db6fb210ee.tar.gz
IVGCVSW-6938 Do not add Floor when FloorDiv is int32 in Tfliteparser
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I7ce633a66e2ecb72a9cdd1bff690c4195a9a449f
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp11
-rw-r--r--src/armnnTfLiteParser/test/FloorDiv.cpp56
2 files changed, 50 insertions, 17 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 7fe954d901..5f71ebcff6 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -4341,12 +4341,21 @@ armnn::IConnectableLayer* TfLiteParserImpl::AddFusedActivationLayer(armnn::IConn
armnn::IConnectableLayer* TfLiteParserImpl::AddFusedFloorLayer(armnn::IConnectableLayer* prevLayer,
unsigned int outputSlot)
{
+
+ auto& prevOutputSlot = prevLayer->GetOutputSlot(outputSlot);
+ DataType dataType = prevOutputSlot.GetTensorInfo().GetDataType();
+
+ if (dataType == DataType::Signed32)
+ {
+ return prevLayer;
+ }
+
std::string layerName = prevLayer->GetName();
IConnectableLayer* floorLayer = m_Network->AddFloorLayer(layerName.c_str());
- auto & prevOutputSlot = prevLayer->GetOutputSlot(outputSlot);
prevOutputSlot.Connect(floorLayer->GetInputSlot(0));
floorLayer->GetOutputSlot(0).SetTensorInfo(prevOutputSlot.GetTensorInfo());
+
return floorLayer;
}
diff --git a/src/armnnTfLiteParser/test/FloorDiv.cpp b/src/armnnTfLiteParser/test/FloorDiv.cpp
index dfd7b14bf4..bcc2e2facd 100644
--- a/src/armnnTfLiteParser/test/FloorDiv.cpp
+++ b/src/armnnTfLiteParser/test/FloorDiv.cpp
@@ -6,7 +6,7 @@
#include "ParserFlatbuffersFixture.hpp"
-TEST_SUITE("TensorflowLiteParser_Div")
+TEST_SUITE("TensorflowLiteParser_FloorDiv")
{
struct FloorDivFixture : public ParserFlatbuffersFixture
{
@@ -15,7 +15,8 @@ struct FloorDivFixture : public ParserFlatbuffersFixture
const std::string& outputShape,
const std::string& inputShapeSignature1,
const std::string& inputShapeSignature2,
- const std::string& outputShapeSignature)
+ const std::string& outputShapeSignature,
+ const std::string& dataType = "FLOAT32")
{
m_JsonString = R"(
{
@@ -32,7 +33,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape1 + R"(,
- "type": "FLOAT32",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "inputTensor1",
"quantization": {
@@ -44,7 +45,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + inputShape2 + R"(,
- "type": "FLOAT32",
+ "type": )" + dataType + R"(,
"buffer": 2,
"name": "inputTensor2",
"quantization": {
@@ -56,7 +57,7 @@ struct FloorDivFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "FLOAT32",
+ "type": )" + dataType + R"(,
"buffer": 3,
"name": "outputTensor",
"quantization": {
@@ -155,6 +156,29 @@ TEST_CASE_FIXTURE(SimpleFloorDivFixture, "ParseFloorDiv")
1.0f, 1.0f, -1.0f } } });
}
+struct SimpleFloorDivInt32Fixture : public FloorDivFixture
+{
+ SimpleFloorDivInt32Fixture() : FloorDivFixture("[ 1, 3, 4 ]", "[ 1, 3, 4 ]", "[ 1, 3, 4 ]",
+ "[ -1, 3, 4 ]", "[ -1, 3, 4 ]", "[ -1, 3, 4 ]", "INT32") {}
+};
+TEST_CASE_FIXTURE(SimpleFloorDivInt32Fixture, "ParseFloorDivInt32")
+{
+ using armnn::DataType;
+
+ RunTest<3, DataType::Signed32>(0, {{ "inputTensor1", { 1, 1, 2,
+ 3, 4, 5,
+ 6, -7, 8,
+ 9, 10, -11 } },
+ { "inputTensor2", { 1, 1, 4,
+ 3, 40, 5,
+ 6, 2, 8,
+ 9, 10, 11} } },
+ {{ "outputTensor", { 1, 1, 0,
+ 1, 0, 1,
+ 1, -4, 1,
+ 1, 1, -1 } } });
+}
+
struct DynamicFloorDivFixture : public FloorDivFixture
{
@@ -169,17 +193,17 @@ TEST_CASE_FIXTURE(DynamicFloorDivFixture, "ParseDynamicFloorDiv")
float NaN = std::numeric_limits<float>::quiet_NaN();
RunTest<3, DataType::Float32, DataType::Float32>(0, {{ "inputTensor1", { 0.0f, 1.0f, 2.0f,
- 3.0f, 4.0f, 5.0f,
- 6.0f, -7.0f, 8.0f,
- 9.0f, 10.0f, -11.0f } },
- { "inputTensor2", { 0.0f, 0.0f, 4.0f,
- 3.0f, 40.0f, 5.0f,
- 6.0f, 2.0f, 8.0f,
- 9.0f, 10.0f, 11.0f} } },
- {{ "outputTensor", { NaN, Inf, 0.0f,
- 1.0f, 0.0f, 1.0f,
- 1.0f, -4.0f, 1.0f,
- 1.0f, 1.0f, -1.0f } } }, true);
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, -7.0f, 8.0f,
+ 9.0f, 10.0f, -11.0f } },
+ { "inputTensor2", { 0.0f, 0.0f, 4.0f,
+ 3.0f, 40.0f, 5.0f,
+ 6.0f, 2.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f} } },
+ {{ "outputTensor", { NaN, Inf, 0.0f,
+ 1.0f, 0.0f, 1.0f,
+ 1.0f, -4.0f, 1.0f,
+ 1.0f, 1.0f, -1.0f } } }, true);
}
}