author     Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>    2019-10-15 17:35:36 +0100
committer  Áron Virginás-Tar <aron.virginas-tar@arm.com>    2019-10-16 09:39:56 +0000
commit     6fe5247f8997a04edfdd7c974c96a0a086ef3ab5 (patch)
tree       52d6cc314797f7bf138a0b2d81491543e05b6900 /include/armnn/Descriptors.hpp
parent     20bea0071d507772e303eb6f1c476bf1feac9be5 (diff)
download   armnn-6fe5247f8997a04edfdd7c974c96a0a086ef3ab5.tar.gz
IVGCVSW-3991 Make Descriptor objects comparable and refactor LayerVisitor tests
* Implemented operator==() for Descriptor structs
* Refactored TestNameAndDescriptorLayerVisitor to eliminate code duplication by using templates and taking advantage of the fact that descriptor objects can now all be compared the same way using ==
* Cleaned up TestNameOnlyLayerVisitor by moving all test cases for layers that require a descriptor to TestNameAndDescriptorLayerVisitor

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Iee38b04d68d34a5f4ec7e5790de39ecb7ab0fb80
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--  include/armnn/Descriptors.hpp  435
1 file changed, 326 insertions, 109 deletions
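For illustration only (not part of this patch): below is a minimal C++ sketch, under stated assumptions, of how the new operator==() lets a single templated check compare any descriptor type. The helper name CheckDescriptor, the use of assert, and the standalone main() are hypothetical and do not appear in the ArmNN LayerVisitor tests.

    #include <armnn/Descriptors.hpp>
    #include <cassert>

    // Any descriptor type can now be compared the same way using operator==.
    template <typename Descriptor>
    void CheckDescriptor(const Descriptor& expected, const Descriptor& actual)
    {
        assert(expected == actual);
    }

    int main()
    {
        armnn::ActivationDescriptor expected;
        expected.m_Function = armnn::ActivationFunction::ReLu;
        expected.m_A = 1.0f;

        armnn::ActivationDescriptor actual = expected; // copy, so every member matches
        CheckDescriptor(expected, actual);             // the same check works for any descriptor
        return 0;
    }

The same CheckDescriptor template applies unchanged to Pooling2dDescriptor, Convolution2dDescriptor, and the other descriptor types touched by this patch, which is what removes the per-layer duplication in the refactored visitor tests.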
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index e2e59741a3..92e842b2c1 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -19,7 +19,16 @@ namespace armnn
/// An ActivationDescriptor for the ActivationLayer.
struct ActivationDescriptor
{
- ActivationDescriptor() : m_Function(ActivationFunction::Sigmoid), m_A(0), m_B(0) {}
+ ActivationDescriptor()
+ : m_Function(ActivationFunction::Sigmoid)
+ , m_A(0)
+ , m_B(0)
+ {}
+
+ bool operator ==(const ActivationDescriptor &rhs) const
+ {
+ return m_Function == rhs.m_Function && m_A == rhs.m_A && m_B == rhs.m_B;
+ }
/// @brief The activation function to use
/// (Sigmoid, TanH, Linear, ReLu, BoundedReLu, SoftReLu, LeakyReLu, Abs, Sqrt, Square).
@@ -34,10 +43,15 @@ struct ActivationDescriptor
struct ArgMinMaxDescriptor
{
ArgMinMaxDescriptor()
- : m_Function(ArgMinMaxFunction::Min)
- , m_Axis(-1)
+ : m_Function(ArgMinMaxFunction::Min)
+ , m_Axis(-1)
{}
+ bool operator ==(const ArgMinMaxDescriptor &rhs) const
+ {
+ return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis;
+ }
+
/// Specify if the function is to find Min or Max.
ArgMinMaxFunction m_Function;
/// Axis to reduce across the input tensor.
@@ -49,12 +63,17 @@ struct PermuteDescriptor
{
PermuteDescriptor()
: m_DimMappings{}
- {
- }
+ {}
+
PermuteDescriptor(const PermutationVector& dimMappings)
: m_DimMappings(dimMappings)
+ {}
+
+ bool operator ==(const PermuteDescriptor &rhs) const
{
+ return m_DimMappings.IsEqual(rhs.m_DimMappings);
}
+
/// @brief Indicates how to translate tensor elements from a given source into the target destination, when
/// source and target potentially have different memory layouts e.g. {0U, 3U, 1U, 2U}.
PermutationVector m_DimMappings;
@@ -64,10 +83,15 @@ struct PermuteDescriptor
struct SoftmaxDescriptor
{
SoftmaxDescriptor()
- : m_Beta(1.0f)
- , m_Axis(-1)
+ : m_Beta(1.0f)
+ , m_Axis(-1)
{}
+ bool operator ==(const SoftmaxDescriptor& rhs) const
+ {
+ return m_Beta == rhs.m_Beta && m_Axis == rhs.m_Axis;
+ }
+
/// Exponentiation value.
float m_Beta;
/// Scalar, defaulted to the last index (-1), specifying the dimension the activation will be performed on.
@@ -91,6 +115,8 @@ struct OriginsDescriptor
OriginsDescriptor& operator=(OriginsDescriptor rhs);
+ bool operator ==(const OriginsDescriptor& rhs) const;
+
/// @brief Set the view origin coordinates. The arguments are: view, dimension, value.
/// If the view is greater than or equal to GetNumViews(), then the view argument is out of range.
/// If the coord is greater than or equal to GetNumDimensions(), then the coord argument is out of range.
@@ -131,6 +157,9 @@ struct ViewsDescriptor
~ViewsDescriptor();
ViewsDescriptor& operator=(ViewsDescriptor rhs);
+
+ bool operator ==(const ViewsDescriptor& rhs) const;
+
/// @brief Set the view origin coordinates. The arguments are: view, dimension, value.
/// If the view is greater than or equal to GetNumViews(), then the view argument is out of range.
/// If the coord is greater than or equal to GetNumDimensions(), then the coord argument is out of range.
@@ -244,20 +273,36 @@ OriginsDescriptor CreateDescriptorForConcatenation(TensorShapeIt first,
struct Pooling2dDescriptor
{
Pooling2dDescriptor()
- : m_PoolType(PoolingAlgorithm::Max)
- , m_PadLeft(0)
- , m_PadRight(0)
- , m_PadTop(0)
- , m_PadBottom(0)
- , m_PoolWidth(0)
- , m_PoolHeight(0)
- , m_StrideX(0)
- , m_StrideY(0)
- , m_OutputShapeRounding(OutputShapeRounding::Floor)
- , m_PaddingMethod(PaddingMethod::Exclude)
- , m_DataLayout(DataLayout::NCHW)
+ : m_PoolType(PoolingAlgorithm::Max)
+ , m_PadLeft(0)
+ , m_PadRight(0)
+ , m_PadTop(0)
+ , m_PadBottom(0)
+ , m_PoolWidth(0)
+ , m_PoolHeight(0)
+ , m_StrideX(0)
+ , m_StrideY(0)
+ , m_OutputShapeRounding(OutputShapeRounding::Floor)
+ , m_PaddingMethod(PaddingMethod::Exclude)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const Pooling2dDescriptor& rhs) const
+ {
+ return m_PoolType == rhs.m_PoolType &&
+ m_PadLeft == rhs.m_PadLeft &&
+ m_PadRight == rhs.m_PadRight &&
+ m_PadTop == rhs.m_PadTop &&
+ m_PadBottom == rhs.m_PadBottom &&
+ m_PoolWidth == rhs.m_PoolWidth &&
+ m_PoolHeight == rhs.m_PoolHeight &&
+ m_StrideX == rhs.m_StrideX &&
+ m_StrideY == rhs.m_StrideY &&
+ m_OutputShapeRounding == rhs.m_OutputShapeRounding &&
+ m_PaddingMethod == rhs.m_PaddingMethod &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// The pooling algorithm to use (Max, Average, L2).
PoolingAlgorithm m_PoolType;
/// Padding left value in the width dimension.
@@ -288,10 +333,15 @@ struct Pooling2dDescriptor
struct FullyConnectedDescriptor
{
FullyConnectedDescriptor()
- : m_BiasEnabled(false)
- , m_TransposeWeightMatrix(false)
+ : m_BiasEnabled(false)
+ , m_TransposeWeightMatrix(false)
{}
+ bool operator ==(const FullyConnectedDescriptor& rhs) const
+ {
+ return m_BiasEnabled == rhs.m_BiasEnabled && m_TransposeWeightMatrix == rhs.m_TransposeWeightMatrix;
+ }
+
/// Enable/disable bias.
bool m_BiasEnabled;
/// Enable/disable transpose weight matrix.
@@ -302,18 +352,32 @@ struct FullyConnectedDescriptor
struct Convolution2dDescriptor
{
Convolution2dDescriptor()
- : m_PadLeft(0)
- , m_PadRight(0)
- , m_PadTop(0)
- , m_PadBottom(0)
- , m_StrideX(0)
- , m_StrideY(0)
- , m_DilationX(1)
- , m_DilationY(1)
- , m_BiasEnabled(false)
- , m_DataLayout(DataLayout::NCHW)
+ : m_PadLeft(0)
+ , m_PadRight(0)
+ , m_PadTop(0)
+ , m_PadBottom(0)
+ , m_StrideX(0)
+ , m_StrideY(0)
+ , m_DilationX(1)
+ , m_DilationY(1)
+ , m_BiasEnabled(false)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const Convolution2dDescriptor& rhs) const
+ {
+ return m_PadLeft == rhs.m_PadLeft &&
+ m_PadRight == rhs.m_PadRight &&
+ m_PadTop == rhs.m_PadTop &&
+ m_PadBottom == rhs.m_PadBottom &&
+ m_StrideX == rhs.m_StrideX &&
+ m_StrideY == rhs.m_StrideY &&
+ m_DilationX == rhs.m_DilationX &&
+ m_DilationY == rhs.m_DilationY &&
+ m_BiasEnabled == rhs.m_BiasEnabled &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Padding left value in the width dimension.
uint32_t m_PadLeft;
/// Padding right value in the width dimension.
@@ -340,18 +404,32 @@ struct Convolution2dDescriptor
struct DepthwiseConvolution2dDescriptor
{
DepthwiseConvolution2dDescriptor()
- : m_PadLeft(0)
- , m_PadRight(0)
- , m_PadTop(0)
- , m_PadBottom(0)
- , m_StrideX(0)
- , m_StrideY(0)
- , m_DilationX(1)
- , m_DilationY(1)
- , m_BiasEnabled(false)
- , m_DataLayout(DataLayout::NCHW)
+ : m_PadLeft(0)
+ , m_PadRight(0)
+ , m_PadTop(0)
+ , m_PadBottom(0)
+ , m_StrideX(0)
+ , m_StrideY(0)
+ , m_DilationX(1)
+ , m_DilationY(1)
+ , m_BiasEnabled(false)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const DepthwiseConvolution2dDescriptor& rhs) const
+ {
+ return m_PadLeft == rhs.m_PadLeft &&
+ m_PadRight == rhs.m_PadRight &&
+ m_PadTop == rhs.m_PadTop &&
+ m_PadBottom == rhs.m_PadBottom &&
+ m_StrideX == rhs.m_StrideX &&
+ m_StrideY == rhs.m_StrideY &&
+ m_DilationX == rhs.m_DilationX &&
+ m_DilationY == rhs.m_DilationY &&
+ m_BiasEnabled == rhs.m_BiasEnabled &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Padding left value in the width dimension.
uint32_t m_PadLeft;
/// Padding right value in the width dimension.
@@ -377,19 +455,34 @@ struct DepthwiseConvolution2dDescriptor
struct DetectionPostProcessDescriptor
{
DetectionPostProcessDescriptor()
- : m_MaxDetections(0)
- , m_MaxClassesPerDetection(1)
- , m_DetectionsPerClass(1)
- , m_NmsScoreThreshold(0)
- , m_NmsIouThreshold(0)
- , m_NumClasses(0)
- , m_UseRegularNms(false)
- , m_ScaleX(0)
- , m_ScaleY(0)
- , m_ScaleW(0)
- , m_ScaleH(0)
+ : m_MaxDetections(0)
+ , m_MaxClassesPerDetection(1)
+ , m_DetectionsPerClass(1)
+ , m_NmsScoreThreshold(0)
+ , m_NmsIouThreshold(0)
+ , m_NumClasses(0)
+ , m_UseRegularNms(false)
+ , m_ScaleX(0)
+ , m_ScaleY(0)
+ , m_ScaleW(0)
+ , m_ScaleH(0)
{}
+ bool operator ==(const DetectionPostProcessDescriptor& rhs) const
+ {
+ return m_MaxDetections == rhs.m_MaxDetections &&
+ m_MaxClassesPerDetection == rhs.m_MaxClassesPerDetection &&
+ m_DetectionsPerClass == rhs.m_DetectionsPerClass &&
+ m_NmsScoreThreshold == rhs.m_NmsScoreThreshold &&
+ m_NmsIouThreshold == rhs.m_NmsIouThreshold &&
+ m_NumClasses == rhs.m_NumClasses &&
+ m_UseRegularNms == rhs.m_UseRegularNms &&
+ m_ScaleX == rhs.m_ScaleX &&
+ m_ScaleY == rhs.m_ScaleY &&
+ m_ScaleW == rhs.m_ScaleW &&
+ m_ScaleH == rhs.m_ScaleH;
+ }
+
/// Maximum number of detections.
uint32_t m_MaxDetections;
/// Maximum number of classes per detection, used in Fast NMS.
@@ -418,15 +511,26 @@ struct DetectionPostProcessDescriptor
struct NormalizationDescriptor
{
NormalizationDescriptor()
- : m_NormChannelType(NormalizationAlgorithmChannel::Across)
- , m_NormMethodType(NormalizationAlgorithmMethod::LocalBrightness)
- , m_NormSize(0)
- , m_Alpha(0.f)
- , m_Beta(0.f)
- , m_K(0.f)
- , m_DataLayout(DataLayout::NCHW)
+ : m_NormChannelType(NormalizationAlgorithmChannel::Across)
+ , m_NormMethodType(NormalizationAlgorithmMethod::LocalBrightness)
+ , m_NormSize(0)
+ , m_Alpha(0.f)
+ , m_Beta(0.f)
+ , m_K(0.f)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const NormalizationDescriptor& rhs) const
+ {
+ return m_NormChannelType == rhs.m_NormChannelType &&
+ m_NormMethodType == rhs.m_NormMethodType &&
+ m_NormSize == rhs.m_NormSize &&
+ m_Alpha == rhs.m_Alpha &&
+ m_Beta == rhs.m_Beta &&
+ m_K == rhs.m_K &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Normalization channel algorithm to use (Across, Within).
NormalizationAlgorithmChannel m_NormChannelType;
/// Normalization method algorithm to use (LocalBrightness, LocalContrast).
@@ -447,10 +551,15 @@ struct NormalizationDescriptor
struct L2NormalizationDescriptor
{
L2NormalizationDescriptor()
- : m_Eps(1e-12f)
- , m_DataLayout(DataLayout::NCHW)
+ : m_Eps(1e-12f)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const L2NormalizationDescriptor& rhs) const
+ {
+ return m_Eps == rhs.m_Eps && m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Used to avoid dividing by zero.
float m_Eps;
/// The data layout to be used (NCHW, NHWC).
@@ -461,10 +570,15 @@ struct L2NormalizationDescriptor
struct BatchNormalizationDescriptor
{
BatchNormalizationDescriptor()
- : m_Eps(0.0001f)
- , m_DataLayout(DataLayout::NCHW)
+ : m_Eps(0.0001f)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const BatchNormalizationDescriptor& rhs) const
+ {
+ return m_Eps == rhs.m_Eps && m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Value to add to the variance. Used to avoid dividing by zero.
float m_Eps;
/// The data layout to be used (NCHW, NHWC).
@@ -481,6 +595,14 @@ struct InstanceNormalizationDescriptor
, m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const InstanceNormalizationDescriptor& rhs) const
+ {
+ return m_Gamma == rhs.m_Gamma &&
+ m_Beta == rhs.m_Beta &&
+ m_Eps == rhs.m_Eps &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Gamma, the scale scalar value applied for the normalized tensor. Defaults to 1.0.
float m_Gamma;
/// Beta, the offset scalar value applied for the normalized tensor. Defaults to 0.0.
@@ -507,6 +629,13 @@ struct BatchToSpaceNdDescriptor
, m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const BatchToSpaceNdDescriptor& rhs) const
+ {
+ return m_BlockShape == rhs.m_BlockShape &&
+ m_Crops == rhs.m_Crops &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Block shape values.
std::vector<unsigned int> m_BlockShape;
/// The values to crop from the input dimension.
@@ -518,11 +647,16 @@ struct BatchToSpaceNdDescriptor
/// A FakeQuantizationDescriptor for the FakeQuantizationLayer.
struct FakeQuantizationDescriptor
{
- FakeQuantizationDescriptor()
- : m_Min(-6.0f)
- , m_Max(6.0f)
+ FakeQuantizationDescriptor()
+ : m_Min(-6.0f)
+ , m_Max(6.0f)
{}
+ bool operator ==(const FakeQuantizationDescriptor& rhs) const
+ {
+ return m_Min == rhs.m_Min && m_Max == rhs.m_Max;
+ }
+
/// Minimum value.
float m_Min;
/// Maximum value.
@@ -533,9 +667,9 @@ struct FakeQuantizationDescriptor
struct ResizeBilinearDescriptor
{
ResizeBilinearDescriptor()
- : m_TargetWidth(0)
- , m_TargetHeight(0)
- , m_DataLayout(DataLayout::NCHW)
+ : m_TargetWidth(0)
+ , m_TargetHeight(0)
+ , m_DataLayout(DataLayout::NCHW)
{}
/// Target width value.
@@ -550,12 +684,20 @@ struct ResizeBilinearDescriptor
struct ResizeDescriptor
{
ResizeDescriptor()
- : m_TargetWidth(0)
- , m_TargetHeight(0)
- , m_Method(ResizeMethod::NearestNeighbor)
- , m_DataLayout(DataLayout::NCHW)
+ : m_TargetWidth(0)
+ , m_TargetHeight(0)
+ , m_Method(ResizeMethod::NearestNeighbor)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const ResizeDescriptor& rhs) const
+ {
+ return m_TargetWidth == rhs.m_TargetWidth &&
+ m_TargetHeight == rhs.m_TargetHeight &&
+ m_Method == rhs.m_Method &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Target width value.
uint32_t m_TargetWidth;
/// Target height value.
@@ -572,13 +714,18 @@ struct ResizeDescriptor
struct ReshapeDescriptor
{
ReshapeDescriptor()
- : m_TargetShape()
+ : m_TargetShape()
{}
ReshapeDescriptor(const TensorShape& shape)
- : m_TargetShape(shape)
+ : m_TargetShape(shape)
{}
+ bool operator ==(const ReshapeDescriptor& rhs) const
+ {
+ return m_TargetShape == rhs.m_TargetShape;
+ }
+
/// Target shape value.
TensorShape m_TargetShape;
};
@@ -587,18 +734,25 @@ struct ReshapeDescriptor
struct SpaceToBatchNdDescriptor
{
SpaceToBatchNdDescriptor()
- : m_BlockShape({1, 1})
- , m_PadList({{0, 0}, {0, 0}})
- , m_DataLayout(DataLayout::NCHW)
+ : m_BlockShape({1, 1})
+ , m_PadList({{0, 0}, {0, 0}})
+ , m_DataLayout(DataLayout::NCHW)
{}
SpaceToBatchNdDescriptor(const std::vector<unsigned int>& blockShape,
const std::vector<std::pair<unsigned int, unsigned int>>& padList)
- : m_BlockShape(blockShape)
- , m_PadList(padList)
- , m_DataLayout(DataLayout::NCHW)
+ : m_BlockShape(blockShape)
+ , m_PadList(padList)
+ , m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const SpaceToBatchNdDescriptor& rhs) const
+ {
+ return m_BlockShape == rhs.m_BlockShape &&
+ m_PadList == rhs.m_PadList &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Block shape value.
std::vector<unsigned int> m_BlockShape;
/// @brief Specifies the padding values for the input dimension:
@@ -620,6 +774,11 @@ struct SpaceToDepthDescriptor
, m_DataLayout(dataLayout)
{}
+ bool operator ==(const SpaceToDepthDescriptor& rhs) const
+ {
+ return m_BlockSize == rhs.m_BlockSize && m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Scalar specifying the input block size. It must be >= 1
unsigned int m_BlockSize;
@@ -634,15 +793,25 @@ using DepthToSpaceDescriptor = SpaceToDepthDescriptor;
struct LstmDescriptor
{
LstmDescriptor()
- : m_ActivationFunc(1) // 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid
- , m_ClippingThresCell(0.0)
- , m_ClippingThresProj(0.0)
- , m_CifgEnabled(true)
- , m_PeepholeEnabled(false)
- , m_ProjectionEnabled(false)
- , m_LayerNormEnabled(false)
+ : m_ActivationFunc(1) // 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid
+ , m_ClippingThresCell(0.0)
+ , m_ClippingThresProj(0.0)
+ , m_CifgEnabled(true)
+ , m_PeepholeEnabled(false)
+ , m_ProjectionEnabled(false)
+ , m_LayerNormEnabled(false)
{}
+ bool operator ==(const LstmDescriptor& rhs) const
+ {
+ return m_ActivationFunc == rhs.m_ActivationFunc &&
+ m_ClippingThresCell == rhs.m_ClippingThresCell &&
+ m_ClippingThresProj == rhs.m_ClippingThresProj &&
+ m_CifgEnabled == rhs.m_CifgEnabled &&
+ m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
+ m_ProjectionEnabled == rhs.m_ProjectionEnabled &&
+ m_LayerNormEnabled == rhs.m_LayerNormEnabled;
+ }
+
/// @brief The activation function to use.
/// 0: None, 1: Relu, 3: Relu6, 4: Tanh, 6: Sigmoid.
uint32_t m_ActivationFunc;
@@ -664,15 +833,20 @@ struct LstmDescriptor
struct MeanDescriptor
{
MeanDescriptor()
- : m_Axis()
- , m_KeepDims(false)
+ : m_Axis()
+ , m_KeepDims(false)
{}
MeanDescriptor(const std::vector<unsigned int>& axis, bool keepDims)
- : m_Axis(axis)
- , m_KeepDims(keepDims)
+ : m_Axis(axis)
+ , m_KeepDims(keepDims)
{}
+ bool operator ==(const MeanDescriptor& rhs) const
+ {
+ return m_Axis == rhs.m_Axis && m_KeepDims == rhs.m_KeepDims;
+ }
+
/// Values for the dimensions to reduce.
std::vector<unsigned int> m_Axis;
/// Enable/disable keep dimensions. If true, then the reduced dimensions that are of length 1 are kept.
@@ -686,9 +860,15 @@ struct PadDescriptor
{}
PadDescriptor(const std::vector<std::pair<unsigned int, unsigned int>>& padList, const float& padValue = 0)
- : m_PadList(padList), m_PadValue(padValue)
+ : m_PadList(padList)
+ , m_PadValue(padValue)
{}
+ bool operator ==(const PadDescriptor& rhs) const
+ {
+ return m_PadList == rhs.m_PadList && m_PadValue == rhs.m_PadValue;
+ }
+
/// @brief Specifies the padding for input dimension.
/// First is the number of values to add before the tensor in the dimension.
/// Second is the number of values to add after the tensor in the dimension.
@@ -710,6 +890,11 @@ struct SliceDescriptor
SliceDescriptor() : SliceDescriptor({}, {})
{}
+ bool operator ==(const SliceDescriptor& rhs) const
+ {
+ return m_Begin == rhs.m_Begin && m_Size == rhs.m_Size;
+ }
+
/// Beginning indices of the slice in each dimension.
std::vector<unsigned int> m_Begin;
@@ -721,17 +906,24 @@ struct SliceDescriptor
struct StackDescriptor
{
StackDescriptor()
- : m_Axis(0)
- , m_NumInputs(0)
- , m_InputShape()
+ : m_Axis(0)
+ , m_NumInputs(0)
+ , m_InputShape()
{}
StackDescriptor(uint32_t axis, uint32_t numInputs, const TensorShape& inputShape)
- : m_Axis(axis)
- , m_NumInputs(numInputs)
- , m_InputShape(inputShape)
+ : m_Axis(axis)
+ , m_NumInputs(numInputs)
+ , m_InputShape(inputShape)
{}
+ bool operator ==(const StackDescriptor& rhs) const
+ {
+ return m_Axis == rhs.m_Axis &&
+ m_NumInputs == rhs.m_NumInputs &&
+ m_InputShape == rhs.m_InputShape;
+ }
+
/// 0-based axis along which to stack the input tensors.
uint32_t m_Axis;
/// Number of input tensors.
@@ -746,21 +938,34 @@ struct StridedSliceDescriptor
StridedSliceDescriptor(const std::vector<int>& begin,
const std::vector<int>& end,
const std::vector<int>& stride)
- : m_Begin(begin)
- , m_End(end)
- , m_Stride(stride)
- , m_BeginMask(0)
- , m_EndMask(0)
- , m_ShrinkAxisMask(0)
- , m_EllipsisMask(0)
- , m_NewAxisMask(0)
- , m_DataLayout(DataLayout::NCHW)
+ : m_Begin(begin)
+ , m_End(end)
+ , m_Stride(stride)
+ , m_BeginMask(0)
+ , m_EndMask(0)
+ , m_ShrinkAxisMask(0)
+ , m_EllipsisMask(0)
+ , m_NewAxisMask(0)
+ , m_DataLayout(DataLayout::NCHW)
{}
StridedSliceDescriptor()
- : StridedSliceDescriptor({}, {}, {})
+ : StridedSliceDescriptor({}, {}, {})
{}
+ bool operator ==(const StridedSliceDescriptor& rhs) const
+ {
+ return m_Begin == rhs.m_Begin &&
+ m_End == rhs.m_End &&
+ m_Stride == rhs.m_Stride &&
+ m_BeginMask == rhs.m_BeginMask &&
+ m_EndMask == rhs.m_EndMask &&
+ m_ShrinkAxisMask == rhs.m_ShrinkAxisMask &&
+ m_EllipsisMask == rhs.m_EllipsisMask &&
+ m_NewAxisMask == rhs.m_NewAxisMask &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
int GetStartForAxis(const TensorShape& inputShape, unsigned int axis) const;
int GetStopForAxis(const TensorShape& inputShape,
unsigned int axis,
@@ -818,6 +1023,18 @@ struct TransposeConvolution2dDescriptor
m_DataLayout(DataLayout::NCHW)
{}
+ bool operator ==(const TransposeConvolution2dDescriptor& rhs) const
+ {
+ return m_PadLeft == rhs.m_PadLeft &&
+ m_PadRight == rhs.m_PadRight &&
+ m_PadTop == rhs.m_PadTop &&
+ m_PadBottom == rhs.m_PadBottom &&
+ m_StrideX == rhs.m_StrideX &&
+ m_StrideY == rhs.m_StrideY &&
+ m_BiasEnabled == rhs.m_BiasEnabled &&
+ m_DataLayout == rhs.m_DataLayout;
+ }
+
/// Padding left value in the width dimension.
uint32_t m_PadLeft;
/// Padding right value in the width dimension.