ArmNN
 23.08
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>

Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
 
bool operator== (const BatchMatMulDescriptor &rhs) const
 
Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
 
virtual ~BaseDescriptor ()=default
 

Static Public Member Functions

static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.
 
static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.
 
static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes which will be transposed.
 

Public Attributes

bool m_TransposeX
 Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_TransposeY
 
bool m_AdjointX
 Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.
 
bool m_AdjointY
 
DataLayout m_DataLayoutX
 Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).
 
DataLayout m_DataLayoutY
 

Detailed Description

A BatchMatMulDescriptor for the BatchMatMul operator.

Definition at line 1563 of file Descriptors.hpp.

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor ( bool  transposeX = false,
bool  transposeY = false,
bool  adjointX = false,
bool  adjointY = false,
DataLayout  dataLayoutX = DataLayout::NCHW,
DataLayout  dataLayoutY = DataLayout::NCHW 
)
inline

Definition at line 1565 of file Descriptors.hpp.

    : m_TransposeX(transposeX)
    , m_TransposeY(transposeY)
    , m_AdjointX(adjointX)
    , m_AdjointY(adjointY)
    , m_DataLayoutX(dataLayoutX)
    , m_DataLayoutY(dataLayoutY)
{}
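
The constructor stores its arguments as given, while the attribute descriptions state that Transpose and Adjoint cannot both be true for the same input. A minimal sketch of that rule follows. It assumes a stand-in struct, `MatMulFlagsSketch`, and a hypothetical helper, `FlagsAreConsistent`. Neither is part of ArmNN, and this page does not show where the library itself enforces the rule.

```cpp
// Stand-in for the four boolean flags of BatchMatMulDescriptor.
// This is an illustrative type, not the ArmNN struct.
struct MatMulFlagsSketch
{
    bool transposeX = false;
    bool transposeY = false;
    bool adjointX   = false;
    bool adjointY   = false;
};

// Hypothetical helper: returns true when neither input has both
// transpose and adjoint requested at the same time.
inline bool FlagsAreConsistent(const MatMulFlagsSketch& flags)
{
    const bool xConflict = flags.transposeX && flags.adjointX;
    const bool yConflict = flags.transposeY && flags.adjointY;
    return !xConflict && !yConflict;
}
```

Transposing X while taking the adjoint of Y passes this check, because the rule applies per input tensor. Setting both flags on X fails it.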

Member Function Documentation

◆ GetAxesNotMul()

std::vector< unsigned int > GetAxesNotMul ( DataLayout dataLayout,
const TensorShape &tensorShape
)
static

Static helper to get the axes (for each input) that will not be multiplied together.

Definition at line 497 of file Descriptors.cpp.

{
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    std::vector<unsigned int> axesNotMul;
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by BatchMatMulQueueDescriptor::Validate().
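
To make the result concrete, the filtering step in the listing above can be reproduced without ArmNN. The sketch below uses a hypothetical free function, `AxesNotMulSketch`. It takes the tensor rank and the pair of multiplied axes directly instead of a `DataLayout` and a `TensorShape`.

```cpp
#include <utility>
#include <vector>

// Hypothetical stand-alone version of the filtering loop: returns every axis
// index in [0, numDims) that is not one of the two multiplied axes.
inline std::vector<unsigned int> AxesNotMulSketch(
    unsigned int numDims,
    std::pair<unsigned int, unsigned int> axesToMul)
{
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i != axesToMul.first && i != axesToMul.second)
        {
            axesNotMul.push_back(i);
        }
    }
    return axesNotMul;
}
```

For a rank-4 tensor whose multiplied axes are (2, 3), the remaining axes are {0, 1}. If the multiplied axes are (1, 2), the remaining axes are {0, 3}.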

◆ GetAxesToMul()

std::pair< unsigned int, unsigned int > GetAxesToMul ( DataLayout dataLayout,
const TensorShape &tensorShape
)
static

Static helper to get the two axes (for each input) for multiplication.

Definition at line 476 of file Descriptors.cpp.

{
    auto numDims = tensorShape.GetNumDimensions();
    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
    switch(dataLayout)
    {
        case DataLayout::NDHWC:
        case DataLayout::NHWC:
            axes.first -= 1;
            axes.second -= 1;
            break;
        case DataLayout::NCDHW:
        case DataLayout::NCHW:
        default:
            break;
    }
    return axes;
}

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

Referenced by BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
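
The axis arithmetic in the listing above is easier to follow with concrete ranks. The sketch below mirrors it with a stand-in enum, `LayoutSketch`, and a hypothetical function, `AxesToMulSketch`, that takes the rank directly. Both exist only so the example compiles without ArmNN. It assumes a rank of at least 2, or at least 3 for the channels-last layouts, so the unsigned subtraction cannot wrap.

```cpp
#include <utility>

// Stand-in for armnn::DataLayout so the sketch compiles without ArmNN.
enum class LayoutSketch { NCHW, NHWC, NDHWC, NCDHW };

// Mirrors the listing above: start from the last two axes, then shift both
// one position towards the front for channels-last layouts (NHWC, NDHWC),
// where the final axis is the channel axis.
inline std::pair<unsigned int, unsigned int> AxesToMulSketch(
    LayoutSketch layout, unsigned int numDims)
{
    std::pair<unsigned int, unsigned int> axes = { numDims - 2, numDims - 1 };
    if (layout == LayoutSketch::NHWC || layout == LayoutSketch::NDHWC)
    {
        axes.first  -= 1;
        axes.second -= 1;
    }
    return axes;
}
```

For a rank-4 tensor this gives (2, 3) under NCHW and (1, 2) under NHWC. For a rank-5 tensor it gives (3, 4) under NCDHW and (2, 3) under NDHWC. In each case the selected axes are the H and W positions of the named layout.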

◆ GetPermuteVec()

PermutationVector GetPermuteVec ( DataLayout dataLayout,
const TensorShape &tensorShape
)
static

Static helper to get the axes which will be transposed.

Definition at line 514 of file Descriptors.cpp.

{
    std::vector<unsigned int> vec;
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first)
        {
            vec.push_back(i+1);
        }
        else if(i == axesToMul.second)
        {
            vec.push_back(i-1);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return PermutationVector(vec.data(),
                             static_cast<unsigned int>(vec.size()));
}

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().
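
The loop in the listing above builds an identity mapping with the two multiplied axes exchanged. The sketch below reproduces that step with a hypothetical function, `SwapMulAxesSketch`, which returns a plain `std::vector` instead of an `armnn::PermutationVector`. Like the listing, it assumes the two multiplied axes are adjacent, which holds for the pairs produced by `GetAxesToMul`.

```cpp
#include <utility>
#include <vector>

// Hypothetical stand-alone version of the loop: every axis maps to itself
// except the two multiplied axes, which map to each other.
// Assumes axesToMul.second == axesToMul.first + 1.
inline std::vector<unsigned int> SwapMulAxesSketch(
    unsigned int numDims,
    std::pair<unsigned int, unsigned int> axesToMul)
{
    std::vector<unsigned int> mapping;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i == axesToMul.first)
        {
            mapping.push_back(i + 1);
        }
        else if (i == axesToMul.second)
        {
            mapping.push_back(i - 1);
        }
        else
        {
            mapping.push_back(i);
        }
    }
    return mapping;
}
```

For a rank-4 tensor with multiplied axes (2, 3) the mapping is {0, 1, 3, 2}. With multiplied axes (1, 2) it is {0, 2, 1, 3}.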

◆ operator==()

bool operator== ( const BatchMatMulDescriptor &rhs) const
inline

Definition at line 1579 of file Descriptors.hpp.

{
    return m_TransposeX == rhs.m_TransposeX &&
           m_TransposeY == rhs.m_TransposeY &&
           m_AdjointX == rhs.m_AdjointX &&
           m_AdjointY == rhs.m_AdjointY &&
           m_DataLayoutX == rhs.m_DataLayoutX &&
           m_DataLayoutY == rhs.m_DataLayoutY;
}

References BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, and BatchMatMulDescriptor::m_TransposeY.

Member Data Documentation

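◆ m_AdjointX

bool m_AdjointX

Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1596 of file Descriptors.hpp.

◆ m_AdjointY

bool m_AdjointY

Definition at line 1597 of file Descriptors.hpp.

◆ m_DataLayoutX

DataLayout m_DataLayoutX

Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).

Definition at line 1600 of file Descriptors.hpp.

◆ m_DataLayoutY

DataLayout m_DataLayoutY

Definition at line 1601 of file Descriptors.hpp.

◆ m_TransposeX

bool m_TransposeX

Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1591 of file Descriptors.hpp.

◆ m_TransposeY

bool m_TransposeY

Definition at line 1592 of file Descriptors.hpp.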

The documentation for this struct was generated from the following files:
Descriptors.hpp
Descriptors.cpp