aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/FullyConnected.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/FullyConnected.cpp')
-rw-r--r--src/armnnTfLiteParser/test/FullyConnected.cpp36
1 files changed, 31 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index fc000bf95b..108b878e20 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
const std::string& filterShape,
const std::string& filterData,
const std::string biasShape = "",
- const std::string biasData = "")
+ const std::string biasData = "",
+ const std::string dataType = "UINT8",
+ const std::string weightsDataType = "UINT8",
+ const std::string biasDataType = "INT32")
{
std::string inputTensors = "[ 0, 2 ]";
std::string biasTensor = "";
@@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
biasTensor = R"(
{
"shape": )" + biasShape + R"( ,
- "type": "INT32",
+ "type": )" + biasDataType + R"(,
"buffer": 3,
"name": "biasTensor",
"quantization": {
@@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
"tensors": [
{
"shape": )" + inputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 0,
"name": "inputTensor",
"quantization": {
@@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + outputShape + R"(,
- "type": "UINT8",
+ "type": )" + dataType + R"(,
"buffer": 1,
"name": "outputTensor",
"quantization": {
@@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture
},
{
"shape": )" + filterShape + R"(,
- "type": "UINT8",
+ "type": )" + weightsDataType + R"(,
"buffer": 2,
"name": "filterTensor",
"quantization": {
@@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo
{{"output", { 20 }}});
}
+struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture
+{
+ FullyConnectedWeightsBiasFloat()
+ : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape
+ "[ 1, 1 ]", // outputShape
+ "[ 1, 4 ]", // filterShape
+ "[ 2, 3, 4, 5 ]", // filterData
+ "[ 1 ]", // biasShape
+ "[ 10, 0, 0, 0 ]", // filterShape
+ "FLOAT32", // input and output dataType
+ "INT8", // weights dataType
+ "FLOAT32") // bias dataType
+ {}
+};
+
+TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat")
+{
+ RunTest<2, armnn::DataType::Float32>(
+ 0,
+ { 10, 20, 30, 40 },
+ { 400 });
+}
+
}