diff options
author | Kevin May <kevin.may@arm.com> | 2022-09-21 15:41:41 +0100 |
---|---|---|
committer | Kevin May <kevin.may@arm.com> | 2022-09-22 10:13:17 +0100 |
commit | 9636a9b109fcbc811ec876ba9ca6512b7fbe2ba0 (patch) | |
tree | 70d1242430e11fc28301443ef80cea94be645a41 /shim/sl/canonical/Converter.cpp | |
parent | 09026930f1cf207cddb243c8bc388e2c390ac940 (diff) | |
download | armnn-9636a9b109fcbc811ec876ba9ca6512b7fbe2ba0.tar.gz |
IVGCVSW-6495 Add Support for BATCH_MATMUL to Arm Support Library
* Update feature level support to FL6
* Add ConvertBatchMatMul function
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I93a77ba869bcddf432229a20e619304305d3982e
Diffstat (limited to 'shim/sl/canonical/Converter.cpp')
-rw-r--r-- | shim/sl/canonical/Converter.cpp | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/shim/sl/canonical/Converter.cpp b/shim/sl/canonical/Converter.cpp index 5d52b4a779..8885fafe53 100644 --- a/shim/sl/canonical/Converter.cpp +++ b/shim/sl/canonical/Converter.cpp @@ -32,6 +32,8 @@ bool Converter::ConvertOperation(const Operation& operation, const Model& model, return ConvertArgMinMax(operation, model, data, ArgMinMaxFunction::Min); case OperationType::AVERAGE_POOL_2D: return ConvertAveragePool2d(operation, model, data); + case OperationType::BATCH_MATMUL: + return ConvertBatchMatMul(operation, model, data); case OperationType::BATCH_TO_SPACE_ND: return ConvertBatchToSpaceNd(operation, model, data); case OperationType::CAST: @@ -328,6 +330,102 @@ bool Converter::ConvertAveragePool2d(const Operation& operation, const Model& mo return ConvertPooling2d(operation, __func__, PoolingAlgorithm::Average, model, data); } +bool Converter::ConvertBatchMatMul(const Operation& operation, const Model& model, ConversionData& data) +{ + VLOG(DRIVER) << "Converter::ConvertBatchMatMul()"; + LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data); + LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data); + + if (!input0.IsValid() || !input1.IsValid()) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + const armnn::TensorInfo& inputInfo0 = input0.GetTensorInfo(); + const armnn::TensorInfo& inputInfo1 = input1.GetTensorInfo(); + + unsigned int rankInput0 = inputInfo0.GetNumDimensions(); + if (rankInput0 > 4 || rankInput0 < 2) + { + Fail("%s: Only inputs with rank at least 2 and up to 4 are supported", __func__); + } + + unsigned int rankInput1 = inputInfo1.GetNumDimensions(); + if (rankInput1 > 4 || rankInput1 < 2) + { + Fail("%s: Only inputs with rank at least 2 and up to 4 are supported", __func__); + } + + // Determine data type of input tensor 0 + OperandType input0Type; + if (!GetOperandType(operation, 0, model, input0Type)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + // Determine data type of input tensor 0 + OperandType input1Type; + if (!GetOperandType(operation, 0, model, input1Type)) + { + return Fail("%s: Operation has invalid inputs", __func__); + } + + if (input0Type != input1Type) + { + return Fail("%s: Operation has invalid inputs (Inputs must have same OperandCode)", __func__); + } + + const Operand* output = GetOutputOperand(operation, 0, model); + if (!output) + { + return Fail("%s: Could not read output 0", __func__); + } + + const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + + armnn::BatchMatMulDescriptor batchMatMulDesc; + + // Inputs 2 and 3 are adjoint in Android NeuralNetworks, but they perform transpose. + // This is why we are linking them with transpose parameters in the descriptor + batchMatMulDesc.m_TransposeX = GetOptionalBool(operation, 2, model, data); + batchMatMulDesc.m_TransposeY = GetOptionalBool(operation, 3, model, data); + + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsBatchMatMulSupported, + data.m_Backends, + isSupported, + inputInfo0, + inputInfo1, + outputInfo, + batchMatMulDesc); + }; + + if(!IsDynamicTensor(outputInfo)) + { + validateFunc(outputInfo, isSupported); + } + else + { + isSupported = AreDynamicTensorsSupported(); + } + + + if (!isSupported) + { + return false; + } + + armnn::IConnectableLayer* const layer = data.m_Network->AddBatchMatMulLayer(batchMatMulDesc); + assert(layer != nullptr); + input0.Connect(layer->GetInputSlot(0)); + input1.Connect(layer->GetInputSlot(1)); + + return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc); +} + bool Converter::ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data) { VLOG(DRIVER) << "Converter::ConvertBatchToSpaceNd()"; |