From 1533b85d198a1dd2b1ce995b6c9d69456e56eb3f Mon Sep 17 00:00:00 2001
From: Kevin Cheng
Date: Wed, 1 Sep 2021 12:51:58 -0700
Subject: Implement Conv3D kernel.

Signed-off-by: Kevin Cheng
Change-Id: Ic16e918b1a2423ad563684e29ce70d9efdbf9c02
---
 reference_model/src/ops/op_factory.cc     |   6 +
 reference_model/src/ops/tensor_ops.cc     | 201 +++++++++++++++++++++++++++++
 reference_model/src/ops/tensor_ops.h      |  32 +++++
 reference_model/src/subgraph_traverser.cc |   1 +
 verif/tosa_test_gen.py                    | 160 ++++++++++++++++++++--
 5 files changed, 391 insertions(+), 9 deletions(-)

diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 193b2af..3bc55a8 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -64,6 +64,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
             DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
             break;
+        case Op_CONV3D:
+            DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8);
+            DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8);
+            break;
         case Op_DEPTHWISE_CONV2D:
             DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
             DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index a150656..a0a1f04 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -481,6 +481,202 @@ int OpConv2d<InDtype, WeightDtype>::eval()
     return GraphNode::eval();
 }
 
+template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+                                         TosaAttributeBase* attribute_,
+                                         TosaQuantInfoBase* qinfo_,
+                                         uint64_t id_)
+    : GraphNode(sgt_, Op_CONV3D, id_)
+{
+    setRequiredOperands(3, 1);
+    setRequiredRank(5);
+
+    INIT_ATTRIBUTE(Conv);
+    INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+{
+    if (attribute)
+        delete attribute;
+    if (qinfo)
+        delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+    if (validateRequiredOperands())
+        return 1;
+
+    if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+    {
+        return 1;
+    }
+
+    // 'bias' is checked separately since it doesn't make sense to require a rank ranging from 1 to 5
+    if (inputs[2]->getRank() != 1)
+    {
+        printNodeValidationError("OpConv3d: bias tensor must be rank 1");
+        return 1;
+    }
+
+    input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+    weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+    bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+    if (attribute->padding().size() != 6)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute padding");
+        return 1;
+    }
+
+    if (attribute->stride().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute stride");
+        return 1;
+    }
+
+    if (attribute->dilation().size() != 3)
+    {
+        printNodeValidationError("OpConv3d: illegal size for attribute dilation");
+        return 1;
+    }
+
+    return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv3d<InDtype, WeightDtype>::eval()
+{
+    int in_batch    = this->input->getShape()[0];
+    int in_depth    = this->input->getShape()[1];
+    int in_height   = this->input->getShape()[2];
+    int in_width    = this->input->getShape()[3];
+    int in_channels = this->input->getShape()[4];
+
+    int f_out_channels = this->weight->getShape()[0];
+    int f_depth        = this->weight->getShape()[1];
+    int f_height       = this->weight->getShape()[2];
+    int f_width        = this->weight->getShape()[3];
+    int f_in_channels  = this->weight->getShape()[4];
+
+    int b_out_channels = this->bias->getShape()[0];
+
+    int out_batch    = this->output->getShape()[0];
+    int out_depth    = this->output->getShape()[1];
+    int out_height   = this->output->getShape()[2];
+    int out_width    = this->output->getShape()[3];
+    int out_channels = this->output->getShape()[4];
+
+    ERROR_IF(in_batch != out_batch, "OpConv3d: tensor batch mismatch %d != %d", in_batch, out_batch);
+    ERROR_IF(f_in_channels != in_channels, "OpConv3d: tensor input channel mismatch %d != %d", f_in_channels,
+             in_channels);
+    ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
+             out_channels);
+    ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+
+    int padding_d0     = this->attribute->padding()[0];
+    int padding_d1     = this->attribute->padding()[1];
+    int padding_top    = this->attribute->padding()[2];
+    int padding_bottom = this->attribute->padding()[3];
+    int padding_left   = this->attribute->padding()[4];
+    int padding_right  = this->attribute->padding()[5];
+    int stride_d       = this->attribute->stride()[0];
+    int stride_h       = this->attribute->stride()[1];
+    int stride_w       = this->attribute->stride()[2];
+    int dilation_d     = this->attribute->dilation()[0];
+    int dilation_h     = this->attribute->dilation()[1];
+    int dilation_w     = this->attribute->dilation()[2];
+
+    DEBUG_INFO(
+        OP,
+        "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
+        "stride=[%d,%d,%d], dilation=[%d,%d,%d], padding=[%d,%d,%d,%d,%d,%d]",
+        in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
+        out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
+        dilation_w, padding_d0, padding_d1, padding_top, padding_bottom, padding_left, padding_right);
+
+    Eigen::array<std::pair<int32_t, int32_t>, 5> padding;
+    padding[0] = std::make_pair(0, 0);
+    padding[1] = std::make_pair(padding_d0, padding_d1);
+    padding[2] = std::make_pair(padding_top, padding_bottom);
+    padding[3] = std::make_pair(padding_left, padding_right);
+    padding[4] = std::make_pair(0, 0);
+
+    TIn input_val      = this->input->getTensor();
+    TWeight weight_val = this->weight->getTensor();
+    if (this->qinfo)
+    {
+        input_val  = input_val - (InEigenType)this->qinfo->input_zp();
+        weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+    }
+
+    ETensor5<InEigenType> input_padded = input_val.pad(padding);
+
+    // 1. initialize with bias
+    Eigen::array<Eigen::Index, 5> reshape_dim;
+    reshape_dim.fill(1);
+    reshape_dim[4] = b_out_channels;
+
+    Eigen::array<Eigen::Index, 5> bcast;
+    bcast[0] = out_batch;
+    bcast[1] = out_depth;
+    bcast[2] = out_height;
+    bcast[3] = out_width;
+    bcast[4] = 1;
+    this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+    // 2. direct convolution
+    AccEigenType acc = 0;
+    int d_idx, h_idx, w_idx;
+
+    for (int ob = 0; ob < out_batch; ob++)
+    {
+        for (int od = 0; od < out_depth; od++)
+        {
+            for (int oh = 0; oh < out_height; oh++)
+            {
+                for (int ow = 0; ow < out_width; ow++)
+                {
+                    for (int oc = 0; oc < out_channels; oc++)
+                    {
+                        acc = 0;
+                        for (int fd = 0; fd < f_depth; fd++)
+                        {
+                            d_idx = od * stride_d + fd * dilation_d;
+                            for (int fh = 0; fh < f_height; fh++)
+                            {
+                                h_idx = oh * stride_h + fh * dilation_h;
+                                for (int fw = 0; fw < f_width; fw++)
+                                {
+                                    w_idx = ow * stride_w + fw * dilation_w;
+                                    for (int ic = 0; ic < in_channels; ic++)
+                                    {
+                                        acc += ((AccEigenType)input_padded(ob, d_idx, h_idx, w_idx, ic) *
+                                                (AccEigenType)weight_val(oc, fd, fh, fw, ic));
+                                    }
+                                }
+                            }
+                        }
+                        this->output->getTensor()(ob, od, oh, ow, oc) = acc;
+                    }
+                }
+            }
+        }
+    }
+
+    if (AccDtype == DType_INT48)
+    {
+        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+    }
+
+    return GraphNode::eval();
+}
+
 template <DType InDtype, DType WeightDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
@@ -1221,6 +1417,11 @@ DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
 DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
 
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
+
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
 DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index eea351d..2174d62 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -108,6 +108,38 @@ protected:
     tosa::TosaConvQuantInfo* qinfo;
 };
 
+template <DType InDtype, DType WeightDtype>
+class OpConv3d : public GraphNode
+{
+public:
+    OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+    virtual ~OpConv3d();
+
+    virtual int checkTensorAttributes() final;
+    virtual int eval() final;
+
+    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+    using InEigenType     = typename GetEigenType<InDtype>::type;
+    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using TIn             = Eigen::Tensor<InEigenType, 5>;
+    using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
+    using TBias           = Eigen::Tensor<AccEigenType, 1>;
+    using TAcc            = Eigen::Tensor<AccEigenType, 5>;
+
+    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+    TosaReference::TensorTemplate<TIn>* input;
+    TosaReference::TensorTemplate<TWeight>* weight;
+    TosaReference::TensorTemplate<TBias>* bias;
+    TosaReference::TensorTemplate<TAcc>* output;
+    tosa::TosaConvAttribute* attribute;
+    tosa::TosaConvQuantInfo* qinfo;
+};
+
 template <DType InDtype, DType WeightDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index ef7bae6..4dba669 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -116,6 +116,7 @@ int SubgraphTraverser::initializeGraph()
         switch (op->GetOp())
         {
             case Op_CONV2D:
+            case Op_CONV3D:
             case Op_DEPTHWISE_CONV2D:
             case Op_TRANSPOSE_CONV2D:
             case Op_FULLY_CONNECTED:
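As a cross-check on the kernel above, here is a minimal NumPy sketch of the same direct 3-D convolution (FLOAT path: NDHWC input, ODHWI weights, bias initialized first, then the dilated/strided accumulation). conv3d_ref and its argument names are illustrative only, not part of the reference model or the test generator:

    import numpy as np

    def conv3d_ref(ifm, weight, bias, strides, padding, dilations):
        # Mirrors OpConv3d::eval(): pad the spatial dims, then for each
        # output position contract the dilated receptive field against
        # the filter taps and input channels.
        N, ID, IH, IW, C = ifm.shape                 # NDHWC
        OC, FD, FH, FW, IC = weight.shape            # ODHWI
        assert IC == C
        pd0, pd1, pt, pb, pl, pr = padding           # [d0, d1, top, bottom, left, right]
        sd, sh, sw = strides
        dd, dh, dw = dilations
        x = np.pad(ifm, ((0, 0), (pd0, pd1), (pt, pb), (pl, pr), (0, 0)))
        OD = (ID + pd0 + pd1 - (FD - 1) * dd - 1) // sd + 1
        OH = (IH + pt + pb - (FH - 1) * dh - 1) // sh + 1
        OW = (IW + pl + pr - (FW - 1) * dw - 1) // sw + 1
        out = np.zeros((N, OD, OH, OW, OC), dtype=ifm.dtype)
        for od in range(OD):
            for oh in range(OH):
                for ow in range(OW):
                    # One einsum replaces the four innermost C++ loops
                    # (fd, fh, fw, ic).
                    patch = x[:,
                              od * sd : od * sd + (FD - 1) * dd + 1 : dd,
                              oh * sh : oh * sh + (FH - 1) * dh + 1 : dh,
                              ow * sw : ow * sw + (FW - 1) * dw + 1 : dw,
                              :]
                    out[:, od, oh, ow, :] = np.einsum("ndhwc,odhwc->no", patch, weight)
        return out + bias.reshape(1, 1, 1, 1, OC)    # step 1 of the kernel: bias init

For example, a (1, 8, 9, 10, 4) input with a (16, 2, 1, 1, 4) filter, unit strides and dilations, and no padding yields shape (1, 7, 9, 10, 16) -- the same result the shape rule in OutputShaper.conv3dOp (added below) produces.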
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 44582ac..9555195 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -256,6 +256,35 @@ class TosaTensorGen:
 
         return [ifm_shape, filter_shape, bias_shape]
 
+    @staticmethod
+    def tgConv3D(testGen, op, rank):
+        pl, const = op["operands"]
+
+        assert rank == 5
+
+        # IFM dimensions are NDHWC
+        ifm_shape = testGen.makeShape(rank)
+
+        # Constrict the batch size?
+        if testGen.args.max_batch_size:
+            ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+        # Get the filter depth/height/width from the operator parameters
+        filter_dhw = op["filter"]
+
+        # Generate a random OFM channel
+        ofm_channel = testGen.makeShape(1)[0]
+
+        # The filter dimensions are ODHWI
+        filter_shape = np.asarray(
+            [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
+        )
+
+        # The bias is OC
+        bias_shape = np.asarray([ofm_channel])
+
+        return [ifm_shape, filter_shape, bias_shape]
+
     @staticmethod
     def tgTransposeConv2D(testGen, op, rank):
         pl, const = op["operands"]
@@ -462,6 +491,43 @@ class TosaArgGen:
             )
         return arg_list
 
+    @staticmethod
+    def agConv3D(testGen, opName, shapeList, dtype):
+        arg_list = []
+
+        ifm_shape = shapeList[0]
+        filter_shape = shapeList[1]
+
+        # Must be rank 5
+        assert len(ifm_shape) == 5
+        assert len(filter_shape) == 5
+
+        # Generate basic argument list now
+        # TODO: increase coverage
+        s = [1, 1, 1]
+        p = [0, 0, 0, 0, 0, 0]
+        d = [1, 1, 1]
+        arg_list.append(
+            (
+                "st{}{}{}_pad{}{}{}{}{}{}_dilat{}{}{}".format(
+                    s[0],
+                    s[1],
+                    s[2],
+                    p[0],
+                    p[1],
+                    p[2],
+                    p[3],
+                    p[4],
+                    p[5],
+                    d[0],
+                    d[1],
+                    d[2],
+                ),
+                [s, p, d],
+            )
+        )
+        return arg_list
+
     @staticmethod
     def agTransposeConv2D(testGen, opName, shapeList, dtype):
         arg_list = []
@@ -1357,6 +1423,20 @@ class TosaTestGen:
         )
         return result_tens
 
+    def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+        assert len(padding) == 6
+        result_tens = OutputShaper.conv3dOp(
+            self.ser, ifm, filter, strides, padding, dilations
+        )
+
+        attr = ts.TosaSerializerAttribute()
+        attr.ConvAttribute(padding, strides, dilations)
+
+        self.ser.addOperator(
+            op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+        )
+        return result_tens
+
     def build_transpose_conv2d(
         self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
     ):
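build_conv3d above delegates output shaping to OutputShaper.conv3dOp, added near the end of this patch. A quick arithmetic check of that shape rule on the depth axis, with made-up sizes (height and width follow the same pattern):

    ifm_d, filter_d = 8, 2                      # 8-frame input, 2-tap filter
    pad_d0, pad_d1, stride_d, dilation_d = 0, 0, 1, 1
    out_d = (
        ifm_d - filter_d - (filter_d - 1) * (dilation_d - 1) + pad_d0 + pad_d1
    ) // stride_d + 1
    assert out_d == 7                           # no padding costs one filter tap of depth
    # With dilation 2 the 2-tap filter spans 3 frames, so only 6 positions fit:
    assert (8 - 2 - (2 - 1) * (2 - 1) + 0 + 0) // 1 + 1 == 6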
@@ -1859,7 +1939,9 @@ class TosaTestGen:
             # Filter out the rank?
             if rankFilter is not None and r not in rankFilter:
                 continue
-            if (
+            if opName.startswith("conv3d"):
+                assert r == 5, "conv3d test must have input rank == 5"
+            elif (
                 rankFilter is None
                 and shapeFilter[0] is None
                 and r not in default_test_rank_range
@@ -2188,9 +2270,9 @@ class TosaTestGen:
     def createDynamicOpLists(self):
 
         # Dynamically create op lists for convolutions with a list of kernel sizes
-        KERNELS = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
+        KERNELS_2D = [[1, 1], [2, 2], [3, 3], [5, 5], [3, 1], [1, 3]]
 
-        for k in KERNELS:
+        for k in KERNELS_2D:
             testName = "conv2d_{}x{}".format(k[0], k[1])
             self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
             self.TOSA_OP_LIST[testName]["filter"] = k
@@ -2210,6 +2292,13 @@ class TosaTestGen:
             self.TOSA_OP_LIST[testName]["filter"] = k
             self.TOSA_OP_LIST[testName]["template"] = False
 
+        KERNELS_3D = [[1, 1, 1], [2, 1, 1], [1, 2, 1], [1, 1, 2]]
+        for k in KERNELS_3D:
+            testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
+            self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
+            self.TOSA_OP_LIST[testName]["filter"] = k
+            self.TOSA_OP_LIST[testName]["template"] = False
+
         # Delete any templates after having created any dynamic ops
         # This is a two-pass operation because it's bad practice to delete
         # keys from dictionaries while iterating
@@ -2286,7 +2375,7 @@ class TosaTestGen:
 
     TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
 
-    TYPE_CONV2D = [
+    TYPE_CONV = [
         [DType.INT8, DType.INT4, DType.INT32],
         [DType.INT8, DType.INT8, DType.INT32],
         [DType.INT16, DType.INT8, DType.INT48],
@@ -2319,11 +2408,20 @@ class TosaTestGen:
         "rank": (4, 4),
         "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
         "qgen": TosaQuantGen.qgConv,
-        "types": TYPE_CONV2D,
+        "types": TYPE_CONV,
         "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
         "template": True,
     },
-    # Conv3d TBD
+    # Templated operator.  Filled in by createDynamicOpLists
+    "conv3d_TEMPLATE": {
+        "op": Op.CONV3D,
+        "operands": (1, 2),
+        "rank": (5, 5),
+        "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv3D),
+        "qgen": TosaQuantGen.qgConv,
+        "types": TYPE_CONV,
+        "template": True,
+    },
     # Templated operator.  Filled in by createDynamicOpLists
     "depthwise_conv2d_TEMPLATE": {
         "op": Op.DEPTHWISE_CONV2D,
@@ -2336,7 +2434,7 @@ class TosaTestGen:
             TosaArgGen.agConv2D,
         ),
         "qgen": TosaQuantGen.qgConv,
-        "types": TYPE_CONV2D,
+        "types": TYPE_CONV,
         "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
         "template": True,
     },
@@ -2346,7 +2444,7 @@ class TosaTestGen:
         "rank": (2, 2),
         "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
         "qgen": TosaQuantGen.qgConv,
-        "types": TYPE_CONV2D,
+        "types": TYPE_CONV,
     },
     "matmul": {
         "op": Op.MATMUL,
@@ -2375,7 +2473,7 @@ class TosaTestGen:
             TosaArgGen.agTransposeConv2D,
         ),
         "qgen": TosaQuantGen.qgConv,
-        "types": TYPE_CONV2D,
+        "types": TYPE_CONV,
         "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
         "template": True,
     },
@@ -2908,6 +3006,50 @@ class OutputShaper:
 
         return ser.addOutput(ofm_shape, out_dtype)
 
+    @staticmethod
+    def conv3dOp(ser, ifm, filter, strides, padding, dilations):
+
+        # IFM:    NDHWC
+        # Filter: ODHWI
+        # OFM:    NDHWC
+
+        d = (
+            ifm.shape[1]
+            - filter.shape[1]
+            - (filter.shape[1] - 1) * (dilations[0] - 1)
+            + padding[0]
+            + padding[1]
+        ) // strides[0] + 1
+
+        h = (
+            ifm.shape[2]
+            - filter.shape[2]
+            - (filter.shape[2] - 1) * (dilations[1] - 1)
+            + padding[2]
+            + padding[3]
+        ) // strides[1] + 1
+
+        w = (
+            ifm.shape[3]
+            - filter.shape[3]
+            - (filter.shape[3] - 1) * (dilations[2] - 1)
+            + padding[4]
+            + padding[5]
+        ) // strides[2] + 1
+
+        ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
+
+        if ifm.dtype == DType.INT8:
+            out_dtype = DType.INT32
+        elif ifm.dtype == DType.INT16:
+            out_dtype = DType.INT48
+        elif ifm.dtype == DType.FLOAT:
+            out_dtype = DType.FLOAT
+        else:
+            raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+
+        return ser.addOutput(ofm_shape, out_dtype)
+
     @staticmethod
     def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
         # IFM: NHWC
--
cgit v1.2.1
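On the quantized paths, OpConv3d::eval() subtracts the zero points before multiplying and accumulates in the wider accumulator type (INT32 for INT8 inputs; INT48, clamped to AccQMin/AccQMax, for INT16 inputs). A small sketch of why the widening matters, with made-up values:

    import numpy as np

    x = np.array([-128, 0, 127], dtype=np.int8)   # one dot product's inputs
    w = np.array([3, -3, 3], dtype=np.int8)       # matching filter taps
    input_zp, weight_zp = -1, 0                   # hypothetical qinfo values
    # Widen first, then subtract zero points and accumulate, as the kernel does.
    acc = np.sum((x.astype(np.int32) - input_zp) * (w.astype(np.int32) - weight_zp))
    assert acc == -127 * 3 + 1 * -3 + 128 * 3     # == 0; partial products like 384
                                                  # already overflow int8, so the
                                                  # accumulation must happen in int32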