aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Yap <samuel.yap@arm.com>2022-08-24 17:04:34 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-08-31 10:21:01 +0100
commite7cd8f9d9b2b366b49207c519f5ec5fd0b41c382 (patch)
tree546406ab4199637a4374d859d0a9ad328a63c97d
parentb9e6b5c3b96792f40201315c831db0aa257f286c (diff)
downloadarmnn-e7cd8f9d9b2b366b49207c519f5ec5fd0b41c382.tar.gz
IVGCVSW-6497: BatchMatMul TfLite Parser
* Added armnnTfLiteParser for BatchMatMul * Added unit testing for parser * Updated CMakeLists Signed-off-by: Samuel Yap <samuel.yap@arm.com> Change-Id: If6842aaf7cf08f688093b714e2ecea6e8cd87161
-rw-r--r--CMakeLists.txt1
-rw-r--r--src/armnn/layers/BatchMatMulLayer.cpp6
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp39
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.hpp1
-rw-r--r--src/armnnTfLiteParser/test/BatchMatMul.cpp114
5 files changed, 158 insertions, 3 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4e4818d232..14236c7ae8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -660,6 +660,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfLiteParser/test/Addition.cpp
src/armnnTfLiteParser/test/ArgMinMax.cpp
src/armnnTfLiteParser/test/AvgPool2D.cpp
+ src/armnnTfLiteParser/test/BatchMatMul.cpp
src/armnnTfLiteParser/test/BatchToSpaceND.cpp
src/armnnTfLiteParser/test/Cast.cpp
src/armnnTfLiteParser/test/Comparison.cpp
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index acd089aef8..0f86b9dc48 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -37,14 +37,14 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
TensorShape inputXShape = inputShapes[0];
TensorShape inputYShape = inputShapes[1];
- // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
- if(m_Param.m_TransposeX)
+ // Adjoint is assumed to be square, but we will apply the permute anyway
+ if(m_Param.m_TransposeX || m_Param.m_AdjointX)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
inputXShape);
inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
}
- if(m_Param.m_TransposeY)
+ if(m_Param.m_TransposeY || m_Param.m_AdjointY)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
inputYShape);
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 880de100c1..030420345e 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -680,6 +680,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX] = &TfLiteParserImpl::ParseArgMax;
m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D] = &TfLiteParserImpl::ParseAveragePool2D;
m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND] = &TfLiteParserImpl::ParseBatchToSpaceND;
+ m_ParserFunctions[tflite::BuiltinOperator_BATCH_MATMUL] = &TfLiteParserImpl::ParseBatchMatMul;
m_ParserFunctions[tflite::BuiltinOperator_CAST] = &TfLiteParserImpl::ParseCast;
m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION] = &TfLiteParserImpl::ParseConcatenation;
m_ParserFunctions[tflite::BuiltinOperator_CONV_2D] = &TfLiteParserImpl::ParseConv2D;
@@ -1565,6 +1566,44 @@ void TfLiteParserImpl::ParseAveragePool2D(size_t subgraphIndex, size_t operatorI
ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Average);
}
+void TfLiteParserImpl::ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex)
+{
+ CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
+
+ auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ auto layerName = fmt::format("BatchMatMul:{}:{}", subgraphIndex, operatorIndex);
+
+ TensorInfo inputXTensorInfo = ToTensorInfo(inputs[0]);
+ TensorInfo inputYTensorInfo = ToTensorInfo(inputs[1]);
+
+ TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
+
+ const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
+ const auto* options = operatorPtr->builtin_options.AsBatchMatMulOptions();
+
+ BatchMatMulDescriptor descriptor(false,
+ false,
+ options->adj_x,
+ options->adj_y);
+ // Arbitrary DataLayout
+
+ IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
+ ARMNN_ASSERT(layer != nullptr);
+
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
+
+ auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
+ RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
+}
+
void TfLiteParserImpl::ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex)
{
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index 49744a0484..f8ddc55649 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -114,6 +114,7 @@ private:
void ParseArgMin(size_t subgraphIndex, size_t operatorIndex);
void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
+ void ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex);
void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
void ParseCast(size_t subgraphIndex, size_t operatorIndex);
void ParseComparison(size_t subgraphIndex, size_t operatorIndex, armnn::ComparisonOperation comparisonOperation);
diff --git a/src/armnnTfLiteParser/test/BatchMatMul.cpp b/src/armnnTfLiteParser/test/BatchMatMul.cpp
new file mode 100644
index 0000000000..f4cdd67fb9
--- /dev/null
+++ b/src/armnnTfLiteParser/test/BatchMatMul.cpp
@@ -0,0 +1,114 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ParserFlatbuffersFixture.hpp"
+
+TEST_SUITE("TensorflowLiteParser_BatchMatMul")
+{
+struct BatchMatMulFixture : public ParserFlatbuffersFixture
+{
+ explicit BatchMatMulFixture(const std::string &inputXShape,
+ const std::string &inputYShape,
+ const std::string &outputShape,
+ const std::string &adjX,
+ const std::string &adjY)
+ {
+ m_JsonString = R"(
+ {
+ "version": 3,
+ "operator_codes": [ { "builtin_code": "BATCH_MATMUL" } ],
+ "subgraphs": [
+ {
+ "tensors": [
+ {
+ "shape": )" + inputXShape + R"(,
+ "type": "FLOAT32",
+ "buffer": 0,
+ "name": "inputXTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + inputYShape + R"(,
+ "type": "FLOAT32",
+ "buffer": 1,
+ "name": "inputYTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ },
+ {
+ "shape": )" + outputShape + R"(,
+ "type": "FLOAT32",
+ "buffer": 2,
+ "name": "outputTensor",
+ "quantization": {
+ "min": [ 0.0 ],
+ "max": [ 255.0 ],
+ "scale": [ 1.0 ],
+ "zero_point": [ 0 ],
+ }
+ }
+ ],
+ "inputs": [ 0, 1 ],
+ "outputs": [ 2 ],
+ "operators": [
+ {
+ "opcode_index": 0,
+ "inputs": [ 0 , 1 ],
+ "outputs": [ 2 ],
+ "builtin_options_type": "BatchMatMulOptions",
+ "builtin_options": {
+ adj_x: )" + adjX + R"(,
+ adj_y: )" + adjY + R"(,
+ "asymmetric_quantize_inputs": false
+ },
+ "custom_options_format": "FLEXBUFFERS"
+ }
+ ]
+ }
+ ],
+ "buffers": [{},{}]
+ }
+ )";
+ Setup();
+ }
+};
+
+struct BatchMatMulParamsFixture : BatchMatMulFixture
+{
+ BatchMatMulParamsFixture()
+ : BatchMatMulFixture("[ 1, 3, 3 ]",
+ "[ 1, 3, 3 ]",
+ "[ 1, 3, 3 ]",
+ "false",
+ "true")
+ {}
+};
+
+TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
+{
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ {{"inputXTensor", {2.0f, 3.0f, 5.0f,
+ 8.0f, 13.0f, 21.0f,
+ 34.0f, 55.0f, 89.0f}},
+ {"inputYTensor", {0.0f, 1.0f, 1.0f,
+ 1.0f, 0.0f, 1.0f,
+ 1.0f, 1.0f, 0.0f}}},
+ {{"outputTensor", {6.0f, 4.0f, 0.0f,
+ 26.0f, 16.0f, 0.0f,
+ 110.0f, 68.0f, 0.0f}}}
+ );
+}
+
+} \ No newline at end of file