author    telsoa01 <telmo.soares@arm.com>    2018-08-31 09:22:23 +0100
committer telsoa01 <telmo.soares@arm.com>    2018-08-31 09:22:23 +0100
commit    c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
tree      bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/layers/FullyConnectedLayer.cpp
parent    4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
download  armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/layers/FullyConnectedLayer.cpp')
 src/armnn/layers/FullyConnectedLayer.cpp | 40 ++++++++++++++++++++++++++++++----------
 1 file changed, 30 insertions(+), 10 deletions(-)
diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp
index 1597e8c2c3..8b8f010bdb 100644
--- a/src/armnn/layers/FullyConnectedLayer.cpp
+++ b/src/armnn/layers/FullyConnectedLayer.cpp
@@ -22,11 +22,15 @@ FullyConnectedLayer::FullyConnectedLayer(const FullyConnectedDescriptor& param,
std::unique_ptr<IWorkload> FullyConnectedLayer::CreateWorkload(const Graph& graph,
const IWorkloadFactory& factory) const
{
+ // At this level, constant data should not be released.
+ BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
+
FullyConnectedQueueDescriptor descriptor;
descriptor.m_Weight = m_Weight.get();
if (m_Param.m_BiasEnabled)
{
+ BOOST_ASSERT_MSG(m_Bias != nullptr, "FullyConnectedLayer: Bias data should not be null.");
descriptor.m_Bias = m_Bias.get();
}
return factory.CreateFullyConnected(descriptor, PrepInfoAndDesc(descriptor, graph));
@@ -45,25 +49,41 @@ FullyConnectedLayer* FullyConnectedLayer::Clone(Graph& graph) const
return std::move(layer);
}
+std::vector<TensorShape> FullyConnectedLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 2);
+ const TensorShape& inputShape = inputShapes[0];
+ const TensorShape weightShape = inputShapes[1];
+
+ // Output for FC is [batches, weightShape[dimIdx]].
+ unsigned int batches = inputShape[0];
+ unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1;
+
+ return std::vector<TensorShape>({ TensorShape({batches, weightShape[dimIdx]})});
+}
+
void FullyConnectedLayer::ValidateTensorShapesFromInputs()
{
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
- "FullyConnectedLayer: InputSlot must be connected to an OutputSlot");
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(),
- "FullyConnectedLayer: TensorInfo must be set on connected OutputSlot.");
+ VerifyLayerConnections(1, CHECK_LOCATION());
+ // Check that the m_Weight data is not null.
+ BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
- TensorShape const& weightShape = m_Weight->GetTensorInfo().GetShape();
+ auto inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ m_Weight->GetTensorInfo().GetShape() });
- // output for FC is [1, w[1]]
- unsigned int batches = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()[0];
- unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1;
- TensorShape outShape({batches, weightShape[dimIdx]});
+ BOOST_ASSERT(inferredShapes.size() == 1);
ConditionalThrowIfNotEqual<LayerValidationException>(
"FullyConnectedLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
GetOutputSlot(0).GetTensorInfo().GetShape(),
- outShape);
+ inferredShapes[0]);
+}
+
+Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef()
+{
+ return {m_Weight, m_Bias};
}
} // namespace armnn
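
For reference, the shape rule introduced by the new InferOutputShapes is simple: the batch count comes from the input tensor, and the output width comes from the weight tensor, using dimension 1 normally or dimension 0 when m_TransposeWeightMatrix is set. The standalone sketch below mirrors that logic with plain std::vector shapes instead of armnn::TensorShape; it is illustrative only and not part of the patch or of the ArmNN API.

    // Illustrative sketch of the fully connected shape rule; not ArmNN code.
    #include <cassert>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<unsigned int>;

    // Mirrors FullyConnectedLayer::InferOutputShapes: output is [batches, outputWidth].
    Shape InferFullyConnectedOutputShape(const Shape& inputShape,
                                         const Shape& weightShape,
                                         bool transposeWeightMatrix)
    {
        assert(!inputShape.empty() && weightShape.size() == 2);
        const unsigned int batches = inputShape[0];
        const unsigned int dimIdx  = transposeWeightMatrix ? 0 : 1;
        return Shape{ batches, weightShape[dimIdx] };
    }

    int main()
    {
        // Input [2, 512] with weights [512, 128] (not transposed) gives output [2, 128].
        Shape out = InferFullyConnectedOutputShape({2, 512}, {512, 128}, false);
        std::cout << out[0] << " x " << out[1] << std::endl; // prints 2 x 128

        // A transposed weight matrix of shape [128, 512] yields the same output shape.
        out = InferFullyConnectedOutputShape({2, 512}, {128, 512}, true);
        std::cout << out[0] << " x " << out[1] << std::endl; // prints 2 x 128
        return 0;
    }

ValidateTensorShapesFromInputs compares this inferred shape against the shape already set on OutputSlot[0] and throws LayerValidationException on a mismatch, which is the behaviour exercised in the hunk above.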