Skip to content

Commit 0b44356

Browse files
committed
exercise update.
1 parent 44a7edd commit 0b44356

11 files changed

+9
-37
lines changed

README.md

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,9 @@ Open the `src/input_opt.py` file. The network `./data/weights`.pkl` contains net
77
\max_\mathbf{x} y_i = f(\mathbf{x}, \theta) .
88
```
99

10-
Use `jax.value_and_grad` to find the gradients of the network input $\mathbf{x}$.
11-
Start with a `jax.random.uniform` network input of shape `[1, 28, 28, 1]` and
12-
iteratively optimize it. Execute your script with `python src/input_opt.py`.
13-
14-
Solution:
15-
16-
![cnn ig MNIST 0](./figures/x_opt.png)
10+
Use `torch.func.grad` to find the gradients of the network input $\mathbf{x}$.
11+
Start with a network input of shape `[1, 1, 28, 28]`. Compare a random initialization
12+
to starting from an array filled with ones, iteratively optimize it. Execute your script with `python src/input_opt.py`.
1713

1814
# Task 2 Integrated Gradients (Optional):
1915

@@ -30,9 +26,6 @@ Finally, m denotes the number of summation steps from the black baseline image t
3026

3127
Follow the todos in `./src/mnist_integrated.py` and then run `scripts/integrated_gradients.slurm`.
3228

33-
Solution:
34-
35-
![cnn ig MNIST 0](./figures/IG_MNIST_0.png)
3629

3730

3831
# Task 3 Deepfake detection (Optional):
@@ -72,19 +65,11 @@ Compute log-scaled frequency domain representations of samples from both sources
7265

7366
Above `h`, `w` and `c` denote image height, width and columns. `Log` denotes the natural logarithm, and bars denote the absolute value. A small epsilon is added for numerical stability.
7467

75-
Use the numpy functions `jnp.log`, `jnp.abs`, `jnp.fft.fft2`. By default, `fft2` transforms the last two axes. The last axis contains the color channels in this case. We are looking to transform the rows and columns.
68+
Use the numpy functions `np.log`, `np.abs`, `np.fft.fft2`. By default, `fft2` transforms the last two axes. The last axis contains the color channels in this case. We are looking to transform the rows and columns.
7669

7770
Plot mean spectra for real and fake images as well as their difference over the entire validation or test sets. For that complete the TODOs in `src/deepfake_interpretation.py` and run the script `scripts/train.slurm`.
7871

79-
Solution:
80-
81-
![spectral analysis 2](./figures/log_mean_fft_2_diff.png)
82-
83-
![spectral analysis](./figures/log_mean_fft_2_1d.png)
8472

8573
## 3.3 Training and interpreting a linear classifier
86-
Train a linear classifier consisting of a single `nn.Dense`-layer on the log-scaled Fourier coefficients using Flax. Plot the result. What do you see?
87-
88-
Solution:
74+
Train a linear classifier consisting of a single `nn.Linear`-layer on the log-scaled Fourier coefficients using Torch. Plot the result. What do you see?
8975

90-
![linear weights](./figures/classifier_comparison.png)

figures/IG_MNIST_0.png

-9.51 KB
Binary file not shown.

figures/classifier_comparison.png

-410 KB
Binary file not shown.

figures/fake_ig_cnn_1.png

-82.4 KB
Binary file not shown.

figures/log_mean_fft_2_1d.png

-21.5 KB
Binary file not shown.

figures/log_mean_fft_2_diff.png

-139 KB
Binary file not shown.
-63.5 KB
Binary file not shown.

figures/real_ig_cnn_0.png

-75.5 KB
Binary file not shown.

figures/x_opt.png

-14.3 KB
Binary file not shown.

src/deepfake_interpretation.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def eval_step(net, loss, img, labels):
8787

8888
def transform(image_data):
8989
"""Transform image data."""
90-
return np.log(np.abs(np.fft.fft2(image_data, axes=(-3, -2))) + 1e-12)
90+
# TODO: Implement the function given in the readme
91+
return np.zeros_like(image_data)
9192

9293

9394
if __name__ == "__main__":
@@ -248,22 +249,8 @@ def transform(image_data):
248249
plt.colorbar()
249250
plt.savefig("mean_freq_difference.jpg")
250251

251-
weight_ffhq = np.mean(
252-
np.reshape(net.dense.weight[0, :].detach().cpu().numpy(), (3, 128, 128)), 0
253-
)
254-
weight_style = np.mean(
255-
np.reshape(net.dense.weight[1, :].detach().cpu().numpy(), (3, 128, 128)), 0
256-
)
257-
258-
plt.subplot(1, 2, 1)
259-
plt.title("Real classifier weights")
260-
plt.imshow(weight_ffhq, vmin=np.min(weight_ffhq), vmax=np.max(weight_ffhq))
261-
plt.subplot(1, 2, 2)
262-
plt.title("Fake classifier weights")
263-
plt.imshow(weight_style, vmin=np.min(weight_ffhq), vmax=np.max(weight_ffhq))
264-
plt.colorbar()
265-
266-
plt.savefig("classifier_weights.jpg")
252+
# TODO: Visualize the weight array `net.dense.weight`.
253+
# By reshaping and plotting the weight matrix.
267254

268255
if type(net) is CNN:
269256
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)