ArmNN 22.11
BatchMatMulQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchMatMulQueueDescriptor:
BatchMatMulQueueDescriptor → QueueDescriptorWithParameters< BatchMatMulDescriptor > → QueueDescriptor

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Additional Inherited Members

- Public Attributes inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
BatchMatMulDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 
- Protected Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 780 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate (const WorkloadInfo & workloadInfo) const

Definition at line 4099 of file WorkloadData.cpp.

References armnn::BFloat16, armnn::Float16, armnn::Float32, BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), BatchMatMulDescriptor::GetPermuteVec(), TensorInfo::GetShape(), WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::NCDHW, armnn::NCHW, armnn::NDHWC, armnn::NHWC, armnnUtils::Permuted(), armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.

4100 {
4101  const std::string descriptorName{"BatchMatMulDescriptor"};
4102 
4103  ValidateNumInputs(workloadInfo, descriptorName, 2);
4104  ValidateNumOutputs(workloadInfo, descriptorName, 1);
4105 
4106  // Inputs must be: both 2D+
4107  // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4108  // axes N and I must be the same size
4109 
4110  const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4111  const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4112  const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4113  // Output info has already been inferred
4114 
4115  std::vector<DataType> supportedTypes =
4116  {
4123  };
4124 
4125  ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4126  ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4127  ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4128 
4129  if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4130  (inputYInfoBeforeParams.GetNumDimensions() < 2))
4131  {
4132  throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4133  }
4134 
4135  TensorInfo inputXInfoAfterParams;
4136  TensorInfo inputYInfoAfterParams;
4137 
4140  {
4141  throw InvalidArgumentException(descriptorName +
4142  ": Invalid descriptor parameters - Transpose and Adjoint "
4143  "cannot both be true for a given input tensor.");
4144  }
4146  {
4147  inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4150  inputXInfoBeforeParams.GetShape()));
4151  }
4152  else if(m_Parameters.m_AdjointX)
4153  {
4155  inputXInfoBeforeParams.GetShape());
4156  if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4157  inputXInfoBeforeParams.GetShape()[axesToMul.second])
4158  {
4159  throw InvalidArgumentException(descriptorName +
4160  ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4161  }
4162  // Shape remains the same as it's square
4163  inputXInfoAfterParams = inputXInfoBeforeParams;
4164  }
4165  else
4166  {
4167  inputXInfoAfterParams = inputXInfoBeforeParams;
4168  }
4169 
4171  {
4172  inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4175  inputYInfoBeforeParams.GetShape()));
4176  }
4177  else if(m_Parameters.m_AdjointY)
4178  {
4180  inputYInfoBeforeParams.GetShape());
4181  if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4182  inputYInfoBeforeParams.GetShape()[axesToMul.second])
4183  {
4184  throw InvalidArgumentException(descriptorName +
4185  ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4186  }
4187  // Shape remains the same as it's square
4188  inputYInfoAfterParams = inputYInfoBeforeParams;
4189  }
4190  else
4191  {
4192  inputYInfoAfterParams = inputYInfoBeforeParams;
4193  }
4194 
4195  switch(m_Parameters.m_DataLayoutX)
4196  {
4197  case DataLayout::NCDHW:
4198  case DataLayout::NDHWC:
4199  if(inputXInfoAfterParams.GetNumDimensions() < 3)
4200  {
4201  throw InvalidArgumentException(descriptorName +
4202  ": Input tensor X does not have the correct "
4203  "number of dimensions for the Data Layout that it has been assigned.");
4204  }
4205  break;
4206  case DataLayout::NCHW:
4207  case DataLayout::NHWC:
4208  default:
4209  break;
4210  }
4211 
4212  switch(m_Parameters.m_DataLayoutY)
4213  {
4214  case DataLayout::NCDHW:
4215  case DataLayout::NDHWC:
4216  if(inputYInfoAfterParams.GetNumDimensions() < 3)
4217  {
4218  throw InvalidArgumentException(descriptorName +
4219  ": Input tensor Y does not have the correct "
4220  "number of dimensions for the Data Layout that it has been assigned.");
4221  }
4222  break;
4223  case DataLayout::NCHW:
4224  case DataLayout::NHWC:
4225  default:
4226  break;
4227  }
4228 
4230  inputXInfoAfterParams.GetShape());
4232  inputXInfoBeforeParams.GetShape());
4233 
4234  if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4235  != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4236  {
4237  throw InvalidArgumentException(descriptorName +
4238  ": The final axis of input tensor X must be the same size as "
4239  "the second last axis of input tensor Y.");
4240  }
4241 
4242  { // Separate scope so we don't pollute the rest of the scope with our temp variables
4243  // e.g. NHWC isnt compatible with NCHW as of now
4246 
4247  if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4248  {
4249  if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4250  {
4251  throw InvalidArgumentException(descriptorName +
4252  ": Invalid input tensor data layout combination.");
4253  }
4254  }
4255  if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4256  {
4257  if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4258  {
4259  throw InvalidArgumentException(descriptorName +
4260  ": Invalid input tensor data layout combination.");
4261  }
4262  }
4263  }
4264 
4265  // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4266  unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4267  inputYInfoAfterParams.GetNumDimensions());
4268  if(outputTensorDimSize-2 > 0)
4269  {
4270  TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4272  TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4274  TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4276 
4277  auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4278  {
4279  auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4280 
4281  for(unsigned int i = 0; i < sizeDiff; i++)
4282  {
4283  axisIndices.insert(axisIndices.begin(), 1);
4284  }
4285 
4286  for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4287  {
4288  ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4289  }
4290  };
4291 
4293  inputXInfoAfterParams.GetShape());
4295  inputYInfoAfterParams.GetShape());
4296 
4297  doAxisExtension(axesXNotMul, tiXNotMul);
4298  doAxisExtension(axesYNotMul, tiYNotMul);
4299 
4300  for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4301  {
4302  tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4303  tiYNotMul.GetShape()[i]);
4304  }
4305 
4306  ValidateBroadcastTensorShapesMatch(tiXNotMul,
4307  tiYNotMul,
4308  tiOutNotMul,
4309  descriptorName,
4310  "input_X",
4311  "input_Y");
4312  }
4313 }
DataLayout
Definition: Types.hpp:62
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
bool m_TransposeX
Transpose the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same input tensor.
bool m_AdjointX
Adjoint the slices of each input tensor. Transpose and Adjoint cannot both be set to true for the same input tensor.
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout).
std::vector< TensorInfo > m_InputTensorInfos
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul(const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
std::vector< TensorInfo > m_OutputTensorInfos
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)
Definition: Permute.cpp:98

The documentation for this struct was generated from the following files: