aboutsummaryrefslogtreecommitdiff
path: root/shim
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2022-09-21 15:41:41 +0100
committerKevin May <kevin.may@arm.com>2022-09-22 10:13:17 +0100
commit9636a9b109fcbc811ec876ba9ca6512b7fbe2ba0 (patch)
tree70d1242430e11fc28301443ef80cea94be645a41 /shim
parent09026930f1cf207cddb243c8bc388e2c390ac940 (diff)
downloadarmnn-9636a9b109fcbc811ec876ba9ca6512b7fbe2ba0.tar.gz
IVGCVSW-6495 Add Support for BATCH_MATMUL to Arm Support Library
* Update feature level support to FL6
* Add ConvertBatchMatMul function

Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: I93a77ba869bcddf432229a20e619304305d3982e
Diffstat (limited to 'shim')
-rw-r--r--shim/sl/canonical/ArmnnDriver.hpp2
-rw-r--r--shim/sl/canonical/Converter.cpp98
-rw-r--r--shim/sl/canonical/Converter.hpp2
3 files changed, 101 insertions, 1 deletion
diff --git a/shim/sl/canonical/ArmnnDriver.hpp b/shim/sl/canonical/ArmnnDriver.hpp
index 484a5318f7..bf5565a219 100644
--- a/shim/sl/canonical/ArmnnDriver.hpp
+++ b/shim/sl/canonical/ArmnnDriver.hpp
@@ -61,7 +61,7 @@ public:
Version getFeatureLevel() const override
{
VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
- return kVersionFeatureLevel5;
+ return kVersionFeatureLevel6;
}
DeviceType getType() const override
diff --git a/shim/sl/canonical/Converter.cpp b/shim/sl/canonical/Converter.cpp
index 5d52b4a779..8885fafe53 100644
--- a/shim/sl/canonical/Converter.cpp
+++ b/shim/sl/canonical/Converter.cpp
@@ -32,6 +32,8 @@ bool Converter::ConvertOperation(const Operation& operation, const Model& model,
return ConvertArgMinMax(operation, model, data, ArgMinMaxFunction::Min);
case OperationType::AVERAGE_POOL_2D:
return ConvertAveragePool2d(operation, model, data);
+ case OperationType::BATCH_MATMUL:
+ return ConvertBatchMatMul(operation, model, data);
case OperationType::BATCH_TO_SPACE_ND:
return ConvertBatchToSpaceNd(operation, model, data);
case OperationType::CAST:
@@ -328,6 +330,102 @@ bool Converter::ConvertAveragePool2d(const Operation& operation, const Model& mo
return ConvertPooling2d(operation, __func__, PoolingAlgorithm::Average, model, data);
}
+bool Converter::ConvertBatchMatMul(const Operation& operation, const Model& model, ConversionData& data)
+{
+    VLOG(DRIVER) << "Converter::ConvertBatchMatMul()";
+    LayerInputHandle input0 = ConvertToLayerInputHandle(operation, 0, model, data);
+    LayerInputHandle input1 = ConvertToLayerInputHandle(operation, 1, model, data);
+
+    if (!input0.IsValid() || !input1.IsValid())
+    {
+        return Fail("%s: Operation has invalid inputs", __func__);
+    }
+
+    const armnn::TensorInfo& inputInfo0 = input0.GetTensorInfo();
+    const armnn::TensorInfo& inputInfo1 = input1.GetTensorInfo();
+
+    unsigned int rankInput0 = inputInfo0.GetNumDimensions();
+    if (rankInput0 > 4 || rankInput0 < 2)
+    {
+        return Fail("%s: Only inputs with rank at least 2 and up to 4 are supported", __func__);
+    }
+
+    unsigned int rankInput1 = inputInfo1.GetNumDimensions();
+    if (rankInput1 > 4 || rankInput1 < 2)
+    {
+        return Fail("%s: Only inputs with rank at least 2 and up to 4 are supported", __func__);
+    }
+
+    // Determine data type of input tensor 0
+    OperandType input0Type;
+    if (!GetOperandType(operation, 0, model, input0Type))
+    {
+        return Fail("%s: Operation has invalid inputs", __func__);
+    }
+
+    // Determine data type of input tensor 1
+    OperandType input1Type;
+    if (!GetOperandType(operation, 1, model, input1Type))
+    {
+        return Fail("%s: Operation has invalid inputs", __func__);
+    }
+
+    if (input0Type != input1Type)
+    {
+        return Fail("%s: Operation has invalid inputs (Inputs must have same OperandCode)", __func__);
+    }
+
+    const Operand* output = GetOutputOperand(operation, 0, model);
+    if (!output)
+    {
+        return Fail("%s: Could not read output 0", __func__);
+    }
+
+    const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+
+    armnn::BatchMatMulDescriptor batchMatMulDesc;
+
+    // Inputs 2 and 3 are the optional adjoint flags (adj_x/adj_y) in Android NeuralNetworks;
+    // Arm NN models them as the transpose parameters of the descriptor
+    batchMatMulDesc.m_TransposeX = GetOptionalBool(operation, 2, model, data);
+    batchMatMulDesc.m_TransposeY = GetOptionalBool(operation, 3, model, data);
+
+    bool isSupported = false;
+    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
+    {
+        FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                                   IsBatchMatMulSupported,
+                                   data.m_Backends,
+                                   isSupported,
+                                   inputInfo0,
+                                   inputInfo1,
+                                   outputInfo,
+                                   batchMatMulDesc);
+    };
+
+    if(!IsDynamicTensor(outputInfo))
+    {
+        validateFunc(outputInfo, isSupported);
+    }
+    else
+    {
+        isSupported = AreDynamicTensorsSupported();
+    }
+
+
+    if (!isSupported)
+    {
+        return false;
+    }
+
+    armnn::IConnectableLayer* const layer = data.m_Network->AddBatchMatMulLayer(batchMatMulDesc);
+    assert(layer != nullptr);
+    input0.Connect(layer->GetInputSlot(0));
+    input1.Connect(layer->GetInputSlot(1));
+
+    return SetupAndTrackLayerOutputSlot(operation, 0, *layer, model, data, nullptr, validateFunc);
+}
+
bool Converter::ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data)
{
VLOG(DRIVER) << "Converter::ConvertBatchToSpaceNd()";
diff --git a/shim/sl/canonical/Converter.hpp b/shim/sl/canonical/Converter.hpp
index 8549289d69..7e4a89ee05 100644
--- a/shim/sl/canonical/Converter.hpp
+++ b/shim/sl/canonical/Converter.hpp
@@ -40,6 +40,8 @@ private:
static bool ConvertAveragePool2d(const Operation& operation, const Model& model, ConversionData& data);
+ static bool ConvertBatchMatMul(const Operation& operation, const Model& model, ConversionData& data);
+
static bool ConvertBatchToSpaceNd(const Operation& operation, const Model& model, ConversionData& data);
static bool ConvertCast(const Operation& operation, const Model& model, ConversionData& data);