4101 const std::string descriptorName{
"BatchMatMulDescriptor"};
4103 ValidateNumInputs(workloadInfo, descriptorName, 2);
4104 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4115 std::vector<DataType> supportedTypes =
4125 ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4126 ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4127 ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4129 if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4130 (inputYInfoBeforeParams.GetNumDimensions() < 2))
4142 ": Invalid descriptor parameters - Transpose and Adjoint " 4143 "cannot both be true for a given input tensor.");
4150 inputXInfoBeforeParams.GetShape()));
4155 inputXInfoBeforeParams.GetShape());
4156 if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4157 inputXInfoBeforeParams.GetShape()[axesToMul.second])
4160 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4163 inputXInfoAfterParams = inputXInfoBeforeParams;
4167 inputXInfoAfterParams = inputXInfoBeforeParams;
4175 inputYInfoBeforeParams.GetShape()));
4180 inputYInfoBeforeParams.GetShape());
4181 if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4182 inputYInfoBeforeParams.GetShape()[axesToMul.second])
4185 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4188 inputYInfoAfterParams = inputYInfoBeforeParams;
4192 inputYInfoAfterParams = inputYInfoBeforeParams;
4202 ": Input tensor X does not have the correct " 4203 "number of dimensions for the Data Layout that it has been assigned.");
4219 ": Input tensor Y does not have the correct " 4220 "number of dimensions for the Data Layout that it has been assigned.");
4232 inputXInfoBeforeParams.GetShape());
4234 if(inputXInfoAfterParams.
GetShape()[axesXToMul.second]
4235 != inputYInfoAfterParams.
GetShape()[axesYToMul.first])
4238 ": The final axis of input tensor X must be the same size as " 4239 "the second last axis of input tensor Y.");
4252 ": Invalid input tensor data layout combination.");
4260 ": Invalid input tensor data layout combination.");
4266 unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.
GetNumDimensions(),
4268 if(outputTensorDimSize-2 > 0)
4277 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices,
TensorInfo& ti)
4279 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4281 for(
unsigned int i = 0; i < sizeDiff; i++)
4283 axisIndices.insert(axisIndices.begin(), 1);
4286 for(
unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4288 ti.GetShape()[i] = inputXInfoAfterParams.
GetShape()[i];
4297 doAxisExtension(axesXNotMul, tiXNotMul);
4298 doAxisExtension(axesYNotMul, tiYNotMul);
4306 ValidateBroadcastTensorShapesMatch(tiXNotMul,
const TensorShape & GetShape() const
bool m_TransposeX
Transpose the slices of each input tensor Transpose and Adjoint can not both be set to true for the s...
bool m_AdjointX
Adjoint the slices of each input tensor Transpose and Adjoint can not both be set to true for the sam...
static PermutationVector GetPermuteVec(DataLayout dataLayout, const TensorShape &tensorShape)
Static helper to get the axes which will be transposed.
BatchMatMulDescriptor m_Parameters
DataLayout m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout) ...
std::vector< TensorInfo > m_InputTensorInfos
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul(const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
std::vector< TensorInfo > m_OutputTensorInfos
unsigned int GetNumDimensions() const
armnn::TensorShape Permuted(const armnn::TensorShape &srcShape, const armnn::PermutationVector &mappings)