From 42defa6c2bc2574f39963554937bdf81e49be449 Mon Sep 17 00:00:00 2001 From: FinnWilliamsArm Date: Wed, 8 Jan 2020 14:57:47 +0000 Subject: IVGCVSW-4315 Fix Fully Connected infer output shape bug Change-Id: If4fd1abdedf7de2046435d418fb1ee95ceb73419 Signed-off-by: FinnWilliamsArm --- 1.0/FullyConnected.hpp | 13 +++++++++++-- ConversionUtils.hpp | 8 ++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp index 26d61e4c..56997ad2 100644 --- a/1.0/FullyConnected.hpp +++ b/1.0/FullyConnected.hpp @@ -12,8 +12,8 @@ namespace armnn_driver { -inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape, - const armnn::TensorShape &weightsShape) +inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape, + const armnn::TensorShape& weightsShape) { if (inputShape.GetNumDimensions() > 2U) { @@ -35,4 +35,13 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i } } +inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape, + const armnn::TensorShape& weightsShape, + const armnn::TensorShape& outputShape, + bool transposeWeightMatrix) +{ + unsigned int dimIdx = transposeWeightMatrix ? 0 : 1; + return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]); +} + } \ No newline at end of file diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index a342d399..3a144f78 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -2645,6 +2645,14 @@ bool ConvertFullyConnected(const Operation& operation, const Model& model, Conve desc.m_TransposeWeightMatrix = true; desc.m_BiasEnabled = true; + if (!VerifyFullyConnectedShapes(reshapedInfo.GetShape(), + weights.GetInfo().GetShape(), + outputInfo.GetShape(), + desc.m_TransposeWeightMatrix)) + { + return Fail("%s: Expected outputShape does not match actual outputShape", __func__); + } + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsFullyConnectedSupported, -- cgit v1.2.1