Skip to content

Commit a2bfe47

Browse files
authored
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Address nod-ai/SHARK-ModelDev#833
1 parent 617c1c7 commit a2bfe47

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,31 @@ struct OpBinder {
338338
return failure();
339339
}
340340

341+
ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
342+
StringRef nameSuffix,
343+
ArrayRef<float> defaults) {
344+
SmallString<64> name("torch.onnx.");
345+
name.append(nameSuffix);
346+
auto attr = op->getAttr(name);
347+
if (!attr) {
348+
values.append(defaults.begin(), defaults.end());
349+
return success();
350+
}
351+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
352+
for (auto element : arrayAttr) {
353+
auto floatAttr = dyn_cast<FloatAttr>(element);
354+
if (!floatAttr)
355+
return failure();
356+
FloatType t = cast<FloatType>(floatAttr.getType());
357+
if (t.getWidth() != 32)
358+
return failure();
359+
values.push_back(floatAttr.getValue().convertToFloat());
360+
}
361+
return success();
362+
}
363+
return failure();
364+
}
365+
341366
ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
342367
StringRef nameSuffix) {
343368
SmallString<64> name("torch.onnx.");

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43394339
llvm::SmallVector<int64_t> ngram_counts;
43404340
llvm::SmallVector<int64_t> ngram_indexes;
43414341
llvm::SmallVector<int64_t> pool_int64s;
4342+
llvm::SmallVector<float> weights;
43424343
std::string mode;
43434344
int64_t min_gram_length;
43444345
int64_t max_gram_length;
@@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
43564357
binder.tensorOperand(input) || binder.tensorResultType(resultType))
43574358
return failure();
43584359

4359-
if (mode != "TF")
4360-
return rewriter.notifyMatchFailure(binder.op,
4361-
"TF mode supported only");
4360+
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
4361+
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
4362+
return failure();
4363+
43624364
if (pool_int64s.size() == 0)
43634365
return rewriter.notifyMatchFailure(
43644366
binder.op, "pool_int64s empty, only integers supported");
@@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
45844586
binder.getLoc(), loopConditionTrue, ValueRange({count}));
45854587
}
45864588
count = skipLoop.getResult(0);
4587-
// insert count "tf" into output
45884589
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
45894590
binder.getLoc(), count);
4591+
if (mode == "IDF" || mode == "TFIDF") {
4592+
// both IDF and TFIDF modes use weights
4593+
float weight = weights[ngram_i];
4594+
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
4595+
binder.getLoc(), rewriter.getF64FloatAttr(weight));
4596+
4597+
// TFIDF
4598+
Value multiplier = countFloat;
4599+
if (mode == "IDF") {
4600+
// All the counts larger than 1 would be truncated to 1
4601+
// and the i-th element in weights would be used to scale
4602+
// (by multiplication) the count of the i-th n-gram in pool.
4603+
4604+
Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
4605+
binder.getLoc(), count);
4606+
// compare intCount > 0
4607+
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
4608+
binder.getLoc(), intCount, zero);
4609+
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
4610+
binder.getLoc(), gtZeroCount);
4611+
Value gtZeroCountFloat =
4612+
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
4613+
gtZeroCount);
4614+
multiplier = gtZeroCountFloat;
4615+
}
4616+
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
4617+
binder.getLoc(), multiplier, constWeight);
4618+
}
45904619
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
45914620
binder.getLoc(),
45924621
rewriter.getType<Torch::ListType>(

0 commit comments

Comments
 (0)