From 2d60f0063eb91f6514b20a1817663ce0ddd3ff4a Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 9 Jun 2021 14:18:32 -0700 Subject: adding batch dimension to MatMul. Signed-off-by: Kevin Cheng Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07 --- reference_model/src/ops/tensor_ops.cc | 79 +++++++++++++++++++++++++++++++---- reference_model/src/ops/tensor_ops.h | 12 ++++-- verif/tosa_test_gen.py | 14 +++---- 3 files changed, 86 insertions(+), 19 deletions(-) diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b8c7ade..0007553 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -742,7 +742,7 @@ OpMatMul::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinf : GraphNode(Op_MATMUL, id_) { setRequiredOperands(2, 1); - setRequiredRank(2); + setRequiredRank(3); INIT_QINFO(MatMul); } @@ -765,16 +765,47 @@ int OpMatMul::checkTensorAttributes() return 1; } - a = dynamic_cast*>(inputs[0]); - b = dynamic_cast*>(inputs[1]); + a = dynamic_cast*>(inputs[0]); + b = dynamic_cast*>(inputs[1]); + output = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(a && b && output); + + // a: [N, H, C] + // b: [N, C, W] + // c: [N, H, W] - if (a->getShape()[1] != b->getShape()[0]) + // Check N + if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0]) { - printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]"); + printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match"); return 1; } + N = a->getShape()[0]; - c = dynamic_cast*>(outputs[0]); + // Check C + if (a->getShape()[2] != b->getShape()[1]) + { + printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]"); + return 1; + } + C = a->getShape()[2]; + + // Check H + if (a->getShape()[1] != output->getShape()[1]) + { + printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]"); + return 1; + } + H = 
a->getShape()[1]; + + // Check W + if (b->getShape()[2] != output->getShape()[2]) + { + printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]"); + return 1; + } + W = b->getShape()[2]; return 0; } @@ -793,12 +824,42 @@ int OpMatMul::eval() b_val = b_val - (InEigenType)this->qinfo->b_zp(); } - this->c->getTensor() = a_val.template cast().contract(b_val.template cast(), dims); + Eigen::array a_rank2_shape({ H, C }); + Eigen::array b_rank2_shape({ C, W }); + Eigen::array output_rank3_shape({ 1, H, W }); + + Eigen::array a_size_array({ 1, H, C }); + Eigen::array b_size_array({ 1, C, W }); + + Eigen::array a_begin_array({ 0, 0, 0 }); + Eigen::array b_begin_array({ 0, 0, 0 }); + + // Iterate N dimension. + for (int i = 0; i < N; i++) + { + a_begin_array[0] = i; + b_begin_array[0] = i; + + TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape); + TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape); + TAccRank2 output_rank2_val = + a_rank2_val.template cast().contract(b_rank2_val.template cast(), dims); + TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape); + if (i == 0) + { + this->output->getTensor() = output_rank3_val; + } + else + { + TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0); + this->output->getTensor() = temp; + } + } if (AccDtype == DType_INT48) { - this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin); - this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); } return GraphNode::eval(); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 26ce84b..9aaa140 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -183, 
+183,21 @@ public: static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using AccEigenType = typename GetEigenType::type; - using TIn = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TIn = Eigen::Tensor; + using TAcc = Eigen::Tensor; + using TInRank2 = Eigen::Tensor; + using TAccRank2 = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; - TosaReference::TensorTemplate* c; + TosaReference::TensorTemplate* output; + int64_t N; + int64_t H; + int64_t W; + int64_t C; tosa::TosaMatMulQuantInfo* qinfo; }; diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 5670d1b..6f9acf4 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -314,12 +314,12 @@ class TosaTensorGen: def tgMatmul(testGen, op, rank): pl, const = op["operands"] - assert rank == 2 + assert rank == 3 assert pl == 2 and const == 0 a_shape = testGen.makeShape(rank) b_oc = testGen.makeShape(1)[0] - b_shape = np.asarray([a_shape[1], b_oc]) + b_shape = np.asarray([a_shape[0], a_shape[2], b_oc]) return [a_shape, b_shape] @@ -1994,7 +1994,7 @@ class TosaTestGen: "matmul": { "op": Op.MATMUL, "operands": (2, 0), - "rank": (2, 2), + "rank": (3, 3), "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, @@ -2630,11 +2630,11 @@ class OutputShaper: @staticmethod def matmulOp(ser, a, b): - # a: M, K - # b: K, N - # out: M, N + # a: N, H, C + # b: N, C, W + # out: N, H, W - output_shape = [a.shape[0], b.shape[1]] + output_shape = [a.shape[0], a.shape[1], b.shape[2]] if a.dtype == DType.INT8: out_dtype = DType.INT32 -- cgit v1.2.1