✨ RL compilation step with GNN #563
antotu wants to merge 131 commits into munich-quantum-toolkit:qce-experiments from
Conversation
📝 Walkthrough

This PR integrates Graph Neural Network (GNN) capabilities into the MQT Predictor module. It adds torch, torch-geometric, optuna, and safetensors dependencies; implements a GNN architecture with GraphConvolutionSage and MLP components; introduces GNN-specific training, evaluation, and prediction functions; and extends the prediction pipeline to support both classical (Random Forest) and graph-based (GNN) modeling pathways with Optuna-powered hyperparameter optimization.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant Client as Predictor Setup
    participant Optuna as Optuna Study
    participant GNN as GNN Model
    participant Loader as DataLoader
    participant Eval as Evaluation
    Client->>Optuna: create_study() for hyperparameter optimization
    Optuna->>Optuna: sampler.ask() for trial parameters
    Optuna->>GNN: instantiate with sampled hyperparams
    GNN->>Loader: process training graphs
    Loader->>GNN: batch forward pass
    GNN->>Eval: compute loss
    Eval-->>Optuna: return trial value
    Optuna->>Optuna: optimize (repeat until convergence)
    Optuna-->>Client: return best hyperparams & model
```
```mermaid
sequenceDiagram
    participant Circuit as Quantum Circuit
    participant DAG as DAG Constructor
    participant GNN as GNN Encoder
    participant MLP as MLP Head
    participant Output as Device Scores
    Circuit->>DAG: decompose & extract graph structure
    DAG->>DAG: compute node features (gate types, qubits)
    DAG-->>GNN: node_vector, edge_index
    GNN->>GNN: apply SAGEConv layers with residuals
    GNN->>GNN: optional SAGPooling
    GNN->>GNN: global mean pooling
    GNN-->>MLP: graph embedding
    MLP->>MLP: apply hidden layers with activation
    MLP-->>Output: per-device logits/scores
```
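The final pooling step in the diagram above can be sketched in plain Python. This is a minimal illustration of what "global mean pooling" computes, not the batched torch_geometric implementation the PR actually uses:

```python
def global_mean_pool(node_embeddings: list[list[float]]) -> list[float]:
    """Average all node embeddings into a single graph embedding.

    Minimal sketch of the 'global mean pooling' step in the diagram above;
    the real pipeline uses torch_geometric's batched pooling operators.
    """
    n = len(node_embeddings)
    dim = len(node_embeddings[0])
    return [sum(node[d] for node in node_embeddings) / n for d in range(dim)]
```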
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@pyproject.toml`:
- Around line 47-50: The pyproject declares torch only for macOS x86 Python<3.13
but the codebase (files src/mqt/predictor/ml/gnn.py,
src/mqt/predictor/predictor.py, src/mqt/predictor/helper.py) imports torch
unconditionally; fix by either adding an unconditional torch dependency in
pyproject.toml (remove the platform marker so "torch>=2.7.0,<2.8.0" is installed
everywhere) or make torch optional by adding an extras_require (e.g. "torch":
[...]) and wrapping all torch imports in import guards/try-except with clear
fallback errors in the referenced modules (gnn.py, predictor.py, helper.py) to
raise a helpful message if torch is missing; update the package metadata and the
three modules consistently so imports match the dependency strategy.
In `@src/mqt/predictor/_version.py`:
- Around line 9-11: Remove the autogenerated file src/mqt/predictor/_version.py
from version control (stop tracking it and delete from the repo tree), add an
entry for src/mqt/predictor/_version.py to .gitignore so it isn't committed
again, and ensure the project relies on the hatch-vcs/hatch hook to regenerate
the file at build time; reference the autogenerated indicator in
src/mqt/predictor/_version.py and the hatch-vcs hook to verify regeneration
works after removal.
In `@src/mqt/predictor/ml/predictor.py`:
- Around line 749-752: The GridSearchCV call can fail when a class has only one
sample because StratifiedKFold requires n_splits between 2 and the minimum class
count; before creating GridSearchCV in predictor.py, check min_class and set the
CV strategy accordingly: if min_class >= 2 compute num_cv = max(2,
min(original_num_cv, min_class)) and use that (allowing StratifiedKFold via
GridSearchCV), but if min_class < 2 then do not rely on StratifiedKFold — either
skip GridSearchCV and call mdl.fit(...) directly or instantiate GridSearchCV
with a non-stratified splitter (e.g., KFold) and appropriate n_splits (>=2 and
<=len(y_train)); update the code around variables num_cv, min_class,
GridSearchCV, and mdl to implement this guard and fallback.
In `@tests/device_selection/test_predictor_ml.py`:
- Around line 42-43: The function signature that currently accepts
path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool,
verbose: bool should make the boolean flags keyword-only to avoid positional
ambiguity; update the signature by inserting a positional-only separator (e.g.
add ", *," before gnn) so gnn and verbose must be passed by keyword (and add
sensible defaults like gnn: bool = False, verbose: bool = False if appropriate),
and apply the same change to the other occurrence referenced in the file.
In `@tests/hellinger_distance/test_estimated_hellinger_distance.py`:
- Line 165: The test function signature test_train_model_and_predict currently
allows the boolean flag verbose to be passed positionally; update the signature
to make boolean flags keyword-only (e.g., def
test_train_model_and_predict(device: Target, model_type: str, *, verbose: bool)
-> None) to satisfy Ruff FBT001 and avoid positional ambiguity; apply the same
change to the other test function(s) with boolean flags reported in the review
(the ones around the later tests) so all boolean parameters are keyword-only.
pyproject.toml
Outdated
```toml
"optuna>=4.0.0",
"torch-geometric>=2.7.0,<2.8.0",
"torch>=2.7.0,<2.8.0; sys_platform == 'darwin' and 'x86_64' in platform_machine and python_version < '3.13'",  # Restrict torch v2.3.0 for macOS x86 since it is not supported anymore.
"safetensors>=0.3.0",
```
🧩 Analysis chain
🏁 Script executed:
```shell
#!/bin/bash
set -euo pipefail
# Verify pinned versions exist on PyPI
echo "=== Checking torch 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
  echo "✓ torch 2.7.0 found on PyPI"
else
  echo "✗ torch 2.7.0 NOT found on PyPI"
fi
echo ""
echo "=== Checking torch-geometric 2.7.0 on PyPI ==="
if curl -s https://pypi.org/pypi/torch-geometric/json 2>/dev/null | jq -r '.releases | keys[]' | grep -q '^2\.7\.0$'; then
  echo "✓ torch-geometric 2.7.0 found on PyPI"
else
  echo "✗ torch-geometric 2.7.0 NOT found on PyPI"
fi
echo ""
echo "=== torch-geometric dependencies on torch ==="
curl -s https://pypi.org/pypi/torch-geometric/2.7.0/json 2>/dev/null | jq -r '.info.requires_dist[]' 2>/dev/null | grep -i 'torch' || echo "No torch dependency found in metadata"
```

Repository: munich-quantum-toolkit/predictor
Length of output: 493
🏁 Script executed:

```shell
# Search for GNN imports and usage
echo "=== Searching for GNN-related imports ==="
find . -type f \( -name "*.py" \) | head -20
echo ""
echo "=== Searching for torch imports ==="
rg "import torch" --type py -A 2 -B 2 | head -100
echo ""
echo "=== Checking for conditional torch imports ==="
rg "try.*import torch|except.*ImportError" --type py -B 2 -A 2
```

Repository: munich-quantum-toolkit/predictor
Length of output: 2922

🏁 Script executed:

```shell
cat pyproject.toml | head -100
```

Repository: munich-quantum-toolkit/predictor
Length of output: 3350

🏁 Script executed:

```shell
# Check the full dependencies section
rg "^\[project\]" -A 100 pyproject.toml | head -150
```

Repository: munich-quantum-toolkit/predictor
Length of output: 3277

🏁 Script executed:

```shell
# Look for all torch-related entries in pyproject.toml
rg "torch" pyproject.toml
```

Repository: munich-quantum-toolkit/predictor
Length of output: 827
Torch dependency is platform-conditional but imported unconditionally—will break on Linux, Windows, macOS ARM, and Python 3.13+.
torch is restricted to macOS x86 with Python < 3.13, but src/mqt/predictor/ml/gnn.py, predictor.py, and helper.py all import it unconditionally. Since torch-geometric doesn't declare torch as a dependency, installations on Linux, Windows, macOS ARM, or Python 3.13+ will fail with ImportError. Add an unconditional torch requirement or implement import guards with optional extras.
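One way to implement the "import guard" option is a small helper that checks availability before importing. This is a sketch only: the helper name and the extras group name (`gnn`) are hypothetical, not part of the package:

```python
import importlib
import importlib.util


def require(module_name: str, extra: str):
    """Import an optional dependency or fail with an actionable message.

    Sketch of the guarded-import strategy suggested above; 'extra' is a
    hypothetical extras group (e.g. pip install mqt.predictor[gnn]).
    """
    if importlib.util.find_spec(module_name) is None:
        msg = (
            f"'{module_name}' is required for this feature but is not installed. "
            f"Install it via: pip install mqt.predictor[{extra}]"
        )
        raise ImportError(msg)
    return importlib.import_module(module_name)
```

Each module that needs torch would then call `require("torch", "gnn")` once at import time instead of a bare `import torch`.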
src/mqt/predictor/_version.py

```python
# file generated by setuptools-scm
# don't change, don't track in version control
```
Autogenerated _version.py should not be tracked.
The header indicates this file is generated by setuptools-scm; committing it risks stale versions and merge conflicts. Please remove it from VCS and add it to .gitignore, relying on the hatch-vcs hook to regenerate it. Based on learnings, this file should be excluded from version control.
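The untracking can be done without deleting the file from the working tree. A sketch of the three required commands, demonstrated in a throwaway repository so it is self-contained (paths taken from the review; in the real repo only the three marked commands are needed):

```shell
set -e
# Demo sandbox; stands in for the actual repository.
repo="$(mktemp -d)"
cd "$repo"
git init -q
git config user.email "demo@example.com"
git config user.name "demo"
mkdir -p src/mqt/predictor
echo "__version__ = '0.0.0'" > src/mqt/predictor/_version.py
git add -A && git commit -qm "init"

# 1. stop tracking the generated file (it stays on disk)
git rm --cached -q src/mqt/predictor/_version.py
# 2. ignore it from now on
echo "src/mqt/predictor/_version.py" >> .gitignore
# 3. commit the change
git add .gitignore && git commit -qm "Stop tracking generated _version.py"
```

After this, hatch-vcs regenerates the file at build time and git no longer sees local copies.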
src/mqt/predictor/ml/predictor.py
Outdated
```python
num_cv = min(len(training_data.y_train), 5)
min_class = min(Counter(training_data.y_train).values())
num_cv = max(2, min(num_cv, min_class))
mdl = GridSearchCV(mdl, tree_param, cv=num_cv, n_jobs=8).fit(training_data.X_train, training_data.y_train)
```
🧩 Analysis chain
🏁 Script executed:

```shell
cd /tmp && git -C /root/repo ls-files src/mqt/predictor/ml/predictor.py
```

Repository: munich-quantum-toolkit/predictor
Length of output: 141

🏁 Script executed:

```shell
# Get context around lines 749-752
head -n 760 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 30 | cat -n
```

Repository: munich-quantum-toolkit/predictor
Length of output: 179

🏁 Script executed:

```shell
# Check the broader function context
head -n 800 /root/repo/src/mqt/predictor/ml/predictor.py | tail -n 100 | cat -n
```

Repository: munich-quantum-toolkit/predictor
Length of output: 179

🏁 Script executed:

```shell
git ls-files | grep predictor.py
```

Repository: munich-quantum-toolkit/predictor
Length of output: 170

🏁 Script executed:

```shell
# Also check the current working directory
pwd && ls -la
```

Repository: munich-quantum-toolkit/predictor
Length of output: 1499

🏁 Script executed:

```shell
# List the repository structure
git ls-files | head -20
```

Repository: munich-quantum-toolkit/predictor
Length of output: 570

🏁 Script executed:

```shell
# Read the file around lines 749-752
sed -n '740,760p' src/mqt/predictor/ml/predictor.py | cat -n
```

Repository: munich-quantum-toolkit/predictor
Length of output: 1200

🏁 Script executed:

```shell
# Get more context - find the function this code belongs to
sed -n '700,760p' src/mqt/predictor/ml/predictor.py | cat -n
```

Repository: munich-quantum-toolkit/predictor
Length of output: 3255

🏁 Script executed:

```shell
# Check the imports at the top of the file
head -n 50 src/mqt/predictor/ml/predictor.py | cat -n
```

Repository: munich-quantum-toolkit/predictor
Length of output: 1886

🏁 Script executed:

```shell
# Check if there are any tests for this function
find tests -name "*.py" -type f | xargs grep -l "train_random_forest_model" | head -5
```

Repository: munich-quantum-toolkit/predictor
Length of output: 137

🏁 Script executed:

```shell
# Search for any existing handling of small class sizes
grep -n "min_class\|StratifiedKFold\|cross_val" src/mqt/predictor/ml/predictor.py
```

Repository: munich-quantum-toolkit/predictor
Length of output: 196

🏁 Script executed:

```shell
# Check the test file to understand the expected behavior
cat -n tests/hellinger_distance/test_estimated_hellinger_distance.py | head -60
```

Repository: munich-quantum-toolkit/predictor
Length of output: 2185

🏁 Script executed:

```shell
# Look for any other CV-related code or tests
find tests -name "*.py" -type f | xargs grep -l "GridSearchCV\|cv=" | head -5
```

Repository: munich-quantum-toolkit/predictor
Length of output: 58

🏁 Script executed:

```shell
# Check if there's any documentation or handling of edge cases in the codebase
grep -rn "min_class\|class.*<.*2\|samples.*per.*class" tests/ src/ | head -10
```

Repository: munich-quantum-toolkit/predictor
Length of output: 264

🏁 Script executed:

```shell
# Check scikit-learn version to understand GridSearchCV behavior
grep -i "scikit-learn\|sklearn" src/mqt/predictor/ml/predictor.py pyproject.toml
```

Repository: munich-quantum-toolkit/predictor
Length of output: 658
🏁 Script executed:

```shell
# Let's verify the GridSearchCV behavior with a simple test
python3 << 'EOF'
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
import inspect

# Check the default cv behavior
print("GridSearchCV default cv parameter:")
print(GridSearchCV.__init__.__doc__[:500])
print("\n---\n")

# Check if StratifiedKFold is used for classifiers
from sklearn.model_selection import _split
print("Let's check what happens with cv=2 and 1 sample per class:")

# Test with minimal data
import numpy as np
from collections import Counter

# Create a minimal problematic case: 2 samples, 1 sample per class
X = np.array([[1], [2]])
y = np.array([0, 1])
print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")
try:
    # Try with cv=2
    mdl = RandomForestClassifier(random_state=0)
    gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
    gs.fit(X, y)
    print("cv=2 with 1 sample per class: SUCCESS")
except Exception as e:
    print(f"cv=2 with 1 sample per class: ERROR - {type(e).__name__}: {e}")
print("\n---\n")

# Now test with min_class < 2
X = np.array([[1], [2], [3]])
y = np.array([0, 1, 0])  # min_class = 1
print(f"Data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {dict(Counter(y))}")
try:
    mdl = RandomForestClassifier(random_state=0)
    gs = GridSearchCV(mdl, {"n_estimators": [10]}, cv=2)
    gs.fit(X, y)
    print("cv=2 with min_class=1: SUCCESS")
except Exception as e:
    print(f"cv=2 with min_class=1: ERROR - {type(e).__name__}: {e}")
EOF
```

Repository: munich-quantum-toolkit/predictor
Length of output: 162
🌐 Web query: GridSearchCV StratifiedKFold cv parameter classifier sklearn minimum samples per class error

💡 Result:

Cause: StratifiedKFold (used by GridSearchCV for classifiers when cv is an int) requires each class to appear in every fold, so n_splits cannot exceed the number of samples in the least-populated class. If a class has fewer samples than n_splits you get the warning/error. [1][2]

Fixes (pick one):
- Reduce n_splits to <= min_count = y.value_counts().min(). Example: n_splits = min(5, y.value_counts().min()).
- Use non-stratified folds: cv=KFold(n_splits=5).
- Use StratifiedShuffleSplit with an appropriate test_size (keeps class proportions but doesn't require every fold to contain every class).
- Merge rare classes, oversample (SMOTE/upsampling), or remove extremely rare classes so min_count >= n_splits.
- Use GroupKFold if you have grouping that must be preserved.

Quick code examples:

```python
# set n_splits dynamically:
from sklearn.model_selection import GridSearchCV, StratifiedKFold
n = min(5, y.value_counts().min())
cv = StratifiedKFold(n_splits=n, shuffle=True, random_state=0)
gs = GridSearchCV(estimator, param_grid, cv=cv)

# or use non-stratified KFold:
from sklearn.model_selection import KFold
gs = GridSearchCV(estimator, param_grid, cv=KFold(5, shuffle=True))
```

References: scikit-learn defaults and the warning explanation; community Q&A with solutions. [1][2][3]

Sources:
[1] Stack Overflow — explanation that GridSearchCV uses StratifiedKFold for classifiers and the requirement that each class appear in every split.
[2] Explanation of the error cause (n_splits > samples in a class).
[3] scikit-learn discussion/issue about the behavior and constraints of StratifiedKFold.

Citations:
- 1: https://stackoverflow.com/questions/74445334/userwarning-the-least-populated-class-in-y-has-only-1-members-which-is-less-th?utm_source=openai
- 2: https://m.php.cn/en/faq/1796910302.html?utm_source=openai
- 3: [Feature-Request] Add a flag to StratifiedKFold to force classes with only 1 sample in training scikit-learn/scikit-learn#10767
🏁 Script executed:

```shell
# Search for sklearn documentation or comments about cv handling in the codebase
grep -rn "StratifiedKFold\|cv=" src/mqt/predictor/ml/ | head -15
```

Repository: munich-quantum-toolkit/predictor
Length of output: 228

🏁 Script executed:

```shell
# Check if there are any related issues or TODOs
grep -rn "TODO\|FIXME\|BUG\|HACK" src/mqt/predictor/ml/predictor.py | grep -i "cv\|class\|fold"
```

Repository: munich-quantum-toolkit/predictor
Length of output: 58
GridSearchCV with StratifiedKFold will fail when min_class < 2.
When min_class equals 1, the code forces num_cv = max(2, min(num_cv, 1)) = 2, which violates scikit-learn's StratifiedKFold requirement that n_splits ≤ min(class_counts). GridSearchCV uses StratifiedKFold by default for classifiers, so this will raise an error at fit time rather than being caught during setup.
Add an explicit guard before GridSearchCV instantiation:
Suggested fix for small class counts:

```diff
 num_cv = min(len(training_data.y_train), 5)
 min_class = min(Counter(training_data.y_train).values())
+if min_class < 2:
+    msg = "Not enough samples per class for cross-validation."
+    raise ValueError(msg)
-num_cv = max(2, min(num_cv, min_class))
+num_cv = min(num_cv, min_class)
```
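The guard can be factored into a small helper. This is a sketch under the review's assumptions (names are illustrative; returning None signals the caller to fit directly or fall back to a non-stratified splitter):

```python
from collections import Counter
from typing import Optional, Sequence


def choose_cv(y_train: Sequence[int], max_folds: int = 5) -> Optional[int]:
    """Pick a StratifiedKFold-safe fold count, or None when CV is impossible.

    StratifiedKFold requires 2 <= n_splits <= min(class counts), so a class
    with a single sample rules out stratified cross-validation entirely.
    """
    min_class = min(Counter(y_train).values())
    if min_class < 2:
        return None  # fit directly, or fall back to a plain KFold splitter
    return max(2, min(max_folds, min_class))
```

The caller would then pass `cv=choose_cv(training_data.y_train)` to GridSearchCV only when the result is not None.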
tests/device_selection/test_predictor_ml.py

```python
    path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
) -> None:
```
🧹 Nitpick | 🔵 Trivial
Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).
♻️ Suggested signature tweak

```diff
-def test_setup_device_predictor_with_prediction(
-    path_uncompiled_circuits: Path, path_compiled_circuits: Path, gnn: bool, verbose: bool
-) -> None:
+def test_setup_device_predictor_with_prediction(
+    path_uncompiled_circuits: Path, path_compiled_circuits: Path, *, gnn: bool, verbose: bool
+) -> None:
@@
-def test_get_prepared_training_data_false_input(gnn: bool) -> None:
+def test_get_prepared_training_data_false_input(*, gnn: bool) -> None:
```

Also applies to: 162-162
🧰 Tools
🪛 Ruff (0.14.13)
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
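The effect of the `*` separator, shown on a standalone hypothetical function (not code from the PR):

```python
def run_training(path: str, *, gnn: bool = False, verbose: bool = False) -> str:
    # Everything after '*' must be passed by keyword, so call sites such as
    # run_training("data", gnn=True) stay self-documenting.
    return f"path={path} gnn={gnn} verbose={verbose}"
```

A positional call like `run_training("data", True)` now raises TypeError, which is exactly the ambiguity Ruff FBT001 flags.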
🤖 Prompt for AI Agents
In `@tests/device_selection/test_predictor_ml.py` around lines 42 - 43, The
function signature that currently accepts path_uncompiled_circuits: Path,
path_compiled_circuits: Path, gnn: bool, verbose: bool should make the boolean
flags keyword-only to avoid positional ambiguity; update the signature by
inserting a positional-only separator (e.g. add ", *," before gnn) so gnn and
verbose must be passed by keyword (and add sensible defaults like gnn: bool =
False, verbose: bool = False if appropriate), and apply the same change to the
other occurrence referenced in the file.
tests/hellinger_distance/test_estimated_hellinger_distance.py

```python
@pytest.mark.parametrize(
    ("model_type", "verbose"), [("rf", False), ("gnn", False), ("gnn", True)], ids=["rf", "gnn", "gnn_verbose"]
)
def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:
```
🧹 Nitpick | 🔵 Trivial
Make boolean flags keyword-only to avoid positional ambiguity (Ruff FBT001).
♻️ Suggested signature tweak

```diff
-def test_train_model_and_predict(device: Target, model_type: str, verbose: bool) -> None:
+def test_train_model_and_predict(device: Target, model_type: str, *, verbose: bool) -> None:
@@
-def test_train_and_qcompile_with_hellinger_model(
-    source_path: Path, target_path: Path, device: Target, model_type: str, verbose: bool
-) -> None:
+def test_train_and_qcompile_with_hellinger_model(
+    source_path: Path, target_path: Path, device: Target, model_type: str, *, verbose: bool
+) -> None:
```

Also applies to: 232-233
🧰 Tools
🪛 Ruff (0.14.13)
165-165: Boolean-typed positional argument in function definition
(FBT001)
🤖 Prompt for AI Agents
In `@tests/hellinger_distance/test_estimated_hellinger_distance.py` at line 165,
The test function signature test_train_model_and_predict currently allows the
boolean flag verbose to be passed positionally; update the signature to make
boolean flags keyword-only (e.g., def test_train_model_and_predict(device:
Target, model_type: str, *, verbose: bool) -> None) to satisfy Ruff FBT001 and
avoid positional ambiguity; apply the same change to the other test function(s)
with boolean flags reported in the review (the ones around the later tests) so
all boolean parameters are keyword-only.
…and test all done

force-pushed from fe12d75 to d239b2d

🧪 CI Insights: Here's what we observed from your CI run for de186b3. 🟢 All jobs passed!
Pull request overview
This PR introduces a GNN-based alternative to the existing RandomForest-based ML device selection pipeline by adding circuit-to-graph (DAG) feature extraction, GNN training/prediction support, and associated test coverage. It also adjusts RL compilation conversion to emit explicit SWAPs during TKET→Qiskit conversion.
Changes:
- Add DAG-based graph feature extraction and a SAGEConv-based GNN model for device selection and Hellinger distance estimation.
- Extend the ML predictor API to support `gnn=True` for training/prediction, persisting graph datasets via `safetensors`.
- Expand tests to cover both RF and GNN paths and add cleanup logic for generated artifacts.
Reviewed changes
Copilot reviewed 12 out of 13 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/hellinger_distance/test_estimated_hellinger_distance.py | Adds RF/GNN parametrized tests, graph-data setup, and cleanup fixture. |
| tests/device_selection/test_predictor_ml.py | Adds GNN coverage for setup/training/prediction and cleans graph artifacts. |
| tests/device_selection/test_helper_ml.py | Adds tests for new DAG creation helper. |
| tests/compilation/test_predictor_rl.py | Adds parametrized RL compilation test for multiple devices. |
| src/mqt/predictor/rl/predictorenv.py | Enables explicit SWAP insertion when converting TKET circuits to Qiskit. |
| src/mqt/predictor/ml/predictor.py | Implements GNN training (Optuna + k-fold), safetensors dataset I/O, and GNN prediction path. |
| src/mqt/predictor/ml/helper.py | Adds DAG graph construction, GNN training/eval utilities, and new model-path helper. |
| src/mqt/predictor/ml/gnn.py | New GNN model implementation (GraphSAGE + optional SAGPooling + MLP head). |
| src/mqt/predictor/ml/__init__.py | Reformats `__all__` export list. |
| src/mqt/predictor/hellinger/utils.py | Extends Hellinger model path helper to support GNN checkpoints. |
| src/mqt/predictor/_version.py | Adds generated version file wired to build configuration. |
| pyproject.toml | Adds torch/torch-geometric/optuna/safetensors dependencies and warning filters; mypy overrides for new deps. |
Comments suppressed due to low confidence (1)
src/mqt/predictor/ml/predictor.py:389
`_generate_training_sample()` declares `file: Path`, but `generate_training_data()` passes `filename.name` (a `str`). This mismatch makes the type hints misleading and can cause confusion for callers.
Either change the parameter type to str (or str | Path) and update docstrings accordingly, or pass the full Path object into _generate_training_sample().
```python
            path_uncompiled_circuits,
            path_compiled_circuits,
            logger.level,
        )
        for filename in path_uncompiled_circuits.glob("*.qasm")
    )
    for sample in results:
        training_sample, circuit_name, scores = sample
        if all(score == -1 for score in scores.values()):
            continue
        if self.gnn:
            x, _y, edge_idx, n_nodes, target_label = training_sample
            value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
            gnn_training_sample = Data(
                x=x,
                # unsqueeze to avoid concatenation issues later on
                y=torch.tensor(value_device, dtype=torch.float32).unsqueeze(0),
                edge_index=edge_idx,
                num_nodes=n_nodes,
                target_label=target_label,
            )
        training_data.append(gnn_training_sample if self.gnn else training_sample)
        names_list.append(circuit_name)
        scores_list.append(scores)
    with resources.as_file(path_training_data) as path:
        if self.gnn:
            dataset_dir = path / f"graph_dataset_{self.figure_of_merit}"
            dataset_dir.mkdir(parents=True, exist_ok=True)
            for idx, data in enumerate(training_data):
                # data is a torch_geometric.data.Data object
                tensors = {
                    "x": data.x,  # node features
                    "y": data.y,  # target values per device
                    "edge_index": data.edge_index,
                    "num_nodes": torch.tensor([data.num_nodes], dtype=torch.int64),
                }
                save_file(tensors, str(dataset_dir / f"{idx}.safetensors"))
                # target_label is a string; save it separately
                label_path = dataset_dir / f"{idx}.label"
                label_path.write_text(str(data.target_label), encoding="utf-8")
        else:
            data = np.asarray(training_data, dtype=object)
            np.save(str(path / ("training_data_" + self.figure_of_merit + ".npy")), data)
        data = np.asarray(names_list, dtype=str)
        np.save(str(path / ("names_list_" + self.figure_of_merit + ".npy")), data)
        data = np.asarray(scores_list, dtype=object)
        np.save(str(path / ("scores_list_" + self.figure_of_merit + ".npy")), data)

def _generate_training_sample(
    self,
    file: Path,
    path_uncompiled_circuit: Path,
    path_compiled_circuits: Path,
    logger_level: int = logging.INFO,
) -> tuple[TrainingSample, str, dict[str, float]]:
    """Handles to create a training sample from a given file.

    Arguments:
        file: The name of the file to be used for training.
```
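Reading the dataset back mirrors the layout written in the excerpt above. A plain-Python sketch of pairing each tensor file with its label sidecar (actual tensor deserialization would use safetensors' `load_file`, omitted here so the sketch stays dependency-free):

```python
from pathlib import Path


def list_graph_samples(dataset_dir: Path) -> list[tuple[Path, str]]:
    """Pair each <idx>.safetensors file with its <idx>.label sidecar.

    Sketch of a loader for the save layout in the excerpt above; missing
    labels yield an empty string rather than an error.
    """
    samples = []
    for tensor_file in sorted(dataset_dir.glob("*.safetensors")):
        label_file = tensor_file.with_suffix(".label")
        label = label_file.read_text(encoding="utf-8") if label_file.exists() else ""
        samples.append((tensor_file, label))
    return samples
```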
src/mqt/predictor/ml/predictor.py (imports)

```python
from qiskit.transpiler.passes import RemoveBarriers
from sklearn.metrics import accuracy_score, classification_report, mean_absolute_error, mean_squared_error, r2_score

from mqt.predictor.utils import calc_supermarq_features, get_openqasm_gates
from joblib import dump as joblib_dump
from torch import nn
from torch_geometric.loader import DataLoader
from typing_extensions import Unpack

from mqt.predictor.ml.gnn import GNN

if sys.version_info >= (3, 11) and TYPE_CHECKING:  # pragma: no cover
    pass

import gc
import json

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from joblib import Parallel, delayed, load
from joblib import dump as joblib_dump
from mqt.bench.targets import get_device
from optuna.samplers import TPESampler

# cspell:disable-next-line
from qiskit import QuantumCircuit
```
```python
# 3. Model Training
pred = ml_Predictor(figure_of_merit="hellinger_distance", devices=[device], gnn=gnn)
if gnn:
    trained_model = pred.train_gnn_model(training_data, num_epochs=200, patience=30, verbose=verbose)
```
src/mqt/predictor/_version.py

```python
# file generated by setuptools-scm
# don't change, don't track in version control
```
tests/compilation/test_predictor_rl.py

```python
@pytest.mark.parametrize("device_name", ["ibm_falcon_127", "quantinuum_h2_56"])
def test_qcompile_with_newly_trained_models(device_name: str) -> None:
    """Test the qcompile function with a newly trained model.

    Important: Those trained models are used in later tests and must not be deleted.
    To test ESP as well, training must be done with a device that provides all relevant information (i.e. T1, T2 and gate times).
    """
    figure_of_merit = "expected_fidelity"
```
tests/hellinger_distance/test_estimated_hellinger_distance.py (excerpt)

```python
labels_list = np.full(n_circuits, distance_label)
training_data = TrainingData(X_train=feature_vector_list, y_train=labels_list)
labels_list = [distance_label] * n_circuits
else:
    training_data_list = []
    for i in range(n_circuits):
        x, edge_idx, n_nodes = training_sample[i]
        gnn_training_sample = Data(
            x=x,
            y=torch.tensor(labels_list[i], dtype=torch.float32),
            edge_index=edge_idx,
            num_nodes=n_nodes,
        )
        training_data_list.append(gnn_training_sample)
    training_data = TrainingData(X_train=training_data_list, y_train=labels)

def test_train_and_qcompile_with_hellinger_model(source_path: Path, target_path: Path, device: Target) -> None:
    """Test the entire predictor toolchain with the Hellinger distance model that was trained in the previous test."""
    # 3. Model Training
```
```python
y_train: NDArray[np.float64] | torch.Tensor
X_test: NDArray[np.float64] | list[torch_geometric.data.Data] | None = None
y_test: NDArray[np.float64] | torch.Tensor | None = None
```
```python
sampler_obj = TPESampler(n_startup_trials=10)
study = optuna.create_study(study_name="Best GNN Model", direction="minimize", sampler=sampler_obj)
k_folds = min(len(training_data.y_train), 5)

def _obj(trial: optuna.Trial) -> float:
    return self.objective(
        trial=trial,
```
Description
Please include a summary of the change and, if applicable, which issue is fixed.
Please also include relevant motivation and context.
List any dependencies that are required for this change.
Fixes #(issue)
Checklist: