Skip to content

Conversation

@xuzhenqi
Copy link
Contributor

Fix inputs num mismatch for node matching.

cc @justinchuby @gramalingam

@codecov
Copy link

codecov bot commented Sep 27, 2024

Codecov Report

Attention: Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Project coverage is 75.08%. Comparing base (37b11fc) to head (3aaa866).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
onnxscript/rewriter/pattern.py 0.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1885      +/-   ##
==========================================
+ Coverage   75.02%   75.08%   +0.06%     
==========================================
  Files         252      252              
  Lines       27415    27417       +2     
  Branches     5012     3190    -1822     
==========================================
+ Hits        20567    20587      +20     
- Misses       5875     5880       +5     
+ Partials      973      950      -23     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@justinchuby
Copy link
Collaborator

Thanks for the PR. Could you provide a little bit of context on why this is needed?

@xuzhenqi
Copy link
Contributor Author

xuzhenqi commented Oct 8, 2024

Thanks for the PR. Could you provide a little bit of context on why this is needed?

@justinchuby

Suppose we need to remove an optional input of a node, for example, we want to replace DFT(x, axis=-2) with DFT(x), since the default axis is -2, so we can remove the axis input.

The solusion:

class DFTSimplify(pat.RewriteRuleAsClass):
    @classmethod
    def pattern(
        cls, ope: pat.OpsetPatternBuilder, inp: pat.Var, axis: pat.Var, onesided: pat.Var
    ) -> pat.NodeOutputPattern:
        ret = ope.DFT(inp, None, axis, onesided=onesided)
        assert isinstance(ret, pat.NodeOutputPattern)
        return ret

    @classmethod
    def rewrite(
        cls, ope: pat.RewriterContext, inp: ir.Value, axis: ir.Value, onesided: ir.Attr
    ) -> ir.Value:
        del axis
        return ope.DFT(inp, dft_length=None, axis=None, onesided=onesided)

    @classmethod
    def check(cls, _context: None, inp: ir.Value, axis: ir.Value, onesided: ir.Attr) -> bool:
        del onesided
        logging.info("check")
        if axis.const_value is None:
            return False
        value = axis.const_value.numpy()
        axis_value = value.item()
        if axis_value < 0:
            return axis_value == -2
        if inp.shape is None:
            return False
        return axis_value + 2 == len(inp.shape)

The code will crash when matching a normal DFT(x) node, since there is no axis input to feed rewrite function.

@xuzhenqi xuzhenqi force-pushed the fix_input_num_mismatch_for_node_matching branch from 8b31a6f to 0dbc14c Compare October 8, 2024 05:57
@justinchuby justinchuby enabled auto-merge (squash) October 8, 2024 15:57
@justinchuby justinchuby disabled auto-merge October 8, 2024 17:12
@justinchuby
Copy link
Collaborator

@xuzhenqi xuzhenqi force-pushed the fix_input_num_mismatch_for_node_matching branch from 0dbc14c to 5a62e2d Compare October 9, 2024 02:34
@xuzhenqi
Copy link
Contributor Author

xuzhenqi commented Oct 9, 2024

Could you fix the lint errors? https://github.com/microsoft/onnxscript/actions/runs/11229454770/job/31245771279?pr=1885 Thanks

Fixed.

Signed-off-by: xuzhenqi <[email protected]>
@xuzhenqi xuzhenqi force-pushed the fix_input_num_mismatch_for_node_matching branch from 5a62e2d to 4419f27 Compare October 9, 2024 06:11
@justinchuby justinchuby merged commit a7c797d into microsoft:main Oct 9, 2024
@xuzhenqi xuzhenqi deleted the fix_input_num_mismatch_for_node_matching branch October 10, 2024 05:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

3 participants