aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpFixture.h208
1 files changed, 208 insertions, 0 deletions
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index 90a4b5cf40..5793ebdd2d 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -611,6 +611,214 @@ protected:
TensorType _target{};
SimpleTensor<int32_t> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs)
+ {
+ GEMMLHSMatrixInfo lhs_info;
+ lhs_info.m0 = m0;
+ lhs_info.k0 = k0;
+
+ GEMMRHSMatrixInfo rhs_info;
+ rhs_info.n0 = n0;
+ rhs_info.k0 = k0;
+ rhs_info.h0 = h0;
+ rhs_info.interleave = interleave_rhs;
+ rhs_info.transpose = transpose_rhs;
+
+ // Set the tensor shapes for LHS and RHS matrices
+ const TensorShape lhs_shape(k, m, batch_size);
+ const TensorShape rhs_shape(n, k, batch_size);
+
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info);
+ _reference = compute_reference(lhs_shape, rhs_shape);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i)
+ {
+ // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+ std::uniform_int_distribution<> distribution(1, 254);
+ library->fill(tensor, distribution, i);
+ }
+
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info)
+ {
+ // Create tensors
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+ TensorType rhs_reshaped;
+ TensorType dst;
+
+ const unsigned int M = lhs_shape[1];
+ const unsigned int N = rhs_shape[0];
+ const unsigned int K = lhs_shape[0];
+
+ // The output tensor will be auto-initialized within the function
+
+ // Create and configure function
+ ReshapeRHSFunctionType reshape_rhs;
+ GEMMFunctionType gemm;
+ reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
+ gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
+
+ ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ rhs_reshaped.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(lhs), 0);
+ fill(AccessorType(rhs), 1);
+
+ // Compute GEMM
+ reshape_rhs.run();
+ gemm.run();
+
+ return dst;
+ }
+
+ SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape)
+ {
+ TensorShape dst_shape = lhs_shape;
+ dst_shape[0] = rhs_shape[0];
+ dst_shape[1] = lhs_shape[1];
+
+ // Create reference
+ SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
+ SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
+
+ TensorType _target{};
+ SimpleTensor<int32_t> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0,
+ bool interleave_rhs, bool transpose_rhs)
+ {
+ GEMMLHSMatrixInfo lhs_info;
+ lhs_info.m0 = m0;
+ lhs_info.k0 = k0;
+
+ GEMMRHSMatrixInfo rhs_info;
+ rhs_info.n0 = n0;
+ rhs_info.k0 = k0;
+ rhs_info.h0 = h0;
+ rhs_info.interleave = interleave_rhs;
+ rhs_info.transpose = transpose_rhs;
+
+ // In case of GEMM3D, m is the product between m_w and m_h
+ const unsigned int m = m_w * m_h;
+
+ // Set the tensor shapes for LHS and RHS matrices
+ const TensorShape lhs_shape(k, m, batch_size);
+ const TensorShape rhs_shape(n, k, batch_size);
+
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h);
+ _reference = compute_reference(lhs_shape, rhs_shape, m_h);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i)
+ {
+ // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+ std::uniform_int_distribution<> distribution(1, 254);
+ library->fill(tensor, distribution, i);
+ }
+
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h)
+ {
+ // Create tensors
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+ TensorType rhs_reshaped;
+ TensorType dst;
+
+ const unsigned int M = lhs_shape[1];
+ const unsigned int N = rhs_shape[0];
+ const unsigned int K = lhs_shape[0];
+
+ // The output tensor will be auto-initialized within the function
+
+ // Create and configure function
+ ReshapeRHSFunctionType reshape_rhs;
+ GEMMFunctionType gemm;
+ reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
+ gemm.configure(&lhs, &rhs_reshaped, &dst, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
+
+ ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ rhs_reshaped.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(lhs), 0);
+ fill(AccessorType(rhs), 1);
+
+ // Compute GEMM
+ reshape_rhs.run();
+ gemm.run();
+
+ return dst;
+ }
+
+ SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h)
+ {
+ TensorShape dst_shape = lhs_shape;
+ dst_shape.set(0, rhs_shape[0]);
+ dst_shape.set(1, lhs_shape[1] / m_h);
+ dst_shape.set(2, m_h);
+ dst_shape.set(3, lhs_shape[2]);
+
+ // Create reference
+ SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
+ SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
+
+ TensorType _target{};
+ SimpleTensor<int32_t> _reference{};
+};
} // namespace validation
} // namespace test
} // namespace arm_compute