Skip to content

Commit 7550668

Browse files
Move basic rewriting code to aesara.graph.rewriting
- `aesara.graph.opt` has been changed to `aesara.graph.rewriting.basic` - `aesara.graph.opt_utils` has been changed to `aesara.graph.rewriting.utils` - `aesara.graph.optdb` has been changed to `aesara.graph.rewriting.db` - `aesara.graph.unify` has been changed to `aesara.graph.rewriting.unify` - `aesara.graph.kanren` has been changed to `aesara.graph.rewriting.kanren` The tests associated with each module have been updated accordingly.
1 parent 746eecb commit 7550668

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+146
-134
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ repos:
1515
aesara/breakpoint\.py|
1616
aesara/graph/op\.py|
1717
aesara/compile/nanguardmode\.py|
18-
aesara/graph/opt\.py|
18+
aesara/graph/rewriting/basic\.py|
1919
aesara/tensor/var\.py|
2020
)$
2121
- id: check-merge-conflict

aesara/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from aesara.graph.fg import FunctionGraph
2525
from aesara.graph.null_type import NullType
2626
from aesara.graph.op import HasInnerGraph, Op
27-
from aesara.graph.opt import in2out, node_rewriter
27+
from aesara.graph.rewriting.basic import in2out, node_rewriter
2828
from aesara.graph.utils import MissingInputError
2929
from aesara.tensor.basic_opt import ShapeFeature
3030

aesara/compile/mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
from aesara.compile.function.types import Supervisor
1111
from aesara.configdefaults import config
1212
from aesara.graph.destroyhandler import DestroyHandler
13-
from aesara.graph.opt import (
13+
from aesara.graph.rewriting.basic import (
1414
CheckStackTraceRewriter,
1515
GraphRewriter,
1616
MergeOptimizer,
1717
NodeProcessingGraphRewriter,
1818
)
19-
from aesara.graph.optdb import (
19+
from aesara.graph.rewriting.db import (
2020
EquilibriumDB,
2121
LocalGroupDB,
2222
RewriteDatabase,

aesara/graph/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from aesara.graph.op import Op
1414
from aesara.graph.type import Type
1515
from aesara.graph.fg import FunctionGraph
16-
from aesara.graph.opt import node_rewriter, graph_rewriter
17-
from aesara.graph.opt_utils import rewrite_graph
18-
from aesara.graph.optdb import RewriteDatabaseQuery
16+
from aesara.graph.rewriting.basic import node_rewriter, graph_rewriter
17+
from aesara.graph.rewriting.utils import rewrite_graph
18+
from aesara.graph.rewriting.db import RewriteDatabaseQuery
1919

2020
# isort: on

aesara/graph/rewriting/__init__.py

Whitespace-only changes.

aesara/graph/opt.py renamed to aesara/graph/rewriting/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@
4040

4141

4242
if TYPE_CHECKING:
43-
from aesara.graph.unify import Var
43+
from aesara.graph.rewriting.unify import Var
4444

4545

46-
_logger = logging.getLogger("aesara.graph.opt")
46+
_logger = logging.getLogger("aesara.graph.rewriting.basic")
4747

4848
RemoveKeyType = Literal["remove"]
4949
TransformOutputType = Union[
@@ -1586,7 +1586,7 @@ def __init__(
15861586
often.
15871587
15881588
"""
1589-
from aesara.graph.unify import convert_strs_to_vars
1589+
from aesara.graph.rewriting.unify import convert_strs_to_vars
15901590

15911591
var_map: Dict[str, "Var"] = {}
15921592
self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map)

aesara/graph/optdb.py renamed to aesara/graph/rewriting/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
77

88
from aesara.configdefaults import config
9-
from aesara.graph import opt as aesara_rewriting
9+
from aesara.graph.rewriting import basic as aesara_rewriting
1010
from aesara.misc.ordered_set import OrderedSet
1111
from aesara.utils import DefaultOrderedDict
1212

aesara/graph/kanren.py renamed to aesara/graph/rewriting/kanren.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from unification.variable import Var
77

88
from aesara.graph.basic import Apply, Variable
9-
from aesara.graph.opt import NodeRewriter
10-
from aesara.graph.unify import eval_if_etuple
9+
from aesara.graph.rewriting.basic import NodeRewriter
10+
from aesara.graph.rewriting.unify import eval_if_etuple
1111

1212

1313
class KanrenRelationSub(NodeRewriter):
@@ -24,7 +24,7 @@ class KanrenRelationSub(NodeRewriter):
2424
from kanren import eq, conso, var
2525
2626
import aesara.tensor as at
27-
from aesara.graph.kanren import KanrenRelationSub
27+
from aesara.graph.rewriting.kanren import KanrenRelationSub
2828
2929
3030
def relation(in_lv, out_lv):

aesara/graph/unify.py renamed to aesara/graph/rewriting/unify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __str__(self):
6969
return f"~{self.token} [{self.constraint}]"
7070

7171
def __repr__(self):
72-
return f"ConstrainedVar({repr(self.constraint)}, {self.token})"
72+
return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"
7373

7474

7575
def car_Variable(x):

aesara/graph/opt_utils.py renamed to aesara/graph/rewriting/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
vars_between,
1212
)
1313
from aesara.graph.fg import FunctionGraph
14-
from aesara.graph.optdb import RewriteDatabaseQuery
14+
from aesara.graph.rewriting.db import RewriteDatabaseQuery
1515

1616

1717
if TYPE_CHECKING:
18-
from aesara.graph.opt import GraphRewriter
18+
from aesara.graph.rewriting.basic import GraphRewriter
1919

2020

2121
def rewrite_graph(
@@ -89,7 +89,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
8989
See help on `aesara.graph.basic.is_same_graph` for additional documentation.
9090
9191
"""
92-
from aesara.graph.opt import MergeOptimizer
92+
from aesara.graph.rewriting.basic import MergeOptimizer
9393

9494
if givens is None:
9595
givens = {}

0 commit comments

Comments
 (0)