aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-09-22 10:12:58 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2022-09-22 10:57:32 +0000
commitbc37a6b83faf92c92570fb3137b5fd549f304b3f (patch)
tree203dd10b33108a27251b3fff555e9b045fe76482
parent49ed0df12338b1e99674edeee4200acf8c05750e (diff)
downloadarmnn-bc37a6b83faf92c92570fb3137b5fd549f304b3f.tar.gz
IVGCVSW-7240 Adjoint is Transpose in TFLite. Change in TFLite parser
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I1bc3d50b8fa6e216d8b6b7e3421d2ff37a21712c
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp7
-rw-r--r--src/armnnTfLiteParser/test/BatchMatMul.cpp14
2 files changed, 11 insertions, 10 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index a26f3e5f04..e036d0ca1c 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1586,10 +1586,11 @@ void TfLiteParserImpl::ParseBatchMatMul(size_t subgraphIndex, size_t operatorInd
const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
const auto* options = operatorPtr->builtin_options.AsBatchMatMulOptions();
- BatchMatMulDescriptor descriptor(false,
+ // Adjoint in tensorflow lite performs transpose operation
+ BatchMatMulDescriptor descriptor(options->adj_x,
+ options->adj_y,
false,
- options->adj_x,
- options->adj_y);
+ false);
// Arbitrary DataLayout
IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
diff --git a/src/armnnTfLiteParser/test/BatchMatMul.cpp b/src/armnnTfLiteParser/test/BatchMatMul.cpp
index f4cdd67fb9..467637f30e 100644
--- a/src/armnnTfLiteParser/test/BatchMatMul.cpp
+++ b/src/armnnTfLiteParser/test/BatchMatMul.cpp
@@ -12,8 +12,8 @@ struct BatchMatMulFixture : public ParserFlatbuffersFixture
explicit BatchMatMulFixture(const std::string &inputXShape,
const std::string &inputYShape,
const std::string &outputShape,
- const std::string &adjX,
- const std::string &adjY)
+ const std::string &tranX,
+ const std::string &tranY)
{
m_JsonString = R"(
{
@@ -68,8 +68,8 @@ struct BatchMatMulFixture : public ParserFlatbuffersFixture
"outputs": [ 2 ],
"builtin_options_type": "BatchMatMulOptions",
"builtin_options": {
- adj_x: )" + adjX + R"(,
- adj_y: )" + adjY + R"(,
+ adj_x: )" + tranX + R"(,
+ adj_y: )" + tranY + R"(,
"asymmetric_quantize_inputs": false
},
"custom_options_format": "FLEXBUFFERS"
@@ -105,9 +105,9 @@ TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams")
{"inputYTensor", {0.0f, 1.0f, 1.0f,
1.0f, 0.0f, 1.0f,
1.0f, 1.0f, 0.0f}}},
- {{"outputTensor", {6.0f, 4.0f, 0.0f,
- 26.0f, 16.0f, 0.0f,
- 110.0f, 68.0f, 0.0f}}}
+ {{"outputTensor", {8.0f, 7.0f, 5.0f,
+ 34.0f, 29.0f, 21.0f,
+ 144.0f, 123.0f, 89.0f}}}
);
}