Skip to content

Commit 53ad7b8

Browse files
committed
fix docs
1 parent 790eee3 commit 53ad7b8

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

examples/pcovr/PCovR-WHODataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
r_pcovr = Ridge(alpha=1e-4, fit_intercept=False, random_state=0).fit(
110110
T_train_pcovr, y_train
111111
)
112-
yp_pcovr = r_pcovr.predict(T_test_pcovr)
112+
yp_pcovr = r_pcovr.predict(T_test_pcovr).reshape(-1,1)
113113

114114
plt.scatter(y_scaler.inverse_transform(y_test), y_scaler.inverse_transform(yp_pcovr))
115115
r_pcovr.score(T_test_pcovr, y_test)
@@ -128,7 +128,7 @@
128128
T_pca = pca.transform(X)
129129

130130
r_pca = Ridge(alpha=1e-4, fit_intercept=False, random_state=0).fit(T_train_pca, y_train)
131-
yp_pca = r_pca.predict(T_test_pca)
131+
yp_pca = r_pca.predict(T_test_pca).reshape(-1,1)
132132

133133
plt.scatter(y_scaler.inverse_transform(y_test), y_scaler.inverse_transform(yp_pca))
134134
r_pca.score(T_test_pca, y_test)
@@ -312,3 +312,5 @@ def add_subplot(ax, axy, T, yp, let=""):
312312
"Linear and Kernel PCovR for Predicting Life Expectancy", y=0.925, fontsize=10
313313
)
314314
plt.show()
315+
316+
# %%

examples/pcovr/PCovR_Regressors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,12 @@
2424
mixing = 0.5
2525

2626
X, y = load_diabetes(return_X_y=True)
27-
y = y.reshape(X.shape[0], -1)
2827

2928
X_scaler = StandardScaler()
3029
X_scaled = X_scaler.fit_transform(X)
3130

3231
y_scaler = StandardScaler()
33-
y_scaled = y_scaler.fit_transform(y)
32+
y_scaled = y_scaler.fit_transform(y.reshape(-1, 1)).ravel()
3433

3534

3635
# %%

examples/pcovr/PCovR_Scaling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
ax1_Y.set_title("Regression\nWithout Scaling")
7979

8080
ax2_Y.scatter(
81-
Yp_scaled, y, c=np.abs(y.flatten() - Yp_scaled.flatten()), cmap="bone_r", ec="k"
81+
Yp_scaled, y, c=np.abs(y.ravel() - Yp_scaled.ravel()), cmap="bone_r", ec="k"
8282
)
8383
ax2_Y.plot(ax2_Y.get_xlim(), ax2_Y.get_xlim(), "r--")
8484
ax2_Y.set_xlabel("True Y, unscaled")

0 commit comments

Comments
 (0)