4148 const std::string descriptorName{
"BatchMatMulDescriptor"};
4150 ValidateNumInputs(workloadInfo, descriptorName, 2);
4151 ValidateNumOutputs(workloadInfo, descriptorName, 1);
4161 std::vector<DataType> supportedTypes =
4171 ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
4172 ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
4173 ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
4175 if ((inputTensorXInfo.GetNumDimensions() < 2) ||
4176 (inputTensorYInfo.GetNumDimensions() < 2))
4187 if(inputTensorXInfo.GetNumDimensions() != 4)
4190 ": Input tensor X does not have the correct " 4191 "number of dimensions for the Data Layout that it has been assigned.");
4196 if(inputTensorXInfo.GetNumDimensions() != 5)
4199 ": Input tensor X does not have the correct " 4200 "number of dimensions for the Data Layout that it has been assigned.");
4214 if(inputTensorYInfo.GetNumDimensions() != 4)
4217 ": Input tensor Y does not have the correct " 4218 "number of dimensions for the Data Layout that it has been assigned.");
4223 if(inputTensorYInfo.GetNumDimensions() != 5)
4226 ": Input tensor Y does not have the correct " 4227 "number of dimensions for the Data Layout that it has been assigned.");
4236 inputTensorXInfo.GetShape(),
4237 inputTensorYInfo.GetShape());
4239 if(inputTensorXInfo.GetShape()[axesToMul.first.second]
4240 != inputTensorYInfo.GetShape()[axesToMul.second.first])
4243 ": The final axis of input tensor X must be the same size as " 4244 "the second last axis of input tensor Y.");
4248 inputTensorXInfo.GetShape(),
4249 inputTensorYInfo.GetShape());
4279 ": Invalid input tensor data layout combination.");
4287 ": Invalid input tensor data layout combination.");
4293 unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
4294 inputTensorYInfo.GetNumDimensions());
4295 if(outputTensorDimSize-2 > 0)
4304 auto doAxisExtension = [&](std::vector<unsigned int> axisIndices,
TensorInfo& ti)
4306 auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4308 for(
unsigned int i = 0; i < sizeDiff; i++)
4310 axisIndices.insert(axisIndices.begin(), 1);
4313 for(
unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4315 ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
4319 doAxisExtension(axesNotMul.first, tiXNotMul);
4320 doAxisExtension(axesNotMul.second, tiYNotMul);
4328 ValidateBroadcastTensorShapesMatch(tiXNotMul,
4342 ": Invalid descriptor parameters - Transpose and Adjoint " 4343 "vectors cannot both be true for a given input tensor.");
4349 ": Invalid descriptor parameter - Transpose X vector must be " 4350 "the same size as tensor input X's dimensionality.");
4355 ": Invalid descriptor parameter - Adjoint X vector must be " 4356 "the same size as tensor input X's dimensionality.");
4361 ": Invalid descriptor parameter - Transpose Y vector must be " 4362 "the same size as tensor input Y's dimensionality.");
4367 ": Invalid descriptor parameter - Adjoint Y vector must be " 4368 "the same size as tensor input Y's dimensionality.");
const TensorShape & GetShape() const
Optional< DataLayout > m_DataLayoutX
Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout).
std::vector< unsigned int > m_TransposeX
Transpose vector for each input tensor (leave as empty vector for no pre-transposing). Transpose and Adjoint cannot both be set to true for the same input tensor.
std::vector< unsigned int > m_AdjointX
Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint). Transpose and Adjoint cannot both be set to true for the same input tensor.
BatchMatMulDescriptor m_Parameters
std::vector< TensorInfo > m_InputTensorInfos
static std::pair< std::vector< unsigned int >, std::vector< unsigned int > > GetAxesNotMul(const BatchMatMulDescriptor &desc, const TensorShape &inputXShape, const TensorShape &inputYShape)
Static helper to get the axes (for each input) that will not be multiplied together.
bool has_value() const noexcept
std::vector< unsigned int > m_TransposeY
static std::pair< std::pair< unsigned int, unsigned int >, std::pair< unsigned int, unsigned int > > GetAxesToMul(const BatchMatMulDescriptor &desc, const TensorShape &tensorXShape, const TensorShape &tensorYShape)
Static helper to get the two axes (for each input) for multiplication.
std::vector< TensorInfo > m_OutputTensorInfos
EmptyOptional is used to initialize the Optional class in case we want to have default value for an Optional in a function declaration.
std::vector< unsigned int > m_AdjointY
unsigned int GetNumDimensions() const
Optional< DataLayout > m_DataLayoutY