diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-09 14:18:32 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-06-09 14:19:17 -0700 |
commit | 2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (patch) | |
tree | befdb31f63a91a245605f94e2c83cbf070210854 /reference_model/src/ops/tensor_ops.h | |
parent | cd79f0e06bf53c2c0fee39ee916bb6d79f177b57 (diff) | |
download | reference_model-2d60f0063eb91f6514b20a1817663ce0ddd3ff4a.tar.gz |
adding batch dimension to MatMul.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 26ce84b..9aaa140 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -183,15 +183,21 @@ public: static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value; using InEigenType = typename GetEigenType<Dtype>::type; using AccEigenType = typename GetEigenType<AccDtype>::type; - using TIn = Eigen::Tensor<InEigenType, 2>; - using TAcc = Eigen::Tensor<AccEigenType, 2>; + using TIn = Eigen::Tensor<InEigenType, 3>; + using TAcc = Eigen::Tensor<AccEigenType, 3>; + using TInRank2 = Eigen::Tensor<InEigenType, 2>; + using TAccRank2 = Eigen::Tensor<AccEigenType, 2>; static constexpr int64_t AccQMin = GetQMin<AccDtype>::value; static constexpr int64_t AccQMax = GetQMax<AccDtype>::value; protected: TosaReference::TensorTemplate<TIn>* a; TosaReference::TensorTemplate<TIn>* b; - TosaReference::TensorTemplate<TAcc>* c; + TosaReference::TensorTemplate<TAcc>* output; + int64_t N; + int64_t H; + int64_t W; + int64_t C; tosa::TosaMatMulQuantInfo* qinfo; }; |