// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "ParserFlatbuffersSerializeFixture.hpp" #include #include #include TEST_SUITE("Deserializer_BatchMatMul") { struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture { explicit BatchMatMulFixture(const std::string& inputXShape, const std::string& inputYShape, const std::string& outputShape, const std::string& dataType) { m_JsonString = R"( { inputIds:[ 0, 1 ], outputIds:[ 3 ], layers:[ { layer_type:"InputLayer", layer:{ base:{ layerBindingId:0, base:{ index:0, layerName:"InputXLayer", layerType:"Input", inputSlots:[ { index:0, connection:{ sourceLayerIndex:0, outputSlotIndex:0 }, } ], outputSlots:[ { index:0, tensorInfo:{ dimensions:)" + inputXShape + R"(, dataType:)" + dataType + R"( }, } ], }, } }, }, { layer_type:"InputLayer", layer:{ base:{ layerBindingId:1, base:{ index:1, layerName:"InputYLayer", layerType:"Input", inputSlots:[ { index:0, connection:{ sourceLayerIndex:0, outputSlotIndex:0 }, } ], outputSlots:[ { index:0, tensorInfo:{ dimensions:)" + inputYShape + R"(, dataType:)" + dataType + R"( }, } ], }, } }, }, { layer_type:"BatchMatMulLayer", layer:{ base:{ index:2, layerName:"BatchMatMulLayer", layerType:"BatchMatMul", inputSlots:[ { index:0, connection:{ sourceLayerIndex:0, outputSlotIndex:0 }, }, { index:1, connection:{ sourceLayerIndex:1, outputSlotIndex:0 }, } ], outputSlots:[ { index:0, tensorInfo:{ dimensions:)" + outputShape + R"(, dataType:)" + dataType + R"( }, } ], }, descriptor:{ transposeX:false, transposeY:false, adjointX:false, adjointY:false, dataLayoutX:NHWC, dataLayoutY:NHWC } }, }, { layer_type:"OutputLayer", layer:{ base:{ layerBindingId:0, base:{ index:3, layerName:"OutputLayer", layerType:"Output", inputSlots:[ { index:0, connection:{ sourceLayerIndex:2, outputSlotIndex:0 }, } ], outputSlots:[ { index:0, tensorInfo:{ dimensions:)" + outputShape + R"(, dataType:)" + dataType + R"( }, } ], } } }, } ] } )"; Setup(); } }; struct SimpleBatchMatMulFixture : BatchMatMulFixture { SimpleBatchMatMulFixture() : BatchMatMulFixture("[ 1, 2, 2, 1 ]", "[ 1, 2, 2, 1 ]", "[ 1, 2, 2, 1 ]", "Float32") {} }; TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest") { RunTest<4, armnn::DataType::Float32>( 0, {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }}, {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}}, {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}}); } }