You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+5-20Lines changed: 5 additions & 20 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -7,13 +7,9 @@ Open the `src/input_opt.py` file. The network `./data/weights`.pkl` contains net
7
7
\max_\mathbf{x} y_i = f(\mathbf{x}, \theta) .
8
8
```
9
9
10
-
Use `jax.value_and_grad` to find the gradients of the network input $\mathbf{x}$.
11
-
Start with a `jax.random.uniform` network input of shape `[1, 28, 28, 1]` and
12
-
iteratively optimize it. Execute your script with `python src/input_opt.py`.
13
-
14
-
Solution:
15
-
16
-

10
+
Use `torch.func.grad` to find the gradients of the network input $\mathbf{x}$.
11
+
Start with a network input of shape `[1, 1, 28, 28]`. Compare a random initialization
12
+
to starting from an array filled with ones, iteratively optimize it. Execute your script with `python src/input_opt.py`.
17
13
18
14
# Task 2 Integrated Gradients (Optional):
19
15
@@ -30,9 +26,6 @@ Finally, m denotes the number of summation steps from the black baseline image t
30
26
31
27
Follow the todos in `./src/mnist_integrated.py` and then run `scripts/integrated_gradients.slurm`.
32
28
33
-
Solution:
34
-
35
-

36
29
37
30
38
31
# Task 3 Deepfake detection (Optional):
@@ -72,19 +65,11 @@ Compute log-scaled frequency domain representations of samples from both sources
72
65
73
66
Above `h`, `w` and `c` denote image height, width and columns. `Log` denotes the natural logarithm, and bars denote the absolute value. A small epsilon is added for numerical stability.
74
67
75
-
Use the numpy functions `jnp.log`, `jnp.abs`, `jnp.fft.fft2`. By default, `fft2` transforms the last two axes. The last axis contains the color channels in this case. We are looking to transform the rows and columns.
68
+
Use the numpy functions `np.log`, `np.abs`, `np.fft.fft2`. By default, `fft2` transforms the last two axes. The last axis contains the color channels in this case. We are looking to transform the rows and columns.
76
69
77
70
Plot mean spectra for real and fake images as well as their difference over the entire validation or test sets. For that complete the TODOs in `src/deepfake_interpretation.py` and run the script `scripts/train.slurm`.
## 3.3 Training and interpreting a linear classifier
86
-
Train a linear classifier consisting of a single `nn.Dense`-layer on the log-scaled Fourier coefficients using Flax. Plot the result. What do you see?
87
-
88
-
Solution:
74
+
Train a linear classifier consisting of a single `nn.Linear`-layer on the log-scaled Fourier coefficients using Torch. Plot the result. What do you see?
0 commit comments