aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/FillLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/FillLayer.cpp')
-rw-r--r--src/armnn/layers/FillLayer.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/armnn/layers/FillLayer.cpp b/src/armnn/layers/FillLayer.cpp
index 03f93f76da..eb9f6af800 100644
--- a/src/armnn/layers/FillLayer.cpp
+++ b/src/armnn/layers/FillLayer.cpp
@@ -33,19 +33,21 @@ void FillLayer::ValidateTensorShapesFromInputs()
{
VerifyLayerConnections(1, CHECK_LOCATION());
- auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
+ auto inferredShapes = InferOutputShapes( { GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() });
ARMNN_ASSERT(inferredShapes.size() == 1);
+ // Cannot validate the output shape from the input shape. but we can validate that the correct dims have been
+ // inferred
ConditionalThrowIfNotEqual<LayerValidationException>(
"FillLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
- GetOutputSlot(0).GetTensorInfo().GetShape(),
- inferredShapes[0]);
+ GetOutputSlot(0).GetTensorInfo().GetNumDimensions(),
+ inferredShapes[0][0]);
}
void FillLayer::Accept(ILayerVisitor& visitor) const
{
- visitor.VisitGatherLayer(this, GetName());
+ visitor.VisitFillLayer(this, GetParameters(), GetName());
}
} // namespace armnn