aboutsummaryrefslogtreecommitdiff
path: root/1.0/FullyConnected.hpp
diff options
context:
space:
mode:
Diffstat (limited to '1.0/FullyConnected.hpp')
-rw-r--r--1.0/FullyConnected.hpp42
1 files changed, 42 insertions, 0 deletions
diff --git a/1.0/FullyConnected.hpp b/1.0/FullyConnected.hpp
new file mode 100644
index 00000000..0fb029de
--- /dev/null
+++ b/1.0/FullyConnected.hpp
@@ -0,0 +1,42 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+#include "../ConversionUtils.hpp"
+
+namespace armnn_driver
+{
+
+inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
+ const armnn::TensorShape &weightsShape)
+{
+ if (inputShape.GetNumDimensions() > 2U)
+ {
+ unsigned int dim0 = inputShape[0];
+ unsigned int dim1 = inputShape[1];
+
+ for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
+ {
+ dim1 *= inputShape[i];
+ }
+
+ unsigned int divisor = weightsShape[1] / dim1;
+ if(dim0 % divisor != 0)
+ {
+ throw std::runtime_error("Failed to deduce tensor shape");
+ }
+
+ return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
+ }
+ else
+ {
+ return inputShape;
+ }
+}
+
+} \ No newline at end of file