@@ -4350,14 +4350,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43504350 if (binder.s64IntegerArrayAttr (ngram_counts, " ngram_counts" , {}) ||
43514351 binder.s64IntegerArrayAttr (ngram_indexes, " ngram_indexes" , {}) ||
43524352 binder.s64IntegerArrayAttr (pool_int64s, " pool_int64s" , {}) ||
4353- binder.f32FloatArrayAttr (weights, " weights" , {}) ||
43544353 binder.customOpNameStringAttr (mode, " mode" , " " ) ||
43554354 binder.s64IntegerAttr (min_gram_length, " min_gram_length" , 0 ) ||
43564355 binder.s64IntegerAttr (max_gram_length, " max_gram_length" , 0 ) ||
43574356 binder.s64IntegerAttr (max_skip_count, " max_skip_count" , 0 ) ||
43584357 binder.tensorOperand (input) || binder.tensorResultType (resultType))
43594358 return failure ();
43604359
4360+ llvm::SmallVector<float > defaultWeights (ngram_indexes.size (), 1 .0f );
4361+ if (binder.f32FloatArrayAttr (weights, " weights" , defaultWeights))
4362+ return failure ();
4363+
43614364 if (pool_int64s.size () == 0 )
43624365 return rewriter.notifyMatchFailure (
43634366 binder.op , " pool_int64s empty, only integers supported" );
0 commit comments