From 989cb050228b47085189b1c5cb0d9b705e1060e7 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Wed, 28 Apr 2021 16:29:44 -0700 Subject: Support mixed-precision input tensors for TOSA unit test. Bring CONV2D/DEPTHWISE_CONV2D/TRANSPOSE_CONV2D/FULLY_CONNECTED up running. Other minor fixes: - reference model should bail out if shape is invalid, along with "goto done" cleanup. - cleanup typos/duplicate in tosa_test_gen.py/tosa_serializer.py. - wrong input_zp/output_zp being generated for RESCALE. Signed-off-by: Kevin Cheng Change-Id: Ic1f3fe0090482bdee8a61508be7c738714191e19 --- reference_model/src/tensor.h | 45 +++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) (limited to 'reference_model/src/tensor.h') diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d39cc7c..6c0622e 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -33,9 +33,7 @@ class GraphNode; class Tensor { public: - Tensor(std::string tensorName_, - DType tensorDtype__, - std::vector shape_); + Tensor(std::string tensorName_, DType tensorDtype__, std::vector shape_); virtual ~Tensor(); @@ -240,9 +238,7 @@ template class TensorTemplate : public Tensor { public: - TensorTemplate(std::string tensorName_, - DType tensorDtype_, - std::vector shape_) + TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector shape_) : Tensor(tensorName_, tensorDtype_, shape_) { tensor = nullptr; @@ -606,11 +602,15 @@ int Tensor6::dumpTensor(FILE* out) const; class TensorFactory { public: - static Tensor* newTensor(std::string tensorName_, - DType tensorDtype_, - std::vector shape_, - const uint32_t rank) + static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector shape_, const uint32_t rank) { + // Bail out if any dimension is invalid. + for (auto& dim : shape_) + { + if (dim <= 0) + goto done; + } + switch (tensorDtype_) { case DType_FLOAT: @@ -630,9 +630,8 @@ public: return new Tensor5(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT32: case DType_UINT8: case DType_INT4: @@ -654,9 +653,8 @@ public: return new Tensor5(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT48: switch (rank) { @@ -674,9 +672,8 @@ public: return new Tensor5(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_BOOL: switch (rank) { @@ -694,16 +691,22 @@ public: return new Tensor5(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; default: - goto done; + break; } done: - FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_], - rank); + std::string shape_str("["); + for (auto& dim : shape_) + { + shape_str += (std::to_string(dim) + ", "); + } + shape_str.append("]"); + + FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(), + EnumNamesDType()[tensorDtype_], rank, shape_str.c_str()); } static Tensor* newTensor(DType type, const std::vector shape); -- cgit v1.2.1