aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Bentham <matthew.bentham@arm.com>2019-04-23 16:43:27 +0100
committerderek.lamberti <derek.lamberti@arm.com>2019-04-25 12:12:48 +0000
commitf61c2705ad97273a9409a2aff427470ac131f596 (patch)
tree8835dae435f174a54afe5bd41e2763d5eb1236bc
parent5404c01b4b0a2455aea0f5d1de0a45e2e859a466 (diff)
downloadandroid-nn-driver-f61c2705ad97273a9409a2aff427470ac131f596.tar.gz
MLCE-117 Add a unit test for implicit flatten of FC layer input
Change-Id: Ia4dd63927a54aa0cc24d5a378f30189c957f12e8 Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
-rw-r--r--1.0/FullyConnected.hpp42
-rw-r--r--1.0/HalPolicy.cpp26
-rw-r--r--test/1.0/FullyConnectedReshape.cpp20
-rw-r--r--test/Android.mk1
4 files changed, 71 insertions, 18 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
new file mode 100644
index 00000000..0fb029de
--- /dev/null
+++ b/1.0/FullyConnected.hpp
@@ -0,0 +1,42 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+#include "../ConversionUtils.hpp"
+
+namespace armnn_driver
+{
+
+inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
+ const armnn::TensorShape &weightsShape)
+{
+ if (inputShape.GetNumDimensions() > 2U)
+ {
+ unsigned int dim0 = inputShape[0];
+ unsigned int dim1 = inputShape[1];
+
+ for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
+ {
+ dim1 *= inputShape[i];
+ }
+
+ unsigned int divisor = weightsShape[1] / dim1;
+ if(dim0 % divisor != 0)
+ {
+ throw std::runtime_error("Failed to deduce tensor shape");
+ }
+
+ return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
+ }
+ else
+ {
+ return inputShape;
+ }
+}
+
+} \ No newline at end of file
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index dee4a7a5..158f0e36 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -5,7 +5,9 @@
#include "HalPolicy.hpp"
-#include "armnn/Optional.hpp"
+#include <armnn/Optional.hpp>
+
+#include "FullyConnected.hpp"
namespace armnn_driver
{
@@ -633,25 +635,13 @@ bool HalPolicy::ConvertFullyConnected(const Operation& operation, const Model& m
armnn::ConstTensor weights = weightsPin.GetConstTensor();
armnn::ConstTensor bias = biasPin.GetConstTensor();
-
armnn::TensorInfo reshapedInfo = inputInfo;
- if (inputInfo.GetNumDimensions() > 2U)
- {
- unsigned int dim0 = inputInfo.GetShape()[0];
- unsigned int dim1 = inputInfo.GetShape()[1];
-
- for (unsigned int i = 2U; i < inputInfo.GetNumDimensions(); ++i)
- {
- dim1 *= inputInfo.GetShape()[i];
- }
- unsigned int divisor = weights.GetInfo().GetShape()[1] / dim1;
- if(dim0 % divisor != 0)
- {
- return Fail("%s: Failed to deduce tensor shape", __func__);
- }
-
- reshapedInfo.SetShape(armnn::TensorShape({dim0 / divisor, dim1 * divisor}));
+ try
+ {
+ reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weights.GetInfo().GetShape()));
+ } catch (const std::exception &e) {
+ return Fail("%s: %s", __func__, e.what());
}
// ensuring that the bias value is within 1% of the weights input (small float differences can exist)
diff --git a/test/1.0/FullyConnectedReshape.cpp b/test/1.0/FullyConnectedReshape.cpp
new file mode 100644
index 00000000..250f8837
--- /dev/null
+++ b/test/1.0/FullyConnectedReshape.cpp
@@ -0,0 +1,20 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "../DriverTestHelpers.hpp"
+#include "../../1.0/FullyConnected.hpp"
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(FullyConnectedReshapeTests)
+
+BOOST_AUTO_TEST_CASE(TestFlattenFullyConnectedInput)
+{
+ using armnn::TensorShape;
+ BOOST_TEST(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), TensorShape({512, 2048})) ==
+ TensorShape({97, 2048}));
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/test/Android.mk b/test/Android.mk
index da3ac706..87f0b636 100644
--- a/test/Android.mk
+++ b/test/Android.mk
@@ -45,6 +45,7 @@ endif # PLATFORM_VERSION == 9
LOCAL_SRC_FILES := \
1.0/Convolution2D.cpp \
+ 1.0/FullyConnectedReshape.cpp \
Tests.cpp \
UtilsTests.cpp \
Concurrent.cpp \