diff options
Diffstat (limited to 'src/armnn/layers/ReverseV2Layer.cpp')
-rw-r--r-- | src/armnn/layers/ReverseV2Layer.cpp | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnn/layers/ReverseV2Layer.cpp b/src/armnn/layers/ReverseV2Layer.cpp index 201e19819b..e1160b6e16 100644 --- a/src/armnn/layers/ReverseV2Layer.cpp +++ b/src/armnn/layers/ReverseV2Layer.cpp @@ -10,9 +10,10 @@ namespace armnn { -ReverseV2Layer::ReverseV2Layer(const armnn::ReverseV2Descriptor ¶m, const char *name) - : LayerWithParameters(1, 1, LayerType::ReverseV2, param, name) -{} +ReverseV2Layer::ReverseV2Layer(const char* name) + : Layer(2, 1, LayerType::ReverseV2, name) +{ +} std::unique_ptr<IWorkload> ReverseV2Layer::CreateWorkload(const armnn::IWorkloadFactory &factory) const { @@ -24,27 +25,48 @@ std::unique_ptr<IWorkload> ReverseV2Layer::CreateWorkload(const armnn::IWorkload ReverseV2Layer* ReverseV2Layer::Clone(armnn::Graph &graph) const { - auto layer = CloneBase<ReverseV2Layer>(graph, m_Param, GetName()); + auto layer = CloneBase<ReverseV2Layer>(graph, GetName()); return std::move(layer); } -/// Use the default Layer::InferOutputShape method +std::vector<TensorShape> ReverseV2Layer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + ARMNN_ASSERT(inputShapes.size() == 2); + + const auto inputDims = inputShapes[0].GetNumDimensions(); + + std::vector<unsigned int> dimSizes(inputDims); + for (unsigned i=0; i<inputDims; i++) + { + dimSizes[i] = inputShapes[0][i]; + } + + TensorShape outputShape({ inputDims, dimSizes.data() }); + + return std::vector<TensorShape>({ outputShape }); +} void ReverseV2Layer::ValidateTensorShapesFromInputs() { - VerifyLayerConnections(1, CHECK_LOCATION()); + VerifyLayerConnections(2, CHECK_LOCATION()); const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); auto inferredShapes = InferOutputShapes({ - GetInputSlot(0).GetTensorInfo().GetShape() }); + GetInputSlot(0).GetTensorInfo().GetShape(), + GetInputSlot(1).GetTensorInfo().GetShape()}); ARMNN_ASSERT(inferredShapes.size() == 1); ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "ReverseV2Layer"); } -}
\ No newline at end of file +void ReverseV2Layer::ExecuteStrategy(IStrategy& strategy) const +{ + strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName()); +} + +} |