aboutsummaryrefslogtreecommitdiff
path: root/tests/dataset/GEMMDataset.h
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2017-06-22 15:46:40 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:14:20 +0100
commit10c672c2e21bb77b7234d9d3611267400dce7ae0 (patch)
treee3a61be142225c38e36dc7db4719948a42afe32d /tests/dataset/GEMMDataset.h
parent84e3120f6803f66cd272729b1f3542cfd3bc75a5 (diff)
downloadComputeLibrary-10c672c2e21bb77b7234d9d3611267400dce7ae0.tar.gz
COMPMID-399 Add MatrixMultiply to benchmark
Change-Id: I86c3f808c0047c8d97211d21f61c4e79e2d2abb1 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78617 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/dataset/GEMMDataset.h')
-rw-r--r--tests/dataset/GEMMDataset.h32
1 files changed, 24 insertions, 8 deletions
diff --git a/tests/dataset/GEMMDataset.h b/tests/dataset/GEMMDataset.h
index f45bc3e838..ddd4a3424f 100644
--- a/tests/dataset/GEMMDataset.h
+++ b/tests/dataset/GEMMDataset.h
@@ -82,10 +82,10 @@ public:
SmallGEMMDataset()
: GenericDataset
{
- GEMMDataObject{ TensorShape(21u, 13u), TensorShape(33u, 21u), TensorShape(33u, 13u), TensorShape(33u, 13u), 1.0f, 0.0f },
- GEMMDataObject{ TensorShape(31u, 1u), TensorShape(23u, 31u), TensorShape(23u, 1u), TensorShape(23u, 1u), 1.0f, 0.0f },
- GEMMDataObject{ TensorShape(38u, 12u), TensorShape(21u, 38u), TensorShape(21u, 12u), TensorShape(21u, 12u), 0.2f, 1.2f },
- GEMMDataObject{ TensorShape(32u, 1u), TensorShape(17u, 32u), TensorShape(17u, 1u), TensorShape(17u, 1u), 0.4f, 0.7f },
+ GEMMDataObject{ TensorShape(21U, 13U), TensorShape(33U, 21U), TensorShape(33U, 13U), TensorShape(33U, 13U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U), TensorShape(23U, 1U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 12U), TensorShape(21U, 12U), 0.2f, 1.2f },
+ GEMMDataObject{ TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U), TensorShape(17U, 1U), 0.4f, 0.7f },
}
{
}
@@ -99,10 +99,10 @@ public:
LargeGEMMDataset()
: GenericDataset
{
- GEMMDataObject{ TensorShape(923u, 429u), TensorShape(871u, 923u), TensorShape(871u, 429u), TensorShape(871u, 429u), 1.0f, 0.0f },
- GEMMDataObject{ TensorShape(1021u, 1u), TensorShape(783u, 1021u), TensorShape(783u, 1u), TensorShape(783u, 1u), 1.0f, 0.0f },
- GEMMDataObject{ TensorShape(681u, 1023u), TensorShape(213u, 681u), TensorShape(213u, 1023u), TensorShape(213u, 1023u), 0.2f, 1.2f },
- GEMMDataObject{ TensorShape(941u, 1u), TensorShape(623u, 941u), TensorShape(623u, 1u), TensorShape(623u, 1u), 0.4f, 0.7f },
+ GEMMDataObject{ TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 429U), TensorShape(871U, 429U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(1021U, 1U), TensorShape(783U, 1021U), TensorShape(783U, 1U), TensorShape(783U, 1U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(681U, 1023U), TensorShape(213U, 681U), TensorShape(213U, 1023U), TensorShape(213U, 1023U), 0.2f, 1.2f },
+ GEMMDataObject{ TensorShape(941U, 1U), TensorShape(623U, 941U), TensorShape(623U, 1U), TensorShape(623U, 1U), 0.4f, 0.7f },
}
{
}
@@ -199,6 +199,22 @@ public:
~GoogLeNetGEMMDataset2() = default;
};
+
+class MatrixMultiplyDataset : public GenericDataset<GEMMDataObject, 3>
+{
+public:
+ MatrixMultiplyDataset()
+ : GenericDataset
+ {
+ GEMMDataObject{ TensorShape(1024U, 1U), TensorShape(1000U, 1024U), TensorShape(1000U, 1U), TensorShape(1000U, 1U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(256U, 784U), TensorShape(64U, 256U), TensorShape(64U, 784U), TensorShape(64U, 784U), 1.0f, 0.0f },
+ GEMMDataObject{ TensorShape(1152U, 2704U), TensorShape(256U, 1152U), TensorShape(256U, 2704U), TensorShape(256U, 2704U), 1.0f, 0.0f },
+ }
+ {
+ }
+
+ ~MatrixMultiplyDataset() = default;
+};
} // namespace test
} // namespace arm_compute
#endif //__ARM_COMPUTE_TEST_DATASET_GEMM_DATASET_H__