aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-09-01 12:51:58 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-09-16 01:06:27 +0100
commit1533b85d198a1dd2b1ce995b6c9d69456e56eb3f (patch)
tree9c2926e6f646d82ff72f832fcb383e88a688f66b /reference_model
parent93a1628bc3dd48d9ba099de503b586a561b4751f (diff)
downloadreference_model-1533b85d198a1dd2b1ce995b6c9d69456e56eb3f.tar.gz
Implement Conv3D kernel.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ic16e918b1a2423ad563684e29ce70d9efdbf9c02
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/op_factory.cc6
-rw-r--r--reference_model/src/ops/tensor_ops.cc200
-rw-r--r--reference_model/src/ops/tensor_ops.h32
-rw-r--r--reference_model/src/subgraph_traverser.cc1
4 files changed, 239 insertions, 0 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 193b2af..3bc55a8 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -64,6 +64,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
break;
+ case Op_CONV3D:
+ DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8);
+ break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index a150656..a0a1f04 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -482,6 +482,201 @@ int OpConv2d<InDtype, WeightDtype>::eval()
}
template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_CONV3D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(5);
+
+ INIT_ATTRIBUTE(Conv);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
+ if (inputs[2]->getRank() != 1)
+ {
+ printNodeValidationError("OpConv3d: bias tensor must be rank 1");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (attribute->padding().size() != 6)
+ {
+ printNodeValidationError("OpConv3d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 3)
+ {
+ printNodeValidationError("OpConv3d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 3)
+ {
+ printNodeValidationError("OpConv3d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_depth = this->input->getShape()[1];
+ int in_height = this->input->getShape()[2];
+ int in_width = this->input->getShape()[3];
+ int in_channels = this->input->getShape()[4];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_depth = this->weight->getShape()[1];
+ int f_height = this->weight->getShape()[2];
+ int f_width = this->weight->getShape()[3];
+ int f_in_channels = this->weight->getShape()[4];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_depth = this->output->getShape()[1];
+ int out_height = this->output->getShape()[2];
+ int out_width = this->output->getShape()[3];
+ int out_channels = this->output->getShape()[4];
+
+ ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
+ out_channels);
+ ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+
+ int padding_d0 = this->attribute->padding()[0];
+ int padding_d1 = this->attribute->padding()[1];
+ int padding_top = this->attribute->padding()[2];
+ int padding_bottom = this->attribute->padding()[3];
+ int padding_left = this->attribute->padding()[4];
+ int padding_right = this->attribute->padding()[5];
+ int stride_d = this->attribute->stride()[0];
+ int stride_h = this->attribute->stride()[1];
+ int stride_w = this->attribute->stride()[2];
+ int dilation_d = this->attribute->dilation()[0];
+ int dilation_h = this->attribute->dilation()[1];
+ int dilation_w = this->attribute->dilation()[2];
+
+ DEBUG_INFO(
+ OP,
+ "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
+ "stride=[%d,%d,%d], dilation=[%d,%d,%d], padding=[%d,%d,%d,%d,%d,%d]",
+ in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
+ out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
+ dilation_w, padding_d0, padding_d1, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<std::pair<int32_t, int32_t>, 5> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_d0, padding_d1);
+ padding[2] = std::make_pair(padding_top, padding_bottom);
+ padding[3] = std::make_pair(padding_left, padding_right);
+ padding[4] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor5<InEigenType> input_padded = input_val.pad(padding);
+
+ // 1. initialize with bias
+ Eigen::array<Eigen::Index, 5> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[4] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 5> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_depth;
+ bcast[2] = out_height;
+ bcast[3] = out_width;
+ bcast[4] = 1;
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ // 2. direct convolution
+ AccEigenType acc = 0;
+ int d_idx, h_idx, w_idx;
+
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int od = 0; od < out_depth; od++)
+ {
+ for (int oh = 0; oh < out_height; oh++)
+ {
+ for (int ow = 0; ow < out_width; ow++)
+ {
+ for (int oc = 0; oc < out_channels; oc++)
+ {
+ acc = 0;
+ for (int fd = 0; fd < f_depth; fd++)
+ {
+ d_idx = od * stride_d + fd * dilation_d;
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ h_idx = oh * stride_h + fh * dilation_h;
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ w_idx = ow * stride_w + fw * dilation_w;
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
+ (AccEigenType)weight_val(oc, fd, fh, fw, ic));
+ }
+ }
+ }
+ }
+ this->output->getTensor()(ob, od, oh, ow, oc) = acc;
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
TosaQuantInfoBase* qinfo_,
@@ -1221,6 +1416,11 @@ DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
+
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index eea351d..2174d62 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -109,6 +109,38 @@ protected:
};
template <DType InDtype, DType WeightDtype>
+class OpConv3d : public GraphNode
+{
+public:
+ OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConv3d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 5>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 5>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 5>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConvAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index ef7bae6..4dba669 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -116,6 +116,7 @@ int SubgraphTraverser::initializeGraph()
switch (op->GetOp())
{
case Op_CONV2D:
+ case Op_CONV3D:
case Op_DEPTHWISE_CONV2D:
case Op_TRANSPOSE_CONV2D:
case Op_FULLY_CONNECTED: