aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.h
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-09 14:18:32 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-06-09 14:19:17 -0700
commit2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (patch)
treebefdb31f63a91a245605f94e2c83cbf070210854 /reference_model/src/ops/tensor_ops.h
parentcd79f0e06bf53c2c0fee39ee916bb6d79f177b57 (diff)
downloadreference_model-2d60f0063eb91f6514b20a1817663ce0ddd3ff4a.tar.gz
adding batch dimension to MatMul.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
Diffstat (limited to 'reference_model/src/ops/tensor_ops.h')
-rw-r--r--reference_model/src/ops/tensor_ops.h12
1 files changed, 9 insertions, 3 deletions
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 26ce84b..9aaa140 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -183,15 +183,21 @@ public:
static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
using InEigenType = typename GetEigenType<Dtype>::type;
using AccEigenType = typename GetEigenType<AccDtype>::type;
- using TIn = Eigen::Tensor<InEigenType, 2>;
- using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ using TIn = Eigen::Tensor<InEigenType, 3>;
+ using TAcc = Eigen::Tensor<AccEigenType, 3>;
+ using TInRank2 = Eigen::Tensor<InEigenType, 2>;
+ using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* a;
TosaReference::TensorTemplate<TIn>* b;
- TosaReference::TensorTemplate<TAcc>* c;
+ TosaReference::TensorTemplate<TAcc>* output;
+ int64_t N;
+ int64_t H;
+ int64_t W;
+ int64_t C;
tosa::TosaMatMulQuantInfo* qinfo;
};