From 77d0f76b42c6353459d3271ccba9e41728d066cd Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 24 Nov 2020 10:26:32 -0800 Subject: Update reference model/serialization library to 0.21.0 with unit tests added/updated - update tosa.GATHER - update tosa.RESIZE - add tosa.SCATTER Signed-off-by: Kevin Cheng Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25 --- reference_model/src/ops/scatter_gather.cc | 210 +++++++++++++++++++++++------- 1 file changed, 160 insertions(+), 50 deletions(-) (limited to 'reference_model/src/ops/scatter_gather.cc') diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index c54204a..2d1026f 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -20,31 +20,61 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template -OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) +template +OpGather::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_GATHER, id_) { setRequiredOperands(2, 1); - setRequiredRank(1, 6); - - INIT_ATTRIBUTE(Axis); } -template -OpGather::~OpGather() -{ - if (attribute) - delete attribute; -} +template +OpGather::~OpGather() +{} -template -int OpGather::checkTensorAttributes() +template +int OpGather::checkTensorAttributes() { if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (outputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: output needs to be rank 3 tensor"); + return 1; + } + + K = inputs[0]->getShape()[1]; + N = outputs[0]->getShape()[0]; + W = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; + + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0]) + { + printNodeValidationError("OpGather: dimension N mismatch"); + return 1; + } + + if (W != inputs[1]->getShape()[1]) + { + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; + } + + if (C != inputs[0]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); return 1; } @@ -55,59 +85,133 @@ int OpGather::checkTensorAttributes() return 1; } - in = dynamic_cast*>(inputs[0]); - index = dynamic_cast*>(inputs[1]); - out = dynamic_cast*>(outputs[0]); + values = dynamic_cast*>(inputs[0]); + indices = dynamic_cast*>(inputs[1]); + output = dynamic_cast*>(outputs[0]); - ASSERT_MEM(in && index && out); + ASSERT_MEM(values && indices && output); return 0; } -template -int OpGather::eval() +template +int OpGather::eval() +{ + for (int32_t n = 0; n < N; n++) + { + for (int32_t w = 0; w < W; w++) + { + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int32_t c = 0; c < C; c++) + { + EigenType value = this->values->getTensor()(n, k, c); + this->output->getTensor()(n, w, c) = value; + } + } + } + + return GraphNode::eval(); +} + +template +OpScatter::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SCATTER, id_) { - int axis = attribute->axis(); + setRequiredOperands(3, 1); +} + +template +OpScatter::~OpScatter() +{} - // calculate size left and right to axis - int left_size = 1; - for (int i = 0; i < axis; ++i) +template +int OpScatter::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (inputs[0]->getRank() != 3) + { + printNodeValidationError("OpGather: values_in needs to be rank 3 tensor"); + return 1; + } + + if (inputs[1]->getRank() != 2) + { + printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + return 1; + } + + if (inputs[2]->getRank() != 3) { - left_size *= in->getShape()[i]; + printNodeValidationError("OpGather: input needs to be rank 3 tensor"); + return 1; } - int right_size = 1; - for (int i = axis + 1; i < in->getRank(); ++i) + if (outputs[0]->getRank() != 3) { - right_size *= in->getShape()[i]; + printNodeValidationError("OpGather: values_out needs to be rank 3 tensor"); + return 1; } - InEigenType* input_data = in->getTensor().data(); - int32_t* index_data = index->getTensor().data(); - OutEigenType* output_data = out->getTensor().data(); + W = inputs[2]->getShape()[1]; + N = outputs[0]->getShape()[0]; + K = outputs[0]->getShape()[1]; + C = outputs[0]->getShape()[2]; - int32_t axis_size = in->getShape()[axis]; - int32_t index_count = index->getElementCount(); + if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0]) + { + printNodeValidationError("OpScatter: dimension N mismatch"); + return 1; + } - // sanity check if index is valid - // need to check until this point since index is not known until runtime - for (size_t i = 0; i < index->getElementCount(); i++) + if (W != inputs[1]->getShape()[1]) { - if (index_data[i] >= axis_size) - { - FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size); - } + printNodeValidationError("OpGather: dimension W mismatch"); + return 1; } - // Eigen stores tensor in column-major - // so we iterate through dimension right to axis and the index array - // do memory copy with size of left size each time - for (int right = 0; right < right_size; ++right) + if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2]) + { + printNodeValidationError("OpGather: dimension C mismatch"); + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + values_in = dynamic_cast*>(inputs[0]); + indices = dynamic_cast*>(inputs[1]); + input = dynamic_cast*>(inputs[2]); + values_out = dynamic_cast*>(outputs[0]); + + ASSERT_MEM(values_in && indices && input && values_out); + + return 0; +} + +template +int OpScatter::eval() +{ + // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. + this->values_out->getTensor() = this->values_in->getTensor(); + + for (int n = 0; n < N; n++) { - for (int i = 0; i < index_count; ++i) + for (int w = 0; w < W; w++) { - std::memcpy(output_data + (right * index_count + i) * left_size, - input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size); + int32_t k = this->indices->getTensor()(n, w); + ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); + for (int c = 0; c < C; c++) + { + EigenType value = this->input->getTensor()(n, w, c); + this->values_out->getTensor()(n, k, c) = value; + } } } @@ -115,6 +219,12 @@ int OpGather::eval() } // template explicit instantiation -DEF_INSTANTIATE_GATHER(OpGather, AINT8); -DEF_INSTANTIATE_GATHER(OpGather, INT16); -DEF_INSTANTIATE_GATHER(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT); + +DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT); -- cgit v1.2.1