aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ResizeBilinearLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ResizeBilinearLayer.cpp')
-rw-r--r--src/armnn/layers/ResizeBilinearLayer.cpp24
1 files changed, 16 insertions, 8 deletions
diff --git a/src/armnn/layers/ResizeBilinearLayer.cpp b/src/armnn/layers/ResizeBilinearLayer.cpp
index 204d5afae8..6477fa375a 100644
--- a/src/armnn/layers/ResizeBilinearLayer.cpp
+++ b/src/armnn/layers/ResizeBilinearLayer.cpp
@@ -30,23 +30,31 @@ ResizeBilinearLayer* ResizeBilinearLayer::Clone(Graph& graph) const
return CloneBase<ResizeBilinearLayer>(graph, m_Param, GetName());
}
-void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> ResizeBilinearLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr,
- "MemCopyLayer: InputSlot must be connected to an OutputSlot");
- ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(),
- "MemCopyLayer: TensorInfo must be set on connected OutputSlot.");
+ BOOST_ASSERT(inputShapes.size() == 1);
+ const TensorShape& inputShape = inputShapes[0];
- const TensorShape& inputShape = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape();
unsigned int outWidth = m_Param.m_TargetWidth;
unsigned int outHeight = m_Param.m_TargetHeight;
unsigned int outChannels = inputShape[1];
unsigned int outBatch = inputShape[0];
- TensorShape outShape({outBatch, outChannels, outHeight, outWidth});
+
+ return std::vector<TensorShape>({ TensorShape({outBatch, outChannels, outHeight, outWidth}) });
+}
+
+void ResizeBilinearLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(1, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
ConditionalThrowIfNotEqual<LayerValidationException>(
"ResizeBilinearLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
GetOutputSlot(0).GetTensorInfo().GetShape(),
- outShape);
+ inferredShapes[0]);
}
} // namespace armnn