From 77d0f76b42c6353459d3271ccba9e41728d066cd Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 24 Nov 2020 10:26:32 -0800 Subject: Update reference model/serialization library to 0.21.0 with unit tests added/updated - update tosa.GATHER - update tosa.RESIZE - add tosa.SCATTER Signed-off-by: Kevin Cheng Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25 --- reference_model/src/ops/image.cc | 116 ++++++++++++----- reference_model/src/ops/image.h | 2 + reference_model/src/ops/op_factory.cc | 20 +-- reference_model/src/ops/op_factory.h | 53 -------- reference_model/src/ops/scatter_gather.cc | 210 +++++++++++++++++++++++------- reference_model/src/ops/scatter_gather.h | 45 +++++-- 6 files changed, 293 insertions(+), 153 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index d3352ce..829a6e0 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -51,6 +51,8 @@ int OpResize::checkTensorAttributes() stride = this->attribute->stride(); offset = this->attribute->offset(); shift = this->attribute->shift(); + stride_fp = this->attribute->stride_fp(); + offset_fp = this->attribute->offset_fp(); mode = this->attribute->mode(); int output_height = outputs[0]->getShape()[1]; @@ -58,7 +60,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -66,7 +68,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -79,18 +81,6 @@ int OpResize::checkTensorAttributes() return 1; } - if (shift < 1 || shift > 11) - { - printNodeValidationError("OpResize: attribute shift should be within [1, 11]"); - return 1; - } - - if (stride[0] <= 0 || stride[1] <= 0) - { - printNodeValidationError("OpResize: invalid attribute stride"); - return 1; - } - in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); @@ -112,6 +102,8 @@ int OpResize::eval() int out_width = out->getShape()[2]; int out_channels = out->getShape()[3]; + ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]"); + ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride"); ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); @@ -120,30 +112,30 @@ int OpResize::eval() for (int oy = 0; oy < out_height; oy++) for (int ox = 0; ox < out_width; ox++) { - int y = oy * stride[0] + offset[0]; - int x = ox * stride[1] + offset[1]; + int32_t y = oy * stride[0] + offset[0]; + int32_t x = ox * stride[1] + offset[1]; - int iy = y >> shift; - int dy = y - (iy << shift); - int ix = x >> shift; - int dx = x - (ix << shift); + int32_t iy = y >> shift; + int32_t dy = y - (iy << shift); + int32_t ix = x >> shift; + int32_t dx = x - (ix << shift); - int iy0 = MAX(iy, 0); - int iy1 = MIN(iy + 1, in_height - 1); - int ix0 = MAX(ix, 0); - int ix1 = MIN(ix + 1, in_width - 1); + int32_t iy0 = MAX(iy, 0); + int32_t iy1 = MIN(iy + 1, in_height - 1); + int32_t ix0 = MAX(ix, 0); + int32_t ix1 = MIN(ix + 1, in_width - 1); ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", iy0, iy1, ix0, ix1); - InEigenType v00 = in->getTensor()(b, iy0, ix0, c); - InEigenType v01 = in->getTensor()(b, iy0, ix1, c); - InEigenType v10 = in->getTensor()(b, iy1, ix0, c); - InEigenType v11 = in->getTensor()(b, iy1, ix1, c); - OutEigenType acc; if (mode == ResizeMode_BILINEAR) { + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx); acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx; acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx); @@ -162,8 +154,74 @@ int OpResize::eval() return GraphNode::eval(); } +template <> +int OpResize::eval() +{ + int in_batch = in->getShape()[0]; + int in_height = in->getShape()[1]; + int in_width = in->getShape()[2]; + int in_channels = in->getShape()[3]; + + int out_batch = out->getShape()[0]; + int out_height = out->getShape()[1]; + int out_width = out->getShape()[2]; + int out_channels = out->getShape()[3]; + + ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift"); + ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride"); + ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); + ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + + for (int b = 0; b < out_batch; b++) + for (int c = 0; c < out_channels; c++) + for (int oy = 0; oy < out_height; oy++) + for (int ox = 0; ox < out_width; ox++) + { + float y = oy * stride_fp[0] + offset_fp[0]; + float x = ox * stride_fp[1] + offset_fp[1]; + + int32_t iy = static_cast(std::floor(y)); + float dy = y - static_cast(iy); + int32_t ix = static_cast(std::floor(x)); + float dx = x - static_cast(ix); + + int32_t iy0 = MAX(iy, 0); + int32_t iy1 = MIN(iy + 1, in_height - 1); + int32_t ix0 = MAX(ix, 0); + int32_t ix1 = MIN(ix + 1, in_width - 1); + + ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", + iy0, iy1, ix0, ix1); + + OutEigenType acc; + if (mode == ResizeMode_BILINEAR) + { + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + + acc = (OutEigenType)v00 * (1.0 - dy) * (1.0 - dx); + acc = acc + (OutEigenType)v01 * (1.0 - dy) * dx; + acc = acc + (OutEigenType)v10 * dy * (1.0 - dx); + acc = acc + (OutEigenType)v11 * dy * dx; + } + else + { + iy = (dy >= 0.5) ? iy1 : iy0; + ix = (dx >= 0.5) ? ix1 : ix0; + acc = in->getTensor()(b, iy, ix, c); + } + + out->getTensor()(b, oy, ox, c) = acc; + } + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48); DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16); +DEF_INSTANTIATE_TWO_TYPE(OpResize, FLOAT, FLOAT); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 9d15d49..5dd14c8 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -43,6 +43,8 @@ protected: std::vector stride; std::vector offset; int32_t shift; + std::vector stride_fp; + std::vector offset_fp; ResizeMode mode; TosaReference::TensorTemplate* in; TosaReference::TensorTemplate* out; diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index bad0c40..4a06248 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -334,14 +334,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, // scatter_gather case Op_GATHER: - { - // output.rank = input.rank - 1 + index.rank - int32_t index_rank = outputRank - inputRank + 1; - DEF_FACTORY_GATHER(OpGather, AINT8); - DEF_FACTORY_GATHER(OpGather, INT16); - DEF_FACTORY_GATHER(OpGather, INT32); - } - break; + DEF_FACTORY_ONE_TYPE(OpGather, AINT8); + DEF_FACTORY_ONE_TYPE(OpGather, INT16); + DEF_FACTORY_ONE_TYPE(OpGather, INT32); + DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); + break; + case Op_SCATTER: + DEF_FACTORY_ONE_TYPE(OpScatter, AINT8); + DEF_FACTORY_ONE_TYPE(OpScatter, INT16); + DEF_FACTORY_ONE_TYPE(OpScatter, INT32); + DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); + break; // image case Op_RESIZE: @@ -349,6 +352,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8); DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, FLOAT, FLOAT); break; // data_nodes diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index cde6841..0c116b6 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -218,59 +218,6 @@ } \ } -#define DEF_FACTORY_GATHER(OP, DTYPE) \ - if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ - { \ - switch (inputRank) \ - { \ - case 1: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \ - } \ - case 2: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \ - } \ - case 3: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \ - } \ - case 4: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \ - } \ - case 5: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \ - } \ - case 6: \ - switch (index_rank) \ - { \ - DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \ - } \ - } \ - } - namespace TosaReference { diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index c54204a..2d1026f 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -20,31 +20,61 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template -OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template +OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_GATHER, id_) { setRequiredOperands(2, 1); - setRequiredRank(1, 6); - - INIT_ATTRIBUTE(Axis); } -template -OpGather::~OpGather() -{ - if (attribute) - delete attribute; -} +template +OpGather::~OpGather() +{} -template -int OpGather::checkTensorAttributes() +template +int OpGather::checkTensorAttributes() { if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (outputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: output needs to be rank 3 tensor"); + return 1; + } + + K = inputs[0]->getShape()[1]; + N = outputs[0]->getShape()[0]; + W = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; + + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0]) + { + printNodeValidationError("OpGather: dimension N mismatch"); + return 1; + } + + if (W != inputs[1]->getShape()[1]) + { + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; + } + + if (C != inputs[0]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); return 1; } @@ -55,59 +85,133 @@ int OpGather::checkTensorAttributes() return 1; } - in = dynamic_cast*>(inputs[0]); - index = dynamic_cast*>(inputs[1]); - out = dynamic_cast*>(outputs[0]); + values = dynamic_cast*>(inputs[0]); + indices = dynamic_cast*>(inputs[1]); + output = dynamic_cast*>(outputs[0]); - ASSERT_MEM(in && index && out); + ASSERT_MEM(values && indices && output); return 0; } -template -int OpGather::eval() +template +int OpGather::eval() +{ + for (int32_t n = 0; n < N; n++) + { + for (int32_t w = 0; w < W; w++) + { + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int32_t c = 0; c < C; c++) + { + EigenType value = this->values->getTensor()(n, k, c); + this->output->getTensor()(n, w, c) = value; + } + } + } + + return GraphNode::eval(); +} + +template +OpScatter::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SCATTER, id_) { - int axis = attribute->axis(); + setRequiredOperands(3, 1); +} + +template +OpScatter::~OpScatter() +{} - // calculate size left and right to axis - int left_size = 1; - for (int i = 0; i < axis; ++i) +template +int OpScatter::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values_in needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) + { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (inputs[2]->getRank() != 3) { - left_size *= in->getShape()[i]; + printNodeValidationError("OpGather: input needs to be rank 3 tensor"); + return 1; } - int right_size = 1; - for (int i = axis + 1; i < in->getRank(); ++i) + if (outputs[0]->getRank() != 3) { - right_size *= in->getShape()[i]; + printNodeValidationError("OpGather: values_out needs to be rank 3 tensor"); + return 1; } - InEigenType* input_data = in->getTensor().data(); - int32_t* index_data = index->getTensor().data(); - OutEigenType* output_data = out->getTensor().data(); + W = inputs[2]->getShape()[1]; + N = outputs[0]->getShape()[0]; + K = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; - int32_t axis_size = in->getShape()[axis]; - int32_t index_count = index->getElementCount(); + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0]) + { + printNodeValidationError("OpScatter: dimension N mismatch"); + return 1; + } - // sanity check if index is valid - // need to check until this point since index is not known until runtime - for (size_t i = 0; i < index->getElementCount(); i++) + if (W != inputs[1]->getShape()[1]) { - if (index_data[i] >= axis_size) - { - FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size); - } + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; } - // Eigen stores tensor in column-major - // so we iterate through dimension right to axis and the index array - // do memory copy with size of left size each time - for (int right = 0; right < right_size; ++right) + if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + values_in = dynamic_cast*>(inputs[0]); + indices = dynamic_cast*>(inputs[1]); + input = dynamic_cast*>(inputs[2]); + values_out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(values_in && indices && input && values_out); + + return 0; +} + +template +int OpScatter::eval() +{ + // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. + this->values_out->getTensor() = this->values_in->getTensor(); + + for (int n = 0; n < N; n++) { - for (int i = 0; i < index_count; ++i) + for (int w = 0; w < W; w++) { - std::memcpy(output_data + (right * index_count + i) * left_size, - input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size); + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int c = 0; c < C; c++) + { + EigenType value = this->input->getTensor()(n, w, c); + this->values_out->getTensor()(n, k, c) = value; + } } } @@ -115,6 +219,12 @@ int OpGather::eval() } // template explicit instantiation -DEF_INSTANTIATE_GATHER(OpGather, AINT8); -DEF_INSTANTIATE_GATHER(OpGather, INT16); -DEF_INSTANTIATE_GATHER(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT); + +DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index d9b1263..17ea723 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -23,9 +23,7 @@ using namespace tosa; namespace TosaReference { -// input and index can have different rank -// and infer OutRank statically -template +template class OpGather : public GraphNode { public: @@ -35,18 +33,39 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr int OutRank = InRank - 1 + IndexRank; - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; - using TIn = Eigen::Tensor; - using TIndex = Eigen::Tensor; - using TOut = Eigen::Tensor; + using EigenType = typename GetEigenType::type; + using TValue = Eigen::Tensor; + using TIndex = Eigen::Tensor; + using TOutput = Eigen::Tensor; protected: - TosaAxisAttribute* attribute; - TosaReference::TensorTemplate* in; - TosaReference::TensorTemplate* index; - TosaReference::TensorTemplate* out; + int32_t N, W, K, C; + TosaReference::TensorTemplate* values; + TosaReference::TensorTemplate* indices; + TosaReference::TensorTemplate* output; +}; + +template +class OpScatter : public GraphNode +{ +public: + OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpScatter(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using EigenType = typename GetEigenType::type; + using TValue = Eigen::Tensor; + using TIndex = Eigen::Tensor; + using TOutput = Eigen::Tensor; + +protected: + int32_t N, W, K, C; + TosaReference::TensorTemplate* values_in; + TosaReference::TensorTemplate* indices; + TosaReference::TensorTemplate* input; + TosaReference::TensorTemplate* values_out; }; }; // namespace TosaReference -- cgit v1.2.1