diff options
-rw-r--r-- | include/armnn/Descriptors.hpp | 12 | ||||
-rw-r--r-- | src/armnn/Descriptors.cpp | 20 | ||||
-rw-r--r-- | src/armnn/SerializeLayerParameters.cpp | 4 | ||||
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 2 |
4 files changed, 37 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index f1ac17f4c6..f60e8f3bea 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -276,9 +276,21 @@ struct ViewsDescriptor : BaseDescriptor /// Swap the ViewsDescriptor value first and second. friend void swap(ViewsDescriptor& first, ViewsDescriptor& second); + + /// Set the axis value. + void SetAxis(int32_t axis); + + /// Get the axis value. + int32_t GetAxis() const; + + /// Returns true if an axis has been set. + bool HasAxis() const; + private: OriginsDescriptor m_Origins; uint32_t** m_ViewSizes; + bool m_IsAxisSet = false; + int32_t m_Axis = 0; }; diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp index a1419cfbf7..e6374aea8f 100644 --- a/src/armnn/Descriptors.cpp +++ b/src/armnn/Descriptors.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "armnn/Descriptors.hpp" @@ -363,6 +363,24 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second) swap(first.m_ViewSizes, second.m_ViewSizes); } +void ViewsDescriptor::SetAxis(int32_t axis) +{ + m_Axis = axis; + m_IsAxisSet = true; +} + +/// Get the axis value. +int32_t ViewsDescriptor::GetAxis() const +{ + return m_Axis; +} + +/// Returns true if an axis has been set. +bool ViewsDescriptor::HasAxis() const +{ + return m_IsAxisSet; +} + int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape, unsigned int axis) const { diff --git a/src/armnn/SerializeLayerParameters.cpp b/src/armnn/SerializeLayerParameters.cpp index 1445c70a70..d65a7d55fa 100644 --- a/src/armnn/SerializeLayerParameters.cpp +++ b/src/armnn/SerializeLayerParameters.cpp @@ -636,6 +636,10 @@ void StringifyLayerParameters<ViewsDescriptor>::Serialize(ParameterStringifyFunc } value << "]"; fn(key.str(), value.str()); + if (desc.HasAxis()) + { + fn("Axis", std::to_string(desc.GetAxis())); + } } StringifyLayerParameters<OriginsDescriptor>::Serialize(fn, desc.GetOrigins()); } diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index bd0bd0380d..ec4b48639d 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -774,6 +774,8 @@ table OriginsDescriptor { table ViewsDescriptor { origins:OriginsDescriptor; viewSizes:[UintVector]; + hasAxis:bool; + axis:int; } table SplitterLayer { |