aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/Conv2D.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/Conv2D.cpp')
-rw-r--r--src/armnnTfLiteParser/test/Conv2D.cpp72
1 files changed, 56 insertions, 16 deletions
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp
index c25e62bb00..45c4a43519 100644
--- a/src/armnnTfLiteParser/test/Conv2D.cpp
+++ b/src/armnnTfLiteParser/test/Conv2D.cpp
@@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D")
struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
{
- explicit Conv2DWithBiasesFixture(const std::string & inputShape,
- const std::string & outputShape,
- const std::string & filterShape,
- const std::string & filterData,
- const std::string & biasShape,
- const std::string & biasData,
- const std::string & strides,
- const std::string & activation="NONE",
- const std::string & filterScale="1.0",
- const std::string & filterZeroPoint="0",
- const std::string & outputScale="2.0",
- const std::string & outputZeroPoint="0")
+ explicit Conv2DWithBiasesFixture(const std::string& inputShape,
+ const std::string& outputShape,
+ const std::string& filterShape,
+ const std::string& filterData,
+ const std::string& biasShape,
+ const std::string& biasData,
+ const std::string& strides,
+ const std::string& activation="NONE",
+ const std::string& filterScale="1.0",
+ const std::string& filterZeroPoint="0",
+ const std::string& outputScale="2.0",
+ const std::string& outputZeroPoint="0",
+ const std::string& dataType = "UINT8",
+ const std::string& filterDataType = "UINT8",
+ const std::string& biasDataType = "INT32")
{
m_JsonString = R"(
{
@@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"( ,
- "type": "UINT8",
+ "type": )" + filterDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D")
});
}
+struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture
+{
+ Conv2FloatWithInt8WeightsAndBiasesFixture()
+ : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
+ "[ 1, 2, 2, 1 ]", // outputShape
+ "[ 1, 2, 2, 1 ]", // filterShape
+ "[ 2,1, 0,6 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]", // biasData
+ "1", // stride w and h
+ "NONE", // activation
+ "1.0", // filterScale
+ "0", // filterZeroPoint
+ "2.0", // outputScale
+ "0", // outputZeroPoint
+ "FLOAT32", // dataType
+ "INT8", // filterDataType
+ "INT8") // biasDataType
+ {}
+};
+
+TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture")
+{
+ RunTest<4, armnn::DataType::Float32>(
+ 0,
+ {
+ 1, 2,
+ 3, 4,
+ },
+ {
+ (1*2 + 2*1 + 3*0 + 4*6 + 10),
+ (2*2 + 0*1 + 4*0 + 0*6 + 10),
+ (3*2 + 4*1 + 0*0 + 0*6 + 10),
+ (4*2 + 0*1 + 0*0 + 0*6 + 10)
+ });
+}
+
}