ArmNN 23.02
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>

BatchMatMulDescriptor inherits from BaseDescriptor.

Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
 
bool operator== (const BatchMatMulDescriptor &rhs) const
 
- Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
 
virtual ~BaseDescriptor ()=default
 

Static Public Member Functions

static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul (const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
 
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul (const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
 
static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.
 
static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.
 
static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes which will be transposed.
 

Public Attributes

bool m_TransposeX
 Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_TransposeY
 
bool m_AdjointX
 Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_AdjointY
 
DataLayout m_DataLayoutX
 Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).
 
DataLayout m_DataLayoutY
 

Detailed Description

A BatchMatMulDescriptor for the BatchMatMul operator.

Definition at line 1531 of file Descriptors.hpp.

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor ( bool  transposeX = false,
bool  transposeY = false,
bool  adjointX = false,
bool  adjointY = false,
DataLayout  dataLayoutX = DataLayout::NCHW,
DataLayout  dataLayoutY = DataLayout::NCHW 
)
inline

Definition at line 1533 of file Descriptors.hpp.

    : m_TransposeX(transposeX)
    , m_TransposeY(transposeY)
    , m_AdjointX(adjointX)
    , m_AdjointY(adjointY)
    , m_DataLayoutX(dataLayoutX)
    , m_DataLayoutY(dataLayoutY)
    {}

Member Function Documentation

◆ GetAxesNotMul() [1/2]

std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul ( const BatchMatMulDescriptor &desc,
const TensorShape &inputXShape,
const TensorShape &inputYShape 
)
static

Definition at line 467 of file Descriptors.cpp.

    {
        return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
                 GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
    }

References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.

Referenced by BatchMatMulQueueDescriptor::Validate().

◆ GetAxesNotMul() [2/2]

std::vector< unsigned int > GetAxesNotMul ( DataLayout  dataLayout,
const TensorShape &tensorShape 
)
static

Static helper to get the axes of a single input tensor that will not be multiplied together: every axis except the two returned by GetAxesToMul(), in ascending order.

Definition at line 497 of file Descriptors.cpp.

    {
        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
        std::vector<unsigned int> axesNotMul;
        for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
        {
            if(i == axesToMul.first || i == axesToMul.second)
            {
                continue;
            }
            axesNotMul.push_back(i);
        }
        return axesNotMul;
    }
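As a rough illustration of GetAxesNotMul() outside ArmNN, the following C++ sketch reproduces the same loop on a plain std::vector<unsigned int> shape. Layout, Shape, AxesToMul and AxesNotMul are hypothetical stand-ins for armnn::DataLayout, armnn::TensorShape and the two static helpers; the rank check in AxesToMul is an added assumption that the listed code does not perform.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins for armnn::DataLayout and armnn::TensorShape.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;

// Same rule as GetAxesToMul(): the last two axes, shifted left by one
// for channels-last layouts (NHWC, NDHWC).
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, const Shape& shape)
{
    auto numDims = static_cast<unsigned int>(shape.size());
    bool channelsLast = (layout == Layout::NHWC || layout == Layout::NDHWC);
    // Added assumption: enough axes that the unsigned subtraction cannot wrap.
    assert(numDims >= (channelsLast ? 3u : 2u));
    unsigned int shift = channelsLast ? 1u : 0u;
    return { numDims - 2 - shift, numDims - 1 - shift };
}

// Same loop as GetAxesNotMul(): every axis that is not multiplied, ascending.
std::vector<unsigned int> AxesNotMul(Layout layout, const Shape& shape)
{
    auto axesToMul = AxesToMul(layout, shape);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < shape.size(); ++i)
    {
        if (i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}
```

For a 4D NHWC shape the multiplied axes are 1 and 2 (H and W), so the remaining axes are 0 (N) and 3 (C); for a 4D NCHW shape they are 0 (N) and 1 (C).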

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

◆ GetAxesToMul() [1/2]

std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul ( const BatchMatMulDescriptor &desc,
const TensorShape &tensorXShape,
const TensorShape &tensorYShape 
)
static

Definition at line 459 of file Descriptors.cpp.

    {
        return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
                 GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
    }
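This overload only resolves each input with its own layout field, so X and Y may use different layouts. The standalone C++ sketch below illustrates that; Layout, Shape, Axes, Layouts and AxesToMul are hypothetical names rather than ArmNN API, and Layouts mirrors only m_DataLayoutX and m_DataLayoutY. The rank check is an added assumption.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins; these are not the ArmNN types.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;
using Axes = std::pair<unsigned int, unsigned int>;

// Mirrors only the two layout fields of BatchMatMulDescriptor.
struct Layouts
{
    Layout x;
    Layout y;
};

// Single-input rule, as in GetAxesToMul(DataLayout, TensorShape).
Axes AxesToMul(Layout layout, const Shape& shape)
{
    auto numDims = static_cast<unsigned int>(shape.size());
    bool channelsLast = (layout == Layout::NHWC || layout == Layout::NDHWC);
    assert(numDims >= (channelsLast ? 3u : 2u));  // added assumption
    unsigned int shift = channelsLast ? 1u : 0u;
    return { numDims - 2 - shift, numDims - 1 - shift };
}

// Like the (desc, tensorXShape, tensorYShape) overload: X is resolved with
// the X layout, Y with the Y layout, and the result is {X axes, Y axes}.
std::pair<Axes, Axes> AxesToMul(const Layouts& layouts, const Shape& x, const Shape& y)
{
    return { AxesToMul(layouts.x, x), AxesToMul(layouts.y, y) };
}
```

With X as NHWC [2, 3, 4, 5] and Y as NCHW [2, 5, 4, 6], the multiplied axes are {1, 2} for X and {2, 3} for Y, so the slices being multiplied are 3 x 4 and 4 x 6.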

References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.

Referenced by BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().

◆ GetAxesToMul() [2/2]

std::pair< unsigned int, unsigned int > GetAxesToMul ( DataLayout  dataLayout,
const TensorShape &tensorShape 
)
static

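Static helper to get the two axes of a single input tensor that are used for multiplication. For NHWC and NDHWC these are the two axes before the trailing channel axis; for NCHW, NCDHW and any other layout they are the last two axes.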

Definition at line 476 of file Descriptors.cpp.

    {
        auto numDims = tensorShape.GetNumDimensions();
        std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
        switch(dataLayout)
        {
            case DataLayout::NDHWC:
            case DataLayout::NHWC:
                axes.first -= 1;
                axes.second -= 1;
                break;
            case DataLayout::NCDHW:
            case DataLayout::NCHW:
            default:
                break;
        }
        return axes;
    }
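The axis selection can be checked by hand with a standalone C++ sketch that mirrors the switch above. Layout and Shape are hypothetical stand-ins for armnn::DataLayout and armnn::TensorShape. Unlike the listed code, the sketch asserts that the rank is large enough, because numDims - 2 (and numDims - 3 for channels-last layouts) would otherwise wrap around in unsigned arithmetic.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-ins for armnn::DataLayout and armnn::TensorShape.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;

std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, const Shape& shape)
{
    auto numDims = static_cast<unsigned int>(shape.size());
    assert(numDims >= 2);  // added assumption: at least a 2D slice
    // Default case (NCHW, NCDHW): the last two axes.
    std::pair<unsigned int, unsigned int> axes = { numDims - 2, numDims - 1 };
    switch (layout)
    {
        case Layout::NDHWC:
        case Layout::NHWC:
            // Channels-last: step over the trailing C axis.
            assert(numDims >= 3);  // added assumption: room for the C axis
            axes.first -= 1;
            axes.second -= 1;
            break;
        case Layout::NCDHW:
        case Layout::NCHW:
        default:
            break;
    }
    return axes;
}
```

NCHW [2, 3, 4, 5] gives {2, 3}, NHWC [2, 3, 4, 5] gives {1, 2}, and NDHWC [1, 2, 3, 4, 5] gives {2, 3}; in each case these are the H and W axes.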

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

◆ GetPermuteVec()

PermutationVector GetPermuteVec ( DataLayout  dataLayout,
const TensorShape &tensorShape 
)
static

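Static helper to get the permutation that transposes a single input tensor: the two axes returned by GetAxesToMul() are swapped and every other axis keeps its position.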

Definition at line 514 of file Descriptors.cpp.

    {
        std::vector<unsigned int> vec;
        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
        for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
        {
            if(i == axesToMul.first)
            {
                vec.push_back(i+1);
            }
            else if(i == axesToMul.second)
            {
                vec.push_back(i-1);
            }
            else
            {
                vec.push_back(i);
            }
        }
        return PermutationVector(vec.data(),
                                 static_cast<unsigned int>(vec.size()));
    }
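The permutation is a swap of the two multiplication axes. The C++ sketch below reproduces it on plain vectors and applies it to a shape; Layout, Shape, AxesToMul, PermuteVec and PermuteShape are hypothetical names, not ArmNN functions. It writes the swap as axesToMul.second / axesToMul.first instead of i+1 / i-1; the two forms agree because the multiplication axes are always adjacent.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-ins for armnn::DataLayout and armnn::TensorShape.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;

// Same rule as GetAxesToMul(); the rank check is an added assumption.
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, const Shape& shape)
{
    auto numDims = static_cast<unsigned int>(shape.size());
    bool channelsLast = (layout == Layout::NHWC || layout == Layout::NDHWC);
    assert(numDims >= (channelsLast ? 3u : 2u));
    unsigned int shift = channelsLast ? 1u : 0u;
    return { numDims - 2 - shift, numDims - 1 - shift };
}

// Sketch of GetPermuteVec(): the identity, except that the two
// multiplication axes are swapped.
std::vector<unsigned int> PermuteVec(Layout layout, const Shape& shape)
{
    auto axesToMul = AxesToMul(layout, shape);
    std::vector<unsigned int> vec;
    for (unsigned int i = 0; i < shape.size(); ++i)
    {
        if (i == axesToMul.first)
        {
            vec.push_back(axesToMul.second);
        }
        else if (i == axesToMul.second)
        {
            vec.push_back(axesToMul.first);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return vec;
}

// Hypothetical helper: moves dimension i of the input to position perm[i].
Shape PermuteShape(const Shape& shape, const std::vector<unsigned int>& perm)
{
    assert(perm.size() == shape.size());
    Shape out(shape.size());
    for (std::size_t i = 0; i < shape.size(); ++i)
    {
        out[perm[i]] = shape[i];
    }
    return out;
}
```

Because a single swap is its own inverse, it makes no difference here whether the vector is read as source-to-destination or destination-to-source: an NCHW shape [2, 3, 4, 5] becomes [2, 3, 5, 4], and an NHWC shape [2, 4, 5, 3] becomes [2, 5, 4, 3].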

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().

◆ operator==()

bool operator== ( const BatchMatMulDescriptor &rhs) const
inline

Definition at line 1547 of file Descriptors.hpp.

    {
        return m_TransposeX == rhs.m_TransposeX &&
               m_TransposeY == rhs.m_TransposeY &&
               m_AdjointX == rhs.m_AdjointX &&
               m_AdjointY == rhs.m_AdjointY &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }

References BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, and BatchMatMulDescriptor::m_TransposeY.

Member Data Documentation

◆ m_AdjointX

bool m_AdjointX

Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1564 of file Descriptors.hpp.

◆ m_AdjointY

bool m_AdjointY

Definition at line 1565 of file Descriptors.hpp.

◆ m_DataLayoutX

DataLayout m_DataLayoutX

Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).

Definition at line 1568 of file Descriptors.hpp.

◆ m_DataLayoutY

DataLayout m_DataLayoutY

Definition at line 1569 of file Descriptors.hpp.

◆ m_TransposeX

bool m_TransposeX

Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1559 of file Descriptors.hpp.

Referenced by BatchMatMulLayer::InferOutputShapes(), BatchMatMulDescriptor::operator==(), StringifyLayerParameters< BatchMatMulDescriptor >::Serialize(), and BatchMatMulQueueDescriptor::Validate().

◆ m_TransposeY

bool m_TransposeY

Definition at line 1560 of file Descriptors.hpp.

The documentation for this struct was generated from the following files:
Descriptors.hpp
Descriptors.cpp