aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ReverseV2Layer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ReverseV2Layer.cpp')
-rw-r--r--src/armnn/layers/ReverseV2Layer.cpp38
1 files changed, 30 insertions, 8 deletions
diff --git a/src/armnn/layers/ReverseV2Layer.cpp b/src/armnn/layers/ReverseV2Layer.cpp
index 201e19819b..e1160b6e16 100644
--- a/src/armnn/layers/ReverseV2Layer.cpp
+++ b/src/armnn/layers/ReverseV2Layer.cpp
@@ -10,9 +10,10 @@
namespace armnn
{
-ReverseV2Layer::ReverseV2Layer(const armnn::ReverseV2Descriptor &param, const char *name)
- : LayerWithParameters(1, 1, LayerType::ReverseV2, param, name)
-{}
+ReverseV2Layer::ReverseV2Layer(const char* name)
+ : Layer(2, 1, LayerType::ReverseV2, name)
+{
+}
std::unique_ptr<IWorkload> ReverseV2Layer::CreateWorkload(const armnn::IWorkloadFactory &factory) const
{
@@ -24,27 +25,48 @@ std::unique_ptr<IWorkload> ReverseV2Layer::CreateWorkload(const armnn::IWorkload
ReverseV2Layer* ReverseV2Layer::Clone(armnn::Graph &graph) const
{
- auto layer = CloneBase<ReverseV2Layer>(graph, m_Param, GetName());
+ auto layer = CloneBase<ReverseV2Layer>(graph, GetName());
return std::move(layer);
}
-/// Use the default Layer::InferOutputShape method
+std::vector<TensorShape> ReverseV2Layer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ ARMNN_ASSERT(inputShapes.size() == 2);
+
+ const auto inputDims = inputShapes[0].GetNumDimensions();
+
+ std::vector<unsigned int> dimSizes(inputDims);
+ for (unsigned i=0; i<inputDims; i++)
+ {
+ dimSizes[i] = inputShapes[0][i];
+ }
+
+ TensorShape outputShape({ inputDims, dimSizes.data() });
+
+ return std::vector<TensorShape>({ outputShape });
+}
void ReverseV2Layer::ValidateTensorShapesFromInputs()
{
- VerifyLayerConnections(1, CHECK_LOCATION());
+ VerifyLayerConnections(2, CHECK_LOCATION());
const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
auto inferredShapes = InferOutputShapes({
- GetInputSlot(0).GetTensorInfo().GetShape() });
+ GetInputSlot(0).GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetTensorInfo().GetShape()});
ARMNN_ASSERT(inferredShapes.size() == 1);
ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "ReverseV2Layer");
}
-} \ No newline at end of file
+void ReverseV2Layer::ExecuteStrategy(IStrategy& strategy) const
+{
+ strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName());
+}
+
+}