aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFinnWilliamsArm <Finn.Williams@Arm.com>2020-01-08 14:57:47 +0000
committerFinnWilliamsArm <Finn.Williams@Arm.com>2020-01-09 12:14:38 +0000
commit7b8d2e65c129263e9cdbdc82e5f73dd4d263aafb (patch)
tree7aafd35f2d7436c5de3d6db8da94d2ae48c8d67b
parent7100649a0f9cdec8195cb26937280080b8c340ce (diff)
downloadandroid-nn-driver-7b8d2e65c129263e9cdbdc82e5f73dd4d263aafb.tar.gz
IVGCVSW-4315 Fix Fully Connected infer output shape bug
Change-Id: If4fd1abdedf7de2046435d418fb1ee95ceb73419 Signed-off-by: FinnWilliamsArm <Finn.Williams@Arm.com>
-rw-r--r--1.0/FullyConnected.hpp13
-rw-r--r--ConversionUtils.hpp8
2 files changed, 19 insertions, 2 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
index 26d61e4c..56997ad2 100644
--- a/1.0/FullyConnected.hpp
+++ b/1.0/FullyConnected.hpp
@@ -12,8 +12,8 @@
namespace armnn_driver
{
-inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
- const armnn::TensorShape &weightsShape)
+inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& weightsShape)
{
if (inputShape.GetNumDimensions() > 2U)
{
@@ -35,4 +35,13 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i
}
}
+inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& weightsShape,
+ const armnn::TensorShape& outputShape,
+ bool transposeWeightMatrix)
+{
+ unsigned int dimIdx = transposeWeightMatrix ? 0 : 1;
+ return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]);
+}
+
} \ No newline at end of file
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 9500ba68..b53432ca 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -2587,6 +2587,14 @@ bool ConvertFullyConnected(const HalOperation& operation, const HalModel& model,
desc.m_TransposeWeightMatrix = true;
desc.m_BiasEnabled = true;
+ if (!VerifyFullyConnectedShapes(reshapedInfo.GetShape(),
+ weights.GetInfo().GetShape(),
+ outputInfo.GetShape(),
+ desc.m_TransposeWeightMatrix))
+ {
+ return Fail("%s: Expected outputShape does not match actual outputShape", __func__);
+ }
+
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsFullyConnectedSupported,