From fca5916e4e6a44cf11b47328659d4d7ee95ec231 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Tue, 8 Aug 2023 12:00:28 +0100 Subject: MLCE-1093 Added Axis to ViewsDescriptor * Added Axis to ViewsDescriptor to store the value where ever possible. * Updated Serializer and Deserializer to handle axis. Signed-off-by: Mike Kelly Change-Id: I56e442872b47485a608b25fbc79063b362a25618 --- include/armnn/Descriptors.hpp | 12 ++++++++++++ src/armnn/Descriptors.cpp | 20 +++++++++++++++++++- src/armnn/SerializeLayerParameters.cpp | 4 ++++ src/armnnSerializer/ArmnnSchema.fbs | 2 ++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index f1ac17f4c6..f60e8f3bea 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -276,9 +276,21 @@ struct ViewsDescriptor : BaseDescriptor /// Swap the ViewsDescriptor value first and second. friend void swap(ViewsDescriptor& first, ViewsDescriptor& second); + + /// Set the axis value. + void SetAxis(int32_t axis); + + /// Get the axis value. + int32_t GetAxis() const; + + /// Returns true if an axis has been set. + bool HasAxis() const; + private: OriginsDescriptor m_Origins; uint32_t** m_ViewSizes; + bool m_IsAxisSet = false; + int32_t m_Axis = 0; }; diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp index a1419cfbf7..e6374aea8f 100644 --- a/src/armnn/Descriptors.cpp +++ b/src/armnn/Descriptors.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "armnn/Descriptors.hpp" @@ -363,6 +363,24 @@ void swap(ViewsDescriptor& first, ViewsDescriptor& second) swap(first.m_ViewSizes, second.m_ViewSizes); } +void ViewsDescriptor::SetAxis(int32_t axis) +{ + m_Axis = axis; + m_IsAxisSet = true; +} + +/// Get the axis value. +int32_t ViewsDescriptor::GetAxis() const +{ + return m_Axis; +} + +/// Returns true if an axis has been set. +bool ViewsDescriptor::HasAxis() const +{ + return m_IsAxisSet; +} + int StridedSliceDescriptor::GetStartForAxis(const TensorShape& inputShape, unsigned int axis) const { diff --git a/src/armnn/SerializeLayerParameters.cpp b/src/armnn/SerializeLayerParameters.cpp index 1445c70a70..d65a7d55fa 100644 --- a/src/armnn/SerializeLayerParameters.cpp +++ b/src/armnn/SerializeLayerParameters.cpp @@ -636,6 +636,10 @@ void StringifyLayerParameters::Serialize(ParameterStringifyFunc } value << "]"; fn(key.str(), value.str()); + if (desc.HasAxis()) + { + fn("Axis", std::to_string(desc.GetAxis())); + } } StringifyLayerParameters::Serialize(fn, desc.GetOrigins()); } diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index bd0bd0380d..ec4b48639d 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -774,6 +774,8 @@ table OriginsDescriptor { table ViewsDescriptor { origins:OriginsDescriptor; viewSizes:[UintVector]; + hasAxis:bool; + axis:int; } table SplitterLayer { -- cgit v1.2.1