From e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa Mon Sep 17 00:00:00 2001
From: Matthew Bentham
Date: Wed, 24 Apr 2019 14:36:20 +0100
Subject: MLCE-117 Handle more cases for implicit flattening of Fully
 Connected input

Adds new unit test cases, and changes the implementation of
FlattenFullyConnectedInput to more closely match the documentation of
Android NNAPI.

Change-Id: I7ca96b1168b9c7bc78db66f53b0cc776153fd780
Signed-off-by: Matthew Bentham
---
 1.0/FullyConnected.hpp             | 14 +++++---------
 test/1.0/FullyConnectedReshape.cpp | 22 ++++++++++++++++++++++
 2 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
index 0fb029de..26d61e4c 100644
--- a/1.0/FullyConnected.hpp
+++ b/1.0/FullyConnected.hpp
@@ -17,21 +17,17 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i
 {
     if (inputShape.GetNumDimensions() > 2U)
     {
-        unsigned int dim0 = inputShape[0];
-        unsigned int dim1 = inputShape[1];
+        unsigned int totalInputElements = inputShape.GetNumElements();
+        unsigned int inputSize = weightsShape[1];
 
-        for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
-        {
-            dim1 *= inputShape[i];
-        }
+        unsigned int batchSize = totalInputElements / inputSize;
 
-        unsigned int divisor = weightsShape[1] / dim1;
-        if(dim0 % divisor != 0)
+        if(totalInputElements % inputSize != 0)
         {
             throw std::runtime_error("Failed to deduce tensor shape");
         }
 
-        return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
+        return armnn::TensorShape({batchSize, inputSize});
     }
     else
     {
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp
index 250f8837..72c90ca5 100644
--- a/test/1.0/FullyConnectedReshape.cpp
+++ b/test/1.0/FullyConnectedReshape.cpp
@@ -13,8 +13,30 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests)
 BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput)
 {
     using armnn::TensorShape;
+
+    // Pass through 2d input
+    BOOST_TEST(FlattenFullyConnectedInput(TensorShape({2,2048}), TensorShape({512, 2048})) ==
+               TensorShape({2, 2048}));
+
+    // Trivial flattening of batched channels
     BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) ==
                TensorShape({97, 2048}));
+
+    // Flatten single batch of rows
+    BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), TensorShape({512, 2048})) ==
+               TensorShape({97, 2048}));
+
+    // Flatten single batch of columns
+    BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), TensorShape({512, 2048})) ==
+               TensorShape({97, 2048}));
+
+    // Move batches into input dimension
+    BOOST_TEST(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), TensorShape({512, 20})) ==
+               TensorShape({25, 20}));
+
+    // Flatten single batch of 3D data (e.g. convolution output)
+    BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), TensorShape({512, 2560})) ==
+               TensorShape({1, 2560}));
 }
 
 BOOST_AUTO_TEST_SUITE_END()
-- 
cgit v1.2.1