From c577f2c6a3b4ddb6ba87a882723c53a248afbeba Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 31 Aug 2018 09:22:23 +0100 Subject: Release 18.08 --- src/armnn/layers/FullyConnectedLayer.cpp | 40 ++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 10 deletions(-) (limited to 'src/armnn/layers/FullyConnectedLayer.cpp') diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp index 1597e8c2c3..8b8f010bdb 100644 --- a/src/armnn/layers/FullyConnectedLayer.cpp +++ b/src/armnn/layers/FullyConnectedLayer.cpp @@ -22,11 +22,15 @@ FullyConnectedLayer::FullyConnectedLayer(const FullyConnectedDescriptor& param, std::unique_ptr FullyConnectedLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const { + // on this level constant data should not be released.. + BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null."); + FullyConnectedQueueDescriptor descriptor; descriptor.m_Weight = m_Weight.get(); if (m_Param.m_BiasEnabled) { + BOOST_ASSERT_MSG(m_Bias != nullptr, "FullyConnectedLayer: Bias data should not be null."); descriptor.m_Bias = m_Bias.get(); } return factory.CreateFullyConnected(descriptor, PrepInfoAndDesc(descriptor, graph)); @@ -45,25 +49,41 @@ FullyConnectedLayer* FullyConnectedLayer::Clone(Graph& graph) const return std::move(layer); } +std::vector FullyConnectedLayer::InferOutputShapes(const std::vector& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 2); + const TensorShape& inputShape = inputShapes[0]; + const TensorShape weightShape = inputShapes[1]; + + // Output for FC is [1, w[1]]. + unsigned int batches = inputShape[0]; + unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1; + + return std::vector({ TensorShape({batches, weightShape[dimIdx]})}); +} + void FullyConnectedLayer::ValidateTensorShapesFromInputs() { - ConditionalThrow(GetInputSlot(0).GetConnection() != nullptr, - "FullyConnectedLayer: InputSlot must be connected to an OutputSlot"); - ConditionalThrow(GetInputSlot(0).GetConnection()->IsTensorInfoSet(), - "FullyConnectedLayer: TensorInfo must be set on connected OutputSlot."); + VerifyLayerConnections(1, CHECK_LOCATION()); + // check if we m_Weight data is not nullptr + BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null."); - TensorShape const& weightShape = m_Weight->GetTensorInfo().GetShape(); + auto inferredShapes = InferOutputShapes({ + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + m_Weight->GetTensorInfo().GetShape() }); - // output for FC is [1, w[1]] - unsigned int batches = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()[0]; - unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1; - TensorShape outShape({batches, weightShape[dimIdx]}); + BOOST_ASSERT(inferredShapes.size() == 1); ConditionalThrowIfNotEqual( "FullyConnectedLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", GetOutputSlot(0).GetTensorInfo().GetShape(), - outShape); + inferredShapes[0]); +} + +Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef() +{ + return {m_Weight, m_Bias}; } } // namespace armnn -- cgit v1.2.1