ArmNN
 22.08
BatchMatMulQueueDescriptor Struct Reference

#include <WorkloadData.hpp>

Inheritance diagram for BatchMatMulQueueDescriptor:
QueueDescriptorWithParameters< BatchMatMulDescriptor > QueueDescriptor

Public Member Functions

void Validate (const WorkloadInfo &workloadInfo) const
 
- Public Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
virtual ~QueueDescriptorWithParameters ()=default
 
- Public Member Functions inherited from QueueDescriptor
virtual ~QueueDescriptor ()=default
 
void ValidateTensorNumDimensions (const TensorInfo &tensor, std::string const &descName, unsigned int numDimensions, std::string const &tensorName) const
 
void ValidateTensorNumDimNumElem (const TensorInfo &tensorInfo, unsigned int numDimension, unsigned int numElements, std::string const &tensorName) const
 
void ValidateInputsOutputs (const std::string &descName, unsigned int numExpectedIn, unsigned int numExpectedOut) const
 
template<typename T >
const T * GetAdditionalInformation () const
 

Additional Inherited Members

- Public Attributes inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
BatchMatMulDescriptor m_Parameters
 
- Public Attributes inherited from QueueDescriptor
std::vector< ITensorHandle * > m_Inputs
 
std::vector< ITensorHandle * > m_Outputs
 
void * m_AdditionalInfoObject
 
bool m_AllowExpandedDims = false
 
- Protected Member Functions inherited from QueueDescriptorWithParameters< BatchMatMulDescriptor >
 QueueDescriptorWithParameters ()=default
 
 QueueDescriptorWithParameters (QueueDescriptorWithParameters const &)=default
 
QueueDescriptorWithParameters & operator= (QueueDescriptorWithParameters const &)=default
 
- Protected Member Functions inherited from QueueDescriptor
 QueueDescriptor ()
 
 QueueDescriptor (QueueDescriptor const &)=default
 
QueueDescriptor & operator= (QueueDescriptor const &)=default
 

Detailed Description

Definition at line 788 of file WorkloadData.hpp.

Member Function Documentation

◆ Validate()

void Validate ( const WorkloadInfo & workloadInfo) const

Definition at line 4146 of file WorkloadData.cpp.

References armnn::BFloat16, armnn::Float16, armnn::Float32, BatchMatMulDescriptor::GetAxesNotMul(), BatchMatMulDescriptor::GetAxesToMul(), TensorInfo::GetNumDimensions(), TensorInfo::GetShape(), WorkloadInfo::m_InputTensorInfos, WorkloadInfo::m_OutputTensorInfos, armnn::NCDHW, armnn::NCHW, armnn::NDHWC, armnn::NHWC, armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.

4147 {
4148  const std::string descriptorName{"BatchMatMulDescriptor"};
4149 
4150  ValidateNumInputs(workloadInfo, descriptorName, 2);
4151  ValidateNumOutputs(workloadInfo, descriptorName, 1);
4152 
4153  // Inputs must be: both 2D+
4154  // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4155  // axes N and I must be the same size
4156 
4157  const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
4158  const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
4159  const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
4160 
4161  std::vector<DataType> supportedTypes =
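4162  {
4163  DataType::BFloat16,
4164  DataType::Float16,
4165  DataType::Float32,
4166  DataType::QAsymmS8,
4167  DataType::QAsymmU8,
4168  DataType::QSymmS16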
4169  };
4170 
4171  ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
4172  ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
4173  ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4174 
4175  if ((inputTensorXInfo.GetNumDimensions() < 2) ||
4176  (inputTensorYInfo.GetNumDimensions() < 2))
4177  {
4178  throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4179  }
4180 
4181  if(m_Parameters.m_DataLayoutX.has_value())
4182  {
4183  switch(m_Parameters.m_DataLayoutX.value())
4184  {
4185  case DataLayout::NCHW:
4186  case DataLayout::NHWC:
4187  if(inputTensorXInfo.GetNumDimensions() != 4)
4188  {
4189  throw InvalidArgumentException(descriptorName +
4190  ": Input tensor X does not have the correct "
4191  "number of dimensions for the Data Layout that it has been assigned.");
4192  }
4193  break;
4194  case DataLayout::NCDHW:
4195  case DataLayout::NDHWC:
4196  if(inputTensorXInfo.GetNumDimensions() != 5)
4197  {
4198  throw InvalidArgumentException(descriptorName +
4199  ": Input tensor X does not have the correct "
4200  "number of dimensions for the Data Layout that it has been assigned.");
4201  }
4202  break;
4203  default:
4204  break;
4205  }
4206  }
4207 
4208  if(m_Parameters.m_DataLayoutY.has_value())
4209  {
4210  switch(m_Parameters.m_DataLayoutY.value())
4211  {
4212  case DataLayout::NCHW:
4213  case DataLayout::NHWC:
4214  if(inputTensorYInfo.GetNumDimensions() != 4)
4215  {
4216  throw InvalidArgumentException(descriptorName +
4217  ": Input tensor Y does not have the correct "
4218  "number of dimensions for the Data Layout that it has been assigned.");
4219  }
4220  break;
4221  case DataLayout::NCDHW:
4222  case DataLayout::NDHWC:
4223  if(inputTensorYInfo.GetNumDimensions() != 5)
4224  {
4225  throw InvalidArgumentException(descriptorName +
4226  ": Input tensor Y does not have the correct "
4227  "number of dimensions for the Data Layout that it has been assigned.");
4228  }
4229  break;
4230  default:
4231  break;
4232  }
4233  }
4234 
4235  auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
4236  inputTensorXInfo.GetShape(),
4237  inputTensorYInfo.GetShape());
4238 
4239  if(inputTensorXInfo.GetShape()[axesToMul.first.second]
4240  != inputTensorYInfo.GetShape()[axesToMul.second.first])
4241  {
4242  throw InvalidArgumentException(descriptorName +
4243  ": The final axis of input tensor X must be the same size as "
4244  "the second last axis of input tensor Y.");
4245  }
4246 
4247  auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
4248  inputTensorXInfo.GetShape(),
4249  inputTensorYInfo.GetShape());
4250 
4251  { // Separate scope so we don't pollute the rest of the scope with our temp variables
4252  // e.g. NHWC isnt compatible with NCHW as of now
4253  DataLayout xLayout;
4254  DataLayout yLayout;
4255 
4256  if(m_Parameters.m_DataLayoutX == EmptyOptional())
4257  {
4258  xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
4259  }
4260  else
4261  {
4262  xLayout = m_Parameters.m_DataLayoutX.value();
4263  }
4264 
4265  if(m_Parameters.m_DataLayoutY == EmptyOptional())
4266  {
4267  yLayout = DataLayout::NCHW;
4268  }
4269  else
4270  {
4271  yLayout = m_Parameters.m_DataLayoutY.value();
4272  }
4273 
4274  if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4275  {
4276  if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4277  {
4278  throw InvalidArgumentException(descriptorName +
4279  ": Invalid input tensor data layout combination.");
4280  }
4281  }
4282  if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4283  {
4284  if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4285  {
4286  throw InvalidArgumentException(descriptorName +
4287  ": Invalid input tensor data layout combination.");
4288  }
4289  }
4290  }
4291 
4292  // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4293  unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
4294  inputTensorYInfo.GetNumDimensions());
4295  if(outputTensorDimSize-2 > 0)
4296  {
4297  TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4298  DataType::Float32);
4299  TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4300  DataType::Float32);
4301  TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4302  DataType::Float32);
4303 
4304  auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4305  {
4306  auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4307 
4308  for(unsigned int i = 0; i < sizeDiff; i++)
4309  {
4310  axisIndices.insert(axisIndices.begin(), 1);
4311  }
4312 
4313  for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4314  {
4315  ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
4316  }
4317  };
4318 
4319  doAxisExtension(axesNotMul.first, tiXNotMul);
4320  doAxisExtension(axesNotMul.second, tiYNotMul);
4321 
4322  for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4323  {
4324  tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4325  tiYNotMul.GetShape()[i]);
4326  }
4327 
4328  ValidateBroadcastTensorShapesMatch(tiXNotMul,
4329  tiYNotMul,
4330  tiOutNotMul,
4331  descriptorName,
4332  "input_X",
4333  "input_Y");
4334  }
4335 
4336  // Also check descriptor parameter validity
4337  // This will eventually be moved to the start of the function as explained below
4338  if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
4339  (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
4340  {
4341  throw InvalidArgumentException(descriptorName +
4342  ": Invalid descriptor parameters - Transpose and Adjoint "
4343  "vectors cannot both be true for a given input tensor.");
4344  }
4345 
4346  if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
4347  {
4348  throw InvalidArgumentException(descriptorName +
4349  ": Invalid descriptor parameter - Transpose X vector must be "
4350  "the same size as tensor input X's dimensionality.");
4351  }
4352  if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
4353  {
4354  throw InvalidArgumentException(descriptorName +
4355  ": Invalid descriptor parameter - Adjoint X vector must be "
4356  "the same size as tensor input X's dimensionality.");
4357  }
4358  if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
4359  {
4360  throw InvalidArgumentException(descriptorName +
4361  ": Invalid descriptor parameter - Transpose Y vector must be "
4362  "the same size as tensor input Y's dimensionality.");
4363  }
4364  if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
4365  {
4366  throw InvalidArgumentException(descriptorName +
4367  ": Invalid descriptor parameter - Adjoint Y vector must be "
4368  "the same size as tensor input Y's dimensionality.");
4369  }
4370  // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
4371 }
DataLayout
Definition: Types.hpp:62
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
Optional< DataLayout > m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)...
std::vector< unsigned int > m_TransposeX
Transpose vector for each input tensor (leave as empty vector for no pre-transposing) Transpose and A...
std::vector< unsigned int > m_AdjointX
Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint) Transpose and Adjoint...
std::vector< TensorInfo > m_InputTensorInfos
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul(const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
Static helper to get the axes (for each input) that will not be multiplied together.
bool has_value() const noexcept
Definition: Optional.hpp:53
std::vector< unsigned int > m_TransposeY
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
Static helper to get the two axes (for each input) for multiplication.
std::vector< TensorInfo > m_OutputTensorInfos
EmptyOptional is used to initialize the Optional class in case we want to have default value for an O...
Definition: Optional.hpp:32
std::vector< unsigned int > m_AdjointY
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
Optional< DataLayout > m_DataLayoutY

The documentation for this struct was generated from the following files: