From 4b7a34dd92eb3f736e05ac6623fd147ecd8636b1 Mon Sep 17 00:00:00 2001 From: Samuel Yap Date: Wed, 6 Jul 2022 15:36:03 +0100 Subject: IVGCVSW-7109: Add Batch MatMul front end support - Reference * Descriptors added for BatchMatMul * Layer definition added * Input validation added (will likely change when opt. param support comes in) * Ref workload implementation for BatchMatMul added (will also change with opt. param support) * Ref layer tests made for BatchMatMul * CMake and other build files updated Signed-off-by: Samuel Yap Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617 --- src/backends/reference/test/RefLayerTests.cpp | 71 +++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) (limited to 'src/backends/reference/test') diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 419ae2b0e9..593dc7851e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1062,6 +1062,77 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1ElementInt32, Multiplicati ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest) +// Batch Mat Mul +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQASymmS16, BatchMatMul2DSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQASymmS16, BatchMatMul3DSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQASymmS16, BatchMatMulNCHWSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQASymmS16, BatchMatMulNHWCSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQASymmS16, BatchMatMul3DBatchTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQASymmS16, BatchMatMul3DBroadcastTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQASymmSS16, BatchMatMul3D2DBroadcastTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQASymmSS16, BatchMatMulNDHWCNHWCTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQASymmS16, BatchMatMul2DTinyTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest); + // Batch Norm ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest) -- cgit v1.2.1