aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--1.0/FullyConnected.hpp14
-rw-r--r--test/1.0/FullyConnectedReshape.cpp22
2 files changed, 27 insertions, 9 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
index 0fb029de..26d61e4c 100644
--- a/1.0/FullyConnected.hpp
+++ b/1.0/FullyConnected.hpp
@@ -17,21 +17,17 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i
{
if (inputShape.GetNumDimensions() > 2U)
{
- unsigned int dim0 = inputShape[0];
- unsigned int dim1 = inputShape[1];
+ unsigned int totalInputElements = inputShape.GetNumElements();
+ unsigned int inputSize = weightsShape[1];
- for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
- {
- dim1 *= inputShape[i];
- }
+ unsigned int batchSize = totalInputElements / inputSize;
- unsigned int divisor = weightsShape[1] / dim1;
- if(dim0 % divisor != 0)
+ if(totalInputElements % batchSize != 0)
{
throw std::runtime_error("Failed to deduce tensor shape");
}
- return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
+ return armnn::TensorShape({batchSize, inputSize});
}
else
{
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp
index 250f8837..72c90ca5 100644
--- a/test/1.0/FullyConnectedReshape.cpp
+++ b/test/1.0/FullyConnectedReshape.cpp
@@ -13,8 +13,30 @@ BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests)
BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput)
{
using armnn::TensorShape;
+
+ // Pass through 2d input
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({2,2048}), TensorShape({512, 2048})) ==
+ TensorShape({2, 2048}));
+
+ // Trivial flattening of batched channels
BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) ==
TensorShape({97, 2048}));
+
+ // Flatten single batch of rows
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Flatten single batch of columns
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+
+ // Move batches into input dimension
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), TensorShape({512, 20})) ==
+ TensorShape({25, 20}));
+
+ // Flatten single batch of 3D data (e.g. convolution output)
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), TensorShape({512, 2560})) ==
+ TensorShape({1, 2560}));
}
BOOST_AUTO_TEST_SUITE_END()