aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/SplitterLayer.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/layers/SplitterLayer.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/layers/SplitterLayer.cpp')
-rw-r--r--src/armnn/layers/SplitterLayer.cpp32
1 files changed, 26 insertions, 6 deletions
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 630921e4d8..5e737a245e 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -22,7 +22,7 @@ std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const Graph& graph, con
{
SplitterQueueDescriptor descriptor;
- // copy the window origins to the descriptor
+ // Copies the window origins to the descriptor.
for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
{
descriptor.m_ViewOrigins.emplace_back(
@@ -34,14 +34,14 @@ std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const Graph& graph, con
void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
{
- //if sub tensors are supported than all the "splitter" need to do is to
+ //If sub tensors are supported than all the "splitter" need to do is to
//set the outputs to be appropriate sub tensors of the input.
if (factory.SupportsSubTensors())
{
const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
ITensorHandle* inputData = outputHandler.GetData();
- //create the outputs as subtensors of the input
+ //Creates the outputs as subtensors of the input.
for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
{
m_OutputHandlers[i].SetData(factory.CreateSubTensorHandle(*inputData,
@@ -63,18 +63,38 @@ SplitterLayer* SplitterLayer::Clone(Graph& graph) const
return CloneBase<SplitterLayer>(graph, m_Param, GetName());
}
-void SplitterLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> SplitterLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
+ BOOST_ASSERT(inputShapes.size() == m_Param.GetNumViews());
+ std::vector<TensorShape> outShapes;
//Output shapes must match View shapes.
for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
{
const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
+ outShapes.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
+ }
+ return outShapes;
+}
+
+void SplitterLayer::ValidateTensorShapesFromInputs()
+{
+ std::vector<TensorShape> views;
+ for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
+ {
+ const uint32_t* sizes = m_Param.GetViewSizes(viewIdx);
+ views.push_back(TensorShape(m_Param.GetNumDimensions(), sizes));
+ }
+
+ auto inferredShapes = InferOutputShapes(views);
- TensorShape outShape(m_Param.GetNumDimensions(), sizes);
+ BOOST_ASSERT(inferredShapes.size() == m_Param.GetNumViews());
+
+ for (unsigned int viewIdx = 0; viewIdx < m_Param.GetNumViews(); viewIdx++)
+ {
ConditionalThrowIfNotEqual<LayerValidationException>(
"SplitterLayer: View sizes must match output tensor shapes.",
GetOutputSlot(viewIdx).GetTensorInfo().GetShape(),
- outShape);
+ inferredShapes[viewIdx]);
}
}