diff options
Diffstat (limited to 'python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i')
-rw-r--r-- | python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i | 136 |
1 files changed, 136 insertions, 0 deletions
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i index 9374945daf..e755ef5982 100644 --- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i +++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i @@ -91,6 +91,104 @@ struct ArgMinMaxDescriptor %feature("docstring", " + A descriptor for the BatchMatMul layer. See `INetwork.AddBatchMatMulLayer()`. + + Contains: + m_TransposeX (bool): Transpose the slices of input tensor X. Transpose and Adjoint can not both be set to true for the same tensor at the same time. + m_TransposeY (bool): Transpose the slices of input tensor Y. Transpose and Adjoint can not both be set to true for the same tensor at the same time. + m_AdjointX (bool): Adjoint the slices of input tensor X. Transpose and Adjoint can not both be set to true for the same tensor at the same time. + m_AdjointY (bool): Adjoint the slices of input tensor Y. Transpose and Adjoint can not both be set to true for the same tensor at the same time. + m_DataLayoutX (DataLayout): Data layout of input tensor X, such as NHWC/NDHWC (leave as default for arbitrary layout). 
+ m_DataLayoutY (DataLayout): Data layout of input tensor Y, such as NHWC/NDHWC (leave as default for arbitrary layout). + ") BatchMatMulDescriptor; +struct BatchMatMulDescriptor +{ + BatchMatMulDescriptor(bool transposeX = false, + bool transposeY = false, + bool adjointX = false, + bool adjointY = false, + DataLayout dataLayoutX = DataLayout::NCHW, + DataLayout dataLayoutY = DataLayout::NCHW) + : m_TransposeX(transposeX) + , m_TransposeY(transposeY) + , m_AdjointX(adjointX) + , m_AdjointY(adjointY) + , m_DataLayoutX(dataLayoutX) + , m_DataLayoutY(dataLayoutY) + {} + + bool operator ==(const BatchMatMulDescriptor &rhs) const + { + return m_TransposeX == rhs.m_TransposeX && + m_TransposeY == rhs.m_TransposeY && + m_AdjointX == rhs.m_AdjointX && + m_AdjointY == rhs.m_AdjointY && + m_DataLayoutX == rhs.m_DataLayoutX && + m_DataLayoutY == rhs.m_DataLayoutY; + } + + bool m_TransposeX; + bool m_TransposeY; + bool m_AdjointX; + bool m_AdjointY; + DataLayout m_DataLayoutX; + DataLayout m_DataLayoutY; + + static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul( + const BatchMatMulDescriptor& desc, + const armnn::TensorShape& tensorXShape, + const armnn::TensorShape& tensorYShape); + + static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul( + const BatchMatMulDescriptor& desc, + const armnn::TensorShape& inputXShape, + const armnn::TensorShape& inputYShape); + + %feature("docstring", + " + Static helper to get the two axes (for each input) for multiplication + Args: + dataLayout (DataLayout) + tensorShape (TensorShape) + + Returns: + std::pair<unsigned int, unsigned int> + ") GetAxesToMul; + static std::pair<unsigned int, unsigned int> GetAxesToMul( + DataLayout dataLayout, + const armnn::TensorShape& tensorShape); + + %feature("docstring", + " + Static helper to get the axes (for each input) that will not be multiplied together + Args: + dataLayout (DataLayout) + tensorShape (TensorShape) 
+ + Returns: + std::vector<unsigned int> + ") GetAxesNotMul; + static std::vector<unsigned int> GetAxesNotMul( + DataLayout dataLayout, + const armnn::TensorShape& tensorShape); + + %feature("docstring", + " + Static helper to get the axes which will be transposed + Args: + dataLayout (DataLayout) + tensorShape (TensorShape) + + Returns: + PermutationVector + ") GetPermuteVec; + static PermutationVector GetPermuteVec( + DataLayout dataLayout, + const armnn::TensorShape& tensorShape); +}; + +%feature("docstring", + " A descriptor for the BatchNormalization layer. See `INetwork.AddBatchNormalizationLayer()`. Contains: @@ -679,6 +777,26 @@ struct PadDescriptor %feature("docstring", " + A descriptor for the ElementwiseBinary layer. See `INetwork.AddElementwiseBinaryLayer()`. + Contains: + m_Operation (int): Indicates which Binary operation to use. (`BinaryOperation_Add`, `BinaryOperation_Div`, + `BinaryOperation_Maximum`, `BinaryOperation_Minimum`, `BinaryOperation_Mul`, `BinaryOperation_Sub`, + `BinaryOperation_SqDiff`, `BinaryOperation_Power`) + Default: `BinaryOperation_Add`. + + ") ElementwiseBinaryDescriptor; +struct ElementwiseBinaryDescriptor +{ + ElementwiseBinaryDescriptor(); + ElementwiseBinaryDescriptor(BinaryOperation operation); + + BinaryOperation m_Operation; + + bool operator ==(const ElementwiseBinaryDescriptor &rhs) const; +}; + +%feature("docstring", + " A descriptor for the ElementwiseUnary layer. See `INetwork.AddElementwiseUnaryLayer()`. Contains: @@ -1211,6 +1329,24 @@ struct LogicalBinaryDescriptor %feature("docstring", " + A descriptor for the Tile layer. See `INetwork.AddTileLayer()`. 
+ + Contains: + m_Multiples (std::vector<uint32_t>): The vector to multiply the input shape by + + ") TileDescriptor; +struct TileDescriptor +{ + TileDescriptor(); + TileDescriptor(const std::vector<uint32_t>& multiples); + + std::vector<uint32_t> m_Multiples; + + bool operator ==(const TileDescriptor &rhs) const; +}; + +%feature("docstring", + " A descriptor for the Transpose layer. See `INetwork.AddTransposeLayer()`. Contains: |