diff options
Diffstat (limited to 'src/armnn/layers/AdditionLayer.cpp')
-rw-r--r-- | src/armnn/layers/AdditionLayer.cpp | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/src/armnn/layers/AdditionLayer.cpp b/src/armnn/layers/AdditionLayer.cpp index 85d12eabcb..ab73a918db 100644 --- a/src/armnn/layers/AdditionLayer.cpp +++ b/src/armnn/layers/AdditionLayer.cpp @@ -28,41 +28,51 @@ AdditionLayer* AdditionLayer::Clone(Graph& graph) const return CloneBase<AdditionLayer>(graph, GetName()); } -void AdditionLayer::ValidateTensorShapesFromInputs() +std::vector<TensorShape> AdditionLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - auto& input0 = GetInputSlot(0).GetConnection()->GetTensorInfo(); - auto& input1 = GetInputSlot(1).GetConnection()->GetTensorInfo(); + BOOST_ASSERT(inputShapes.size() == 2); + auto& input0 = inputShapes[0]; + auto& input1 = inputShapes[1]; - // Get the max of the inputs + // Get the max of the inputs. BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions()); unsigned int numDims = input0.GetNumDimensions(); std::vector<unsigned int> dims(numDims); - // validate inputs are broadcast compatible -#if !NDEBUG for (unsigned int i = 0; i < numDims; i++) { - unsigned int dim0 = input0.GetShape()[i]; - unsigned int dim1 = input1.GetShape()[i]; + unsigned int dim0 = input0[i]; + unsigned int dim1 = input1[i]; + + // Validates inputs are broadcast compatible. +#if !NDEBUG if (dim0 != dim1) { BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1."); } - } #endif - for (unsigned int i = 0; i < numDims; i++) - { - unsigned int dim0 = input0.GetShape()[i]; - unsigned int dim1 = input1.GetShape()[i]; dims[i] = std::max(dim0, dim1); } - TensorShape outShape(numDims, dims.data()); + return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) }); +} + +void AdditionLayer::ValidateTensorShapesFromInputs() +{ + VerifyLayerConnections(2, CHECK_LOCATION()); + + auto inferredShapes = InferOutputShapes({ + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() + }); + + BOOST_ASSERT(inferredShapes.size() == 1); + ConditionalThrowIfNotEqual<LayerValidationException>( "AdditionLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", GetOutputSlot(0).GetTensorInfo().GetShape(), - outShape); + inferredShapes[0]); } } // namespace armnn |