diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-10-15 17:35:36 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-10-16 09:39:56 +0000 |
commit | 6fe5247f8997a04edfdd7c974c96a0a086ef3ab5 (patch) | |
tree | 52d6cc314797f7bf138a0b2d81491543e05b6900 /include | |
parent | 20bea0071d507772e303eb6f1c476bf1feac9be5 (diff) | |
download | armnn-6fe5247f8997a04edfdd7c974c96a0a086ef3ab5.tar.gz |
IVGCVSW-3991 Make Descriptor objects comparable and refactor LayerVisitor tests
* Implemented operator==() for Descriptor structs
* Refactored TestNameAndDescriptorLayerVisitor to eliminate code duplication
by using templates and taking advantage of the fact that descriptor objects
can now all be compared the same way using ==
* Cleaned up TestNameOnlylayerVisitor by moving all test cases for layers
that require a descriptor to TestNameAndDescriptorLayerVisitor
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Iee38b04d68d34a5f4ec7e5790de39ecb7ab0fb80
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 435 |
1 files changed, 326 insertions, 109 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index e2e59741a3..92e842b2c1 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -19,7 +19,16 @@ namespace armnn /// An ActivationDescriptor for the ActivationLayer. struct ActivationDescriptor { - ActivationDescriptor() : m_Function(ActivationFunction::Sigmoid), m_A(0), m_B(0) {} + ActivationDescriptor() + : m_Function(ActivationFunction::Sigmoid) + , m_A(0) + , m_B(0) + {} + + bool operator ==(const ActivationDescriptor &rhs) const + { + return m_Function == rhs.m_Function && m_A == rhs.m_B && m_B == rhs.m_B; + } /// @brief The activation function to use /// (Sigmoid, TanH, Linear, ReLu, BoundedReLu, SoftReLu, LeakyReLu, Abs, Sqrt, Square). @@ -34,10 +43,15 @@ struct ActivationDescriptor struct ArgMinMaxDescriptor { ArgMinMaxDescriptor() - : m_Function(ArgMinMaxFunction::Min) - , m_Axis(-1) + : m_Function(ArgMinMaxFunction::Min) + , m_Axis(-1) {} + bool operator ==(const ArgMinMaxDescriptor &rhs) const + { + return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis; + } + /// Specify if the function is to find Min or Max. ArgMinMaxFunction m_Function; /// Axis to reduce across the input tensor. @@ -49,12 +63,17 @@ struct PermuteDescriptor { PermuteDescriptor() : m_DimMappings{} - { - } + {} + PermuteDescriptor(const PermutationVector& dimMappings) : m_DimMappings(dimMappings) + {} + + bool operator ==(const PermuteDescriptor &rhs) const { + return m_DimMappings.IsEqual(rhs.m_DimMappings); } + /// @brief Indicates how to translate tensor elements from a given source into the target destination, when /// source and target potentially have different memory layouts e.g. {0U, 3U, 1U, 2U}. PermutationVector m_DimMappings; @@ -64,10 +83,15 @@ struct PermuteDescriptor struct SoftmaxDescriptor { SoftmaxDescriptor() - : m_Beta(1.0f) - , m_Axis(-1) + : m_Beta(1.0f) + , m_Axis(-1) {} + bool operator ==(const SoftmaxDescriptor& rhs) const + { + return m_Beta == rhs.m_Beta && m_Axis == rhs.m_Axis; + } + /// Exponentiation value. float m_Beta; /// Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed on. @@ -91,6 +115,8 @@ struct OriginsDescriptor OriginsDescriptor& operator=(OriginsDescriptor rhs); + bool operator ==(const OriginsDescriptor& rhs) const; + /// @Brief Set the view origin coordinates. The arguments are: view, dimension, value. /// If the view is greater than or equal to GetNumViews(), then the view argument is out of range. /// If the coord is greater than or equal to GetNumDimensions(), then the coord argument is out of range. @@ -131,6 +157,9 @@ struct ViewsDescriptor ~ViewsDescriptor(); ViewsDescriptor& operator=(ViewsDescriptor rhs); + + bool operator ==(const ViewsDescriptor& rhs) const; + /// @Brief Set the view origin coordinates. The arguments are: view, dimension, value. /// If the view is greater than or equal to GetNumViews(), then the view argument is out of range. /// If the coord is greater than or equal to GetNumDimensions(), then the coord argument is out of range. @@ -244,20 +273,36 @@ OriginsDescriptor CreateDescriptorForConcatenation(TensorShapeIt first, struct Pooling2dDescriptor { Pooling2dDescriptor() - : m_PoolType(PoolingAlgorithm::Max) - , m_PadLeft(0) - , m_PadRight(0) - , m_PadTop(0) - , m_PadBottom(0) - , m_PoolWidth(0) - , m_PoolHeight(0) - , m_StrideX(0) - , m_StrideY(0) - , m_OutputShapeRounding(OutputShapeRounding::Floor) - , m_PaddingMethod(PaddingMethod::Exclude) - , m_DataLayout(DataLayout::NCHW) + : m_PoolType(PoolingAlgorithm::Max) + , m_PadLeft(0) + , m_PadRight(0) + , m_PadTop(0) + , m_PadBottom(0) + , m_PoolWidth(0) + , m_PoolHeight(0) + , m_StrideX(0) + , m_StrideY(0) + , m_OutputShapeRounding(OutputShapeRounding::Floor) + , m_PaddingMethod(PaddingMethod::Exclude) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const Pooling2dDescriptor& rhs) const + { + return m_PoolType == rhs.m_PoolType && + m_PadLeft == rhs.m_PadLeft && + m_PadRight == rhs.m_PadRight && + m_PadTop == rhs.m_PadTop && + m_PadBottom == rhs.m_PadBottom && + m_PoolWidth == rhs.m_PoolWidth && + m_PoolHeight == rhs.m_PoolHeight && + m_StrideX == rhs.m_StrideX && + m_StrideY == rhs.m_StrideY && + m_OutputShapeRounding == rhs.m_OutputShapeRounding && + m_PaddingMethod == rhs.m_PaddingMethod && + m_DataLayout == rhs.m_DataLayout; + } + /// The pooling algorithm to use (Max. Average, L2). PoolingAlgorithm m_PoolType; /// Padding left value in the width dimension. @@ -288,10 +333,15 @@ struct Pooling2dDescriptor struct FullyConnectedDescriptor { FullyConnectedDescriptor() - : m_BiasEnabled(false) - , m_TransposeWeightMatrix(false) + : m_BiasEnabled(false) + , m_TransposeWeightMatrix(false) {} + bool operator ==(const FullyConnectedDescriptor& rhs) const + { + return m_BiasEnabled == rhs.m_BiasEnabled && m_TransposeWeightMatrix == rhs.m_TransposeWeightMatrix; + } + /// Enable/disable bias. bool m_BiasEnabled; /// Enable/disable transpose weight matrix. @@ -302,18 +352,32 @@ struct FullyConnectedDescriptor struct Convolution2dDescriptor { Convolution2dDescriptor() - : m_PadLeft(0) - , m_PadRight(0) - , m_PadTop(0) - , m_PadBottom(0) - , m_StrideX(0) - , m_StrideY(0) - , m_DilationX(1) - , m_DilationY(1) - , m_BiasEnabled(false) - , m_DataLayout(DataLayout::NCHW) + : m_PadLeft(0) + , m_PadRight(0) + , m_PadTop(0) + , m_PadBottom(0) + , m_StrideX(0) + , m_StrideY(0) + , m_DilationX(1) + , m_DilationY(1) + , m_BiasEnabled(false) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const Convolution2dDescriptor& rhs) const + { + return m_PadLeft == rhs.m_PadLeft && + m_PadRight == rhs.m_PadRight && + m_PadTop == rhs.m_PadTop && + m_PadBottom == rhs.m_PadBottom && + m_StrideX == rhs.m_StrideX && + m_StrideY == rhs.m_StrideY && + m_DilationX == rhs.m_DilationX && + m_DilationY == rhs.m_DilationY && + m_BiasEnabled == rhs.m_BiasEnabled && + m_DataLayout == rhs.m_DataLayout; + } + /// Padding left value in the width dimension. uint32_t m_PadLeft; /// Padding right value in the width dimension. @@ -340,18 +404,32 @@ struct Convolution2dDescriptor struct DepthwiseConvolution2dDescriptor { DepthwiseConvolution2dDescriptor() - : m_PadLeft(0) - , m_PadRight(0) - , m_PadTop(0) - , m_PadBottom(0) - , m_StrideX(0) - , m_StrideY(0) - , m_DilationX(1) - , m_DilationY(1) - , m_BiasEnabled(false) - , m_DataLayout(DataLayout::NCHW) + : m_PadLeft(0) + , m_PadRight(0) + , m_PadTop(0) + , m_PadBottom(0) + , m_StrideX(0) + , m_StrideY(0) + , m_DilationX(1) + , m_DilationY(1) + , m_BiasEnabled(false) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const DepthwiseConvolution2dDescriptor& rhs) const + { + return m_PadLeft == rhs.m_PadLeft && + m_PadRight == rhs.m_PadRight && + m_PadTop == rhs.m_PadTop && + m_PadBottom == rhs.m_PadBottom && + m_StrideX == rhs.m_StrideX && + m_StrideY == rhs.m_StrideY && + m_DilationX == rhs.m_DilationX && + m_DilationY == rhs.m_DilationY && + m_BiasEnabled == rhs.m_BiasEnabled && + m_DataLayout == rhs.m_DataLayout; + } + /// Padding left value in the width dimension. uint32_t m_PadLeft; /// Padding right value in the width dimension. @@ -377,19 +455,34 @@ struct DepthwiseConvolution2dDescriptor struct DetectionPostProcessDescriptor { DetectionPostProcessDescriptor() - : m_MaxDetections(0) - , m_MaxClassesPerDetection(1) - , m_DetectionsPerClass(1) - , m_NmsScoreThreshold(0) - , m_NmsIouThreshold(0) - , m_NumClasses(0) - , m_UseRegularNms(false) - , m_ScaleX(0) - , m_ScaleY(0) - , m_ScaleW(0) - , m_ScaleH(0) + : m_MaxDetections(0) + , m_MaxClassesPerDetection(1) + , m_DetectionsPerClass(1) + , m_NmsScoreThreshold(0) + , m_NmsIouThreshold(0) + , m_NumClasses(0) + , m_UseRegularNms(false) + , m_ScaleX(0) + , m_ScaleY(0) + , m_ScaleW(0) + , m_ScaleH(0) {} + bool operator ==(const DetectionPostProcessDescriptor& rhs) const + { + return m_MaxDetections == rhs.m_MaxDetections && + m_MaxClassesPerDetection == rhs.m_MaxClassesPerDetection && + m_DetectionsPerClass == rhs.m_DetectionsPerClass && + m_NmsScoreThreshold == rhs.m_NmsScoreThreshold && + m_NmsIouThreshold == rhs.m_NmsIouThreshold && + m_NumClasses == rhs.m_NumClasses && + m_UseRegularNms == rhs.m_UseRegularNms && + m_ScaleX == rhs.m_ScaleX && + m_ScaleY == rhs.m_ScaleY && + m_ScaleW == rhs.m_ScaleW && + m_ScaleH == rhs.m_ScaleH; + } + /// Maximum numbers of detections. uint32_t m_MaxDetections; /// Maximum numbers of classes per detection, used in Fast NMS. @@ -418,15 +511,26 @@ struct DetectionPostProcessDescriptor struct NormalizationDescriptor { NormalizationDescriptor() - : m_NormChannelType(NormalizationAlgorithmChannel::Across) - , m_NormMethodType(NormalizationAlgorithmMethod::LocalBrightness) - , m_NormSize(0) - , m_Alpha(0.f) - , m_Beta(0.f) - , m_K(0.f) - , m_DataLayout(DataLayout::NCHW) + : m_NormChannelType(NormalizationAlgorithmChannel::Across) + , m_NormMethodType(NormalizationAlgorithmMethod::LocalBrightness) + , m_NormSize(0) + , m_Alpha(0.f) + , m_Beta(0.f) + , m_K(0.f) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const NormalizationDescriptor& rhs) const + { + return m_NormChannelType == rhs.m_NormChannelType && + m_NormMethodType == rhs.m_NormMethodType && + m_NormSize == rhs.m_NormSize && + m_Alpha == rhs.m_Alpha && + m_Beta == rhs.m_Beta && + m_K == rhs.m_K && + m_DataLayout == rhs.m_DataLayout; + } + /// Normalization channel algorithm to use (Across, Within). NormalizationAlgorithmChannel m_NormChannelType; /// Normalization method algorithm to use (LocalBrightness, LocalContrast). @@ -447,10 +551,15 @@ struct NormalizationDescriptor struct L2NormalizationDescriptor { L2NormalizationDescriptor() - : m_Eps(1e-12f) - , m_DataLayout(DataLayout::NCHW) + : m_Eps(1e-12f) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const L2NormalizationDescriptor& rhs) const + { + return m_Eps == rhs.m_Eps && m_DataLayout == rhs.m_DataLayout; + } + /// Used to avoid dividing by zero. float m_Eps; /// The data layout to be used (NCHW, NHWC). @@ -461,10 +570,15 @@ struct L2NormalizationDescriptor struct BatchNormalizationDescriptor { BatchNormalizationDescriptor() - : m_Eps(0.0001f) - , m_DataLayout(DataLayout::NCHW) + : m_Eps(0.0001f) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const BatchNormalizationDescriptor& rhs) const + { + return m_Eps == rhs.m_Eps && m_DataLayout == rhs.m_DataLayout; + } + /// Value to add to the variance. Used to avoid dividing by zero. float m_Eps; /// The data layout to be used (NCHW, NHWC). @@ -481,6 +595,14 @@ struct InstanceNormalizationDescriptor , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const InstanceNormalizationDescriptor& rhs) const + { + return m_Gamma == rhs.m_Gamma && + m_Beta == rhs.m_Beta && + m_Eps == rhs.m_Eps && + m_DataLayout == rhs.m_DataLayout; + } + /// Gamma, the scale scalar value applied for the normalized tensor. Defaults to 1.0. float m_Gamma; /// Beta, the offset scalar value applied for the normalized tensor. Defaults to 1.0. @@ -507,6 +629,13 @@ struct BatchToSpaceNdDescriptor , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const BatchToSpaceNdDescriptor& rhs) const + { + return m_BlockShape == rhs.m_BlockShape && + m_Crops == rhs.m_Crops && + m_DataLayout == rhs.m_DataLayout; + } + /// Block shape values. std::vector<unsigned int> m_BlockShape; /// The values to crop from the input dimension. @@ -518,11 +647,16 @@ struct BatchToSpaceNdDescriptor /// A FakeQuantizationDescriptor for the FakeQuantizationLayer. struct FakeQuantizationDescriptor { - FakeQuantizationDescriptor() - : m_Min(-6.0f) - , m_Max(6.0f) + FakeQuantizationDescriptor() + : m_Min(-6.0f) + , m_Max(6.0f) {} + bool operator ==(const FakeQuantizationDescriptor& rhs) const + { + return m_Min == rhs.m_Min && m_Max == rhs.m_Max; + } + /// Minimum value. float m_Min; /// Maximum value. @@ -533,9 +667,9 @@ struct FakeQuantizationDescriptor struct ResizeBilinearDescriptor { ResizeBilinearDescriptor() - : m_TargetWidth(0) - , m_TargetHeight(0) - , m_DataLayout(DataLayout::NCHW) + : m_TargetWidth(0) + , m_TargetHeight(0) + , m_DataLayout(DataLayout::NCHW) {} /// Target width value. @@ -550,12 +684,20 @@ struct ResizeBilinearDescriptor struct ResizeDescriptor { ResizeDescriptor() - : m_TargetWidth(0) - , m_TargetHeight(0) - , m_Method(ResizeMethod::NearestNeighbor) - , m_DataLayout(DataLayout::NCHW) + : m_TargetWidth(0) + , m_TargetHeight(0) + , m_Method(ResizeMethod::NearestNeighbor) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const ResizeDescriptor& rhs) const + { + return m_TargetWidth == rhs.m_TargetWidth && + m_TargetHeight == rhs.m_TargetHeight && + m_Method == rhs.m_Method && + m_DataLayout == rhs.m_DataLayout; + } + /// Target width value. uint32_t m_TargetWidth; /// Target height value. @@ -572,13 +714,18 @@ struct ResizeDescriptor struct ReshapeDescriptor { ReshapeDescriptor() - : m_TargetShape() + : m_TargetShape() {} ReshapeDescriptor(const TensorShape& shape) - : m_TargetShape(shape) + : m_TargetShape(shape) {} + bool operator ==(const ReshapeDescriptor& rhs) const + { + return m_TargetShape == rhs.m_TargetShape; + } + /// Target shape value. TensorShape m_TargetShape; }; @@ -587,18 +734,25 @@ struct ReshapeDescriptor struct SpaceToBatchNdDescriptor { SpaceToBatchNdDescriptor() - : m_BlockShape({1, 1}) - , m_PadList({{0, 0}, {0, 0}}) - , m_DataLayout(DataLayout::NCHW) + : m_BlockShape({1, 1}) + , m_PadList({{0, 0}, {0, 0}}) + , m_DataLayout(DataLayout::NCHW) {} SpaceToBatchNdDescriptor(const std::vector<unsigned int>& blockShape, const std::vector<std::pair<unsigned int, unsigned int>>& padList) - : m_BlockShape(blockShape) - , m_PadList(padList) - , m_DataLayout(DataLayout::NCHW) + : m_BlockShape(blockShape) + , m_PadList(padList) + , m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const SpaceToBatchNdDescriptor& rhs) const + { + return m_BlockShape == rhs.m_BlockShape && + m_PadList == rhs.m_PadList && + m_DataLayout == rhs.m_DataLayout; + } + /// Block shape value. std::vector<unsigned int> m_BlockShape; /// @brief Specifies the padding values for the input dimension: @@ -620,6 +774,11 @@ struct SpaceToDepthDescriptor , m_DataLayout(dataLayout) {} + bool operator ==(const SpaceToDepthDescriptor& rhs) const + { + return m_BlockSize == rhs.m_BlockSize && m_DataLayout == rhs.m_DataLayout; + } + /// Scalar specifying the input block size. It must be >= 1 unsigned int m_BlockSize; @@ -634,15 +793,25 @@ using DepthToSpaceDescriptor = SpaceToDepthDescriptor; struct LstmDescriptor { LstmDescriptor() - : m_ActivationFunc(1) // 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid - , m_ClippingThresCell(0.0) - , m_ClippingThresProj(0.0) - , m_CifgEnabled(true) - , m_PeepholeEnabled(false) - , m_ProjectionEnabled(false) - , m_LayerNormEnabled(false) + : m_ActivationFunc(1) // 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid + , m_ClippingThresCell(0.0) + , m_ClippingThresProj(0.0) + , m_CifgEnabled(true) + , m_PeepholeEnabled(false) + , m_ProjectionEnabled(false) + , m_LayerNormEnabled(false) {} + bool operator ==(const LstmDescriptor& rhs) const + { + return m_ActivationFunc == rhs.m_ActivationFunc && + m_ClippingThresCell == rhs.m_ClippingThresCell && + m_ClippingThresProj == rhs.m_ClippingThresProj && + m_CifgEnabled == rhs.m_CifgEnabled && + m_PeepholeEnabled == rhs.m_PeepholeEnabled && + m_LayerNormEnabled == rhs.m_LayerNormEnabled; + } + /// @brief The activation function to use. /// 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid. uint32_t m_ActivationFunc; @@ -664,15 +833,20 @@ struct LstmDescriptor struct MeanDescriptor { MeanDescriptor() - : m_Axis() - , m_KeepDims(false) + : m_Axis() + , m_KeepDims(false) {} MeanDescriptor(const std::vector<unsigned int>& axis, bool keepDims) - : m_Axis(axis) - , m_KeepDims(keepDims) + : m_Axis(axis) + , m_KeepDims(keepDims) {} + bool operator ==(const MeanDescriptor& rhs) const + { + return m_Axis == rhs.m_Axis && m_KeepDims == rhs.m_KeepDims; + } + /// Values for the dimensions to reduce. std::vector<unsigned int> m_Axis; /// Enable/disable keep dimensions. If true, then the reduced dimensions that are of length 1 are kept. @@ -686,9 +860,15 @@ struct PadDescriptor {} PadDescriptor(const std::vector<std::pair<unsigned int, unsigned int>>& padList, const float& padValue = 0) - : m_PadList(padList), m_PadValue(padValue) + : m_PadList(padList) + , m_PadValue(padValue) {} + bool operator ==(const PadDescriptor& rhs) const + { + return m_PadList == rhs.m_PadList && m_PadValue == rhs.m_PadValue; + } + /// @brief Specifies the padding for input dimension. /// First is the number of values to add before the tensor in the dimension. /// Second is the number of values to add after the tensor in the dimension. @@ -710,6 +890,11 @@ struct SliceDescriptor SliceDescriptor() : SliceDescriptor({}, {}) {} + bool operator ==(const SliceDescriptor& rhs) const + { + return m_Begin == rhs.m_Begin && m_Size == rhs.m_Size; + } + /// Beginning indices of the slice in each dimension. std::vector<unsigned int> m_Begin; @@ -721,17 +906,24 @@ struct SliceDescriptor struct StackDescriptor { StackDescriptor() - : m_Axis(0) - , m_NumInputs(0) - , m_InputShape() + : m_Axis(0) + , m_NumInputs(0) + , m_InputShape() {} StackDescriptor(uint32_t axis, uint32_t numInputs, const TensorShape& inputShape) - : m_Axis(axis) - , m_NumInputs(numInputs) - , m_InputShape(inputShape) + : m_Axis(axis) + , m_NumInputs(numInputs) + , m_InputShape(inputShape) {} + bool operator ==(const StackDescriptor& rhs) const + { + return m_Axis == rhs.m_Axis && + m_NumInputs == rhs.m_NumInputs && + m_InputShape == rhs.m_InputShape; + } + /// 0-based axis along which to stack the input tensors. uint32_t m_Axis; /// Number of input tensors. @@ -746,21 +938,34 @@ struct StridedSliceDescriptor StridedSliceDescriptor(const std::vector<int>& begin, const std::vector<int>& end, const std::vector<int>& stride) - : m_Begin(begin) - , m_End(end) - , m_Stride(stride) - , m_BeginMask(0) - , m_EndMask(0) - , m_ShrinkAxisMask(0) - , m_EllipsisMask(0) - , m_NewAxisMask(0) - , m_DataLayout(DataLayout::NCHW) + : m_Begin(begin) + , m_End(end) + , m_Stride(stride) + , m_BeginMask(0) + , m_EndMask(0) + , m_ShrinkAxisMask(0) + , m_EllipsisMask(0) + , m_NewAxisMask(0) + , m_DataLayout(DataLayout::NCHW) {} StridedSliceDescriptor() - : StridedSliceDescriptor({}, {}, {}) + : StridedSliceDescriptor({}, {}, {}) {} + bool operator ==(const StridedSliceDescriptor& rhs) const + { + return m_Begin == rhs.m_Begin && + m_End == rhs.m_End && + m_Stride == rhs.m_Stride && + m_BeginMask == rhs.m_BeginMask && + m_EndMask == rhs.m_EndMask && + m_ShrinkAxisMask == rhs.m_ShrinkAxisMask && + m_EllipsisMask == rhs.m_EllipsisMask && + m_NewAxisMask == rhs.m_NewAxisMask && + m_DataLayout == rhs.m_DataLayout; + } + int GetStartForAxis(const TensorShape& inputShape, unsigned int axis) const; int GetStopForAxis(const TensorShape& inputShape, unsigned int axis, @@ -818,6 +1023,18 @@ struct TransposeConvolution2dDescriptor m_DataLayout(DataLayout::NCHW) {} + bool operator ==(const TransposeConvolution2dDescriptor& rhs) const + { + return m_PadLeft == rhs.m_PadLeft && + m_PadRight == rhs.m_PadRight && + m_PadTop == rhs.m_PadTop && + m_PadBottom == rhs.m_PadBottom && + m_StrideX == rhs.m_StrideX && + m_StrideY == rhs.m_StrideY && + m_BiasEnabled == rhs.m_BiasEnabled && + m_DataLayout == rhs.m_DataLayout; + } + /// Padding left value in the width dimension. uint32_t m_PadLeft; /// Padding right value in the width dimension. |