Skip to content

Commit ecbc6c7

Browse files
authored
Merge pull request #90 from graphcore-research/import-fixes
Import fixes (fixes #89)
2 parents 215ba63 + 3efc366 commit ecbc6c7

File tree

4 files changed

+26
-15
lines changed

4 files changed

+26
-15
lines changed

pyproject.toml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,8 @@ classifiers = [
2828
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2929
]
3030
dependencies = [
31-
"datasets",
3231
"docstring-parser",
3332
"einops",
34-
"numpy<2.0.0",
35-
"seaborn",
3633
"tabulate",
3734
"torch>=2.2",
3835
]
@@ -46,6 +43,12 @@ dynamic = ["version"]
4643
[project.optional-dependencies]
4744
dev = ["check-manifest"]
4845
test = ["pytest"]
46+
analysis = [
47+
"datasets",
48+
"matplotlib",
49+
"pandas",
50+
"seaborn",
51+
]
4952

5053
[tool.setuptools]
5154
packages = ["unit_scaling", "unit_scaling.core", "unit_scaling.transforms"]
@@ -55,3 +58,7 @@ version = {attr = "unit_scaling._version.__version__"}
5558

5659
[tool.setuptools_scm]
5760
version_file = "unit_scaling/_version.py"
61+
62+
[tool.isort]
63+
profile = "black"
64+
extend_skip = ["unit_scaling/_version.py"]

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
datasets==3.1.0
77
docstring-parser==0.16
88
einops==0.8.0
9-
numpy==1.26.4
9+
numpy==2.2.6
1010
seaborn==0.13.2
1111
tabulate==0.9.0
1212
torch==2.5.1+cpu

unit_scaling/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
TransformerLayer,
2727
)
2828
from ._version import __version__
29-
from .analysis import visualiser
3029
from .core.functional import transformer_residual_scaling_rule
3130
from .parameter import MupType, Parameter
3231

@@ -58,6 +57,5 @@
5857
# Functions
5958
"Parameter",
6059
"transformer_residual_scaling_rule",
61-
"visualiser",
6260
"__version__",
6361
]

unit_scaling/analysis.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,21 @@
88
from math import isnan
99
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
1010

11-
import matplotlib
12-
import matplotlib.colors
13-
import matplotlib.pyplot as plt
14-
import pandas as pd
15-
import seaborn as sns # type: ignore[import-untyped]
16-
from datasets import load_dataset # type: ignore[import-untyped]
17-
from torch import Tensor, nn
18-
from torch.fx.graph import Graph
19-
from torch.fx.node import Node
11+
try:
12+
import matplotlib
13+
import matplotlib.colors
14+
import matplotlib.pyplot as plt
15+
import pandas as pd
16+
import seaborn as sns # type: ignore[import-untyped]
17+
from datasets import load_dataset # type: ignore[import-untyped]
18+
from torch import Tensor, nn
19+
from torch.fx.graph import Graph
20+
from torch.fx.node import Node
21+
except ImportError as e:
22+
raise ImportError(
23+
"Optional dependencies for `unit_scaling.analysis` are missing."
24+
" Please install `unit-scaling[analysis]`"
25+
) from e
2026

2127
from ._internal_utils import generate__all__
2228
from .transforms import (

0 commit comments

Comments
 (0)