From 0bc5a258903bbbf8623b6e1c4253d958551293ed Mon Sep 17 00:00:00 2001 From: Gian Marco Date: Mon, 4 Dec 2017 13:55:08 +0000 Subject: COMPMID-556 - Fix SGEMM example - alpha and beta were integer values whilst should be float. - Replaced CLImage with CLTensor - Replaced Format with DataType Change-Id: I19f81b52d2eab8976be689b601d8e8e2bedfc6aa Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111725 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Anthony Barbier --- examples/cl_sgemm.cpp | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) (limited to 'examples/cl_sgemm.cpp') diff --git a/examples/cl_sgemm.cpp b/examples/cl_sgemm.cpp index 8808f7ebf5..e1729a85b0 100644 --- a/examples/cl_sgemm.cpp +++ b/examples/cl_sgemm.cpp @@ -31,14 +31,16 @@ #include "arm_compute/runtime/CL/CLTuner.h" #include "utils/Utils.h" +#include <cstdlib> + using namespace arm_compute; using namespace utils; void main_cl_sgemm(int argc, const char **argv) { NPYLoader npy0, npy1, npy2; - CLImage src0, src1, src2, dst; - int alpha = 1, beta = 0; + CLTensor src0, src1, src2, dst; + float alpha = 1.0f, beta = 0.0f; CLTuner tuner; CLScheduler::get().default_init(&tuner); @@ -53,21 +55,21 @@ void main_cl_sgemm(int argc, const char **argv) { // Print help std::cout << "Usage: 1) ./build/cl_sgemm input_matrix_1.npy input_matrix_2.npy [input_matrix_3.npy] [alpha = 1] [beta = 0]\n"; - std::cout << " 2) ./build/cl_sgemm M N K [alpha = 1] [beta = 0]\n\n"; - std::cout << "Too few or no input_matrices provided, creating random 5x7, 3x5 and 3x7 matrices\n\n"; + std::cout << " 2) ./build/cl_sgemm M N K [alpha = 1.0f] [beta = 0.0f]\n\n"; + std::cout << "Too few or no input_matrices provided. 
Using M=7, N=3, K=5, alpha=1.0f and beta=0.0f\n\n"; - src0.allocator()->init(TensorInfo(5, 7, Format::F32)); - src1.allocator()->init(TensorInfo(3, 5, Format::F32)); - src2.allocator()->init(TensorInfo(3, 7, Format::F32)); + src0.allocator()->init(TensorInfo(TensorShape(5U, 7U), 1, DataType::F32)); + src1.allocator()->init(TensorInfo(TensorShape(3U, 5U), 1, DataType::F32)); + src2.allocator()->init(TensorInfo(TensorShape(3U, 7U), 1, DataType::F32)); } else { - if(stream.good()) /* case file1.npy file2.npy [file3.npy] [alpha = 1] [beta = 0] */ + if(stream.good()) /* case file1.npy file2.npy [file3.npy] [alpha = 1.0f] [beta = 0.0f] */ { npy0.open(argv[1]); - npy0.init_tensor(src0, Format::F32); + npy0.init_tensor(src0, DataType::F32); npy1.open(argv[2]); - npy1.init_tensor(src1, Format::F32); + npy1.init_tensor(src1, DataType::F32); if(argc > 3) { @@ -77,52 +79,54 @@ void main_cl_sgemm(int argc, const char **argv) if(stream.good()) /* case with third file */ { npy2.open(argv[3]); - npy2.init_tensor(src2, Format::F32); + npy2.init_tensor(src2, DataType::F32); if(argc > 4) { - alpha = strtol(argv[4], nullptr, 10); + // Convert string to float + alpha = strtof(argv[4], nullptr); if(argc > 5) { - beta = strtol(argv[5], nullptr, 10); + // Convert string to float + beta = strtof(argv[5], nullptr); } } } else /* case without third file */ { - alpha = strtol(argv[3], nullptr, 10); + alpha = strtof(argv[3], nullptr); if(argc > 4) { - beta = strtol(argv[4], nullptr, 10); + beta = strtof(argv[4], nullptr); } } } } - else /* case M N K [alpha = 1] [beta = 0] */ + else /* case M N K [alpha = 1.0f] [beta = 0.0f] */ { size_t M = strtol(argv[1], nullptr, 10); size_t N = strtol(argv[2], nullptr, 10); size_t K = strtol(argv[3], nullptr, 10); - src0.allocator()->init(TensorInfo(K, M, Format::F32)); - src1.allocator()->init(TensorInfo(N, K, Format::F32)); - src2.allocator()->init(TensorInfo(N, M, Format::F32)); + src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32)); + 
src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32)); + src2.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32)); if(argc > 4) { - alpha = strtol(argv[4], nullptr, 10); + alpha = strtof(argv[4], nullptr); if(argc > 5) { - beta = strtol(argv[5], nullptr, 10); + beta = strtof(argv[5], nullptr); } } } } - init_sgemm_output(dst, src0, src1, Format::F32); + init_sgemm_output(dst, src0, src1, DataType::F32); // Configure function CLGEMM sgemm; -- cgit v1.2.1