diff options
author | Matthew Bentham <matthew.bentham@arm.com> | 2019-04-24 14:36:20 +0100 |
---|---|---|
committer | derek.lamberti <derek.lamberti@arm.com> | 2019-04-25 12:16:51 +0000 |
commit | e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa (patch) | |
tree | 46ade5055b255c87a08cff8f848f325431998143 /test | |
parent | f61c2705ad97273a9409a2aff427470ac131f596 (diff) | |
download | android-nn-driver-e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa.tar.gz |
MLCE-117 Handle more cases for implicit flattening of Fully Connected input
Adds new unit test cases, and changes the implementation of
FlattenFullyConnectedInput to more closely match the documentation
of Android NNAPI.
Change-Id: I7ca96b1168b9c7bc78db66f53b0cc776153fd780
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to 'test')
-rw-r--r-- | test/1.0/FullyConnectedReshape.cpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp index 250f8837..72c90ca5 100644 --- a/test/1.0/FullyConnectedReshape.cpp +++ b/test/1.0/FullyConnectedReshape.cpp @@ -13,8 +13,30 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests) BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput) { using armnn::TensorShape; + + // Pass through 2d input + BOOST_TEST(FlattenFullyConnectedInput(TensorShape({2,2048}), TensorShape({512, 2048})) == + TensorShape({2, 2048})); + + // Trivial flattening of batched channels BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) == TensorShape({97, 2048})); + + // Flatten single batch of rows + BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), TensorShape({512, 2048})) == + TensorShape({97, 2048})); + + // Flatten single batch of columns + BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), TensorShape({512, 2048})) == + TensorShape({97, 2048})); + + // Move batches into input dimension + BOOST_TEST(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), TensorShape({512, 20})) == + TensorShape({25, 20})); + + // Flatten single batch of 3D data (e.g. convolution output) + BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), TensorShape({512, 2560})) == + TensorShape({1, 2560})); } BOOST_AUTO_TEST_SUITE_END() |