ArmNN 22.11
BatchMatMulDescriptor Struct Reference

A BatchMatMulDescriptor for the BatchMatMul operator.

#include <Descriptors.hpp>

BatchMatMulDescriptor inherits from BaseDescriptor.

Public Member Functions

 BatchMatMulDescriptor (bool transposeX=false, bool transposeY=false, bool adjointX=false, bool adjointY=false, DataLayout dataLayoutX=DataLayout::NCHW, DataLayout dataLayoutY=DataLayout::NCHW)
 
bool operator== (const BatchMatMulDescriptor &rhs) const
 
Public Member Functions inherited from BaseDescriptor
virtual bool IsNull () const
 
virtual ~BaseDescriptor ()=default
 

Static Public Member Functions

static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul (const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
 
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul (const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
 
static std::pair< unsigned int, unsigned int > GetAxesToMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the two axes (for each input) for multiplication.

static std::vector< unsigned int > GetAxesNotMul (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes (for each input) that will not be multiplied together.

static PermutationVector GetPermuteVec (DataLayout dataLayout, const TensorShape &tensorShape)
 Static helper to get the axes which will be transposed.
 

Public Attributes

bool m_TransposeX
 Transpose the slices of input tensor X. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

bool m_TransposeY
 Transpose the slices of input tensor Y (same restriction as m_TransposeX).

bool m_AdjointX
 Adjoint the slices of input tensor X. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

bool m_AdjointY
 Adjoint the slices of input tensor Y (same restriction as m_AdjointX).

DataLayout m_DataLayoutX
 Data layout of input tensor X, such as NHWC/NDHWC (leave as default for arbitrary layout).

DataLayout m_DataLayoutY
 Data layout of input tensor Y (see m_DataLayoutX).

Detailed Description

A BatchMatMulDescriptor for the BatchMatMul operator.

Definition at line 1517 of file Descriptors.hpp.

Constructor & Destructor Documentation

◆ BatchMatMulDescriptor()

BatchMatMulDescriptor ( bool transposeX = false,
bool transposeY = false,
bool adjointX = false,
bool adjointY = false,
DataLayout dataLayoutX = DataLayout::NCHW,
DataLayout dataLayoutY = DataLayout::NCHW
)
inline

Definition at line 1519 of file Descriptors.hpp.

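    BatchMatMulDescriptor(bool transposeX = false,
                          bool transposeY = false,
                          bool adjointX = false,
                          bool adjointY = false,
                          DataLayout dataLayoutX = DataLayout::NCHW,
                          DataLayout dataLayoutY = DataLayout::NCHW)
        : m_TransposeX(transposeX)
        , m_TransposeY(transposeY)
        , m_AdjointX(adjointX)
        , m_AdjointY(adjointY)
        , m_DataLayoutX(dataLayoutX)
        , m_DataLayoutY(dataLayoutY)
    {}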
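The constructor only copies each argument into the member of the same name. A default-constructed descriptor therefore applies no transpose or adjoint and treats both inputs as NCHW. The sketch below mirrors this in a standalone form so it compiles without ArmNN. `DescriptorSketch`, the local `Layout` enum and `HasConflictingFlags()` are hypothetical names. The check encodes only the documented rule that Transpose and Adjoint cannot both be true for the same tensor. Where ArmNN itself enforces that rule is not shown on this page.

```cpp
#include <cassert>

// Hypothetical stand-in for armnn::DataLayout (only the values used on this page).
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Hypothetical standalone mirror of BatchMatMulDescriptor's members and constructor defaults.
struct DescriptorSketch
{
    DescriptorSketch(bool transposeX = false, bool transposeY = false,
                     bool adjointX = false, bool adjointY = false,
                     Layout dataLayoutX = Layout::NCHW,
                     Layout dataLayoutY = Layout::NCHW)
        : m_TransposeX(transposeX)
        , m_TransposeY(transposeY)
        , m_AdjointX(adjointX)
        , m_AdjointY(adjointY)
        , m_DataLayoutX(dataLayoutX)
        , m_DataLayoutY(dataLayoutY)
    {}

    bool m_TransposeX;
    bool m_TransposeY;
    bool m_AdjointX;
    bool m_AdjointY;
    Layout m_DataLayoutX;
    Layout m_DataLayoutY;
};

// Hypothetical check of the documented rule: Transpose and Adjoint
// must not both be true for the same input tensor.
bool HasConflictingFlags(const DescriptorSketch& desc)
{
    return (desc.m_TransposeX && desc.m_AdjointX) ||
           (desc.m_TransposeY && desc.m_AdjointY);
}
```

For example, `DescriptorSketch(false, true)` transposes only Y and keeps both layouts at NCHW. `DescriptorSketch(true, false, true)` is flagged because X would be both transposed and adjointed.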

Member Function Documentation

◆ GetAxesNotMul() [1/2]

std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul ( const BatchMatMulDescriptor & desc,
const TensorShape & inputXShape,
const TensorShape & inputYShape
)
static

Definition at line 467 of file Descriptors.cpp.

References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.

Referenced by BatchMatMulQueueDescriptor::Validate().

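std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
    const BatchMatMulDescriptor& desc,
    const TensorShape& inputXShape,
    const TensorShape& inputYShape)
{
    return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
             GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
}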
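This overload returns one list of non-multiplied axes per input, X first and Y second. Reading the tensor sizes at those axes gives the extents that are not part of the matrix product. For rank-4 tensors these are N and C in both NCHW and NHWC. In the standalone sketch below, an NHWC X of shape [2, 3, 4, 1] and an NCHW Y of shape [2, 1, 4, 5] both yield [2, 1]. `Layout`, `AxesNotMul`, `AxesNotMulXY` and `SizesAt` are hypothetical names, and plain vectors stand in for `TensorShape`. How `BatchMatMulQueueDescriptor::Validate()` compares these extents is not shown on this page.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for armnn::DataLayout.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;

// Mirrors GetAxesNotMul(DataLayout, TensorShape): all axes except the two multiplied ones.
std::vector<unsigned int> AxesNotMul(Layout layout, const Shape& shape)
{
    const auto numDims = static_cast<unsigned int>(shape.size());
    unsigned int first = numDims - 2;
    unsigned int second = numDims - 1;
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        // Channels-last: the matrix axes sit just before the innermost C axis.
        first -= 1;
        second -= 1;
    }
    std::vector<unsigned int> result;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i != first && i != second)
        {
            result.push_back(i);
        }
    }
    return result;
}

// Mirrors the descriptor overload: one list per input, X first, Y second.
std::pair<std::vector<unsigned int>, std::vector<unsigned int>>
AxesNotMulXY(Layout layoutX, const Shape& x, Layout layoutY, const Shape& y)
{
    return { AxesNotMul(layoutX, x), AxesNotMul(layoutY, y) };
}

// Hypothetical helper: the tensor sizes found at the given axes.
Shape SizesAt(const Shape& shape, const std::vector<unsigned int>& axes)
{
    Shape sizes;
    for (unsigned int axis : axes)
    {
        sizes.push_back(shape[axis]);
    }
    return sizes;
}
```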

◆ GetAxesNotMul() [2/2]

std::vector< unsigned int > GetAxesNotMul ( DataLayout dataLayout,
const TensorShape & tensorShape
)
static

Static helper to get the axes (for each input) that will not be multiplied together.

Definition at line 497 of file Descriptors.cpp.

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

std::vector<unsigned int> BatchMatMulDescriptor::GetAxesNotMul(DataLayout dataLayout,
                                                              const TensorShape& tensorShape)
{
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    std::vector<unsigned int> axesNotMul;
    // Collect every axis that is not one of the two multiplication axes.
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first || i == axesToMul.second)
        {
            continue;
        }
        axesNotMul.push_back(i);
    }
    return axesNotMul;
}
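The result is simply the complement of the multiplication axes. For a rank-4 NHWC tensor, the multiplication axes are 1 and 2, so this leaves axes 0 (N) and 3 (C). C therefore ends up among the non-multiplied axes alongside N. The standalone sketch below reproduces the loop. It takes a plain rank instead of a `TensorShape`, and the names `Layout`, `AxesToMul` and `AxesNotMul` are hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for armnn::DataLayout.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul(DataLayout, TensorShape); assumes numDims >= 3 for channels-last layouts.
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    std::pair<unsigned int, unsigned int> axes = { numDims - 2, numDims - 1 };
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        // Channels-last: the innermost axis is C, so the matrix axes shift left by one.
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors GetAxesNotMul(DataLayout, TensorShape): every axis except the two multiplied ones.
std::vector<unsigned int> AxesNotMul(Layout layout, unsigned int numDims)
{
    const auto axesToMul = AxesToMul(layout, numDims);
    std::vector<unsigned int> axesNotMul;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i != axesToMul.first && i != axesToMul.second)
        {
            axesNotMul.push_back(i);
        }
    }
    return axesNotMul;
}
```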

◆ GetAxesToMul() [1/2]

std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul ( const BatchMatMulDescriptor & desc,
const TensorShape & tensorXShape,
const TensorShape & tensorYShape
)
static

Definition at line 459 of file Descriptors.cpp.

References BatchMatMulDescriptor::m_DataLayoutX, and BatchMatMulDescriptor::m_DataLayoutY.

Referenced by BatchMatMul::BatchMatMul(), BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetPermuteVec(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().

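std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>
BatchMatMulDescriptor::GetAxesToMul(const BatchMatMulDescriptor& desc,
                                    const TensorShape& tensorXShape,
                                    const TensorShape& tensorYShape)
{
    return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
             GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
}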
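This overload pairs two calls to the single-layout overload: the first pair holds the axes of X and the second holds the axes of Y. Because each input can have its own layout, the pairs may differ. An NHWC X of shape [1, 2, 3, 1] multiplies over axes (1, 2), while an NCHW Y of shape [1, 1, 3, 4] multiplies over (2, 3). The sketch below is standalone, with hypothetical names `Layout`, `AxesToMul` and `AxesToMulXY`. It also adds a hypothetical `InnerDimsMatch()` that applies the ordinary matrix-product rule (the column count of each X slice equals the row count of each Y slice), ignoring transpose and adjoint. That helper only illustrates how the two pairs can be used and is not ArmNN's validation code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for armnn::DataLayout.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };
using Shape = std::vector<unsigned int>;
using AxisPair = std::pair<unsigned int, unsigned int>;

// Mirrors GetAxesToMul(DataLayout, TensorShape) for a shape given as a vector.
AxisPair AxesToMul(Layout layout, const Shape& shape)
{
    const auto numDims = static_cast<unsigned int>(shape.size());
    AxisPair axes = { numDims - 2, numDims - 1 };
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        // Channels-last: skip the innermost C axis.
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors the descriptor overload: one axis pair per input, X first, Y second.
std::pair<AxisPair, AxisPair> AxesToMulXY(Layout layoutX, const Shape& x,
                                          Layout layoutY, const Shape& y)
{
    return { AxesToMul(layoutX, x), AxesToMul(layoutY, y) };
}

// Hypothetical check (no transpose/adjoint): columns of X slices == rows of Y slices.
bool InnerDimsMatch(Layout layoutX, const Shape& x, Layout layoutY, const Shape& y)
{
    const auto axes = AxesToMulXY(layoutX, x, layoutY, y);
    return x[axes.first.second] == y[axes.second.first];
}
```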

◆ GetAxesToMul() [2/2]

std::pair< unsigned int, unsigned int > GetAxesToMul ( DataLayout dataLayout,
const TensorShape & tensorShape
)
static

Static helper to get the two axes (for each input) for multiplication.

Definition at line 476 of file Descriptors.cpp.

References TensorShape::GetNumDimensions(), armnn::NCDHW, armnn::NCHW, armnn::NDHWC, and armnn::NHWC.

std::pair<unsigned int, unsigned int> BatchMatMulDescriptor::GetAxesToMul(DataLayout dataLayout,
                                                                         const TensorShape& tensorShape)
{
    auto numDims = tensorShape.GetNumDimensions();
    // By default the two innermost axes form each matrix slice.
    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
    switch(dataLayout)
    {
        case DataLayout::NDHWC:
        case DataLayout::NHWC:
            // Channels-last: C is innermost, so use the two axes before it.
            axes.first -= 1;
            axes.second -= 1;
            break;
        case DataLayout::NCDHW:
        case DataLayout::NCHW:
        default:
            break;
    }
    return axes;
}
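In every supported layout, the helper selects the spatial H and W axes. For rank-4 NCHW it returns (2, 3). For rank-4 NHWC it returns (1, 2), which are again H and W because C sits last. Any other layout value falls through to `default`, so for an arbitrary-layout tensor left at the NCHW default, the last two axes are the matrix dimensions. The standalone mirror of the switch below takes the rank directly. `Layout` and `AxesToMul` are hypothetical names; ArmNN's real helper takes an `armnn::TensorShape`.

```cpp
#include <cassert>
#include <utility>

// Hypothetical stand-in for armnn::DataLayout, limited to the layouts handled on this page.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Standalone mirror of GetAxesToMul(DataLayout, TensorShape), taking the rank directly.
// Like the original, it does not guard against numDims < 2 (or < 3 for channels-last).
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    // Default: the two innermost axes form each matrix slice.
    std::pair<unsigned int, unsigned int> axes = { numDims - 2, numDims - 1 };
    switch (layout)
    {
        case Layout::NDHWC:
        case Layout::NHWC:
            // Channels-last: skip the innermost C axis.
            axes.first -= 1;
            axes.second -= 1;
            break;
        case Layout::NCDHW:
        case Layout::NCHW:
        default:
            break;
    }
    return axes;
}
```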

◆ GetPermuteVec()

PermutationVector GetPermuteVec ( DataLayout dataLayout,
const TensorShape & tensorShape
)
static

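Static helper to get the axes which will be transposed. It returns a PermutationVector that swaps the two multiplication axes reported by GetAxesToMul() and leaves every other axis in place.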

Definition at line 514 of file Descriptors.cpp.

References BatchMatMulDescriptor::GetAxesToMul(), and TensorShape::GetNumDimensions().

Referenced by BatchMatMul::BatchMatMul(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulQueueDescriptor::Validate().

PermutationVector BatchMatMulDescriptor::GetPermuteVec(DataLayout dataLayout,
                                                       const TensorShape& tensorShape)
{
    std::vector<unsigned int> vec;
    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
    // Swap the two (adjacent) multiplication axes; keep every other axis in place.
    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
    {
        if(i == axesToMul.first)
        {
            vec.push_back(i+1);
        }
        else if(i == axesToMul.second)
        {
            vec.push_back(i-1);
        }
        else
        {
            vec.push_back(i);
        }
    }
    return PermutationVector(vec.data(),
                             static_cast<unsigned int>(vec.size()));
}
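The two multiplication axes are always adjacent (numDims-2 and numDims-1, shifted down by one for channels-last layouts). The loop therefore produces a single swap: rank-4 NCHW gives {0, 1, 3, 2} and rank-4 NHWC gives {0, 2, 1, 3}. A swap is its own inverse, so the result is the same whether the vector is read as source-to-destination or destination-to-source. The standalone sketch below reproduces the loop with a plain `std::vector` in place of `armnn::PermutationVector`. `Layout`, `AxesToMul`, `PermuteVec` and `PermuteShape` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for armnn::DataLayout.
enum class Layout { NCHW, NHWC, NCDHW, NDHWC };

// Mirrors GetAxesToMul(DataLayout, TensorShape) using a plain rank.
std::pair<unsigned int, unsigned int> AxesToMul(Layout layout, unsigned int numDims)
{
    std::pair<unsigned int, unsigned int> axes = { numDims - 2, numDims - 1 };
    if (layout == Layout::NHWC || layout == Layout::NDHWC)
    {
        axes.first -= 1;
        axes.second -= 1;
    }
    return axes;
}

// Mirrors GetPermuteVec: swap the two multiplication axes, keep all others in place.
std::vector<unsigned int> PermuteVec(Layout layout, unsigned int numDims)
{
    const auto axesToMul = AxesToMul(layout, numDims);
    std::vector<unsigned int> vec;
    for (unsigned int i = 0; i < numDims; ++i)
    {
        if (i == axesToMul.first)       { vec.push_back(i + 1); }
        else if (i == axesToMul.second) { vec.push_back(i - 1); }
        else                            { vec.push_back(i); }
    }
    return vec;
}

// Hypothetical helper: move source dimension i to position mapping[i].
std::vector<unsigned int> PermuteShape(const std::vector<unsigned int>& shape,
                                       const std::vector<unsigned int>& mapping)
{
    std::vector<unsigned int> out(shape.size());
    for (std::size_t i = 0; i < shape.size(); ++i)
    {
        out[mapping[i]] = shape[i];
    }
    return out;
}
```

Applied to an NHWC shape [1, 2, 3, 1], the permutation exchanges H and W and yields [1, 3, 2, 1].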

◆ operator==()

bool operator== ( const BatchMatMulDescriptor & rhs) const
inline

Definition at line 1533 of file Descriptors.hpp.

References BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, BatchMatMulDescriptor::m_TransposeX, and BatchMatMulDescriptor::m_TransposeY.

    bool operator==(const BatchMatMulDescriptor& rhs) const
    {
        return m_TransposeX == rhs.m_TransposeX &&
               m_TransposeY == rhs.m_TransposeY &&
               m_AdjointX == rhs.m_AdjointX &&
               m_AdjointY == rhs.m_AdjointY &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }
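Because operator== compares all six members, two descriptors that differ only in one data layout are not equal. The standalone sketch below performs the same member-by-member comparison; `DescriptorSketch` and the two-value `Layout` enum are hypothetical. This page documents no operator!=, so the sketch expresses inequality as `!(a == b)`.

```cpp
#include <cassert>

// Hypothetical stand-in for armnn::DataLayout (two values are enough here).
enum class Layout { NCHW, NHWC };

// Hypothetical mirror of the six members compared by BatchMatMulDescriptor::operator==.
struct DescriptorSketch
{
    bool m_TransposeX = false;
    bool m_TransposeY = false;
    bool m_AdjointX = false;
    bool m_AdjointY = false;
    Layout m_DataLayoutX = Layout::NCHW;
    Layout m_DataLayoutY = Layout::NCHW;

    // Same member-by-member comparison as the documented operator==.
    bool operator==(const DescriptorSketch& rhs) const
    {
        return m_TransposeX == rhs.m_TransposeX &&
               m_TransposeY == rhs.m_TransposeY &&
               m_AdjointX == rhs.m_AdjointX &&
               m_AdjointY == rhs.m_AdjointY &&
               m_DataLayoutX == rhs.m_DataLayoutX &&
               m_DataLayoutY == rhs.m_DataLayoutY;
    }
};
```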

Member Data Documentation

◆ m_AdjointX

bool m_AdjointX

Adjoint the slices of input tensor X. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1550 of file Descriptors.hpp.

Referenced by BatchMatMul::BatchMatMul(), armnnSerializer::GetFlatBufferArgMinMaxFunction(), BatchMatMulLayer::InferOutputShapes(), armnn::NeonBatchMatMulValidate(), NeonBatchMatMulWorkload::NeonBatchMatMulWorkload(), and BatchMatMulDescriptor::operator==().

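◆ m_AdjointY

bool m_AdjointY

◆ m_DataLayoutX

DataLayout m_DataLayoutX

Data layout of input tensor X, such as NHWC/NDHWC (leave as default for arbitrary layout).

◆ m_DataLayoutY

DataLayout m_DataLayoutY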

◆ m_TransposeX

bool m_TransposeX

Transpose the slices of input tensor X. Transpose and Adjoint cannot both be set to true for the same tensor at the same time.

Definition at line 1545 of file Descriptors.hpp.

Referenced by BatchMatMul::BatchMatMul(), Converter::ConvertOperation(), armnnSerializer::GetFlatBufferArgMinMaxFunction(), BatchMatMulLayer::InferOutputShapes(), and BatchMatMulDescriptor::operator==().

◆ m_TransposeY

bool m_TransposeY


The documentation for this struct was generated from the following files:
Descriptors.hpp
Descriptors.cpp