ArmNN
 23.02
BatchMatMulQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchMatMulQueueDescriptor:
QueueDescriptorWithParameters< BatchMatMulDescriptor > QueueDescriptor

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Additional Inherited Members

- Public Attributes inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
BatchMatMulDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 
- Protected Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 743 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 4053 of file WorkloadData.cpp.

4054 {
4055  const std::string descriptorName{"BatchMatMulDescriptor"};
4056 
4057  ValidateNumInputs(workloadInfo, descriptorName, 2);
4058  ValidateNumOutputs(workloadInfo, descriptorName, 1);
4059 
4060  // Inputs must be: both 2D+
4061  // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4062  // axes N and I must be the same size
4063 
4064  const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4065  const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4066  const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4067  // Output info has already been inferred
4068 
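4069  std::vector<DataType> supportedTypes =
4070  {
4071  DataType::BFloat16,
4072  DataType::Float16,
4073  DataType::Float32,
4074  DataType::QAsymmS8,
4075  DataType::QAsymmU8,
4076  DataType::QSymmS16
4077  };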
4078 
4079  ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4080  ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4081  ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4082 
4083  if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4084  (inputYInfoBeforeParams.GetNumDimensions() < 2))
4085  {
4086  throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4087  }
4088 
4089  TensorInfo inputXInfoAfterParams;
4090  TensorInfo inputYInfoAfterParams;
4091 
4094  {
4095  throw InvalidArgumentException(descriptorName +
4096  ": Invalid descriptor parameters - Transpose and Adjoint "
4097  "cannot both be true for a given input tensor.");
4098  }
4100  {
4101  inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4104  inputXInfoBeforeParams.GetShape()));
4105  }
4106  else if(m_Parameters.m_AdjointX)
4107  {
4109  inputXInfoBeforeParams.GetShape());
4110  if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4111  inputXInfoBeforeParams.GetShape()[axesToMul.second])
4112  {
4113  throw InvalidArgumentException(descriptorName +
4114  ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4115  }
4116  // Shape remains the same as it's square
4117  inputXInfoAfterParams = inputXInfoBeforeParams;
4118  }
4119  else
4120  {
4121  inputXInfoAfterParams = inputXInfoBeforeParams;
4122  }
4123 
4125  {
4126  inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4129  inputYInfoBeforeParams.GetShape()));
4130  }
4131  else if(m_Parameters.m_AdjointY)
4132  {
4134  inputYInfoBeforeParams.GetShape());
4135  if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4136  inputYInfoBeforeParams.GetShape()[axesToMul.second])
4137  {
4138  throw InvalidArgumentException(descriptorName +
4139  ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4140  }
4141  // Shape remains the same as it's square
4142  inputYInfoAfterParams = inputYInfoBeforeParams;
4143  }
4144  else
4145  {
4146  inputYInfoAfterParams = inputYInfoBeforeParams;
4147  }
4148 
4149  switch(m_Parameters.m_DataLayoutX)
4150  {
4151  case DataLayout::NCDHW:
4152  case DataLayout::NDHWC:
4153  if(inputXInfoAfterParams.GetNumDimensions() < 3)
4154  {
4155  throw InvalidArgumentException(descriptorName +
4156  ": Input tensor X does not have the correct "
4157  "number of dimensions for the Data Layout that it has been assigned.");
4158  }
4159  break;
4160  case DataLayout::NCHW:
4161  case DataLayout::NHWC:
4162  default:
4163  break;
4164  }
4165 
4166  switch(m_Parameters.m_DataLayoutY)
4167  {
4168  case DataLayout::NCDHW:
4169  case DataLayout::NDHWC:
4170  if(inputYInfoAfterParams.GetNumDimensions() < 3)
4171  {
4172  throw InvalidArgumentException(descriptorName +
4173  ": Input tensor Y does not have the correct "
4174  "number of dimensions for the Data Layout that it has been assigned.");
4175  }
4176  break;
4177  case DataLayout::NCHW:
4178  case DataLayout::NHWC:
4179  default:
4180  break;
4181  }
4182 
4184  inputXInfoAfterParams.GetShape());
4186  inputXInfoBeforeParams.GetShape());
4187 
4188  if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4189  != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4190  {
4191  throw InvalidArgumentException(descriptorName +
4192  ": The final axis of input tensor X must be the same size as "
4193  "the second last axis of input tensor Y.");
4194  }
4195 
4196  { // Separate scope so we don't pollute the rest of the scope with our temp variables
4197  // e.g. NHWC isnt compatible with NCHW as of now
4200 
4201  if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4202  {
4203  if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4204  {
4205  throw InvalidArgumentException(descriptorName +
4206  ": Invalid input tensor data layout combination.");
4207  }
4208  }
4209  if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4210  {
4211  if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4212  {
4213  throw InvalidArgumentException(descriptorName +
4214  ": Invalid input tensor data layout combination.");
4215  }
4216  }
4217  }
4218 
4219  // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4220  unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4221  inputYInfoAfterParams.GetNumDimensions());
4222  if(outputTensorDimSize-2 > 0)
4223  {
4224  TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4226  TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4228  TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4230 
4231  auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4232  {
4233  auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4234 
4235  for(unsigned int i = 0; i < sizeDiff; i++)
4236  {
4237  axisIndices.insert(axisIndices.begin(), 1);
4238  }
4239 
4240  for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4241  {
4242  ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4243  }
4244  };
4245 
4247  inputXInfoAfterParams.GetShape());
4249  inputYInfoAfterParams.GetShape());
4250 
4251  doAxisExtension(axesXNotMul, tiXNotMul);
4252  doAxisExtension(axesYNotMul, tiYNotMul);
4253 
4254  for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4255  {
4256  tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4257  tiYNotMul.GetShape()[i]);
4258  }
4259 
4260  ValidateBroadcastTensorShapesMatch(tiXNotMul,
4261  tiYNotMul,
4262  tiOutNotMul,
4263  descriptorName,
4264  "input_X",
4265  "input_Y");
4266  }
4267 }

References armnn::BFloat16, armnn::Float16, armnn::Float32, BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), BatchMatMulDescriptor::GetPermuteVec(), TensorInfo::GetShape(), BatchMatMulDescriptor::m_AdjointX, BatchMatMulDescriptor::m_AdjointY, BatchMatMulDescriptor::m_DataLayoutX, BatchMatMulDescriptor::m_DataLayoutY, WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, QueueDescriptorWithParameters< BatchMatMulDescriptor >::m_Parameters, BatchMatMulDescriptor::m_TransposeX, BatchMatMulDescriptor::m_TransposeY, armnn::NCDHW, armnn::NCHW, armnn::NDHWC, armnn::NHWC, armnnUtils::Permuted(), armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.


The documentation for this struct was generated from the following files:
- WorkloadData.hpp
- WorkloadData.cpp
armnn::BatchMatMulDescriptor::m_TransposeX
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1559
armnn::DataType::QAsymmU8
@ QAsymmU8
armnn::DataLayout
DataLayout
Definition: Types.hpp:62
armnn::DataType::Float16
@ Float16
armnn::DataType::QAsymmS8
@ QAsymmS8
armnnUtils::Permuted
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:98
armnn::TensorShape
Definition: Tensor.hpp:20
armnn::DataLayout::NCHW
@ NCHW
armnn::WorkloadInfo::m_OutputTensorInfos
std::vector< TensorInfo > m_OutputTensorInfos
Definition: WorkloadInfo.hpp:19
armnn::BatchMatMulDescriptor::m_DataLayoutX
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
Definition: Descriptors.hpp:1568
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
armnn::DataLayout::NCDHW
@ NCDHW
armnn::DataType::Float32
@ Float32
armnn::BatchMatMulDescriptor::m_TransposeY
bool m_TransposeY
Definition: Descriptors.hpp:1560
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
armnn::DataLayout::NHWC
@ NHWC
armnn::QueueDescriptorWithParameters< BatchMatMulDescriptor >::m_Parameters
BatchMatMulDescriptor m_Parameters
Definition: WorkloadData.hpp:66
armnn::DataType::BFloat16
@ BFloat16
armnn::BatchMatMulDescriptor::m_DataLayoutY
DataLayout m_DataLayoutY
Definition: Descriptors.hpp:1569
armnn::DataType::QSymmS16
@ QSymmS16
armnn::DataLayout::NDHWC
@ NDHWC
armnn::BatchMatMulDescriptor::GetPermuteVec
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
Definition: Descriptors.cpp:514
armnn::BatchMatMulDescriptor::m_AdjointY
bool m_AdjointY
Definition: Descriptors.hpp:1565
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn::BatchMatMulDescriptor::m_AdjointX
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint can not both be set to true for the same tensor at the same time.
Definition: Descriptors.hpp:1564
armnn::BatchMatMulDescriptor::GetAxesToMul
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
Definition: Descriptors.cpp:459
armnn::WorkloadInfo::m_InputTensorInfos
std::vector< TensorInfo > m_InputTensorInfos
Definition: WorkloadInfo.hpp:18
armnn::BatchMatMulDescriptor::GetAxesNotMul
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul(const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
Definition: Descriptors.cpp:467