aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-06-14 16:11:10 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-20 16:02:39 +0000
commite16c8906a2aedf00e910754a01fca8bc4189cfc7 (patch)
treede9b88917bb00a76a9df68c9e92f05e38c5de817 /tests
parent0cbfda629dd8f684e625173341bab972f004222c (diff)
downloadComputeLibrary-e16c8906a2aedf00e910754a01fca8bc4189cfc7.tar.gz
COMPMID-2053: Fuse bias addition with CLGEMMMatrixMultiplyReshapedKernel
Change-Id: I5bfd38c94a6fd18a1cba2104f7e1b04e7bef6ec2 Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1359 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/framework/Macros.h9
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp69
-rw-r--r--tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp27
-rw-r--r--tests/validation/fixtures/GEMMFixture.h255
4 files changed, 209 insertions, 151 deletions
diff --git a/tests/framework/Macros.h b/tests/framework/Macros.h
index 591b80e9d8..134f75e287 100644
--- a/tests/framework/Macros.h
+++ b/tests/framework/Macros.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,8 +49,8 @@
#define CONCAT(ARG0, ARG1) ARG0##ARG1
-#define VARIADIC_SIZE_IMPL(e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, size, ...) size
-#define VARIADIC_SIZE(...) VARIADIC_SIZE_IMPL(__VA_ARGS__, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
+#define VARIADIC_SIZE_IMPL(e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, size, ...) size
+#define VARIADIC_SIZE(...) VARIADIC_SIZE_IMPL(__VA_ARGS__, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
#define JOIN_PARAM1(OP, param) OP(0, param)
#define JOIN_PARAM2(OP, param, ...) \
@@ -83,6 +83,9 @@
#define JOIN_PARAM11(OP, param, ...) \
OP(10, param) \
, JOIN_PARAM10(OP, __VA_ARGS__)
+#define JOIN_PARAM12(OP, param, ...) \
+ OP(11, param) \
+ , JOIN_PARAM11(OP, __VA_ARGS__)
#define JOIN_PARAM(OP, NUM, ...) \
CONCAT(JOIN_PARAM, NUM) \
(OP, __VA_ARGS__)
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index 564d3f4c2f..69e58303f3 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -76,6 +76,9 @@ constexpr float tolerance_num_f16 = 0.02f;
/** Alpha values to test - Precommit */
const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
+/** Beta values to test - Precommit */
+const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
+
/** M values to test */
const auto m_values = framework::dataset::make("M", 37);
@@ -130,8 +133,11 @@ const auto i_values_lhs = framework::dataset::make("interleave_lhs", { true, fal
/** Interleave values to test with RHS matrix */
const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, false });
+/** Broadcast bias from vector to matrix */
+const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} );
+
/** Configuration test */
-void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, DataType data_type)
+void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, DataType data_type)
{
const unsigned int M = m_value;
const unsigned int N = n_value;
@@ -151,7 +157,7 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned
rhs_info.interleave = i_value_rhs;
rhs_info.transpose = true;
- GEMMReshapeInfo gemm_info(M, N, K);
+ GEMMReshapeInfo gemm_info(M, N, K, false, false, 0, false, broadcast_bias);
const TensorShape lhs_shape(K, M, b_value);
const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, data_type),
@@ -166,18 +172,24 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned
TensorInfo(rhs_shape_reshaped, 1, data_type),
gemm_info);
+ const TensorShape bias_shape(N,
+ broadcast_bias? 1 : M,
+ broadcast_bias? 1 : b_value);
+
// Create tensors
CLTensor lhs_reshaped = create_tensor<CLTensor>(lhs_shape_reshaped, data_type);
CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
+ CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Create and configure function
CLGEMMMatrixMultiplyReshaped gemm;
- gemm.configure(&lhs_reshaped, &rhs_reshaped, &dst, 1.0f, lhs_info, rhs_info, gemm_info);
+ gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, gemm_info);
}
} // namespace
@@ -185,7 +197,7 @@ TEST_SUITE(CL)
TEST_SUITE(GEMMMatrixMultiplyReshaped)
TEST_SUITE(Float)
TEST_SUITE(FP32)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -197,13 +209,14 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi
h0_values_precommit),
i_values_lhs),
i_values_rhs),
-m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs)
+ broadcast_bias_values),
+m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias)
{
- validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, DataType::F32);
+ validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, DataType::F32);
}
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -216,14 +229,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values))
+ a_values),
+ beta_values),
+ broadcast_bias_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -236,14 +251,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<float>, fra
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values))
+ a_values),
+ beta_values),
+ broadcast_bias_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
@@ -257,14 +274,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values))
+ a_values),
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
@@ -278,7 +296,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<float>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
- a_values))
+ a_values),
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -287,7 +306,7 @@ TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -300,14 +319,16 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values))
+ a_values),
+ beta_values),
+ broadcast_bias_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -320,14 +341,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, fram
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values))
+ a_values),
+ beta_values),
+ broadcast_bias_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
}
FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
@@ -341,14 +364,15 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values))
+ a_values),
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
}
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
@@ -362,7 +386,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>,
i_values_lhs),
i_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
- a_values))
+ a_values),
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
index 23ae004912..133170e2d3 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRHS.cpp
@@ -123,7 +123,7 @@ const auto i_values_rhs = framework::dataset::make("interleave_rhs", { true, fal
/** Transpose values to test with RHS matrix */
const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true, false });
-/**Broadcast bias from vector to matrix */
+/** Broadcast bias from vector to matrix */
const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", {false, true} );
/** Configuration test */
@@ -155,18 +155,15 @@ void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned
TensorInfo(rhs_shape_reshaped, 1, data_type),
gemm_info);
+ const TensorShape bias_shape(N,
+ broadcast_bias? 1 : M,
+ broadcast_bias? 1 : b_value);
+
// Create tensors
CLTensor lhs = create_tensor<CLTensor>(lhs_shape, data_type);
CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
- CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
-
- TensorShape bias_shape = dst_shape;
- if (broadcast_bias)
- {
- bias_shape[1] = 1;
- bias_shape[2] = 1;
- }
CLTensor bias = create_tensor<CLTensor>(bias_shape, data_type);
+ CLTensor dst = create_tensor<CLTensor>(dst_shape, data_type);
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -257,7 +254,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -278,7 +275,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F32)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
@@ -300,7 +297,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -321,7 +318,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedOnlyRHSFixture<half
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values),
+ beta_values),
broadcast_bias_values))
{
// Validate output
@@ -343,7 +340,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
@@ -364,7 +361,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshapedOnlyRHS3DFixture<
t_values_rhs),
framework::dataset::make("DataType", DataType::F16)),
a_values),
- b_values))
+ beta_values))
{
// Validate output
validate(CLAccessor(_target), _reference, rel_tolerance_f16, tolerance_num_f16);
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 34f9bd848c..fcb41bb0ba 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -157,7 +157,7 @@ class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture
public:
template <typename...>
void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0, bool interleave_lhs,
- bool interleave_rhs, DataType data_type, float alpha)
+ bool interleave_rhs, DataType data_type, float alpha, float beta, bool broadcast_bias)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
@@ -176,9 +176,12 @@ public:
// Set the tensor shapes for LHS and RHS matrices
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
+ const TensorShape bias_shape(n,
+ broadcast_bias ? 1 : m,
+ broadcast_bias ? 1 : batch_size);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias);
+ _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias);
}
protected:
@@ -193,11 +196,13 @@ protected:
library->fill_borders_with_garbage(tensor, distribution_inf, i);
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ DataType data_type, float alpha, float beta, bool broadcast_bias)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
TensorType lhs_reshaped;
TensorType rhs_reshaped;
TensorType dst;
@@ -214,20 +219,23 @@ protected:
GEMMFunctionType gemm;
reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
- gemm.configure(&lhs_reshaped, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
+ gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, 0, false, broadcast_bias));
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
lhs.allocator()->allocate();
rhs.allocator()->allocate();
lhs_reshaped.allocator()->allocate();
rhs_reshaped.allocator()->allocate();
+ bias.allocator()->allocate();
dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -235,6 +243,7 @@ protected:
// Fill tensors
fill(AccessorType(lhs), 0);
fill(AccessorType(rhs), 1);
+ fill(AccessorType(bias), 2);
// Compute GEMM
reshape_lhs.run();
@@ -244,7 +253,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha)
+ SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, bool broadcast_bias)
{
TensorShape dst_shape = lhs_shape;
dst_shape[0] = rhs_shape[0];
@@ -253,13 +262,27 @@ protected:
// Create reference
SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
- SimpleTensor<T> c{ dst_shape, data_type, 1 };
+ SimpleTensor<T> bias{ dst_shape, data_type, 1 };
+
+ const int n = rhs_shape[0];
+ const int m = lhs_shape[1];
+ const int batch_size = lhs_shape[2];
// Fill reference
fill(lhs, 0);
fill(rhs, 1);
+ fill(bias, 2);
- return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f);
+ if(broadcast_bias)
+ {
+ // In case of broadcast, we need simply copy the first into the following "M" ones
+ for(int i = 1; i < m * batch_size; i++)
+ {
+ memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
+ }
+ }
+
+ return reference::gemm<T>(lhs, rhs, bias, alpha, beta);
}
TensorType _target{};
@@ -273,7 +296,7 @@ public:
template <typename...>
void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
bool interleave_lhs,
- bool interleave_rhs, DataType data_type, float alpha)
+ bool interleave_rhs, DataType data_type, float alpha, float beta)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
@@ -295,9 +318,10 @@ public:
// Set the tensor shapes for LHS and RHS matrices
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
+ const TensorShape bias_shape(n, 1, 1);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha, m_h);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, m_h);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h);
+ _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h);
}
protected:
@@ -308,12 +332,13 @@ protected:
library->fill(tensor, distribution, i);
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha,
- unsigned int m_h)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ DataType data_type, float alpha, float beta, unsigned int m_h)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
TensorType lhs_reshaped;
TensorType rhs_reshaped;
TensorType dst;
@@ -330,27 +355,31 @@ protected:
GEMMFunctionType gemm;
reshape_lhs.configure(&lhs, &lhs_reshaped, lhs_info);
reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
- gemm.configure(&lhs_reshaped, &rhs_reshaped, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h));
+ gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h, false, true));
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
lhs.allocator()->allocate();
rhs.allocator()->allocate();
lhs_reshaped.allocator()->allocate();
rhs_reshaped.allocator()->allocate();
+ bias.allocator()->allocate();
dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
fill(AccessorType(lhs), 0);
fill(AccessorType(rhs), 1);
+ fill(AccessorType(bias), 2);
// Compute GEMM
reshape_lhs.run();
@@ -360,7 +389,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, unsigned int m_h)
+ SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h)
{
TensorShape dst_shape = lhs_shape;
dst_shape.set(0, rhs_shape[0]);
@@ -371,13 +400,24 @@ protected:
// Create reference
SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
- SimpleTensor<T> c{ dst_shape, data_type, 1 };
+ SimpleTensor<T> bias{ dst_shape, data_type, 1 };
+
+ const int n = rhs_shape[0];
+ const int m = lhs_shape[1];
+ const int batch_size = lhs_shape[2];
// Fill reference
fill(lhs, 0);
fill(rhs, 1);
+ fill(bias, 2);
- return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f);
+ // In case of broadcast, we need simply copy the first into the following "M" ones
+ for(int i = 1; i < m * batch_size; i++)
+ {
+ memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
+ }
+
+ return reference::gemm<T>(lhs, rhs, bias, alpha, beta);
}
TensorType _target{};
@@ -406,16 +446,9 @@ public:
// Set the tensor shapes for LHS and RHS matrices
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
-
- TensorShape bias_shape;
- if(broadcast_bias)
- {
- bias_shape = TensorShape(n, 1, 1);
- }
- else
- {
- bias_shape = TensorShape(n, m, batch_size);
- }
+ const TensorShape bias_shape(n,
+ broadcast_bias ? 1 : m,
+ broadcast_bias ? 1 : batch_size);
_target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias);
_reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, broadcast_bias);
@@ -457,6 +490,7 @@ protected:
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
lhs.allocator()->allocate();
@@ -468,6 +502,7 @@ protected:
ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
@@ -500,20 +535,16 @@ protected:
// Fill reference
fill(lhs, 0);
fill(rhs, 1);
+ fill(bias, 2);
if(broadcast_bias)
{
- SimpleTensor<T> tmp{ bias_shape, data_type, 1 };
- fill(tmp, 2);
- for(int i = 0; i < m * batch_size; i++)
+ // In case of broadcast, we need simply copy the first into the following "M" ones
+ for(int i = 1; i < m * batch_size; i++)
{
- memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T));
+ memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
}
}
- else
- {
- fill(bias, 2);
- }
return (reference::gemm<T>(lhs, rhs, bias, alpha, beta));
}
@@ -522,27 +553,35 @@ protected:
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename GEMMFunctionType>
-class GEMMMatrixMultiplyNativeValidationFixture : public framework::Fixture
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+class GEMMMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, DataType data_type, float alpha)
+ void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0,
+ bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
lhs_info.k0 = k0;
GEMMRHSMatrixInfo rhs_info;
- rhs_info.n0 = n0;
- rhs_info.k0 = k0;
+ rhs_info.n0 = n0;
+ rhs_info.k0 = k0;
+ rhs_info.h0 = h0;
+ rhs_info.interleave = interleave_rhs;
+ rhs_info.transpose = transpose_rhs;
+
+ // In case of GEMM3D, m is the product between m_w and m_h
+ const unsigned int m = m_w * m_h;
// Set the tensor shapes for LHS and RHS matrices
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
+ const TensorShape bias_shape(n, 1, 1);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha);
- _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha);
+ _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h);
+ _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h);
}
protected:
@@ -551,100 +590,116 @@ protected:
{
std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, i);
-
- // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
- std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
- library->fill_borders_with_garbage(tensor, distribution_inf, i);
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+ DataType data_type, float alpha, float beta,
+ unsigned int m_h)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
+ TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
+ TensorType rhs_reshaped;
TensorType dst;
const unsigned int M = lhs_shape[1];
const unsigned int N = rhs_shape[0];
const unsigned int K = lhs_shape[0];
+ // The output tensor will be auto-initialized within the function
+
// Create and configure function
- GEMMFunctionType gemm;
- gemm.configure(&lhs, &rhs, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
+ ReshapeRHSFunctionType reshape_rhs;
+ GEMMFunctionType gemm;
+ reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
+ gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h, false, true));
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
lhs.allocator()->allocate();
rhs.allocator()->allocate();
+ rhs_reshaped.allocator()->allocate();
+ bias.allocator()->allocate();
dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
fill(AccessorType(lhs), 0);
fill(AccessorType(rhs), 1);
+ fill(AccessorType(bias), 2);
// Compute GEMM
+ reshape_rhs.run();
gemm.run();
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha)
+ SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h)
{
TensorShape dst_shape = lhs_shape;
- dst_shape[0] = rhs_shape[0];
- dst_shape[1] = lhs_shape[1];
+ dst_shape.set(0, rhs_shape[0]);
+ dst_shape.set(1, lhs_shape[1] / m_h);
+ dst_shape.set(2, m_h);
+ dst_shape.set(3, lhs_shape[2]);
// Create reference
SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
- SimpleTensor<T> c{ dst_shape, data_type, 1 };
+ SimpleTensor<T> bias{ dst_shape, data_type, 1 };
+
+ const int n = rhs_shape[0];
+ const int m = lhs_shape[1];
+ const int batch_size = lhs_shape[2];
// Fill reference
fill(lhs, 0);
fill(rhs, 1);
+ fill(bias, 2);
- return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f);
+ // In case of broadcast, we need simply copy the first into the following "M" ones
+ for(int i = 1; i < m * batch_size; i++)
+ {
+ memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
+ }
+
+ return reference::gemm<T>(lhs, rhs, bias, alpha, beta);
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
-class GEMMMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture
+template <typename TensorType, typename AccessorType, typename T, typename GEMMFunctionType>
+class GEMMMatrixMultiplyNativeValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0,
- bool interleave_rhs, bool transpose_rhs, DataType data_type, float alpha, float beta)
+ void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, DataType data_type, float alpha)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
lhs_info.k0 = k0;
GEMMRHSMatrixInfo rhs_info;
- rhs_info.n0 = n0;
- rhs_info.k0 = k0;
- rhs_info.h0 = h0;
- rhs_info.interleave = interleave_rhs;
- rhs_info.transpose = transpose_rhs;
-
- // In case of GEMM3D, m is the product between m_w and m_h
- const unsigned int m = m_w * m_h;
+ rhs_info.n0 = n0;
+ rhs_info.k0 = k0;
// Set the tensor shapes for LHS and RHS matrices
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
- const TensorShape bias_shape(n, 1, 1);
- _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, m_h);
- _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, alpha, beta, m_h);
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type, alpha);
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha);
}
protected:
@@ -653,30 +708,26 @@ protected:
{
std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, i);
+
+ // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
+ std::uniform_real_distribution<> distribution_inf(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
+ library->fill_borders_with_garbage(tensor, distribution_inf, i);
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
- DataType data_type, float alpha, float beta,
- unsigned int m_h)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, DataType data_type, float alpha)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
- TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
- TensorType rhs_reshaped;
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
TensorType dst;
const unsigned int M = lhs_shape[1];
const unsigned int N = rhs_shape[0];
const unsigned int K = lhs_shape[0];
- // The output tensor will be auto-initialized within the function
-
// Create and configure function
- ReshapeRHSFunctionType reshape_rhs;
- GEMMFunctionType gemm;
- reshape_rhs.configure(&rhs, &rhs_reshaped, rhs_info);
- gemm.configure(&lhs, &rhs_reshaped, &bias, &dst, alpha, beta, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K, 1, 1, m_h, false, true));
+ GEMMFunctionType gemm;
+ gemm.configure(&lhs, &rhs, &dst, alpha, lhs_info, rhs_info, GEMMReshapeInfo(M, N, K));
ARM_COMPUTE_EXPECT(lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -684,56 +735,38 @@ protected:
// Allocate tensors
lhs.allocator()->allocate();
rhs.allocator()->allocate();
- rhs_reshaped.allocator()->allocate();
- bias.allocator()->allocate();
dst.allocator()->allocate();
ARM_COMPUTE_EXPECT(!lhs.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!rhs.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
fill(AccessorType(lhs), 0);
fill(AccessorType(rhs), 1);
- fill(AccessorType(bias), 2);
// Compute GEMM
- reshape_rhs.run();
gemm.run();
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, float alpha, float beta, unsigned int m_h)
+ SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha)
{
TensorShape dst_shape = lhs_shape;
- dst_shape.set(0, rhs_shape[0]);
- dst_shape.set(1, lhs_shape[1] / m_h);
- dst_shape.set(2, m_h);
- dst_shape.set(3, lhs_shape[2]);
+ dst_shape[0] = rhs_shape[0];
+ dst_shape[1] = lhs_shape[1];
// Create reference
SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
- SimpleTensor<T> bias{ dst_shape, data_type, 1 };
-
- const int n = rhs_shape[0];
- const int m = lhs_shape[1];
- const int batch_size = lhs_shape[2];
+ SimpleTensor<T> c{ dst_shape, data_type, 1 };
// Fill reference
fill(lhs, 0);
fill(rhs, 1);
- SimpleTensor<T> tmp{ bias_shape, data_type, 1 };
- fill(tmp, 2);
- for(int i = 0; i < m * batch_size; i++)
- {
- memcpy(bias.data() + i * n, tmp.data(), n * sizeof(T));
- }
-
- return reference::gemm<T>(lhs, rhs, bias, alpha, beta);
+ return reference::gemm<T>(lhs, rhs, c, alpha, 0.0f);
}
TensorType _target{};