aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2019-04-24 14:36:20 +0100
committerderek.lamberti <derek.lamberti@arm.com>2019-04-25 12:16:51 +0000
commite9aa4699ba4c4e16893867dda8dab0d0eed7e7fa (patch)
tree46ade5055b255c87a08cff8f848f325431998143 /test
parentf61c2705ad97273a9409a2aff427470ac131f596 (diff)
downloadandroid-nn-driver-e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa.tar.gz
MLCE-117 Handle more cases for implicit flattening of Fully Connected input
Adds new unit test cases, and changes the implementation of FlattenFullyConnectedInput to more closely match the documentation of Android NNAPI. Change-Id: I7ca96b1168b9c7bc78db66f53b0cc776153fd780 Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to 'test')
-rw-r--r--test/1.0/FullyConnectedReshape.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp
index 250f8837..72c90ca5 100644
--- a/test/1.0/FullyConnectedReshape.cpp
+++ b/test/1.0/FullyConnectedReshape.cpp
@@ -13,8 +13,30 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests)
BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput)
{
using armnn::TensorShape;
+
+ // Pass through 2d input
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({2,2048}), TensorShape({512, 2048})) ==
+ TensorShape({2, 2048}));
+
+ // Trivial flattening of batched channels
BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) ==
TensorShape({97, 2048}));
+
+ // Flatten single batch of rows
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Flatten single batch of columns
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Move batches into input dimension
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), TensorShape({512, 20})) ==
+ TensorShape({25, 20}));
+
+ // Flatten single batch of 3D data (e.g. convolution output)
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), TensorShape({512, 2560})) ==
+ TensorShape({1, 2560}));
}
BOOST_AUTO_TEST_SUITE_END()