aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-09 14:18:32 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-06-09 14:19:17 -0700
commit2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (patch)
treebefdb31f63a91a245605f94e2c83cbf070210854
parentcd79f0e06bf53c2c0fee39ee916bb6d79f177b57 (diff)
downloadreference_model-2d60f0063eb91f6514b20a1817663ce0ddd3ff4a.tar.gz
adding batch dimension to MatMul.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I83f75dd5beb60fe7ca2d573ea0f81bac4cd62a07
-rw-r--r--reference_model/src/ops/tensor_ops.cc79
-rw-r--r--reference_model/src/ops/tensor_ops.h12
-rw-r--r--verif/tosa_test_gen.py14
3 files changed, 86 insertions, 19 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b8c7ade..0007553 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -742,7 +742,7 @@ OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinf
: GraphNode(Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(2);
+ setRequiredRank(3);
INIT_QINFO(MatMul);
}
@@ -765,16 +765,47 @@ int OpMatMul<Dtype>::checkTensorAttributes()
return 1;
}
- a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ ASSERT_MEM(a && b && output);
+
+ // a: [N, H, C]
+ // b: [N, C, W]
+ // c: [N, H, W]
- if (a->getShape()[1] != b->getShape()[0])
+ // Check N
+ if (a->getShape()[0] != b->getShape()[0] || a->getShape()[0] != output->getShape()[0])
{
- printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+ printNodeValidationError("OpMatMul operator a.shape[0], b.shape[0] and output.shape[0] should match");
return 1;
}
+ N = a->getShape()[0];
- c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ // Check C
+ if (a->getShape()[2] != b->getShape()[1])
+ {
+ printNodeValidationError("OpMatMul operator a.shape[2] should match b.shape[1]");
+ return 1;
+ }
+ C = a->getShape()[2];
+
+ // Check H
+ if (a->getShape()[1] != output->getShape()[1])
+ {
+ printNodeValidationError("OpMatMul operator a.shape[1] should match output.shape[1]");
+ return 1;
+ }
+ H = a->getShape()[1];
+
+ // Check W
+ if (b->getShape()[2] != output->getShape()[2])
+ {
+        printNodeValidationError("OpMatMul operator b.shape[2] should match output.shape[2]");
+ return 1;
+ }
+ W = b->getShape()[2];
return 0;
}
@@ -793,12 +824,42 @@ int OpMatMul<Dtype>::eval()
b_val = b_val - (InEigenType)this->qinfo->b_zp();
}
- this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+ Eigen::array<Eigen::Index, 2> a_rank2_shape({ H, C });
+ Eigen::array<Eigen::Index, 2> b_rank2_shape({ C, W });
+ Eigen::array<Eigen::Index, 3> output_rank3_shape({ 1, H, W });
+
+ Eigen::array<Eigen::Index, 3> a_size_array({ 1, H, C });
+ Eigen::array<Eigen::Index, 3> b_size_array({ 1, C, W });
+
+ Eigen::array<Eigen::Index, 3> a_begin_array({ 0, 0, 0 });
+ Eigen::array<Eigen::Index, 3> b_begin_array({ 0, 0, 0 });
+
+ // Iterate N dimension.
+ for (int i = 0; i < N; i++)
+ {
+ a_begin_array[0] = i;
+ b_begin_array[0] = i;
+
+ TInRank2 a_rank2_val = a_val.slice(a_begin_array, a_size_array).reshape(a_rank2_shape);
+ TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
+ TAccRank2 output_rank2_val =
+ a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
+ TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+ if (i == 0)
+ {
+ this->output->getTensor() = output_rank3_val;
+ }
+ else
+ {
+ TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0);
+ this->output->getTensor() = temp;
+ }
+ }
if (AccDtype == DType_INT48)
{
- this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
}
return GraphNode::eval();
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 26ce84b..9aaa140 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -183,15 +183,21 @@ public:
static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
using InEigenType = typename GetEigenType<Dtype>::type;
using AccEigenType = typename GetEigenType<AccDtype>::type;
- using TIn = Eigen::Tensor<InEigenType, 2>;
- using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ using TIn = Eigen::Tensor<InEigenType, 3>;
+ using TAcc = Eigen::Tensor<AccEigenType, 3>;
+ using TInRank2 = Eigen::Tensor<InEigenType, 2>;
+ using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* a;
TosaReference::TensorTemplate<TIn>* b;
- TosaReference::TensorTemplate<TAcc>* c;
+ TosaReference::TensorTemplate<TAcc>* output;
+ int64_t N;
+ int64_t H;
+ int64_t W;
+ int64_t C;
tosa::TosaMatMulQuantInfo* qinfo;
};
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 5670d1b..6f9acf4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -314,12 +314,12 @@ class TosaTensorGen:
def tgMatmul(testGen, op, rank):
pl, const = op["operands"]
- assert rank == 2
+ assert rank == 3
assert pl == 2 and const == 0
a_shape = testGen.makeShape(rank)
b_oc = testGen.makeShape(1)[0]
- b_shape = np.asarray([a_shape[1], b_oc])
+ b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
return [a_shape, b_shape]
@@ -1994,7 +1994,7 @@ class TosaTestGen:
"matmul": {
"op": Op.MATMUL,
"operands": (2, 0),
- "rank": (2, 2),
+ "rank": (3, 3),
"build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
@@ -2630,11 +2630,11 @@ class OutputShaper:
@staticmethod
def matmulOp(ser, a, b):
- # a: M, K
- # b: K, N
- # out: M, N
+ # a: N, H, C
+ # b: N, C, W
+ # out: N, H, W
- output_shape = [a.shape[0], b.shape[1]]
+ output_shape = [a.shape[0], a.shape[1], b.shape[2]]
if a.dtype == DType.INT8:
out_dtype = DType.INT32