15 GemmFixture(
const std::string& alpha,
16 const std::string& beta,
17 const std::string& transA,
18 const std::string& transB,
19 const std::vector<int>& inputAShape,
20 const std::vector<int>& inputBShape,
21 const std::vector<int>& inputCShape,
22 const std::vector<int>& outputShape)
26 producer_name: "onnx-example" 104 struct GemmAllAttributesFixture : GemmFixture
106 GemmAllAttributesFixture() : GemmFixture(
"0.25",
"0.35",
"1",
"1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
112 struct GemmSimpleFixture : GemmFixture
114 GemmSimpleFixture() : GemmFixture(
"1",
"1",
"0",
"0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
120 struct GemmTransAFixture : GemmFixture
122 GemmTransAFixture() : GemmFixture(
"1",
"1",
"1",
"0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
128 struct GemmTransBFixture : GemmFixture
130 GemmTransBFixture() : GemmFixture(
"1",
"1",
"0",
"1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
136 struct GemmParseExceptionFixture : GemmFixture
138 GemmParseExceptionFixture() : GemmFixture(
"1",
"1",
"0",
"1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
143 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
144 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
145 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
146 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
147 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
148 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
149 {
"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
150 {{
"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
151 12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
152 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
157 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
158 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
159 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
160 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
161 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
162 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
163 {
"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
164 {{
"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
165 196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
166 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
171 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
172 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
173 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
174 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
175 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
176 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
177 {
"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
178 {{
"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
179 146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
180 112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
185 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
186 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
187 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
188 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
189 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
190 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
191 {
"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
192 {{
"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
193 60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
194 20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
205 GemmConstantFixture()
209 producer_name: "onnx-example" 315 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
316 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
317 {{
"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
318 12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
319 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
324 GemmConstantSimpleFixture()
328 producer_name: "onnx-example" 434 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
435 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
436 {{
"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
437 196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
438 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
443 GemmABFixture(
const std::string& alpha,
444 const std::string& beta,
445 const std::string& transA,
446 const std::string& transB,
447 const std::vector<int>& inputAShape,
448 const std::vector<int>& inputBShape,
449 const std::vector<int>& outputShape)
453 producer_name: "onnx-example" 520 struct GemmAlphaTransAFixture : GemmABFixture
522 GemmAlphaTransAFixture() : GemmABFixture(
"0.25",
"0.35",
"1",
"0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
525 struct GemmAlphaTransBFixture : GemmABFixture
527 GemmAlphaTransBFixture() : GemmABFixture(
"0.25",
"0.35",
"0",
"1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
532 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
533 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
534 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
535 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
536 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
537 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
538 {{
"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
539 36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
540 28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
545 RunTest<2, float>({{
"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
546 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
547 {
"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
548 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
549 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
550 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
551 {{
"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
552 15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
553 5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
std::string ConstructTensorShapeString(const std::vector< int > &shape)
TEST_CASE_FIXTURE(ClContextControlFixture, "CopyBetweenNeonAndGpu")