aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets_new/FullyConnectedLayerDataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets_new/FullyConnectedLayerDataset.h')
-rw-r--r--tests/datasets_new/FullyConnectedLayerDataset.h34
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/datasets_new/FullyConnectedLayerDataset.h b/tests/datasets_new/FullyConnectedLayerDataset.h
index 562295f00f..8401e39ece 100644
--- a/tests/datasets_new/FullyConnectedLayerDataset.h
+++ b/tests/datasets_new/FullyConnectedLayerDataset.h
@@ -59,7 +59,7 @@ public:
description << "In=" << *_src_it << ":";
description << "Weights=" << *_weights_it << ":";
description << "Biases=" << *_biases_it << ":";
- description << "Out=" << *_dst_it << ":";
+ description << "Out=" << *_dst_it;
return description.str();
}
@@ -113,6 +113,38 @@ private:
std::vector<TensorShape> _bias_shapes{};
std::vector<TensorShape> _dst_shapes{};
};
+
+class SmallFullyConnectedLayerDataset final : public FullyConnectedLayerDataset
+{
+public:
+ SmallFullyConnectedLayerDataset()
+ {
+ // Conv -> FC
+ add_config(TensorShape(9U, 5U, 7U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U));
+ // Conv -> FC (batched)
+ add_config(TensorShape(9U, 5U, 7U, 3U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U, 3U));
+ // FC -> FC
+ add_config(TensorShape(201U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U));
+ // FC -> FC (batched)
+ add_config(TensorShape(201U, 3U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U, 3U));
+
+ add_config(TensorShape(9U, 5U, 7U, 3U, 2U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U, 3U, 2U));
+ }
+};
+
+class LargeFullyConnectedLayerDataset final : public FullyConnectedLayerDataset
+{
+public:
+ LargeFullyConnectedLayerDataset()
+ {
+ add_config(TensorShape(9U, 5U, 257U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U));
+ add_config(TensorShape(9U, 5U, 257U, 2U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U, 2U));
+ add_config(TensorShape(3127U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U));
+ add_config(TensorShape(3127U, 2U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U, 2U));
+
+ add_config(TensorShape(9U, 5U, 257U, 2U, 3U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U, 2U, 3U));
+ }
+};
} // namespace datasets
} // namespace test
} // namespace arm_compute