author     Kevin Cheng <kevin.cheng@arm.com>    2020-11-24 10:26:32 -0800
committer  Kevin Cheng <kevin.cheng@arm.com>    2020-11-24 14:06:04 -0800
commit     77d0f76b42c6353459d3271ccba9e41728d066cd (patch)
tree       70f95376de80e249a15b15dba6faa27fb4d48f45
parent     aee1facbde25caf27cc34e5ec08eb8bba6af8e18 (diff)
download   reference_model-77d0f76b42c6353459d3271ccba9e41728d066cd.tar.gz
Update reference model/serialization library to 0.21.0 with unit tests added/updated
- update tosa.GATHER
- update tosa.RESIZE
- add tosa.SCATTER

Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I1c3247b831a64e35a85c4044b24c6c29b8e18d25
-rw-r--r--  reference_model/src/graph_node.h             27
-rw-r--r--  reference_model/src/ops/image.cc            116
-rw-r--r--  reference_model/src/ops/image.h               2
-rw-r--r--  reference_model/src/ops/op_factory.cc        20
-rw-r--r--  reference_model/src/ops/op_factory.h         53
-rw-r--r--  reference_model/src/ops/scatter_gather.cc   210
-rw-r--r--  reference_model/src/ops/scatter_gather.h     45
-rw-r--r--  serialization/attribute.def                   4
-rw-r--r--  serialization/operator.def                    3
-rw-r--r--  serialization/tosa.fbs                        5
-rw-r--r--  serialization/tosa_generated.h               61
-rw-r--r--  verif/tosa/Op.py                             21
-rw-r--r--  verif/tosa/ResizeAttribute.py                54
-rw-r--r--  verif/tosa/Version.py                         4
-rw-r--r--  verif/tosa_serializer.py                     21
-rw-r--r--  verif/tosa_test_gen.py                      175
16 files changed, 574 insertions, 247 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index 5b4a767..eee5464 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -123,33 +123,6 @@
DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE)
-#define DEF_INSTANTIATE_GATHER(OP, DTYPE) \
- /* gather op takes input and index rank as template argument */ \
- /* note output rank = input rank - 1 + index rank */ \
- /* and max rank allowed in tosa_reference is 6 */ \
- /* so only specific input and index pair is instantiated */ \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
- DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE)
-
#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \
if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
{ \
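The block removed above existed only because the old GATHER was templated on both input rank and index rank; with output rank = input rank - 1 + index rank and a reference-model cap of rank 6, exactly 21 (input, index) rank pairs were legal. A small illustrative Python sketch (not part of the patch) of that constraint:

# Enumerate the (input_rank, index_rank) pairs the removed macro instantiated.
MAX_RANK = 6
pairs = [(in_rank, idx_rank)
         for in_rank in range(1, MAX_RANK + 1)
         for idx_rank in range(1, MAX_RANK + 1)
         if in_rank - 1 + idx_rank <= MAX_RANK]
print(len(pairs), pairs)   # 21 pairs, one per DEF_INSTANTIATE_TWO_RANK_ONE_TYPE line above

The fixed rank-3 GATHER/SCATTER formulation later in this commit makes the whole enumeration unnecessary.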
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index d3352ce..829a6e0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -51,6 +51,8 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
stride = this->attribute->stride();
offset = this->attribute->offset();
shift = this->attribute->shift();
+ stride_fp = this->attribute->stride_fp();
+ offset_fp = this->attribute->offset_fp();
mode = this->attribute->mode();
int output_height = outputs[0]->getShape()[1];
@@ -58,7 +60,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -66,7 +68,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -79,18 +81,6 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (shift < 1 || shift > 11)
- {
- printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
- return 1;
- }
-
- if (stride[0] <= 0 || stride[1] <= 0)
- {
- printNodeValidationError("OpResize: invalid attribute stride");
- return 1;
- }
-
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -112,6 +102,8 @@ int OpResize<InDtype, OutDtype>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
+ ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]");
+ ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride");
ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
@@ -120,30 +112,30 @@ int OpResize<InDtype, OutDtype>::eval()
for (int oy = 0; oy < out_height; oy++)
for (int ox = 0; ox < out_width; ox++)
{
- int y = oy * stride[0] + offset[0];
- int x = ox * stride[1] + offset[1];
+ int32_t y = oy * stride[0] + offset[0];
+ int32_t x = ox * stride[1] + offset[1];
- int iy = y >> shift;
- int dy = y - (iy << shift);
- int ix = x >> shift;
- int dx = x - (ix << shift);
+ int32_t iy = y >> shift;
+ int32_t dy = y - (iy << shift);
+ int32_t ix = x >> shift;
+ int32_t dx = x - (ix << shift);
- int iy0 = MAX(iy, 0);
- int iy1 = MIN(iy + 1, in_height - 1);
- int ix0 = MAX(ix, 0);
- int ix1 = MIN(ix + 1, in_width - 1);
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
iy0, iy1, ix0, ix1);
- InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
- InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
- InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
- InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
-
OutEigenType acc;
if (mode == ResizeMode_BILINEAR)
{
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
@@ -162,8 +154,74 @@ int OpResize<InDtype, OutDtype>::eval()
return GraphNode::eval();
}
+template <>
+int OpResize<DType_FLOAT, DType_FLOAT>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift");
+ ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride");
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ float y = oy * stride_fp[0] + offset_fp[0];
+ float x = ox * stride_fp[1] + offset_fp[1];
+
+ int32_t iy = static_cast<int32_t>(std::floor(y));
+ float dy = y - static_cast<float>(iy);
+ int32_t ix = static_cast<int32_t>(std::floor(x));
+ float dx = x - static_cast<float>(ix);
+
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ acc = (OutEigenType)v00 * (1.0 - dy) * (1.0 - dx);
+ acc = acc + (OutEigenType)v01 * (1.0 - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * (1.0 - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >= 0.5) ? iy1 : iy0;
+ ix = (dx >= 0.5) ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, FLOAT, FLOAT);
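For reference, a minimal Python sketch (not part of the patch) of the per-pixel arithmetic shared by the integer path and the new FLOAT specialization above; the tensor contents and names are hypothetical, and the integer result is the raw accumulator before any rescale:

import math

def resize_bilinear_point(values, oy, ox, stride, offset, shift,
                          stride_fp, offset_fp, use_float):
    # values is a 2-D list indexed [y][x]; one batch and one channel for brevity.
    H, W = len(values), len(values[0])
    if use_float:
        # FLOAT specialization: positions come straight from stride_fp/offset_fp, shift == 0.
        y = oy * stride_fp[0] + offset_fp[0]
        x = ox * stride_fp[1] + offset_fp[1]
        iy, ix = math.floor(y), math.floor(x)
        dy, dx = y - iy, x - ix
        unit = 1.0
    else:
        # Integer path: positions carry `shift` fractional bits.
        y = oy * stride[0] + offset[0]
        x = ox * stride[1] + offset[1]
        iy, ix = y >> shift, x >> shift
        dy, dx = y - (iy << shift), x - (ix << shift)
        unit = 1 << shift
    iy0, iy1 = max(iy, 0), min(iy + 1, H - 1)
    ix0, ix1 = max(ix, 0), min(ix + 1, W - 1)
    v00, v01 = values[iy0][ix0], values[iy0][ix1]
    v10, v11 = values[iy1][ix0], values[iy1][ix1]
    return (v00 * (unit - dy) * (unit - dx) + v01 * (unit - dy) * dx
            + v10 * dy * (unit - dx) + v11 * dy * dx)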
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 9d15d49..5dd14c8 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -43,6 +43,8 @@ protected:
std::vector<int32_t> stride;
std::vector<int32_t> offset;
int32_t shift;
+ std::vector<float> stride_fp;
+ std::vector<float> offset_fp;
ResizeMode mode;
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index bad0c40..4a06248 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -334,14 +334,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// scatter_gather
case Op_GATHER:
- {
- // output.rank = input.rank - 1 + index.rank
- int32_t index_rank = outputRank - inputRank + 1;
- DEF_FACTORY_GATHER(OpGather, AINT8);
- DEF_FACTORY_GATHER(OpGather, INT16);
- DEF_FACTORY_GATHER(OpGather, INT32);
- }
- break;
+ DEF_FACTORY_ONE_TYPE(OpGather, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT16);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
+ break;
+ case Op_SCATTER:
+ DEF_FACTORY_ONE_TYPE(OpScatter, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
+ break;
// image
case Op_RESIZE:
@@ -349,6 +352,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, FLOAT, FLOAT);
break;
// data_nodes
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index cde6841..0c116b6 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -218,59 +218,6 @@
} \
}
-#define DEF_FACTORY_GATHER(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
- { \
- switch (inputRank) \
- { \
- case 1: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \
- } \
- case 2: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \
- } \
- case 3: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \
- } \
- case 4: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \
- } \
- case 5: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \
- } \
- case 6: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \
- } \
- } \
- }
-
namespace TosaReference
{
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index c54204a..2d1026f 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -20,31 +20,61 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <DType Dtype>
+OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_GATHER, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(1, 6);
-
- INIT_ATTRIBUTE(Axis);
}
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::~OpGather()
-{
- if (attribute)
- delete attribute;
-}
+template <DType Dtype>
+OpGather<Dtype>::~OpGather()
+{}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+template <DType Dtype>
+int OpGather<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
{
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (outputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: output needs to be rank 3 tensor");
+ return 1;
+ }
+
+ K = inputs[0]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ W = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
+
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
+ {
+ printNodeValidationError("OpGather: dimension N mismatch");
+ return 1;
+ }
+
+ if (W != inputs[1]->getShape()[1])
+ {
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
+ }
+
+ if (C != inputs[0]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
return 1;
}
@@ -55,59 +85,133 @@ int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
- ASSERT_MEM(in && index && out);
+ ASSERT_MEM(values && indices && output);
return 0;
}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::eval()
+template <DType Dtype>
+int OpGather<Dtype>::eval()
+{
+ for (int32_t n = 0; n < N; n++)
+ {
+ for (int32_t w = 0; w < W; w++)
+ {
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int32_t c = 0; c < C; c++)
+ {
+ EigenType value = this->values->getTensor()(n, k, c);
+ this->output->getTensor()(n, w, c) = value;
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SCATTER, id_)
{
- int axis = attribute->axis();
+ setRequiredOperands(3, 1);
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::~OpScatter()
+{}
- // calculate size left and right to axis
- int left_size = 1;
- for (int i = 0; i < axis; ++i)
+template <DType Dtype>
+int OpScatter<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
+ {
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (inputs[2]->getRank() != 3)
{
- left_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: input needs to be rank 3 tensor");
+ return 1;
}
- int right_size = 1;
- for (int i = axis + 1; i < in->getRank(); ++i)
+ if (outputs[0]->getRank() != 3)
{
- right_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
+ return 1;
}
- InEigenType* input_data = in->getTensor().data();
- int32_t* index_data = index->getTensor().data();
- OutEigenType* output_data = out->getTensor().data();
+ W = inputs[2]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ K = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
- int32_t axis_size = in->getShape()[axis];
- int32_t index_count = index->getElementCount();
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
+ {
+ printNodeValidationError("OpScatter: dimension N mismatch");
+ return 1;
+ }
- // sanity check if index is valid
- // need to check until this point since index is not known until runtime
- for (size_t i = 0; i < index->getElementCount(); i++)
+ if (W != inputs[1]->getShape()[1])
{
- if (index_data[i] >= axis_size)
- {
- FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
- }
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
}
- // Eigen stores tensor in column-major
- // so we iterate through dimension right to axis and the index array
- // do memory copy with size of left size each time
- for (int right = 0; right < right_size; ++right)
+ if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
+ values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
+
+ ASSERT_MEM(values_in && indices && input && values_out);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpScatter<Dtype>::eval()
+{
+ // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
+ this->values_out->getTensor() = this->values_in->getTensor();
+
+ for (int n = 0; n < N; n++)
{
- for (int i = 0; i < index_count; ++i)
+ for (int w = 0; w < W; w++)
{
- std::memcpy(output_data + (right * index_count + i) * left_size,
- input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int c = 0; c < C; c++)
+ {
+ EigenType value = this->input->getTensor()(n, w, c);
+ this->values_out->getTensor()(n, k, c) = value;
+ }
}
}
@@ -115,6 +219,12 @@ int OpGather<InRank, IndexRank, Dtype>::eval()
}
// template explicit instantiation
-DEF_INSTANTIATE_GATHER(OpGather, AINT8);
-DEF_INSTANTIATE_GATHER(OpGather, INT16);
-DEF_INSTANTIATE_GATHER(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);
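A minimal NumPy sketch (not part of the patch) of the rank-3 semantics the loops above implement, with values shaped (N, K, C), indices shaped (N, W) and, for SCATTER, input shaped (N, W, C):

import numpy as np

def gather_ref(values, indices):
    # output[n, w, c] = values[n, indices[n, w], c]
    N, K, C = values.shape
    W = indices.shape[1]
    out = np.empty((N, W, C), dtype=values.dtype)
    for n in range(N):
        for w in range(W):
            out[n, w, :] = values[n, indices[n, w], :]
    return out

def scatter_ref(values_in, indices, input):
    # values_out starts as values_in; positions not targeted by any index keep their value
    values_out = values_in.copy()
    N, W, C = input.shape
    for n in range(N):
        for w in range(W):
            values_out[n, indices[n, w], :] = input[n, w, :]
    return values_out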
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index d9b1263..17ea723 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -23,9 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// input and index can have different rank
-// and infer OutRank statically
-template <int InRank, int IndexRank, DType Dtype>
+template <DType Dtype>
class OpGather : public GraphNode
{
public:
@@ -35,18 +33,39 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr int OutRank = InRank - 1 + IndexRank;
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, InRank>;
- using TIndex = Eigen::Tensor<int32_t, IndexRank>;
- using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
protected:
- TosaAxisAttribute* attribute;
- TosaReference::TensorTemplate<TIn>* in;
- TosaReference::TensorTemplate<TIndex>* index;
- TosaReference::TensorTemplate<TOut>* out;
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TOutput>* output;
+};
+
+template <DType Dtype>
+class OpScatter : public GraphNode
+{
+public:
+ OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpScatter();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
+
+protected:
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values_in;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TValue>* input;
+ TosaReference::TensorTemplate<TOutput>* values_out;
};
}; // namespace TosaReference
diff --git a/serialization/attribute.def b/serialization/attribute.def
index d937395..a146f47 100644
--- a/serialization/attribute.def
+++ b/serialization/attribute.def
@@ -59,11 +59,13 @@ DEF_ATTRIBUTE(Slice, 2,
DEF_ATTRIBUTE(Tile, 1,
int32_t, V, multiples)
-DEF_ATTRIBUTE(Resize, 5,
+DEF_ATTRIBUTE(Resize, 7,
int32_t, V, output_size,
int32_t, V, stride,
int32_t, V, offset,
int32_t, S, shift,
+ float, V, stride_fp,
+ float, V, offset_fp,
ResizeMode, S, mode)
DEF_ATTRIBUTE(Clamp, 4,
diff --git a/serialization/operator.def b/serialization/operator.def
index 267976c..9a93b70 100644
--- a/serialization/operator.def
+++ b/serialization/operator.def
@@ -100,7 +100,8 @@ DEF_OPERATOR(tile, TILE, Tile,
DEF_OPERATOR(transpose, TRANSPOSE, Transpose, None, None)
/* gather/scatter */
-DEF_OPERATOR(gather, GATHER, Gather, Axis, None)
+DEF_OPERATOR(gather, GATHER, Gather, None, None)
+DEF_OPERATOR(scatter, SCATTER, Scatter, None, None)
/* image */
DEF_OPERATOR(resize, RESIZE, Resize, Resize, None)
diff --git a/serialization/tosa.fbs b/serialization/tosa.fbs
index f57d9dc..9df8746 100644
--- a/serialization/tosa.fbs
+++ b/serialization/tosa.fbs
@@ -133,6 +133,7 @@ enum Op:uint32 {
// Gather/scatter operation
GATHER,
+ SCATTER,
// Image
RESIZE,
@@ -219,6 +220,8 @@ table ResizeAttribute {
stride: [int32];
offset: [int32];
shift: int32;
+ stride_fp: [float];
+ offset_fp: [float];
mode: ResizeMode;
}
@@ -285,7 +288,7 @@ table PadQuantInfo {
table Version {
_major: int32 = 0;
- _minor: int32 = 20;
+ _minor: int32 = 21;
_patch: int32 = 0;
_experimental: bool = false;
}
diff --git a/serialization/tosa_generated.h b/serialization/tosa_generated.h
index 5140f7b..fad0520 100644
--- a/serialization/tosa_generated.h
+++ b/serialization/tosa_generated.h
@@ -296,21 +296,22 @@ enum Op {
Op_TILE = 56,
Op_TRANSPOSE = 57,
Op_GATHER = 58,
- Op_RESIZE = 59,
- Op_CAST = 60,
- Op_RESCALE = 61,
- Op_CONST = 62,
- Op_PLACEHOLDER = 63,
- Op_IDENTITY = 64,
- Op_IDENTITYN = 65,
- Op_CUSTOM = 66,
- Op_COND_IF = 67,
- Op_WHILE_LOOP = 68,
+ Op_SCATTER = 59,
+ Op_RESIZE = 60,
+ Op_CAST = 61,
+ Op_RESCALE = 62,
+ Op_CONST = 63,
+ Op_PLACEHOLDER = 64,
+ Op_IDENTITY = 65,
+ Op_IDENTITYN = 66,
+ Op_CUSTOM = 67,
+ Op_COND_IF = 68,
+ Op_WHILE_LOOP = 69,
Op_MIN = Op_UNKNOWN,
Op_MAX = Op_WHILE_LOOP
};
-inline const Op (&EnumValuesOp())[69] {
+inline const Op (&EnumValuesOp())[70] {
static const Op values[] = {
Op_UNKNOWN,
Op_ARGMAX,
@@ -371,6 +372,7 @@ inline const Op (&EnumValuesOp())[69] {
Op_TILE,
Op_TRANSPOSE,
Op_GATHER,
+ Op_SCATTER,
Op_RESIZE,
Op_CAST,
Op_RESCALE,
@@ -446,6 +448,7 @@ inline const char * const *EnumNamesOp() {
"TILE",
"TRANSPOSE",
"GATHER",
+ "SCATTER",
"RESIZE",
"CAST",
"RESCALE",
@@ -1176,7 +1179,9 @@ struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_STRIDE = 6,
VT_OFFSET = 8,
VT_SHIFT = 10,
- VT_MODE = 12
+ VT_STRIDE_FP = 12,
+ VT_OFFSET_FP = 14,
+ VT_MODE = 16
};
const flatbuffers::Vector<int32_t> *output_size() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_OUTPUT_SIZE);
@@ -1190,6 +1195,12 @@ struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
int32_t shift() const {
return GetField<int32_t>(VT_SHIFT, 0);
}
+ const flatbuffers::Vector<float> *stride_fp() const {
+ return GetPointer<const flatbuffers::Vector<float> *>(VT_STRIDE_FP);
+ }
+ const flatbuffers::Vector<float> *offset_fp() const {
+ return GetPointer<const flatbuffers::Vector<float> *>(VT_OFFSET_FP);
+ }
ResizeMode mode() const {
return static_cast<ResizeMode>(GetField<uint32_t>(VT_MODE, 0));
}
@@ -1202,6 +1213,10 @@ struct ResizeAttribute FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_OFFSET) &&
verifier.VerifyVector(offset()) &&
VerifyField<int32_t>(verifier, VT_SHIFT) &&
+ VerifyOffset(verifier, VT_STRIDE_FP) &&
+ verifier.VerifyVector(stride_fp()) &&
+ VerifyOffset(verifier, VT_OFFSET_FP) &&
+ verifier.VerifyVector(offset_fp()) &&
VerifyField<uint32_t>(verifier, VT_MODE) &&
verifier.EndTable();
}
@@ -1222,6 +1237,12 @@ struct ResizeAttributeBuilder {
void add_shift(int32_t shift) {
fbb_.AddElement<int32_t>(ResizeAttribute::VT_SHIFT, shift, 0);
}
+ void add_stride_fp(flatbuffers::Offset<flatbuffers::Vector<float>> stride_fp) {
+ fbb_.AddOffset(ResizeAttribute::VT_STRIDE_FP, stride_fp);
+ }
+ void add_offset_fp(flatbuffers::Offset<flatbuffers::Vector<float>> offset_fp) {
+ fbb_.AddOffset(ResizeAttribute::VT_OFFSET_FP, offset_fp);
+ }
void add_mode(ResizeMode mode) {
fbb_.AddElement<uint32_t>(ResizeAttribute::VT_MODE, static_cast<uint32_t>(mode), 0);
}
@@ -1243,9 +1264,13 @@ inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttribute(
flatbuffers::Offset<flatbuffers::Vector<int32_t>> stride = 0,
flatbuffers::Offset<flatbuffers::Vector<int32_t>> offset = 0,
int32_t shift = 0,
+ flatbuffers::Offset<flatbuffers::Vector<float>> stride_fp = 0,
+ flatbuffers::Offset<flatbuffers::Vector<float>> offset_fp = 0,
ResizeMode mode = ResizeMode_UNKNOWN) {
ResizeAttributeBuilder builder_(_fbb);
builder_.add_mode(mode);
+ builder_.add_offset_fp(offset_fp);
+ builder_.add_stride_fp(stride_fp);
builder_.add_shift(shift);
builder_.add_offset(offset);
builder_.add_stride(stride);
@@ -1259,16 +1284,22 @@ inline flatbuffers::Offset<ResizeAttribute> CreateResizeAttributeDirect(
const std::vector<int32_t> *stride = nullptr,
const std::vector<int32_t> *offset = nullptr,
int32_t shift = 0,
+ const std::vector<float> *stride_fp = nullptr,
+ const std::vector<float> *offset_fp = nullptr,
ResizeMode mode = ResizeMode_UNKNOWN) {
auto output_size__ = output_size ? _fbb.CreateVector<int32_t>(*output_size) : 0;
auto stride__ = stride ? _fbb.CreateVector<int32_t>(*stride) : 0;
auto offset__ = offset ? _fbb.CreateVector<int32_t>(*offset) : 0;
+ auto stride_fp__ = stride_fp ? _fbb.CreateVector<float>(*stride_fp) : 0;
+ auto offset_fp__ = offset_fp ? _fbb.CreateVector<float>(*offset_fp) : 0;
return tosa::CreateResizeAttribute(
_fbb,
output_size__,
stride__,
offset__,
shift,
+ stride_fp__,
+ offset_fp__,
mode);
}
@@ -1875,7 +1906,7 @@ struct Version FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return GetField<int32_t>(VT__MAJOR, 0);
}
int32_t _minor() const {
- return GetField<int32_t>(VT__MINOR, 20);
+ return GetField<int32_t>(VT__MINOR, 21);
}
int32_t _patch() const {
return GetField<int32_t>(VT__PATCH, 0);
@@ -1900,7 +1931,7 @@ struct VersionBuilder {
fbb_.AddElement<int32_t>(Version::VT__MAJOR, _major, 0);
}
void add__minor(int32_t _minor) {
- fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 20);
+ fbb_.AddElement<int32_t>(Version::VT__MINOR, _minor, 21);
}
void add__patch(int32_t _patch) {
fbb_.AddElement<int32_t>(Version::VT__PATCH, _patch, 0);
@@ -1923,7 +1954,7 @@ struct VersionBuilder {
inline flatbuffers::Offset<Version> CreateVersion(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t _major = 0,
- int32_t _minor = 20,
+ int32_t _minor = 21,
int32_t _patch = 0,
bool _experimental = false) {
VersionBuilder builder_(_fbb);
diff --git a/verif/tosa/Op.py b/verif/tosa/Op.py
index 09f1364..ea9cdfe 100644
--- a/verif/tosa/Op.py
+++ b/verif/tosa/Op.py
@@ -77,14 +77,15 @@ class Op(object):
TILE = 56
TRANSPOSE = 57
GATHER = 58
- RESIZE = 59
- CAST = 60
- RESCALE = 61
- CONST = 62
- PLACEHOLDER = 63
- IDENTITY = 64
- IDENTITYN = 65
- CUSTOM = 66
- COND_IF = 67
- WHILE_LOOP = 68
+ SCATTER = 59
+ RESIZE = 60
+ CAST = 61
+ RESCALE = 62
+ CONST = 63
+ PLACEHOLDER = 64
+ IDENTITY = 65
+ IDENTITYN = 66
+ CUSTOM = 67
+ COND_IF = 68
+ WHILE_LOOP = 69
diff --git a/verif/tosa/ResizeAttribute.py b/verif/tosa/ResizeAttribute.py
index 1e6941f..35be73a 100644
--- a/verif/tosa/ResizeAttribute.py
+++ b/verif/tosa/ResizeAttribute.py
@@ -107,13 +107,57 @@ class ResizeAttribute(object):
return 0
# ResizeAttribute
- def Mode(self):
+ def StrideFp(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ResizeAttribute
+ def StrideFpAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
+ return 0
+
+ # ResizeAttribute
+ def StrideFpLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # ResizeAttribute
+ def OffsetFp(self, j):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ a = self._tab.Vector(o)
+ return self._tab.Get(flatbuffers.number_types.Float32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
+ return 0
+
+ # ResizeAttribute
+ def OffsetFpAsNumpy(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Float32Flags, o)
+ return 0
+
+ # ResizeAttribute
+ def OffsetFpLength(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
+ if o != 0:
+ return self._tab.VectorLen(o)
+ return 0
+
+ # ResizeAttribute
+ def Mode(self):
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
+ if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0
-def ResizeAttributeStart(builder): builder.StartObject(5)
+def ResizeAttributeStart(builder): builder.StartObject(7)
def ResizeAttributeAddOutputSize(builder, outputSize): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(outputSize), 0)
def ResizeAttributeStartOutputSizeVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def ResizeAttributeAddStride(builder, stride): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(stride), 0)
@@ -121,5 +165,9 @@ def ResizeAttributeStartStrideVector(builder, numElems): return builder.StartVec
def ResizeAttributeAddOffset(builder, offset): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(offset), 0)
def ResizeAttributeStartOffsetVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def ResizeAttributeAddShift(builder, shift): builder.PrependInt32Slot(3, shift, 0)
-def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(4, mode, 0)
+def ResizeAttributeAddStrideFp(builder, strideFp): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(strideFp), 0)
+def ResizeAttributeStartStrideFpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddOffsetFp(builder, offsetFp): builder.PrependUOffsetTRelativeSlot(5, flatbuffers.number_types.UOffsetTFlags.py_type(offsetFp), 0)
+def ResizeAttributeStartOffsetFpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
+def ResizeAttributeAddMode(builder, mode): builder.PrependUint32Slot(6, mode, 0)
def ResizeAttributeEnd(builder): return builder.EndObject()
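The accessor changes above follow the usual FlatBuffers slot numbering: field slot i is read through vtable byte offset 4 + 2*i, so inserting stride_fp and offset_fp at slots 4 and 5 pushes mode from Offset(12) to Offset(16). Illustrative sketch (not part of the patch):

fields = ['output_size', 'stride', 'offset', 'shift', 'stride_fp', 'offset_fp', 'mode']
print({name: 4 + 2 * slot for slot, name in enumerate(fields)})
# {'output_size': 4, 'stride': 6, 'offset': 8, 'shift': 10,
#  'stride_fp': 12, 'offset_fp': 14, 'mode': 16}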
diff --git a/verif/tosa/Version.py b/verif/tosa/Version.py
index ddfdb2d..e327507 100644
--- a/verif/tosa/Version.py
+++ b/verif/tosa/Version.py
@@ -45,7 +45,7 @@ class Version(object):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
- return 20
+ return 21
# Version
def _patch(self):
@@ -63,7 +63,7 @@ class Version(object):
def VersionStart(builder): builder.StartObject(4)
def VersionAdd_major(builder, Major): builder.PrependInt32Slot(0, Major, 0)
-def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 20)
+def VersionAdd_minor(builder, Minor): builder.PrependInt32Slot(1, Minor, 21)
def VersionAdd_patch(builder, Patch): builder.PrependInt32Slot(2, Patch, 0)
def VersionAdd_experimental(builder, Experimental): builder.PrependBoolSlot(3, Experimental, 0)
def VersionEnd(builder): return builder.EndObject()
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index 07e0e1a..3b7e339 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -63,12 +63,14 @@ class TosaSerializerUnion:
self.floats = []
self.strings = []
self.intvecs = []
+ self.fpvecs = []
def serialize(self, builder):
# We have to build strings and vectors first
strList = []
intVecList = []
+ fpVecList = []
for fcn, val in self.strings:
strList.append((fcn, builder.CreateString(val)))
@@ -76,6 +78,9 @@ class TosaSerializerUnion:
for fcn, val in self.intvecs:
intVecList.append((fcn, TosaSerializer.serializeInt32Vec(builder, val)))
+ for fcn, val in self.fpvecs:
+ fpVecList.append((fcn, TosaSerializer.serializeFpVec(builder, val)))
+
startFcn, endFcn = self.optFcns
# Then serialize the options object from the list of primitives and
@@ -96,6 +101,9 @@ class TosaSerializerUnion:
for fcn, val in intVecList:
fcn(builder, val)
+ for fcn, val in fpVecList:
+ fcn(builder, val)
+
return endFcn(builder)
class TosaSerializerAttribute(TosaSerializerUnion):
@@ -193,7 +201,7 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.TileAttributeAddMultiples,
multiples))
- def ResizeAttribute(self, output_size, stride, offset, shift, mode):
+ def ResizeAttribute(self, output_size, stride, offset, shift, stride_fp, offset_fp, mode):
from tosa import ResizeAttribute as a, Attribute
self.utype = Attribute.Attribute().ResizeAttribute
@@ -207,6 +215,10 @@ class TosaSerializerAttribute(TosaSerializerUnion):
offset))
self.ints.append((a.ResizeAttributeAddShift,
shift))
+ self.fpvecs.append((a.ResizeAttributeAddStrideFp,
+ stride_fp))
+ self.fpvecs.append((a.ResizeAttributeAddOffsetFp,
+ offset_fp))
self.ints.append((a.ResizeAttributeAddMode,
mode))
@@ -692,6 +704,13 @@ class TosaSerializer:
return builder.EndVector(len(vec))
@staticmethod
+ def serializeFpVec(builder, vec):
+ builder.StartVector(4, len(vec), 4)
+ for v in vec[::-1]:
+ builder.PrependFloat32(v)
+ return builder.EndVector(len(vec))
+
+ @staticmethod
def serializeObjVec(builder, vec, start_fcn):
serialized_vec = []
for v in vec[::-1]:
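A hedged usage sketch (not part of the patch) of the widened ResizeAttribute helper; the numeric values and the mode constant below are illustrative assumptions, not taken from the test generator:

import tosa_serializer as ts

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute([32, 32],        # output_size
                     [1024, 1024],    # stride, in units of 1/(1 << shift)
                     [0, 0],          # offset
                     10,              # shift
                     [0.0, 0.0],      # stride_fp (only used by the FLOAT path)
                     [0.0, 0.0],      # offset_fp (only used by the FLOAT path)
                     1)               # mode: assumed enum value for NEAREST/BILINEAR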
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 302e4f4..0e57a7b 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -158,6 +158,29 @@ class TosaTensorGen():
return shape_list
@staticmethod
+ def tgScatter(testGen, opName, rank):
+ pl, const = opName['operands']
+
+ assert(pl == 2)
+ assert(const == 0)
+ assert(rank == 3)
+
+ values_in_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
+
+ W = testGen.randInt(testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1])
+ input_shape = [values_in_shape[0], W, values_in_shape[2]]
+
+ shape_list = []
+ shape_list.append(values_in_shape.copy())
+ shape_list.append(input_shape.copy())
+
+ return shape_list
+
+ @staticmethod
def tgBroadcastFuzz(testGen, op, rank):
shape = testGen.makeShape(rank)
@@ -650,6 +673,8 @@ class TosaArgGen:
outputDTypeList = [ DType.INT8 ]
elif m == ResizeMode.BILINEAR and dtype == DType.INT16:
outputDTypeList = [ DType.INT48 ]
+ elif dtype == DType.FLOAT:
+ outputDTypeList = [ DType.FLOAT ]
else:
continue
@@ -659,19 +684,52 @@ class TosaArgGen:
# Randomly generate legal output dimensions and shift
# and then compute the stride and offset based on them
output_dims = [ testGen.randInt(), testGen.randInt() ]
-
- shift = testGen.randInt(1, 11)
-
- stride = [ (ifm_shape[1] << shift) // output_dims[0],
- (ifm_shape[2] << shift) // output_dims[1] ]
-
- offset = [ testGen.randInt(-stride[0], (ifm_shape[1] << shift) - (output_dims[0] - 1) * stride[0]),
- testGen.randInt(-stride[1], (ifm_shape[2] << shift) - (output_dims[1] - 1) * stride[1]) ]
-
- arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
- testGen.typeStr(outputDType), stride[0], stride[1],
- offset[0], offset[1]),
- [m, stride, offset, shift, output_dims, dtype, outputDType]))
+ in_center_h = (ifm_shape[1] - 1) / 2.0
+ in_center_w = (ifm_shape[2] - 1) / 2.0
+ out_center_h = (output_dims[0] - 1) / 2.0
+ out_center_w = (output_dims[1] - 1) / 2.0
+
+ fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
+ fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
+ fp_offset_y = in_center_h - fp_stride_y * out_center_h
+ fp_offset_x = in_center_w - fp_stride_x * out_center_w
+
+ if outputDType == DType.FLOAT:
+ shift = 0
+ stride = [0, 0]
+ offset = [0, 0]
+ stride_fp = [ fp_stride_y, fp_stride_x]
+ offset_fp = [ fp_offset_y, fp_offset_x]
+ arg_list.append(('mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}'.format(m, output_dims[0], output_dims[1],
+ testGen.typeStr(outputDType), stride_fp[0], stride_fp[1],
+ offset_fp[0], offset_fp[1]),
+ [m, stride, offset, shift, stride_fp, offset_fp, output_dims, dtype, outputDType]))
+ else:
+ shift = 11
+ unit = float(1 << shift)
+ stride_y = int(round(fp_stride_y * unit))
+ stride_x = int(round(fp_stride_x * unit))
+ offset_y = int(round(fp_offset_y * unit))
+ offset_x = int(round(fp_offset_x * unit))
+
+ while (stride_y >= 32768 or stride_x >= 32768 or offset_y >= 32768 or offset_x >= 32768 or offset_y < -32768 or offset_x < -32768):
+ shift = shift - 1
+ unit = float(1 << shift)
+ stride_y = int(round(fp_stride_y * unit))
+ stride_x = int(round(fp_stride_x * unit))
+ offset_y = int(round(fp_offset_y * unit))
+ offset_x = int(round(fp_offset_x * unit))
+
+ stride = [ stride_y, stride_x]
+ offset = [ offset_y, offset_x]
+
+ stride_fp = [0.0, 0.0]
+ offset_fp = [0.0, 0.0]
+
+ arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
+ testGen.typeStr(outputDType), stride[0], stride[1],
+ offset[0], offset[1]),
+ [m, stride, offset, shift, stride_fp, offset_fp, output_dims, dtype, outputDType]))
return arg_list
@@ -1139,29 +1197,44 @@ class TosaTestGen:
return result_tens
- def build_gather(self, op, values, axis):
+ def build_gather(self, op, values):
# Create a new indicies tensor
# here with data that doesn't exceed the dimensions of the values tensor
- max_val = values.shape[axis]
- indicies_arr = np.int32(self.rng.integers(low=0, high=max_val, size=[self.randInt(1, max_val + 1)]))
+ K = values.shape[1] # K
+ W = self.randInt(self.args.tensor_shape_range[0], self.args.tensor_shape_range[1]) # W
+ indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[values.shape[0], W])) # (N, W)
indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, Usage.INDEX, [], indicies_arr)
- result_tens = OutputShaper.gatherOp(self.ser, values, indicies, axis)
+ result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
- attr = ts.TosaSerializerAttribute()
- attr.AxisAttribute(axis)
+ self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name])
- self.ser.addOperator(op, [values.name, indicies.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_scatter(self, op, values_in, input):
+
+ # Create a new indicies tensor
+ # here with data that doesn't exceed the dimensions of the values_in tensor
+
+ K = values_in.shape[1] # K
+ W = input.shape[1] # W
+ indicies_arr = np.int32(self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])) # (N, W)
+ indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, Usage.INDEX, [], indicies_arr)
+
+ result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
+
+ self.ser.addOperator(op, [values_in.name, indicies.name, input.name], [result_tens.name])
return result_tens
- def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
- result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype)
+ def build_resize(self, op, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype):
+ result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype)
attr = ts.TosaSerializerAttribute()
- attr.ResizeAttribute(output_dims, stride, offset, shift, mode)
+
+ attr.ResizeAttribute(output_dims, stride, offset, shift, stride_fp, offset_fp, mode)
self.ser.addOperator(op, [input.name], [result_tens.name], attr)
return result_tens
@@ -1966,10 +2039,20 @@ class TosaTestGen:
# Scatter/Gather
'gather':
{ 'op': Op.GATHER,
+ # Only specify 'values' tensor here. 'indices' is generated in op building stage
'operands': (1, 0),
- 'build_fcn': (build_gather, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
- 'types': TYPE_INT },
+ 'rank': (3, 3),
+ 'build_fcn': (build_gather, TosaTensorGen.tgBasic, None),
+ 'types': TYPE_INT_FP },
+ 'scatter':
+ { 'op': Op.SCATTER,
+ # Only specify 'values_in' tensor here.
+ #'indices' and 'input' are generated in op building stage
+ 'operands': (2, 0),
+ 'rank': (3, 3),
+ 'build_fcn': (build_scatter, TosaTensorGen.tgScatter, None),
+ 'types': TYPE_INT_FP },
# Image operations
'resize':
@@ -1977,7 +2060,7 @@ class TosaTestGen:
'operands': (1, 0),
'rank': (4, 4),
'build_fcn': ( build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
- 'types': [ DType.INT8, DType.INT16 ] },
+ 'types': [ DType.INT8, DType.INT16, DType.FLOAT ] },
# Data nodes
@@ -2319,11 +2402,27 @@ class OutputShaper:
return ser.addOutput(output_shape, a.dtype, a.usage, a.dformat)
@staticmethod
- def gatherOp(ser, values, indicies, axis):
- # indicies minus the axis + values - the indexes used to look up values.
- output_shape = [*values.shape[0:axis], indicies.shape[0], *values.shape[axis+1:]]
+ def gatherOp(ser, values, indices):
+ assert len(values.shape) == 3
+ assert len(indices.shape) == 2
+ assert values.shape[0] == indices.shape[0]
+
+ output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
+
+ return ser.addOutput(output_shape, values.dtype, values.usage, values.dformat)
+
+ @staticmethod
+ def scatterOp(ser, values_in, indices, input):
+ assert len(values_in.shape) == 3
+ assert len(indices.shape) == 2
+ assert len(input.shape) == 3
+ assert values_in.shape[0] == indices.shape[0] # N
+ assert input.shape[1] == indices.shape[1] # W
+ assert values_in.shape[2] == input.shape[2] # C
+
+ output_shape = values_in.shape
- return ser.addOutput(output_shape, values.dtype, indicies.usage, indicies.dformat)
+ return ser.addOutput(output_shape, values_in.dtype, values_in.usage, values_in.dformat)
@staticmethod
def tableOp(ser, input, table):
@@ -2331,12 +2430,16 @@ class OutputShaper:
return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat)
@staticmethod
- def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
+ def resizeOp(ser, input, mode, stride, offset, shift, stride_fp, offset_fp, output_dims, input_dtype, output_dtype):
output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
- if stride[0] <= 0 or stride[1] <= 0:
- ser.setExpectedFailure(True, 'Negative or zero stride')
+ if input_dtype == DType.FLOAT:
+ if stride_fp[0] <= 0 or stride_fp[1] <= 0:
+ ser.setExpectedFailure(True, 'Negative or zero stride')
+ else:
+ if stride[0] <= 0 or stride[1] <= 0:
+ ser.setExpectedFailure(True, 'Negative or zero stride')
if mode == ResizeMode.BILINEAR:
if input_dtype == DType.INT8:
@@ -2345,6 +2448,9 @@ class OutputShaper:
elif input_dtype == DType.INT16:
if output_dtype != DType.INT48:
ser.setexpectedfailure(true, 'Invalid output data type')
+ elif input_dtype == DType.FLOAT:
+ if output_dtype != DType.FLOAT:
+ ser.setexpectedfailure(true, 'Invalid output data type')
else:
ser.setexpectedfailure(true, 'Invalid input data type')
@@ -2355,6 +2461,9 @@ class OutputShaper:
elif input_dtype == DType.INT16:
if output_dtype != DType.INT16:
ser.setexpectedfailure(true, 'Invalid output data type')
+ elif input_dtype == DType.FLOAT:
+ if output_dtype != DType.FLOAT:
+ ser.setexpectedfailure(true, 'Invalid output data type')
else:
ser.setexpectedfailure(true, 'Invalid input data type')
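For the non-FLOAT cases in agResize above, the stride/offset quantization can be read as the following standalone Python sketch (not part of the patch): derive the float stride and center-aligned offset from the shapes, scale by 2^shift starting at shift = 11, and drop the shift until every value fits in the int16 range the reference model accepts.

def quantize_resize_params(ifm_hw, out_hw, max_shift=11):
    # mirrors the while loop in agResize: shrink shift until stride/offset fit in int16
    fp_stride = [ifm_hw[0] / out_hw[0], ifm_hw[1] / out_hw[1]]
    in_center = [(d - 1) / 2.0 for d in ifm_hw]
    out_center = [(d - 1) / 2.0 for d in out_hw]
    fp_offset = [in_center[i] - fp_stride[i] * out_center[i] for i in range(2)]

    shift = max_shift
    while True:
        unit = 1 << shift
        stride = [int(round(s * unit)) for s in fp_stride]
        offset = [int(round(o * unit)) for o in fp_offset]
        if all(v < 32768 for v in stride) and all(-32768 <= v < 32768 for v in offset):
            return shift, stride, offset
        shift -= 1

print(quantize_resize_params([64, 64], [128, 128]))   # e.g. (11, [1024, 1024], [-512, -512])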