diff options
author | Matthew Bentham <matthew.bentham@arm.com> | 2019-04-24 14:36:20 +0100 |
---|---|---|
committer | derek.lamberti <derek.lamberti@arm.com> | 2019-04-25 12:16:51 +0000 |
commit | e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa (patch) | |
tree | 46ade5055b255c87a08cff8f848f325431998143 /1.0 | |
parent | f61c2705ad97273a9409a2aff427470ac131f596 (diff) | |
download | android-nn-driver-e9aa4699ba4c4e16893867dda8dab0d0eed7e7fa.tar.gz |
MLCE-117 Handle more cases for implicit flattening of Fully Connected input
Adds new unit test cases, and changes the implementation of
FlattenFullyConnectedInput to more closely match the documentation
of Android NNAPI.
Change-Id: I7ca96b1168b9c7bc78db66f53b0cc776153fd780
Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to '1.0')
-rw-r--r-- | 1.0/FullyConnected.hpp | 14 |
1 files changed, 5 insertions, 9 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp index 0fb029de..26d61e4c 100644 --- a/1.0/FullyConnected.hpp +++ b/1.0/FullyConnected.hpp @@ -17,21 +17,17 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i { if (inputShape.GetNumDimensions() > 2U) { - unsigned int dim0 = inputShape[0]; - unsigned int dim1 = inputShape[1]; + unsigned int totalInputElements = inputShape.GetNumElements(); + unsigned int inputSize = weightsShape[1]; - for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i) - { - dim1 *= inputShape[i]; - } + unsigned int batchSize = totalInputElements / inputSize; - unsigned int divisor = weightsShape[1] / dim1; - if(dim0 % divisor != 0) + if(totalInputElements % batchSize != 0) { throw std::runtime_error("Failed to deduce tensor shape"); } - return armnn::TensorShape({dim0 / divisor, dim1 * divisor}); + return armnn::TensorShape({batchSize, inputSize}); } else { |