aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/GatherLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/GatherLayer.cpp')
-rw-r--r--src/armnn/layers/GatherLayer.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index 2e5d011599..d7ed4b2542 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -32,6 +32,32 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const
void GatherLayer::ValidateTensorShapesFromInputs()
{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo();
+
+ const unsigned int paramsDim = params.GetNumDimensions();
+ const unsigned int indicesDim = indices.GetNumDimensions();
+ const unsigned int outputDim = paramsDim - 1 + indicesDim;
+
+ std::vector<unsigned int> dimSizes;
+
+ for (unsigned int i = 0; i < indicesDim; ++i)
+ {
+ dimSizes.push_back(indices.GetShape()[i]);
+ }
+ for (unsigned int i = 1; i < paramsDim; ++i)
+ {
+ dimSizes.push_back(params.GetShape()[i]);
+ }
+
+ const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data());
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShape);
}
} // namespace armnn