aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Layers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Layers.cpp')
-rw-r--r--src/armnn/Layers.cpp69
1 files changed, 56 insertions, 13 deletions
diff --git a/src/armnn/Layers.cpp b/src/armnn/Layers.cpp
index ddbc7d222c..48a02aba9c 100644
--- a/src/armnn/Layers.cpp
+++ b/src/armnn/Layers.cpp
@@ -11,6 +11,8 @@
#include "Permute.hpp"
+#include <queue>
+
namespace armnn
{
@@ -21,6 +23,7 @@ LayerType* Layer::CloneBase(Graph& graph, Params&& ... params) const
LayerType* const layer = graph.AddLayer<LayerType>(std::forward<Params>(params)...);
layer->SetComputeDevice(m_ComputeDevice);
+ layer->SetGuid(GetGuid());
return layer;
}
@@ -82,12 +85,11 @@ void AdditionLayer::ValidateTensorShapesFromInputs()
unsigned int dim1 = input1.GetShape()[i];
if (dim0 != dim1)
{
- BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be one length");
+ BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1.");
}
}
#endif
-
for (unsigned int i = 0; i < numDims; i++)
{
unsigned int dim0 = input0.GetShape()[i];
@@ -439,14 +441,31 @@ void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fact
m_OutputHandlers[0].CreateTensorHandles(factory);
if (factory.SupportsSubTensors())
{
- const unsigned int numInputSlots = GetNumInputSlots();
- for (unsigned int i = 0; i < numInputSlots; ++i)
+ std::queue<MergerLayer*> m_MergerLayers;
+
+ m_MergerLayers.push(this);
+ while (!m_MergerLayers.empty())
{
- OutputHandler& outputHandler = GetInputSlot(i).GetConnectedOutputSlot()->GetOutputHandler();
+ MergerLayer* currentLayer = m_MergerLayers.front();
+ ITensorHandle* parentTensor = currentLayer->GetOutputHandler(0).GetData();
- outputHandler.SetData(factory.CreateSubTensorHandle(*m_OutputHandlers[0].GetData(),
- outputHandler.GetTensorInfo().GetShape(),
- m_Param.GetViewOrigin(i)));
+ m_MergerLayers.pop();
+
+ const unsigned int numInputSlots = currentLayer->GetNumInputSlots();
+ for (unsigned int i = 0; i < numInputSlots; ++i)
+ {
+ OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
+ OutputHandler& outputHandler = slot->GetOutputHandler();
+ outputHandler.SetData(factory.CreateSubTensorHandle(*parentTensor,
+ outputHandler.GetTensorInfo().GetShape(),
+ currentLayer->m_Param.GetViewOrigin(i)));
+
+ Layer& inputLayer = slot->GetOwningLayer();
+ if (inputLayer.GetType() == LayerType::Merger)
+ {
+ m_MergerLayers.push(boost::polymorphic_downcast<MergerLayer*>(&inputLayer));
+ }
+ }
}
}
}
@@ -568,12 +587,36 @@ MultiplicationLayer* MultiplicationLayer::Clone(Graph& graph) const
void MultiplicationLayer::ValidateTensorShapesFromInputs()
{
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() ==
- GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
- "MultiplicationLayer: Inputs must match");
+ auto& input0 = GetInputSlot(0).GetConnection()->GetTensorInfo();
+ auto& input1 = GetInputSlot(1).GetConnection()->GetTensorInfo();
+
+ // Get the max of the inputs
+ BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
+ unsigned int numDims = input0.GetNumDimensions();
+ std::vector<unsigned int> dims(numDims);
+
+ // validate inputs are broadcast compatible
+#if !NDEBUG
+ for (unsigned int i = 0; i < numDims; i++)
+ {
+ unsigned int dim0 = input0.GetShape()[i];
+ unsigned int dim1 = input1.GetShape()[i];
+ if (dim0 != dim1)
+ {
+ BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1.");
+ }
+ }
+#endif
- TensorInfo infoOut(GetInputSlot(0).GetConnection()->GetTensorInfo());
- ConditionalThrow<LayerValidationException>(GetOutputSlot(0).ValidateTensorShape(infoOut.GetShape()),
+ for (unsigned int i = 0; i < numDims; i++)
+ {
+ unsigned int dim0 = input0.GetShape()[i];
+ unsigned int dim1 = input1.GetShape()[i];
+ dims[i] = std::max(dim0, dim1);
+ }
+
+ TensorShape outShape(numDims, dims.data());
+ ConditionalThrow<LayerValidationException>(GetOutputSlot(0).ValidateTensorShape(outShape),
"MultiplicationLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.");
}