diff options
Diffstat (limited to 'src/armnn/Layers.cpp')
-rw-r--r-- | src/armnn/Layers.cpp | 69 |
1 files changed, 56 insertions, 13 deletions
diff --git a/src/armnn/Layers.cpp b/src/armnn/Layers.cpp index ddbc7d222c..48a02aba9c 100644 --- a/src/armnn/Layers.cpp +++ b/src/armnn/Layers.cpp @@ -11,6 +11,8 @@ #include "Permute.hpp" +#include <queue> + namespace armnn { @@ -21,6 +23,7 @@ LayerType* Layer::CloneBase(Graph& graph, Params&& ... params) const LayerType* const layer = graph.AddLayer<LayerType>(std::forward<Params>(params)...); layer->SetComputeDevice(m_ComputeDevice); + layer->SetGuid(GetGuid()); return layer; } @@ -82,12 +85,11 @@ void AdditionLayer::ValidateTensorShapesFromInputs() unsigned int dim1 = input1.GetShape()[i]; if (dim0 != dim1) { - BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be one length"); + BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); } } #endif - for (unsigned int i = 0; i < numDims; i++) { unsigned int dim0 = input0.GetShape()[i]; @@ -439,14 +441,31 @@ void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fact m_OutputHandlers[0].CreateTensorHandles(factory); if (factory.SupportsSubTensors()) { - const unsigned int numInputSlots = GetNumInputSlots(); - for (unsigned int i = 0; i < numInputSlots; ++i) + std::queue<MergerLayer*> m_MergerLayers; + + m_MergerLayers.push(this); + while (!m_MergerLayers.empty()) { - OutputHandler& outputHandler = GetInputSlot(i).GetConnectedOutputSlot()->GetOutputHandler(); + MergerLayer* currentLayer = m_MergerLayers.front(); + ITensorHandle* parentTensor = currentLayer->GetOutputHandler(0).GetData(); - outputHandler.SetData(factory.CreateSubTensorHandle(*m_OutputHandlers[0].GetData(), - outputHandler.GetTensorInfo().GetShape(), - m_Param.GetViewOrigin(i))); + m_MergerLayers.pop(); + + const unsigned int numInputSlots = currentLayer->GetNumInputSlots(); + for (unsigned int i = 0; i < numInputSlots; ++i) + { + OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot(); + OutputHandler& outputHandler = slot->GetOutputHandler(); + outputHandler.SetData(factory.CreateSubTensorHandle(*parentTensor, + outputHandler.GetTensorInfo().GetShape(), + currentLayer->m_Param.GetViewOrigin(i))); + + Layer& inputLayer = slot->GetOwningLayer(); + if (inputLayer.GetType() == LayerType::Merger) + { + m_MergerLayers.push(boost::polymorphic_downcast<MergerLayer*>(&inputLayer)); + } + } } } } @@ -568,12 +587,36 @@ MultiplicationLayer* MultiplicationLayer::Clone(Graph& graph) const void MultiplicationLayer::ValidateTensorShapesFromInputs() { - ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() == - GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(), - "MultiplicationLayer: Inputs must match"); + auto& input0 = GetInputSlot(0).GetConnection()->GetTensorInfo(); + auto& input1 = GetInputSlot(1).GetConnection()->GetTensorInfo(); + + // Get the max of the inputs + BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); + unsigned int numDims = input0.GetNumDimensions(); + std::vector<unsigned int> dims(numDims); + + // validate inputs are broadcast compatible +#if !NDEBUG + for (unsigned int i = 0; i < numDims; i++) + { + unsigned int dim0 = input0.GetShape()[i]; + unsigned int dim1 = input1.GetShape()[i]; + if (dim0 != dim1) + { + BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); + } + } +#endif - TensorInfo infoOut(GetInputSlot(0).GetConnection()->GetTensorInfo()); - ConditionalThrow<LayerValidationException>(GetOutputSlot(0).ValidateTensorShape(infoOut.GetShape()), + for (unsigned int i = 0; i < numDims; i++) + { + unsigned int dim0 = input0.GetShape()[i]; + unsigned int dim1 = input1.GetShape()[i]; + dims[i] = std::max(dim0, dim1); + } + + TensorShape outShape(numDims, dims.data()); + ConditionalThrow<LayerValidationException>(GetOutputSlot(0).ValidateTensorShape(outShape), "MultiplicationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape."); } |