aboutsummaryrefslogtreecommitdiff
path: root/1.0/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to '1.0/FullyConnected.hpp')
-rw-r--r--1.0/FullyConnected.hpp13
1 files changed, 11 insertions, 2 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
index 26d61e4c..56997ad2 100644
--- a/1.0/FullyConnected.hpp
+++ b/1.0/FullyConnected.hpp
@@ -12,8 +12,8 @@
namespace armnn_driver
{
-inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
- const armnn::TensorShape &weightsShape)
+inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& weightsShape)
{
if (inputShape.GetNumDimensions() > 2U)
{
@@ -35,4 +35,13 @@ inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &i
}
}
+inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& weightsShape,
+ const armnn::TensorShape& outputShape,
+ bool transposeWeightMatrix)
+{
+ unsigned int dimIdx = transposeWeightMatrix ? 0 : 1;
+ return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]);
+}
+
} \ No newline at end of file