@@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43394339 llvm::SmallVector<int64_t > ngram_counts;
43404340 llvm::SmallVector<int64_t > ngram_indexes;
43414341 llvm::SmallVector<int64_t > pool_int64s;
4342+ llvm::SmallVector<float > weights;
43424343 std::string mode;
43434344 int64_t min_gram_length;
43444345 int64_t max_gram_length;
@@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43564357 binder.tensorOperand (input) || binder.tensorResultType (resultType))
43574358 return failure ();
43584359
4359- if (mode != " TF" )
4360- return rewriter.notifyMatchFailure (binder.op ,
4361- " TF mode supported only" );
4360+ llvm::SmallVector<float > defaultWeights (ngram_indexes.size (), 1 .0f );
4361+ if (binder.f32FloatArrayAttr (weights, " weights" , defaultWeights))
4362+ return failure ();
4363+
43624364 if (pool_int64s.size () == 0 )
43634365 return rewriter.notifyMatchFailure (
43644366 binder.op , " pool_int64s empty, only integers supported" );
@@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
45844586 binder.getLoc (), loopConditionTrue, ValueRange ({count}));
45854587 }
45864588 count = skipLoop.getResult (0 );
4587- // insert count "tf" into output
45884589 Value countFloat = rewriter.create <Torch::AtenFloatScalarOp>(
45894590 binder.getLoc (), count);
4591+ if (mode == " IDF" || mode == " TFIDF" ) {
4592+ // both IDF and TFIDF modes use weights
4593+ float weight = weights[ngram_i];
4594+ Value constWeight = rewriter.create <Torch::ConstantFloatOp>(
4595+ binder.getLoc (), rewriter.getF64FloatAttr (weight));
4596+
4597+ // TFIDF
4598+ Value multiplier = countFloat;
4599+ if (mode == " IDF" ) {
4600+ // All the counts larger than 1 would be truncated to 1
4601+ // and the i-th element in weights would be used to scale
4602+ // (by multiplication) the count of the i-th n-gram in pool.
4603+
4604+ Value intCount = rewriter.create <Torch::AtenIntScalarOp>(
4605+ binder.getLoc (), count);
4606+ // compare intCount > 0
4607+ Value gtZeroCount = rewriter.create <Torch::AtenGtIntOp>(
4608+ binder.getLoc (), intCount, zero);
4609+ gtZeroCount = rewriter.create <Torch::AtenIntBoolOp>(
4610+ binder.getLoc (), gtZeroCount);
4611+ Value gtZeroCountFloat =
4612+ rewriter.create <Torch::AtenFloatScalarOp>(binder.getLoc (),
4613+ gtZeroCount);
4614+ multiplier = gtZeroCountFloat;
4615+ }
4616+ countFloat = rewriter.create <Torch::AtenMulFloatOp>(
4617+ binder.getLoc (), multiplier, constWeight);
4618+ }
45904619 Value dataList = rewriter.create <Torch::PrimListConstructOp>(
45914620 binder.getLoc (),
45924621 rewriter.getType <Torch::ListType>(
0 commit comments