Diffstat (limited to '1.0/HalPolicy.cpp')
-rw-r--r--  1.0/HalPolicy.cpp | 26 ++++++++------------------
1 file changed, 8 insertions(+), 18 deletions(-)
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index dee4a7a5..158f0e36 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -5,7 +5,9 @@
#include "HalPolicy.hpp"
-#include "armnn/Optional.hpp"
+#include <armnn/Optional.hpp>
+
+#include "FullyConnected.hpp"
namespace armnn_driver
{
@@ -633,25 +635,13 @@ bool HalPolicy::ConvertFullyConnected(const Operation& operation, const Model& m
    armnn::ConstTensor weights = weightsPin.GetConstTensor();
    armnn::ConstTensor bias = biasPin.GetConstTensor();
-
    armnn::TensorInfo reshapedInfo = inputInfo;
-    if (inputInfo.GetNumDimensions() > 2U)
-    {
-        unsigned int dim0 = inputInfo.GetShape()[0];
-        unsigned int dim1 = inputInfo.GetShape()[1];
-
-        for (unsigned int i = 2U; i < inputInfo.GetNumDimensions(); ++i)
-        {
-            dim1 *= inputInfo.GetShape()[i];
-        }
-        unsigned int divisor = weights.GetInfo().GetShape()[1] / dim1;
-        if(dim0 % divisor != 0)
-        {
-            return Fail("%s: Failed to deduce tensor shape", __func__);
-        }
-
-        reshapedInfo.SetShape(armnn::TensorShape({dim0 / divisor, dim1 * divisor}));
+    try
+    {
+        reshapedInfo.SetShape(FlattenFullyConnectedInput(inputInfo.GetShape(), weights.GetInfo().GetShape()));
+    } catch (const std::exception &e) {
+        return Fail("%s: %s", __func__, e.what());
    }

    // ensuring that the bias value is within 1% of the weights input (small float differences can exist)
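
Note: the FlattenFullyConnectedInput helper called above comes from the newly included FullyConnected.hpp, whose contents are not part of this diff. A minimal, non-authoritative sketch of what that helper might look like, reconstructed from the reshaping logic removed above and assuming it takes the input and weights shapes and throws where the old code returned Fail(), is:

// Hypothetical reconstruction of the helper factored out into FullyConnected.hpp.
// It mirrors the logic removed from ConvertFullyConnected above, reporting
// failure by throwing instead of returning Fail().

#include <armnn/Tensor.hpp>

#include <stdexcept>

namespace armnn_driver
{

inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
                                                     const armnn::TensorShape& weightsShape)
{
    if (inputShape.GetNumDimensions() > 2U)
    {
        // Collapse every dimension after the first into the second dimension.
        unsigned int dim0 = inputShape[0];
        unsigned int dim1 = inputShape[1];

        for (unsigned int i = 2U; i < inputShape.GetNumDimensions(); ++i)
        {
            dim1 *= inputShape[i];
        }

        // Redistribute elements between the two dimensions so that the second
        // dimension matches the weights' input size.
        unsigned int divisor = weightsShape[1] / dim1;
        if (dim0 % divisor != 0)
        {
            throw std::runtime_error("Failed to deduce tensor shape");
        }

        return armnn::TensorShape({dim0 / divisor, dim1 * divisor});
    }

    // Inputs that are already 2D (or smaller) pass through unchanged.
    return inputShape;
}

} // namespace armnn_driver

Throwing from a shared helper lets each HAL policy translate the failure into its own Fail() call, which is exactly what the added try/catch above does.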