diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2022-09-22 10:12:58 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2022-09-22 10:57:32 +0000 |
commit | bc37a6b83faf92c92570fb3137b5fd549f304b3f (patch) | |
tree | 203dd10b33108a27251b3fff555e9b045fe76482 | |
parent | 49ed0df12338b1e99674edeee4200acf8c05750e (diff) | |
download | armnn-bc37a6b83faf92c92570fb3137b5fd549f304b3f.tar.gz |
IVGCVSW-7240 Adjoint is Transpose in TFLite. Change in TFLite parser
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I1bc3d50b8fa6e216d8b6b7e3421d2ff37a21712c
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 7 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/BatchMatMul.cpp | 14 |
2 files changed, 11 insertions, 10 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index a26f3e5f04..e036d0ca1c 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1586,10 +1586,11 @@ void TfLiteParserImpl::ParseBatchMatMul(size_t subgraphIndex, size_t operatorInd const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex]; const auto* options = operatorPtr->builtin_options.AsBatchMatMulOptions(); - BatchMatMulDescriptor descriptor(false, + // Adjoint in tensorflow lite performs transpose operation + BatchMatMulDescriptor descriptor(options->adj_x, + options->adj_y, false, - options->adj_x, - options->adj_y); + false); // Arbitrary DataLayout IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str()); diff --git a/src/armnnTfLiteParser/test/BatchMatMul.cpp b/src/armnnTfLiteParser/test/BatchMatMul.cpp index f4cdd67fb9..467637f30e 100644 --- a/src/armnnTfLiteParser/test/BatchMatMul.cpp +++ b/src/armnnTfLiteParser/test/BatchMatMul.cpp @@ -12,8 +12,8 @@ struct BatchMatMulFixture : public ParserFlatbuffersFixture explicit BatchMatMulFixture(const std::string &inputXShape, const std::string &inputYShape, const std::string &outputShape, - const std::string &adjX, - const std::string &adjY) + const std::string &tranX, + const std::string &tranY) { m_JsonString = R"( { @@ -68,8 +68,8 @@ struct BatchMatMulFixture : public ParserFlatbuffersFixture "outputs": [ 2 ], "builtin_options_type": "BatchMatMulOptions", "builtin_options": { - adj_x: )" + adjX + R"(, - adj_y: )" + adjY + R"(, + adj_x: )" + tranX + R"(, + adj_y: )" + tranY + R"(, "asymmetric_quantize_inputs": false }, "custom_options_format": "FLEXBUFFERS" @@ -105,9 +105,9 @@ TEST_CASE_FIXTURE(BatchMatMulParamsFixture, "ParseBatchMatMulParams") {"inputYTensor", {0.0f, 1.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f, 0.0f}}}, - {{"outputTensor", {6.0f, 4.0f, 0.0f, - 26.0f, 16.0f, 0.0f, - 110.0f, 68.0f, 0.0f}}} + {{"outputTensor", {8.0f, 7.0f, 5.0f, + 34.0f, 29.0f, 21.0f, + 144.0f, 123.0f, 89.0f}}} ); } |