diff options
author | narpra01 <narumol.prangnawarat@arm.com> | 2019-01-16 17:22:19 +0000 |
---|---|---|
committer | Aron Virginas-Tar <aron.virginas-tar@arm.com> | 2019-01-22 17:08:42 +0000 |
commit | 33f8e3b6c71070fd867809ca6934069a950081dc (patch) | |
tree | 7738f96c7108cd38eb06f780f08d1e4bfb46c080 /src/armnn/layers | |
parent | 649dd9515ddf4bd00a0bff64d51dfd835a6c7b39 (diff) | |
download | armnn-33f8e3b6c71070fd867809ca6934069a950081dc.tar.gz |
IVGCVSW-2509 Add GatherLayer implementation
* implementation of ValidateTensorShapesFromInputs
* unit tests
Change-Id: I1ed88f8ba0ea20329a259c5f36caea4b1fbeb013
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/GatherLayer.cpp | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp index 2e5d011599..d7ed4b2542 100644 --- a/src/armnn/layers/GatherLayer.cpp +++ b/src/armnn/layers/GatherLayer.cpp @@ -32,6 +32,32 @@ GatherLayer* GatherLayer::Clone(Graph& graph) const void GatherLayer::ValidateTensorShapesFromInputs() { + VerifyLayerConnections(2, CHECK_LOCATION()); + + const TensorInfo& params = GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& indices = GetInputSlot(1).GetConnection()->GetTensorInfo(); + + const unsigned int paramsDim = params.GetNumDimensions(); + const unsigned int indicesDim = indices.GetNumDimensions(); + const unsigned int outputDim = paramsDim - 1 + indicesDim; + + std::vector<unsigned int> dimSizes; + + for (unsigned int i = 0; i < indicesDim; ++i) + { + dimSizes.push_back(indices.GetShape()[i]); + } + for (unsigned int i = 1; i < paramsDim; ++i) + { + dimSizes.push_back(params.GetShape()[i]); + } + + const TensorShape& inferredShape = TensorShape(outputDim, dimSizes.data()); + + ConditionalThrowIfNotEqual<LayerValidationException>( + "GatherLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", + GetOutputSlot(0).GetTensorInfo().GetShape(), + inferredShape); } } // namespace armnn |