aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test')
-rw-r--r--src/armnnTfLiteParser/test/Conv2D.cpp72
-rw-r--r--src/armnnTfLiteParser/test/FullyConnected.cpp36
2 files changed, 87 insertions, 21 deletions
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp
index c25e62bb00..45c4a43519 100644
--- a/src/armnnTfLiteParser/test/Conv2D.cpp
+++ b/src/armnnTfLiteParser/test/Conv2D.cpp
@@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D")
struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
{
- explicit Conv2DWithBiasesFixture(const std::string & inputShape,
- const std::string & outputShape,
- const std::string & filterShape,
- const std::string & filterData,
- const std::string & biasShape,
- const std::string & biasData,
- const std::string & strides,
- const std::string & activation="NONE",
- const std::string & filterScale="1.0",
- const std::string & filterZeroPoint="0",
- const std::string & outputScale="2.0",
- const std::string & outputZeroPoint="0")
+ explicit Conv2DWithBiasesFixture(const std::string& inputShape,
+ const std::string& outputShape,
+ const std::string& filterShape,
+ const std::string& filterData,
+ const std::string& biasShape,
+ const std::string& biasData,
+ const std::string& strides,
+ const std::string& activation="NONE",
+ const std::string& filterScale="1.0",
+ const std::string& filterZeroPoint="0",
+ const std::string& outputScale="2.0",
+ const std::string& outputZeroPoint="0",
+ const std::string& dataType = "UINT8",
+ const std::string& filterDataType = "UINT8",
+ const std::string& biasDataType = "INT32")
{
m_JsonString = R"(
{
@@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"( ,
- "type": "UINT8",
+ "type": )" + filterDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D")
});
}
+struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture
+{
+ Conv2FloatWithInt8WeightsAndBiasesFixture()
+ : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape
+ "[ 1, 2, 2, 1 ]", // outputShape
+ "[ 1, 2, 2, 1 ]", // filterShape
+ "[ 2,1, 0,6 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]", // biasData
+ "1", // stride w and h
+ "NONE", // activation
+ "1.0", // filterScale
+ "0", // filterZeroPoint
+ "2.0", // outputScale
+ "0", // outputZeroPoint
+ "FLOAT32", // dataType
+ "INT8", // filterDataType
+ "INT8") // biasDataType
+ {}
+};
+
+TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture")
+{
+ RunTest<4, armnn::DataType::Float32>(
+ 0,
+ {
+ 1, 2,
+ 3, 4,
+ },
+ {
+ (1*2 + 2*1 + 3*0 + 4*6 + 10),
+ (2*2 + 0*1 + 4*0 + 0*6 + 10),
+ (3*2 + 4*1 + 0*0 + 0*6 + 10),
+ (4*2 + 0*1 + 0*0 + 0*6 + 10)
+ });
+}
+
}
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index fc000bf95b..108b878e20 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
const std::string& filterShape,
const std::string& filterData,
const std::string biasShape = "",
- const std::string biasData = "")
+ const std::string biasData = "",
+ const std::string dataType = "UINT8",
+ const std::string weightsDataType = "UINT8",
+ const std::string biasDataType = "INT32")
{
std::string inputTensors = "[ 0, 2 ]";
std::string biasTensor = "";
@@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
biasTensor = R"(
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"(,
- "type": "UINT8",
+ "type": )" + weightsDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo
{{"output", { 20 }}});
}
+struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
+{
+ FullyConnectedWeightsBiasFloat()
+ : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
+ "[ 1, 1 ]", // outputShape
+ "[ 1, 4 ]", // filterShape
+ "[ 2, 3, 4, 5 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]", // filterShape
+ "FLOAT32", // input and output dataType
+ "INT8", // weights dataType
+ "FLOAT32") // bias dataType
+ {}
+};
+
+TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ { 10, 20, 30, 40 },
+ { 400 });
+}
+
}