aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2019-04-24 14:36:20 +0100
committerderek.lamberti <derek.lamberti@arm.com>2019-04-25 12:16:51 +0000
commite9aa4699ba4c4e16893867dda8dab0d0eed7e7fa (patch)
tree46ade5055b255c87a08cff8f848f325431998143
parentf61c2705ad97273a9409a2aff427470ac131f596 (diff)
downloadandroid-nn-driver-e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa.tar.gz
MLCE-117 Handle more cases for implicit flattening of Fully Connected input
Adds new unit test cases, and changes the implementation of FlattenFullyConnectedInput to more closely match the documentation of Android NNAPI. Change-Id: I7ca96b1168b9c7bc78db66f53b0cc776153fd780 Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
-rw-r--r--1.0/FullyConnected.hpp14
-rw-r--r--test/1.0/FullyConnectedReshape.cpp22
2 files changed, 27 insertions, 9 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
index 0fb029de..26d61e4c 100644
--- a/1.0/FullyConnected.hpp
+++ b/1.0/FullyConnected.hpp
@@ -17,21 +17,17 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i
{
if (inputShape.GetNumDimensions() > 2U)
{
- unsigned int dim0 = inputShape[0];
- unsigned int dim1 = inputShape[1];
+ unsigned int totalInputElements = inputShape.GetNumElements();
+ unsigned int inputSize = weightsShape[1];
- for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
- {
- dim1 *= inputShape[i];
- }
+ unsigned int batchSize = totalInputElements / inputSize;
- unsigned int divisor = weightsShape[1] / dim1;
- if(dim0 % divisor != 0)
+ if(totalInputElements % batchSize != 0)
{
throw std::runtime_error("Failed to deduce tensor shape");
}
- return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
+ return armnn::TensorShape({batchSize, inputSize});
}
else
{
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp
index 250f8837..72c90ca5 100644
--- a/test/1.0/FullyConnectedReshape.cpp
+++ b/test/1.0/FullyConnectedReshape.cpp
@@ -13,8 +13,30 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests)
BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput)
{
using armnn::TensorShape;
+
+ // Pass through 2d input
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({2,2048}), TensorShape({512, 2048})) ==
+ TensorShape({2, 2048}));
+
+ // Trivial flattening of batched channels
BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) ==
TensorShape({97, 2048}));
+
+ // Flatten single batch of rows
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Flatten single batch of columns
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Move batches into input dimension
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), TensorShape({512, 20})) ==
+ TensorShape({25, 20}));
+
+ // Flatten single batch of 3D data (e.g. convolution output)
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), TensorShape({512, 2560})) ==
+ TensorShape({1, 2560}));
}
BOOST_AUTO_TEST_SUITE_END()