diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-09 14:18:32 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-09 14:19:17 -0700 |
commit | 2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (patch) | |
tree | befdb31f63a91a245605f94e2c83cbf070210854 /reference_model/src | |
parent | cd79f0e06bf53c2c0fee39ee916bb6d79f177b57 (diff) | |
download | reference_model-2d60f0063eb91f6514b20a1817663ce0ddd3ff4a.tar.gz |
adding batch dimension to MatMul.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 79 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 12 |
2 files changed, 79 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b8c7ade..0007553 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -742,7 +742,7 @@ OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinf : GraphNode(Op_MATMUL, id_) { setRequiredOperands(2, 1); - setRequiredRank(2); + setRequiredRank(3); INIT_QINFO(MatMul); } @@ -765,16 +765,47 @@ int OpMatMul<Dtype>::checkTensorAttributes() return 1; } - a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + ASSERT_MEM(a && b && output); + + // a: [N, H, C] + // b: [N, C, W] + // c: [N, H, W] - if (a->getShape()[1] != b->getShape()[0]) + // Check N + if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0]) { - printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]"); + printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match"); return 1; } + N = a->getShape()[0]; - c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + // Check C + if (a->getShape()[2] != b->getShape()[1]) + { + printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]"); + return 1; + } + C = a->getShape()[2]; + + // Check H + if (a->getShape()[1] != output->getShape()[1]) + { + printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]"); + return 1; + } + H = a->getShape()[1]; + + // Check W + if (b->getShape()[2] != output->getShape()[2]) + { + printNodeValidationError("OpMatMul operator output.shape[2] should match output.shape[2]"); + return 1; + } + W = b->getShape()[2]; return 0; } @@ -793,12 +824,42 @@ int OpMatMul<Dtype>::eval() b_val = b_val - (InEigenType)this->qinfo->b_zp(); } - this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims); + Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C }); + Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W }); + Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W }); + + Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C }); + Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W }); + + Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 }); + Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 }); + + // Iterate N dimension. + for (int i = 0; i < N; i++) + { + a_begin_array[0] = i; + b_begin_array[0] = i; + + TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape); + TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape); + TAccRank2 output_rank2_val = + a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims); + TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape); + if (i == 0) + { + this->output->getTensor() = output_rank3_val; + } + else + { + TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0); + this->output->getTensor() = temp; + } + } if (AccDtype == DType_INT48) { - this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin); - this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); } return GraphNode::eval(); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 26ce84b..9aaa140 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -183,15 +183,21 @@ public: static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value; using InEigenType = typename GetEigenType<Dtype>::type; using AccEigenType = typename GetEigenType<AccDtype>::type; - using TIn = Eigen::Tensor<InEigenType, 2>; - using TAcc = Eigen::Tensor<AccEigenType, 2>; + using TIn = Eigen::Tensor<InEigenType, 3>; + using TAcc = Eigen::Tensor<AccEigenType, 3>; + using TInRank2 = Eigen::Tensor<InEigenType, 2>; + using TAccRank2 = Eigen::Tensor<AccEigenType, 2>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; protected: TosaReference::TensorTemplate<TIn>* a; TosaReference::TensorTemplate<TIn>* b; - TosaReference::TensorTemplate<TAcc>* c; + TosaReference::TensorTemplate<TAcc>* output; + int64_t N; + int64_t H; + int64_t W; + int64_t C; tosa::TosaMatMulQuantInfo* qinfo; }; |