diff options
Diffstat (limited to 'tests/datasets_new/FullyConnectedLayerDataset.h')
-rw-r--r-- | tests/datasets_new/FullyConnectedLayerDataset.h | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/datasets_new/FullyConnectedLayerDataset.h b/tests/datasets_new/FullyConnectedLayerDataset.h index 562295f00f..8401e39ece 100644 --- a/tests/datasets_new/FullyConnectedLayerDataset.h +++ b/tests/datasets_new/FullyConnectedLayerDataset.h @@ -59,7 +59,7 @@ public: description << "In=" << *_src_it << ":"; description << "Weights=" << *_weights_it << ":"; description << "Biases=" << *_biases_it << ":"; - description << "Out=" << *_dst_it << ":"; + description << "Out=" << *_dst_it; return description.str(); } @@ -113,6 +113,38 @@ private: std::vector<TensorShape> _bias_shapes{}; std::vector<TensorShape> _dst_shapes{}; }; + +class SmallFullyConnectedLayerDataset final : public FullyConnectedLayerDataset +{ +public: + SmallFullyConnectedLayerDataset() + { + // Conv -> FC + add_config(TensorShape(9U, 5U, 7U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U)); + // Conv -> FC (batched) + add_config(TensorShape(9U, 5U, 7U, 3U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U, 3U)); + // FC -> FC + add_config(TensorShape(201U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U)); + // FC -> FC (batched) + add_config(TensorShape(201U, 3U), TensorShape(201U, 529U), TensorShape(529U), TensorShape(529U, 3U)); + + add_config(TensorShape(9U, 5U, 7U, 3U, 2U), TensorShape(315U, 271U), TensorShape(271U), TensorShape(271U, 3U, 2U)); + } +}; + +class LargeFullyConnectedLayerDataset final : public FullyConnectedLayerDataset +{ +public: + LargeFullyConnectedLayerDataset() + { + add_config(TensorShape(9U, 5U, 257U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U)); + add_config(TensorShape(9U, 5U, 257U, 2U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U, 2U)); + add_config(TensorShape(3127U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U)); + add_config(TensorShape(3127U, 2U), TensorShape(3127U, 989U), TensorShape(989U), TensorShape(989U, 2U)); + + add_config(TensorShape(9U, 5U, 257U, 2U, 3U), TensorShape(11565U, 2123U), TensorShape(2123U), TensorShape(2123U, 2U, 3U)); + } +}; } // namespace datasets } // namespace test } // namespace arm_compute |