diff options
Diffstat (limited to 'src/armnn/layers/FillLayer.cpp')
-rw-r--r-- | src/armnn/layers/FillLayer.cpp | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/src/armnn/layers/FillLayer.cpp b/src/armnn/layers/FillLayer.cpp index 03f93f76da..eb9f6af800 100644 --- a/src/armnn/layers/FillLayer.cpp +++ b/src/armnn/layers/FillLayer.cpp @@ -33,19 +33,21 @@ void FillLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(1, CHECK_LOCATION()); - auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + auto inferredShapes = InferOutputShapes( { GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); ARMNN_ASSERT(inferredShapes.size() == 1); + // Cannot validate the output shape from the input shape. but we can validate that the correct dims have been + // inferred ConditionalThrowIfNotEqual<LayerValidationException>( "FillLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", - GetOutputSlot(0).GetTensorInfo().GetShape(), - inferredShapes[0]); + GetOutputSlot(0).GetTensorInfo().GetNumDimensions(), + inferredShapes[0][0]); } void FillLayer::Accept(ILayerVisitor& visitor) const { - visitor.VisitGatherLayer(this, GetName()); + visitor.VisitFillLayer(this, GetParameters(), GetName()); } } // namespace armnn |