aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/armnn/Descriptors.hpp12
-rw-r--r--src/armnn/Descriptors.cpp20
-rw-r--r--src/armnn/SerializeLayerParameters.cpp4
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs2
4 files changed, 37 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index f1ac17f4c6..f60e8f3bea 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -276,9 +276,21 @@ struct ViewsDescriptor : BaseDescriptor
/// Swap the ViewsDescriptor value first and second.
friend void swap(ViewsDescriptor& first, ViewsDescriptor& second);
+
+ /// Set the axis value.
+ void SetAxis(int32_t axis);
+
+ /// Get the axis value.
+ int32_t GetAxis() const;
+
+ /// Returns true if an axis has been set.
+ bool HasAxis() const;
+
private:
OriginsDescriptor m_Origins;
uint32_t** m_ViewSizes;
+ bool m_IsAxisSet = false;
+ int32_t m_Axis = 0;
};
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index a1419cfbf7..e6374aea8f 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "armnn/Descriptors.hpp"
@@ -363,6 +363,24 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second)
swap(first.m_ViewSizes, second.m_ViewSizes);
}
+void ViewsDescriptor::SetAxis(int32_t axis)
+{
+ m_Axis = axis;
+ m_IsAxisSet = true;
+}
+
+/// Get the axis value.
+int32_t ViewsDescriptor::GetAxis() const
+{
+ return m_Axis;
+}
+
+/// Returns true if an axis has been set.
+bool ViewsDescriptor::HasAxis() const
+{
+ return m_IsAxisSet;
+}
+
int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape,
unsigned int axis) const
{
diff --git a/src/armnn/SerializeLayerParameters.cpp b/src/armnn/SerializeLayerParameters.cpp
index 1445c70a70..d65a7d55fa 100644
--- a/src/armnn/SerializeLayerParameters.cpp
+++ b/src/armnn/SerializeLayerParameters.cpp
@@ -636,6 +636,10 @@ void StringifyLayerParameters<ViewsDescriptor>::Serialize(ParameterStringifyFunc
}
value << "]";
fn(key.str(), value.str());
+ if (desc.HasAxis())
+ {
+ fn("Axis", std::to_string(desc.GetAxis()));
+ }
}
StringifyLayerParameters<OriginsDescriptor>::Serialize(fn, desc.GetOrigins());
}
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index bd0bd0380d..ec4b48639d 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -774,6 +774,8 @@ table OriginsDescriptor {
table ViewsDescriptor {
origins:OriginsDescriptor;
viewSizes:[UintVector];
+ hasAxis:bool;
+ axis:int;
}
table SplitterLayer {