Skip to content

Commit 64475e8

Browse files
committed
default value for weights vector
1 parent 8440c11 commit 64475e8

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4350,14 +4350,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43504350
if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) ||
43514351
binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) ||
43524352
binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) ||
4353-
binder.f32FloatArrayAttr(weights, "weights", {}) ||
43544353
binder.customOpNameStringAttr(mode, "mode", "") ||
43554354
binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) ||
43564355
binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) ||
43574356
binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) ||
43584357
binder.tensorOperand(input) || binder.tensorResultType(resultType))
43594358
return failure();
43604359

4360+
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
4361+
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
4362+
return failure();
4363+
43614364
if (pool_int64s.size() == 0)
43624365
return rewriter.notifyMatchFailure(
43634366
binder.op, "pool_int64s empty, only integers supported");

0 commit comments

Comments
 (0)