Skip to content

Plots

Code for creating plots for the thesis.

Visualizing experiment data downloaded from Weights & Biases API.

PlotGroups

Source code in src/plots/plot_groups.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class PlotGroups:
    """Render thematic groups of plots for a single W&B sweep.

    Each public method produces one family of related plots (e.g. learning
    rate vs. every tracked metric) into its own subdirectory of
    ``output_dir``. Axis labels are in Polish because the plots are embedded
    in the thesis.
    """

    # Shared (filename suffix, metric key, axis label) triples used by the
    # five hyperparameter-vs-metric plot families below.
    _COMMON_METRICS: tuple[tuple[str, str, str], ...] = (
        ("lfw", "summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        ("rof-m", "summary_benchmark/rof_masked_accuracy", "Dokładność ROF-m"),
        ("rof-s", "summary_benchmark/rof_sunglasses_accuracy", "Dokładność ROF-s"),
        (
            "train-accuracy",
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        (
            "val-accuracy",
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    )

    def __init__(self, sweep_data: SweepData, output_dir: Path) -> None:
        self.sweep_data = sweep_data
        self.output_dir = output_dir
        self.plot_maker = PlotMaker()
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def make_output_subdir(self, subdir_name: str) -> Path:
        """Return ``output_dir / subdir_name``, creating it if missing."""
        subdir = self.output_dir / subdir_name
        subdir.mkdir(parents=True, exist_ok=True)
        return subdir

    def _common_plot_args(self, prefix: str) -> list[tuple[str, str, str]]:
        """Expand the shared metric triples into (filename, key, label) tuples."""
        return [
            (f"{prefix}-vs-{suffix}.png", metric_key, metric_name)
            for suffix, metric_key, metric_name in self._COMMON_METRICS
        ]

    def learning_rate_v_metrics(self) -> None:
        """Create plots showing the effect of learning rate on various metrics."""
        for filename, metric_key, metric_name in self._common_plot_args("lr"):
            self.plot_maker.make_lr_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("lr"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def margin_v_metrics(self) -> None:
        """Create plots showing the effect of margin on various metrics."""
        for filename, metric_key, metric_name in self._common_plot_args("margin"):
            # NOTE(review): unlike every sibling plot call this one passes no
            # show=False — confirm make_margin_plot's default is non-interactive.
            self.plot_maker.make_margin_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("margin"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
            )

    def augmentation_v_metrics(self) -> None:
        """Create plots showing the effect of data augmentation on various metrics."""
        for filename, metric_key, metric_name in self._common_plot_args(
            "augmentation"
        ):
            self.plot_maker.make_augmentation_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("augmentation"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def only_head_v_metrics(self) -> None:
        """Create plots showing the effect of only_head on various metrics and runtime."""
        plot_args = self._common_plot_args("only-head")
        # This family additionally plots total run time.
        plot_args.append(("only-head-vs-runtime.png", "runtime", "Czas obliczeń [s]"))

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_only_head_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("only-head"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def batch_size_v_metrics(self) -> None:
        """Create plots showing the effect of batch size on various metrics."""
        for filename, metric_key, metric_name in self._common_plot_args(
            "mini-batch-size"
        ):
            self.plot_maker.make_mini_batch_size_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("mini-batch-size"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def parameter_importance_and_correlation(
        self, experiment_data_dict: dict[str, Any]
    ) -> None:
        """Create plots showing parameter importance and correlation with respect to metrics for selected metrics.

        Args:
            experiment_data_dict: Maps a metric key (e.g. ``"lfw"``) to a dict
                with ``"importance"`` and ``"correlation"`` entries.
        """
        plot_args = [
            ("parameter-importance-lfw.png", "lfw", "Dokładność LFW"),
            ("parameter-importance-rof-m.png", "rof-m", "Dokładność ROF-m"),
            ("parameter-importance-rof-s.png", "rof-s", "Dokładność ROF-s"),
            (
                "parameter-importance-val-accuracy.png",
                "val_accuracy",
                "Dokładność walidacyjna",
            ),
        ]

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_parameter_analysis_plot(
                importance_dict=experiment_data_dict[metric_key]["importance"],
                correlation_dict=experiment_data_dict[metric_key]["correlation"],
                output_dir=self.make_output_subdir("parameter-analysis"),
                filename=filename,
                title=metric_name,
            )

    def training_curves(self) -> None:
        """Create training curves for all runs in the sweep."""
        plot_args = [
            (
                "train-loss-over-epochs.png",
                "training/train_loss",
                # Label grammar fixed: "strata" is feminine → "treningowa".
                "Strata treningowa",
            ),
            (
                "val-loss-over-epochs.png",
                "training/val_loss",
                "Strata walidacyjna",
            ),
            (
                "train-accuracy-over-epochs.png",
                "training/train_accuracy",
                "Dokładność treningowa",
            ),
            (
                "val-accuracy-over-epochs.png",
                "training/val_accuracy",
                "Dokładność walidacyjna",
            ),
        ]

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_loss_or_acc_plot(
                run_histories=self.sweep_data.run_histories,
                output_dir=self.make_output_subdir("training"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def augmentation_comparison(self) -> None:
        """Group of plots comparing different augmentations with respect to a metric for all metrics."""
        # Baseline values hardcoded; tuples are
        # (filename, metric key, label, best-aggregation, xmin, xmax, baseline).
        plot_args = [
            (
                "aug-comparison-lfw.png",
                "summary_benchmark/lfw_accuracy",
                "Dokładność LFW",
                "max",
                0.90,
                1.0,
                0.971,
            ),
            (
                "aug-comparison-rof-m.png",
                "summary_benchmark/rof_masked_accuracy",
                "Dokładność ROF-m",
                "max",
                0.70,
                0.90,
                0.859,
            ),
            (
                "aug-comparison-rof-s.png",
                "summary_benchmark/rof_sunglasses_accuracy",
                "Dokładność ROF-s",
                "max",
                0.70,
                0.90,
                0.872,
            ),
            (
                "aug-comparison-eer.png",
                "summary_rococo-evaluation/eer",
                "EER",
                "min",
                0.0,
                0.5,
                0.102,
            ),
            (
                "aug-comparison-frr-at-far-zero.png",
                "summary_rococo-evaluation/frr_at_far_zero",
                "FRR @ FAR=0",
                "min",
                0.0,
                1.0,
                0.763,
            ),
        ]

        for (
            filename,
            metric_key,
            metric_name,
            aggregate_best,
            xmin,
            xmax,
            base,
        ) in plot_args:
            self.plot_maker.make_aug_comparison_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("augmentation-comparison"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                aggregate_best=aggregate_best,
                xmin=xmin,
                xmax=xmax,
                base=base,
                show=False,
            )

augmentation_comparison()

Group of plots comparing different augmentations with respect to a metric for all metrics.

Source code in src/plots/plot_groups.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def augmentation_comparison(self) -> None:
    """Compare augmentation variants against each evaluation metric.

    One plot per metric; hardcoded baseline scores are drawn for reference.
    """
    # (filename, metric key, label, best-aggregation, xmin, xmax, baseline)
    specs = [
        ("aug-comparison-lfw.png", "summary_benchmark/lfw_accuracy",
         "Dokładność LFW", "max", 0.90, 1.0, 0.971),
        ("aug-comparison-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m", "max", 0.70, 0.90, 0.859),
        ("aug-comparison-rof-s.png", "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s", "max", 0.70, 0.90, 0.872),
        ("aug-comparison-eer.png", "summary_rococo-evaluation/eer",
         "EER", "min", 0.0, 0.5, 0.102),
        ("aug-comparison-frr-at-far-zero.png",
         "summary_rococo-evaluation/frr_at_far_zero",
         "FRR @ FAR=0", "min", 0.0, 1.0, 0.763),
    ]

    for spec in specs:
        fname, key, label, best, lo, hi, baseline = spec
        self.plot_maker.make_aug_comparison_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("augmentation-comparison"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            aggregate_best=best,
            xmin=lo,
            xmax=hi,
            base=baseline,
            show=False,
        )

augmentation_v_metrics()

Create plots showing the effect of data augmentation on various metrics.

Source code in src/plots/plot_groups.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def augmentation_v_metrics(self) -> None:
    """Render one plot per metric showing how data augmentation affects it."""
    specs = [
        ("augmentation-vs-lfw.png", "summary_benchmark/lfw_accuracy",
         "Dokładność LFW"),
        ("augmentation-vs-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m"),
        ("augmentation-vs-rof-s.png", "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s"),
        ("augmentation-vs-train-accuracy.png", "max_training/train_accuracy",
         "Dokładność na zbiorze treningowym"),
        ("augmentation-vs-val-accuracy.png", "max_training/val_accuracy",
         "Dokładność na zbiorze walidacyjnym"),
    ]

    for spec in specs:
        fname, key, label = spec
        self.plot_maker.make_augmentation_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("augmentation"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

batch_size_v_metrics()

Create plots showing the effect of batch size on various metrics.

Source code in src/plots/plot_groups.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def batch_size_v_metrics(self) -> None:
    """Render one plot per metric showing how mini-batch size affects it."""
    specs = [
        ("mini-batch-size-vs-lfw.png", "summary_benchmark/lfw_accuracy",
         "Dokładność LFW"),
        ("mini-batch-size-vs-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m"),
        ("mini-batch-size-vs-rof-s.png",
         "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s"),
        ("mini-batch-size-vs-train-accuracy.png", "max_training/train_accuracy",
         "Dokładność na zbiorze treningowym"),
        ("mini-batch-size-vs-val-accuracy.png", "max_training/val_accuracy",
         "Dokładność na zbiorze walidacyjnym"),
    ]

    for spec in specs:
        fname, key, label = spec
        self.plot_maker.make_mini_batch_size_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("mini-batch-size"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

learning_rate_v_metrics()

Create plots showing the effect of learning rate on various metrics.

Source code in src/plots/plot_groups.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def learning_rate_v_metrics(self) -> None:
    """Render one plot per metric showing how learning rate affects it."""
    specs = [
        ("lr-vs-lfw.png", "summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        ("lr-vs-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m"),
        ("lr-vs-rof-s.png", "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s"),
        ("lr-vs-train-accuracy.png", "max_training/train_accuracy",
         "Dokładność na zbiorze treningowym"),
        ("lr-vs-val-accuracy.png", "max_training/val_accuracy",
         "Dokładność na zbiorze walidacyjnym"),
    ]

    for spec in specs:
        fname, key, label = spec
        self.plot_maker.make_lr_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("lr"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

margin_v_metrics()

Create plots showing the effect of margin on various metrics.

Source code in src/plots/plot_groups.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def margin_v_metrics(self) -> None:
    """Render one plot per metric showing how the margin setting affects it."""
    specs = [
        ("margin-vs-lfw.png", "summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        ("margin-vs-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m"),
        ("margin-vs-rof-s.png", "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s"),
        ("margin-vs-train-accuracy.png", "max_training/train_accuracy",
         "Dokładność na zbiorze treningowym"),
        ("margin-vs-val-accuracy.png", "max_training/val_accuracy",
         "Dokładność na zbiorze walidacyjnym"),
    ]

    for spec in specs:
        fname, key, label = spec
        # NOTE(review): unlike the sibling plot methods this call passes no
        # show=False — confirm make_margin_plot's default is non-interactive.
        self.plot_maker.make_margin_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("margin"),
            filename=fname,
            metric_key=key,
            metric_name=label,
        )

only_head_v_metrics()

Create plots showing the effect of only_head on various metrics and runtime.

Source code in src/plots/plot_groups.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def only_head_v_metrics(self) -> None:
    """Render plots of the only_head setting against each metric and run time."""
    specs = [
        ("only-head-vs-lfw.png", "summary_benchmark/lfw_accuracy",
         "Dokładność LFW"),
        ("only-head-vs-rof-m.png", "summary_benchmark/rof_masked_accuracy",
         "Dokładność ROF-m"),
        ("only-head-vs-rof-s.png", "summary_benchmark/rof_sunglasses_accuracy",
         "Dokładność ROF-s"),
        ("only-head-vs-train-accuracy.png", "max_training/train_accuracy",
         "Dokładność na zbiorze treningowym"),
        ("only-head-vs-val-accuracy.png", "max_training/val_accuracy",
         "Dokładność na zbiorze walidacyjnym"),
        # This family additionally plots total run time.
        ("only-head-vs-runtime.png", "runtime", "Czas obliczeń [s]"),
    ]

    for spec in specs:
        fname, key, label = spec
        self.plot_maker.make_only_head_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("only-head"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

parameter_importance_and_correlation(experiment_data_dict)

Create plots showing parameter importance and correlation with respect to metrics for selected metrics.

Source code in src/plots/plot_groups.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def parameter_importance_and_correlation(
    self, experiment_data_dict: dict[str, Any]
) -> None:
    """Plot parameter importance and correlation for each selected metric.

    Args:
        experiment_data_dict: Maps a metric key (e.g. ``"lfw"``) to a dict
            with ``"importance"`` and ``"correlation"`` entries.
    """
    specs = [
        ("parameter-importance-lfw.png", "lfw", "Dokładność LFW"),
        ("parameter-importance-rof-m.png", "rof-m", "Dokładność ROF-m"),
        ("parameter-importance-rof-s.png", "rof-s", "Dokładność ROF-s"),
        ("parameter-importance-val-accuracy.png", "val_accuracy",
         "Dokładność walidacyjna"),
    ]

    for spec in specs:
        fname, key, label = spec
        metric_data = experiment_data_dict[key]
        self.plot_maker.make_parameter_analysis_plot(
            importance_dict=metric_data["importance"],
            correlation_dict=metric_data["correlation"],
            output_dir=self.make_output_subdir("parameter-analysis"),
            filename=fname,
            title=label,
        )

training_curves()

Create training curves for all runs in the sweep.

Source code in src/plots/plot_groups.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def training_curves(self) -> None:
    """Create training curves for all runs in the sweep.

    Plots train/val loss and accuracy over epochs from the cached run
    histories.
    """
    plot_args = [
        (
            "train-loss-over-epochs.png",
            "training/train_loss",
            # Label grammar fixed: "strata" is feminine → "treningowa".
            "Strata treningowa",
        ),
        (
            "val-loss-over-epochs.png",
            "training/val_loss",
            "Strata walidacyjna",
        ),
        (
            "train-accuracy-over-epochs.png",
            "training/train_accuracy",
            "Dokładność treningowa",
        ),
        (
            "val-accuracy-over-epochs.png",
            "training/val_accuracy",
            "Dokładność walidacyjna",
        ),
    ]

    for filename, metric_key, metric_name in plot_args:
        self.plot_maker.make_loss_or_acc_plot(
            run_histories=self.sweep_data.run_histories,
            output_dir=self.make_output_subdir("training"),
            filename=filename,
            metric_key=metric_key,
            metric_name=metric_name,
            show=False,
        )

WandbClient

Wrapper for Weights & Biases API client.

Handles fetching and caching of sweep data. Because API requests can be slow and are subject to rate limits, fetched data is cached in a local JSON file.

Source code in src/plots/wandb_client.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class WandbClient:
    """Wrapper for the Weights & Biases API client.

    Handles fetching and caching of sweep data.
    Because API requests can be slow and are subject to rate limits,
    fetched data is cached in a local JSON file keyed by sweep id.
    """

    def __init__(self, project_name: str, cache_dir: Path):
        self.api = wandb.Api()
        self.project_name = project_name
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_sweep_data(self, sweep_id: str) -> SweepData:
        """Fetch sweep data from cache or API."""

        if self._get_cache_filepath(sweep_id).exists():
            print(f"Loading sweep data from cache: {sweep_id}")
            return self._load_sweep_data_from_json(sweep_id)
        else:
            print(f"Fetching sweep data from API: {sweep_id}")
            sweep_data = self._fetch_sweep_data_from_api(sweep_id)
            self._save_sweep_data_to_json(sweep_data)
            return sweep_data

    def _fetch_sweep_data_from_api(self, sweep_id: str) -> SweepData:
        """Download a sweep's run table and per-run histories from the API.

        Only runs in the "finished" state are kept in the run table.
        """
        sweep = self._get_sweep_object(sweep_id)
        sweep_runs_data = self._extract_sweep_run_data(sweep)
        sweep_runs_data = sweep_runs_data[sweep_runs_data["state"] == "finished"]
        assert isinstance(sweep_runs_data, DataFrame)
        # NOTE(review): histories are collected for *all* runs, including the
        # non-finished ones filtered out of sweep_runs_data above — confirm
        # this asymmetry is intended.
        run_histories = [self._extract_run_history(run) for run in sweep.runs]
        return SweepData(
            sweep_id=sweep_id,
            sweep_runs_data=sweep_runs_data,
            run_histories=run_histories,
        )

    def _get_cache_filepath(self, sweep_id: str) -> Path:
        """Return the cache file path for the given sweep id."""
        return self.cache_dir / f"sweep_{sweep_id}_data.json"

    def _save_sweep_data_to_json(self, sweep_data: SweepData) -> None:
        """Serialize sweep data to the JSON cache file.

        ``default=str`` stringifies values JSON cannot encode (e.g.
        timestamps), so a cache round-trip may not preserve exact dtypes.
        """
        cache_path = self._get_cache_filepath(sweep_data.sweep_id)
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "sweep_runs_data": sweep_data.sweep_runs_data.to_dict(
                        orient="records"
                    ),
                    "run_histories": [
                        rh.to_dict(orient="records") for rh in sweep_data.run_histories
                    ],
                },
                f,
                indent=4,
                default=str,
            )

    def _load_sweep_data_from_json(self, sweep_id: str) -> SweepData:
        """Deserialize sweep data from the JSON cache file."""
        with open(self._get_cache_filepath(sweep_id), "r", encoding="utf-8") as f:
            data = json.load(f)
        return SweepData(
            sweep_id=sweep_id,
            sweep_runs_data=DataFrame(data["sweep_runs_data"]),
            run_histories=[DataFrame(rh) for rh in data["run_histories"]],
        )

    def _extract_sweep_run_data(self, sweep) -> DataFrame:
        """Transform Sweep object into a DataFrame.

        Each row corresponds to a single run in the sweep with metrics and
        config as columns.
        """
        data = []

        for run in sweep.runs:
            row = {
                "run_id": run.id,
                "run_name": run.name,
                "state": run.state,
                "created_at": run.created_at,
                # .get(): crashed/killed runs may lack a _runtime summary
                # entry; a plain lookup would raise KeyError here.
                "runtime": run.summary.get("_runtime"),
            }

            # Config parameters
            for key, value in run.config.items():
                row[f"config_{key}"] = value

            # Summary metrics
            for key, value in run.summary.items():
                if not key.startswith("_"):  # Skip internal wandb fields
                    row[f"summary_{key}"] = value

            # History metric values (final and max)
            history = run.history()
            if not history.empty:
                for col in history.columns:
                    if not col.startswith("_"):
                        # Non-empty history guarantees iloc[-1] is valid.
                        row[f"final_{col}"] = history[col].iloc[-1]

                accuracy_cols = [
                    col for col in history.columns if "accuracy" in col.lower()
                ]
                for col in accuracy_cols:
                    row[f"max_{col}"] = history[col].max()

            data.append(row)

        return DataFrame(data)

    def _extract_run_history(self, run) -> DataFrame:
        """Transform a Run object into a DataFrame of its history.

        History includes metrics logged during training.
        Each row corresponds to a single logging step (epoch).
        """
        history = run.history()
        metrics = [
            "training/train_loss",
            "training/val_loss",
            "training/train_accuracy",
            "training/val_accuracy",
        ]

        # Filter to specific metrics (plus _step and _timestamp when present —
        # an empty history has no columns at all). .copy() prevents the pandas
        # SettingWithCopyWarning on the column assignments below.
        index_cols = [c for c in ("_step", "_timestamp") if c in history.columns]
        available_metrics = [m for m in metrics if m in history.columns]
        history = history[index_cols + available_metrics].copy()

        history["run_id"] = run.id
        history["run_name"] = run.name

        return history

    def _get_sweep_object(self, sweep_id: str):
        """Fetch the Sweep object for this project from the W&B API."""
        return self.api.sweep(f"{self.project_name}/{sweep_id}")

get_sweep_data(sweep_id)

Fetch sweep data from cache or API.

Source code in src/plots/wandb_client.py
24
25
26
27
28
29
30
31
32
33
34
def get_sweep_data(self, sweep_id: str) -> SweepData:
    """Fetch sweep data from cache or API."""

    cache_file = self._get_cache_filepath(sweep_id)
    if cache_file.exists():
        # Cache hit: deserialize the previously saved sweep data.
        print(f"Loading sweep data from cache: {sweep_id}")
        return self._load_sweep_data_from_json(sweep_id)

    # Cache miss: query the wandb API and persist the result for next time.
    print(f"Fetching sweep data from API: {sweep_id}")
    sweep_data = self._fetch_sweep_data_from_api(sweep_id)
    self._save_sweep_data_to_json(sweep_data)
    return sweep_data

individual_plots

Reusable functions for creating plots from experiment data.

PlotMaker

Class for creating plots with shared configuration and automatic setup/cleanup.

Source code in src/plots/individual_plots.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
class PlotMaker:
    """Create thesis plots with shared styling and automatic setup/cleanup.

    All figure-level configuration (size, dpi, transparency, fonts) is stored
    once on the instance and reused by every plot method.  Methods decorated
    with ``@plot_wrapper`` receive a copy of the run data and get figure
    creation, saving and closing handled by the decorator.
    """

    def __init__(
        self,
        figsize=(4.5, 3),
        dpi=300,
        alpha=0.7,
        marker_size=60,
        linewidth=0.5,
        fontsize=12,
        grid_alpha=0.3,
    ):
        """Store shared plot styling used by all plot methods."""
        self.figsize = figsize
        self.dpi = dpi
        self.alpha = alpha
        self.marker_size = marker_size
        self.linewidth = linewidth
        self.fontsize = fontsize
        self.grid_alpha = grid_alpha

    @staticmethod
    def _group_metric(plot_data: pd.DataFrame, column: str, values, metric_key: str):
        """Return one numpy array of *metric_key* values per entry of *values*."""
        return [
            plot_data.loc[plot_data[column] == value, metric_key].to_numpy()
            for value in values
        ]

    def _labeled_pair_boxplot(self, groups, labels, metric_name: str) -> None:
        """Draw a horizontal two-box boxplot with a text label above each box.

        Shared by :meth:`make_augmentation_plot` and :meth:`make_only_head_plot`,
        which previously duplicated this logic.
        """
        box_plot = plt.boxplot(
            groups,
            tick_labels=["", ""],
            patch_artist=True,
            showmeans=True,
            vert=False,
        )

        for box, facecolor in zip(box_plot["boxes"], ("lightblue", "lightcoral")):
            box.set_facecolor(facecolor)
            box.set_alpha(self.alpha)

        # Get the left edge of the entire plot area with some margin
        x_min = plt.xlim()[0]
        x_range = plt.xlim()[1] - plt.xlim()[0]
        x_margin = x_min + (0.05 * x_range)  # 5% margin from left edge

        for i, label in enumerate(labels):
            # Get the top edge of the box for positioning above
            box_top = box_plot["boxes"][i].get_path().vertices[:, 1].max()

            plt.text(
                x_margin,
                box_top + 0.15,
                label,
                horizontalalignment="left",
                verticalalignment="bottom",
                fontsize=self.fontsize - 2,
            )

        plt.xlabel(metric_name, fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha, axis="x")

    def _finalize_figure(self, output_dir: Path, filename: str, show: bool) -> None:
        """Tighten layout, save the current figure, optionally show, then close."""
        plt.tight_layout()
        plt.savefig(output_dir / filename, dpi=self.dpi, bbox_inches="tight")
        if show:
            plt.show()
        plt.close()

    @plot_wrapper
    def make_lr_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
    ):
        """Create learning rate plot: metric vs. log-scaled learning rate,
        one scatter series per fine-tuning mode."""
        # Series.eq() is equivalent to `== True` / `== False` element-wise
        # comparison but avoids the E712 lint smell.
        series_specs = (
            (plot_data[plot_data["config_only_head"].eq(True)], "steelblue", "only_head=True"),
            (plot_data[plot_data["config_only_head"].eq(False)], "red", "only_head=False"),
        )
        for subset, color, label in series_specs:
            plt.scatter(
                subset["config_learning_rate"],
                subset[metric_key],
                alpha=self.alpha,
                s=self.marker_size,
                c=color,
                edgecolors="black",
                linewidth=self.linewidth,
                label=label,
            )

        plt.xscale("log")
        plt.xlabel("Współczynnik uczenia (skala log)", fontsize=self.fontsize)
        plt.ylabel(metric_name, fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha)
        plt.legend(fontsize=self.fontsize - 1, loc="lower left")
        plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x, p: f"{x:.1e}"))

    @plot_wrapper
    def make_margin_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
    ):
        """Create margin plot: metric vs. the configured margin."""
        plt.scatter(
            plot_data["config_margin"],
            plot_data[metric_key],
            alpha=self.alpha,
            s=self.marker_size,
            c="green",
            edgecolors="black",
            linewidth=self.linewidth,
        )

        plt.xlabel("Margines", fontsize=self.fontsize)
        plt.ylabel(metric_name, fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha)

    @plot_wrapper
    def make_augmentation_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
    ):
        """Create augmentation plot: paired boxplots with vs. without augmentation."""
        # Missing augmentation config means the run used no augmentation.
        plot_data["config_augmentation"] = plot_data["config_augmentation"].fillna(
            "None"
        )

        augmentations = ["None", "AddRandomRectangleAverageColor"]
        data_by_augmentation = self._group_metric(
            plot_data, "config_augmentation", augmentations, metric_key
        )
        self._labeled_pair_boxplot(data_by_augmentation, augmentations, metric_name)

    @plot_wrapper
    def make_only_head_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
    ):
        """Create only head plot: paired boxplots for head-only vs. full fine-tuning."""
        # Stringify the flag so groups have stable, printable keys.
        plot_data["config_only_head"] = plot_data["config_only_head"].replace(
            {True: "True", False: "False"}
        )

        data_by_only_head = self._group_metric(
            plot_data, "config_only_head", ["True", "False"], metric_key
        )
        labels = ["only_head=True", "only_head=False"]
        self._labeled_pair_boxplot(data_by_only_head, labels, metric_name)
        plt.tick_params(axis="both", which="major", labelsize=self.fontsize - 3)

    @plot_wrapper
    def make_mini_batch_size_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
    ):
        """Create mini batch size plot: one boxplot per batch size."""
        mini_batch_sizes = sorted(plot_data["config_batch_size"].unique())
        data_by_mini_batch_size = self._group_metric(
            plot_data, "config_batch_size", mini_batch_sizes, metric_key
        )

        box_plot = plt.boxplot(
            data_by_mini_batch_size,
            tick_labels=[str(size) for size in mini_batch_sizes],
            patch_artist=True,
            showmeans=True,
            vert=False,
        )

        # plt.get_cmap: plt.cm.get_cmap was deprecated in matplotlib 3.7
        # and removed in 3.9.
        colors = plt.get_cmap("viridis")(np.linspace(0, 1, len(mini_batch_sizes)))
        for i, box in enumerate(box_plot["boxes"]):
            box.set_facecolor(colors[i])
            box.set_alpha(self.alpha)

        plt.xlabel(metric_name, fontsize=self.fontsize)
        plt.ylabel("Rozmiar mini-pakietu", fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha, axis="x")

    def make_loss_or_acc_plot(
        self,
        run_histories: list[pd.DataFrame],
        output_dir: Path,
        filename: str,
        metric_key: str,
        metric_name: str,
        show: bool = False,
    ):
        """Create loss or accuracy plot (no decorator - different signature).

        Overlays one faint curve per run so the overall trend stands out.
        """
        plt.figure(figsize=self.figsize)

        for history in run_histories:
            plt.plot(
                history["_step"],
                history[metric_key],
                alpha=0.3,
                linewidth=1,
            )

        plt.xlabel("Epoka", fontsize=self.fontsize)
        plt.ylabel(metric_name, fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha)
        self._finalize_figure(output_dir, filename, show)

    @plot_wrapper
    def make_aug_comparison_plot(
        self,
        plot_data: pd.DataFrame,
        metric_key: str,
        metric_name: str,
        base: float,
        aggregate_best: str = "max",
        xmin: float = 0.0,
        xmax: float = 1.0,
    ):
        """Create augmentation comparison plot.

        One horizontal bar per augmentation showing the best (aggregated)
        metric value, plus a dashed vertical line at the *base* score.
        """
        # Missing augmentation config means the run used no augmentation.
        plot_data["config_augmentation"] = plot_data["config_augmentation"].fillna(
            "None"
        )

        aug_names = plot_data["config_augmentation"].unique()
        best_result_by_aug = (
            plot_data.groupby("config_augmentation")
            .agg({metric_key: aggregate_best})
            .reset_index()
            .sort_values(by=metric_key, ascending=True)
        )

        # Create horizontal bar plot
        y_pos = np.arange(len(best_result_by_aug))
        plt.barh(
            y_pos,
            best_result_by_aug[metric_key],
            # plt.get_cmap: plt.cm.get_cmap was removed in matplotlib 3.9.
            color=plt.get_cmap("viridis")(np.linspace(0, 1, len(aug_names))),
            alpha=self.alpha,
            edgecolor="black",
        )

        # Set y-axis labels to augmentation names
        plt.yticks(
            y_pos,
            best_result_by_aug["config_augmentation"].tolist(),
            fontsize=self.fontsize - 5,
        )

        plt.xlim(xmin, xmax)
        plt.xlabel(metric_name, fontsize=self.fontsize)
        plt.grid(True, alpha=self.grid_alpha, axis="x")

        # Add vertical line for base score
        plt.axvline(
            x=base,
            color="red",
            linestyle="--",
            linewidth=2,
            alpha=0.8,
            label="Baseline",
        )

    def make_parameter_analysis_plot(
        self,
        importance_dict: dict,
        correlation_dict: dict,
        output_dir: Path,
        filename: str,
        title: str,
        show: bool = False,
    ):
        """Create parameter analysis plot (no decorator - different signature).

        Left subplot: parameter importance.  Right subplot: |correlation|
        with the sign encoded as bar color (green positive, red negative).
        """
        parameters = list(importance_dict.keys())
        importances = [importance_dict[param] for param in parameters]
        correlations = [correlation_dict[param] for param in parameters]

        # Bar heights are |corr|; the sign is shown via color and the label.
        abs_correlations = [abs(corr) for corr in correlations]
        corr_colors = ["green" if corr > 0 else "red" for corr in correlations]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 6))

        subplot_specs = [
            (ax1, importances, "skyblue", "Ważność", "Ważność parametrów", importances),
            (
                ax2,
                abs_correlations,
                corr_colors,
                "Korelacja",
                "Korelacja parametrów",
                correlations,
            ),
        ]
        for ax, heights, bar_color, ylabel, subplot_title, label_values in subplot_specs:
            bars = ax.bar(
                parameters,
                heights,
                color=bar_color,
                edgecolor="black",
                alpha=self.alpha,
            )
            ax.set_xlabel("Parametr", fontsize=self.fontsize)
            ax.set_ylabel(ylabel, fontsize=self.fontsize)
            ax.set_title(subplot_title, fontsize=self.fontsize + 2)
            ax.set_ylim(0, max(heights) * 1.1)
            ax.grid(axis="y", alpha=self.grid_alpha)

            # Annotate each bar with its value (signed for correlations).
            for bar, value in zip(bars, label_values):
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height(),
                    f"{value:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=self.fontsize - 2,
                )

        # Set overall title
        fig.suptitle(title, fontsize=self.fontsize + 4)

        self._finalize_figure(output_dir, filename, show)

make_aug_comparison_plot(plot_data, metric_key, metric_name, base, aggregate_best='max', xmin=0.0, xmax=1.0)

Create augmentation comparison plot.

Source code in src/plots/individual_plots.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
@plot_wrapper
def make_aug_comparison_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
    base: float,
    aggregate_best: str = "max",
    xmin: float = 0.0,
    xmax: float = 1.0,
):
    """Create augmentation comparison plot.

    One horizontal bar per augmentation showing the best (per
    *aggregate_best*) metric value, plus a dashed baseline at *base*.

    Args:
        plot_data: Sweep run data; NaN ``config_augmentation`` means no augmentation.
        metric_key: Column to aggregate and plot.
        metric_name: Human-readable x-axis label.
        base: Baseline score marked with a vertical dashed line.
        aggregate_best: Pandas aggregation name applied per augmentation.
        xmin: Lower x-axis limit.
        xmax: Upper x-axis limit.
    """
    # Missing augmentation config means the run used no augmentation.
    plot_data["config_augmentation"] = plot_data["config_augmentation"].fillna(
        "None"
    )

    aug_names = plot_data["config_augmentation"].unique()
    best_result_by_aug = (
        plot_data.groupby("config_augmentation")
        .agg({metric_key: aggregate_best})
        .reset_index()
        .sort_values(by=metric_key, ascending=True)
    )

    # Create horizontal bar plot
    y_pos = np.arange(len(best_result_by_aug))
    plt.barh(
        y_pos,
        best_result_by_aug[metric_key],
        # plt.get_cmap: plt.cm.get_cmap was deprecated in matplotlib 3.7
        # and removed in 3.9.
        color=plt.get_cmap("viridis")(np.linspace(0, 1, len(aug_names))),
        alpha=self.alpha,
        edgecolor="black",
    )

    # Set y-axis labels to augmentation names
    plt.yticks(
        y_pos,
        best_result_by_aug["config_augmentation"].tolist(),
        fontsize=self.fontsize - 5,
    )

    plt.xlim(xmin, xmax)
    plt.xlabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha, axis="x")

    # Add vertical line for base score
    plt.axvline(
        x=base,
        color="red",
        linestyle="--",
        linewidth=2,
        alpha=0.8,
        label="Baseline",
    )

make_augmentation_plot(plot_data, metric_key, metric_name)

Create augmentation plot.

Source code in src/plots/individual_plots.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
@plot_wrapper
def make_augmentation_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
):
    """Create augmentation plot."""
    # A missing augmentation config means the run trained without one.
    plot_data["config_augmentation"] = plot_data["config_augmentation"].fillna(
        "None"
    )

    augmentations = ["None", "AddRandomRectangleAverageColor"]

    data_by_augmentation = [
        plot_data.loc[plot_data["config_augmentation"] == aug, metric_key].to_numpy()
        for aug in augmentations
    ]

    box_plot = plt.boxplot(
        data_by_augmentation,
        tick_labels=["", ""],
        patch_artist=True,
        showmeans=True,
        vert=False,
    )

    for patch, facecolor in zip(box_plot["boxes"], ("lightblue", "lightcoral")):
        patch.set_facecolor(facecolor)
        patch.set_alpha(self.alpha)

    # Anchor the group labels 5% in from the left edge of the plot area.
    left_edge, right_edge = plt.xlim()
    label_x = left_edge + 0.05 * (right_edge - left_edge)

    for patch, group_label in zip(box_plot["boxes"], augmentations):
        # Place the label just above the top edge of its box.
        top_of_box = patch.get_path().vertices[:, 1].max()
        plt.text(
            label_x,
            top_of_box + 0.15,
            group_label,
            horizontalalignment="left",
            verticalalignment="bottom",
            fontsize=self.fontsize - 2,
        )

    plt.xlabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha, axis="x")

make_loss_or_acc_plot(run_histories, output_dir, filename, metric_key, metric_name, show=False)

Create loss or accuracy plot (no decorator - different signature).

Source code in src/plots/individual_plots.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def make_loss_or_acc_plot(
    self,
    run_histories: list[pd.DataFrame],
    output_dir: Path,
    filename: str,
    metric_key: str,
    metric_name: str,
    show: bool = False,
):
    """Create loss or accuracy plot (no decorator - different signature)."""
    plt.figure(figsize=self.figsize)

    # Overlay one faint curve per run; the overlap reveals the trend.
    for run_history in run_histories:
        epochs = run_history["_step"]
        metric_values = run_history[metric_key]
        plt.plot(epochs, metric_values, alpha=0.3, linewidth=1)

    plt.xlabel("Epoka", fontsize=self.fontsize)
    plt.ylabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha)
    plt.tight_layout()

    plt.savefig(output_dir / filename, dpi=self.dpi, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

make_lr_plot(plot_data, metric_key, metric_name)

Create learning rate plot.

Source code in src/plots/individual_plots.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
@plot_wrapper
def make_lr_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
):
    """Create learning rate plot."""
    # One scatter series per fine-tuning mode (head only vs. full model).
    series_specs = (
        (True, "steelblue", "only_head=True"),
        (False, "red", "only_head=False"),
    )
    for only_head, color, series_label in series_specs:
        subset = plot_data[plot_data["config_only_head"] == only_head]
        plt.scatter(
            subset["config_learning_rate"],
            subset[metric_key],
            alpha=self.alpha,
            s=self.marker_size,
            c=color,
            edgecolors="black",
            linewidth=self.linewidth,
            label=series_label,
        )

    plt.xscale("log")
    plt.xlabel("Współczynnik uczenia (skala log)", fontsize=self.fontsize)
    plt.ylabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha)
    plt.legend(fontsize=self.fontsize - 1, loc="lower left")
    plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda v, _: f"{v:.1e}"))

make_margin_plot(plot_data, metric_key, metric_name)

Create margin plot.

Source code in src/plots/individual_plots.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
@plot_wrapper
def make_margin_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
):
    """Create margin plot."""
    # Scatter the chosen metric against the configured triplet margin.
    margins = plot_data["config_margin"]
    metric_values = plot_data[metric_key]
    plt.scatter(
        margins,
        metric_values,
        alpha=self.alpha,
        s=self.marker_size,
        c="green",
        edgecolors="black",
        linewidth=self.linewidth,
    )

    plt.xlabel("Margines", fontsize=self.fontsize)
    plt.ylabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha)

make_mini_batch_size_plot(plot_data, metric_key, metric_name)

Create mini batch size plot.

Source code in src/plots/individual_plots.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
@plot_wrapper
def make_mini_batch_size_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
):
    """Create mini batch size plot.

    One horizontal boxplot per distinct batch size, colored along the
    viridis colormap.
    """
    mini_batch_sizes = sorted(plot_data["config_batch_size"].unique())

    data_by_mini_batch_size = []
    for size in mini_batch_sizes:
        filtered_data = plot_data[plot_data["config_batch_size"] == size]
        metric_series: pd.Series = filtered_data[metric_key]  # type: ignore[assignment]
        data_by_mini_batch_size.append(metric_series.to_numpy())

    box_plot = plt.boxplot(
        data_by_mini_batch_size,
        tick_labels=[str(size) for size in mini_batch_sizes],
        patch_artist=True,
        showmeans=True,
        vert=False,
    )

    # plt.get_cmap: plt.cm.get_cmap was deprecated in matplotlib 3.7
    # and removed in 3.9.
    colors = plt.get_cmap("viridis")(np.linspace(0, 1, len(mini_batch_sizes)))
    for i, box in enumerate(box_plot["boxes"]):
        box.set_facecolor(colors[i])
        box.set_alpha(self.alpha)

    plt.xlabel(metric_name, fontsize=self.fontsize)
    plt.ylabel("Rozmiar mini-pakietu", fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha, axis="x")

make_only_head_plot(plot_data, metric_key, metric_name)

Create only head plot.

Source code in src/plots/individual_plots.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
@plot_wrapper
def make_only_head_plot(
    self,
    plot_data: pd.DataFrame,
    metric_key: str,
    metric_name: str,
):
    """Create only head plot."""
    # Stringify the boolean flag so groups have printable keys.
    plot_data["config_only_head"] = plot_data["config_only_head"].replace(
        {True: "True", False: "False"}
    )

    data_by_only_head = [
        plot_data.loc[plot_data["config_only_head"] == flag, metric_key].to_numpy()
        for flag in ["True", "False"]
    ]

    labels = ["only_head=True", "only_head=False"]

    box_plot = plt.boxplot(
        data_by_only_head,
        tick_labels=["", ""],
        patch_artist=True,
        showmeans=True,
        vert=False,
    )

    for patch, facecolor in zip(box_plot["boxes"], ("lightblue", "lightcoral")):
        patch.set_facecolor(facecolor)
        patch.set_alpha(self.alpha)

    # Anchor the group labels 5% in from the left edge of the plot area.
    left_edge, right_edge = plt.xlim()
    label_x = left_edge + 0.05 * (right_edge - left_edge)

    for patch, group_label in zip(box_plot["boxes"], labels):
        # Place the label just above the top edge of its box.
        top_of_box = patch.get_path().vertices[:, 1].max()
        plt.text(
            label_x,
            top_of_box + 0.15,
            group_label,
            horizontalalignment="left",
            verticalalignment="bottom",
            fontsize=self.fontsize - 2,
        )

    plt.xlabel(metric_name, fontsize=self.fontsize)
    plt.grid(True, alpha=self.grid_alpha, axis="x")
    plt.tick_params(axis="both", which="major", labelsize=self.fontsize - 3)

make_parameter_analysis_plot(importance_dict, correlation_dict, output_dir, filename, title, show=False)

Create parameter analysis plot (no decorator - different signature).

Source code in src/plots/individual_plots.py
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
def make_parameter_analysis_plot(
    self,
    importance_dict: dict,
    correlation_dict: dict,
    output_dir: Path,
    filename: str,
    title: str,
    show: bool = False,
):
    """Create parameter analysis plot (no decorator - different signature)."""
    parameters = list(importance_dict.keys())
    importances = [importance_dict[param] for param in parameters]
    correlations = [correlation_dict[param] for param in parameters]

    # Correlation bars use |corr| for height; the sign is encoded in the
    # bar color (green positive, red otherwise) and the printed label.
    abs_correlations = [abs(corr) for corr in correlations]
    corr_colors = ["green" if corr > 0 else "red" for corr in correlations]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 6))

    # (axis, bar heights, bar color(s), y-label, subplot title, label values)
    subplot_specs = [
        (ax1, importances, "skyblue", "Ważność", "Ważność parametrów", importances),
        (
            ax2,
            abs_correlations,
            corr_colors,
            "Korelacja",
            "Korelacja parametrów",
            correlations,
        ),
    ]
    for ax, heights, bar_color, ylabel, subplot_title, label_values in subplot_specs:
        bars = ax.bar(
            parameters,
            heights,
            color=bar_color,
            edgecolor="black",
            alpha=self.alpha,
        )
        ax.set_xlabel("Parametr", fontsize=self.fontsize)
        ax.set_ylabel(ylabel, fontsize=self.fontsize)
        ax.set_title(subplot_title, fontsize=self.fontsize + 2)
        ax.set_ylim(0, max(heights) * 1.1)
        ax.grid(axis="y", alpha=self.grid_alpha)

        # Annotate each bar with its value (signed for correlations).
        for bar, value in zip(bars, label_values):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height(),
                f"{value:.3f}",
                ha="center",
                va="bottom",
                fontsize=self.fontsize - 2,
            )

    # Set overall title
    fig.suptitle(title, fontsize=self.fontsize + 4)

    plt.tight_layout()
    plt.savefig(output_dir / filename, dpi=self.dpi, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()

plot_wrapper(func)

Decorator that handles plot setup and cleanup using class configuration.

Source code in src/plots/individual_plots.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def plot_wrapper(func: Callable) -> Callable:
    """Decorator that handles plot setup and cleanup using class configuration.

    The wrapped plotting method receives a defensive copy of *run_data*, so
    it may freely mutate columns.  On success the figure is tightened, saved
    to ``output_dir / filename`` and optionally shown; the wrapped function's
    return value is passed through.  The figure is always closed — even when
    plotting raises — so failed plots do not leak figures.  Unlike the
    previous version, a failed plot is no longer written to disk, which
    avoided masking errors with half-drawn output files.
    """

    @wraps(func)
    def wrapper(
        self,  # The PlotMaker instance
        run_data: pd.DataFrame,
        output_dir: Path,
        filename: str,
        *args,
        show: bool = False,
        **kwargs,
    ):
        # Copy so plot methods can mutate columns without affecting the caller.
        plot_data = run_data.copy()
        plt.figure(figsize=self.figsize)

        try:
            # Call the actual plotting function with plot_data
            result = func(self, plot_data, *args, **kwargs)
            # Save only when plotting succeeded; the exception (if any)
            # still propagates after cleanup below.
            plt.tight_layout()
            plt.savefig(output_dir / filename, dpi=self.dpi, bbox_inches="tight")
            if show:
                plt.show()
            return result
        finally:
            # Always release the figure, success or failure.
            plt.close()

    return wrapper

parameter_importance

Data on parameter importance and linear correlation with respect to metrics.

This data is calculated by wandb but cannot be downloaded via the API. I manually copied it from the web interface. It is used for recreating the plots in the thesis.

plot_groups

Functions for creating groups of plots for the thesis.

There are groups of similar plots used together in the thesis. (e.g. plots of learning rate vs metric X for different metrics).

PlotGroups

Source code in src/plots/plot_groups.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class PlotGroups:
    """Creates the groups of related plots used in the thesis.

    Each public method renders one family of plots (e.g. learning rate vs.
    several metrics) into its own subdirectory under ``output_dir``.
    Axis labels are in Polish because the thesis figures are.
    """

    # Metric set shared by most hyperparameter plot groups:
    # (filename suffix, metric key in the sweep run data, axis label).
    _STANDARD_METRICS: list[tuple[str, str, str]] = [
        ("lfw", "summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        ("rof-m", "summary_benchmark/rof_masked_accuracy", "Dokładność ROF-m"),
        (
            "rof-s",
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        (
            "train-accuracy",
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        (
            "val-accuracy",
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    ]

    def __init__(self, sweep_data: "SweepData", output_dir: Path) -> None:
        """Store the sweep data and ensure the output directory exists."""
        self.sweep_data = sweep_data
        self.output_dir = output_dir
        self.plot_maker = PlotMaker()
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def make_output_subdir(self, subdir_name: str) -> Path:
        """Create (if needed) and return a subdirectory of ``output_dir``."""
        subdir = self.output_dir / subdir_name
        subdir.mkdir(parents=True, exist_ok=True)
        return subdir

    def _standard_plot_args(self, prefix: str) -> list[tuple[str, str, str]]:
        """Build the (filename, metric_key, metric_name) tuples shared by the
        hyperparameter plot groups, with filenames prefixed by ``prefix``."""
        return [
            (f"{prefix}-vs-{suffix}.png", key, name)
            for suffix, key, name in self._STANDARD_METRICS
        ]

    def learning_rate_v_metrics(self) -> None:
        """Create plots showing the effect of learning rate on various metrics."""
        for filename, metric_key, metric_name in self._standard_plot_args("lr"):
            self.plot_maker.make_lr_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("lr"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def margin_v_metrics(self) -> None:
        """Create plots showing the effect of margin on various metrics."""
        # NOTE(review): unlike the other plot groups this call passes no
        # ``show`` argument — presumably make_margin_plot defaults to not
        # showing; confirm before relying on it in non-interactive runs.
        for filename, metric_key, metric_name in self._standard_plot_args(
            "margin"
        ):
            self.plot_maker.make_margin_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("margin"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
            )

    def augmentation_v_metrics(self) -> None:
        """Create plots showing the effect of data augmentation on various metrics."""
        for filename, metric_key, metric_name in self._standard_plot_args(
            "augmentation"
        ):
            self.plot_maker.make_augmentation_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("augmentation"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def only_head_v_metrics(self) -> None:
        """Create plots showing the effect of only_head on various metrics and runtime."""
        # The standard metric set plus the total run time.
        plot_args = self._standard_plot_args("only-head") + [
            ("only-head-vs-runtime.png", "runtime", "Czas obliczeń [s]"),
        ]

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_only_head_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("only-head"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def batch_size_v_metrics(self) -> None:
        """Create plots showing the effect of batch size on various metrics."""
        for filename, metric_key, metric_name in self._standard_plot_args(
            "mini-batch-size"
        ):
            self.plot_maker.make_mini_batch_size_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("mini-batch-size"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def parameter_importance_and_correlation(
        self, experiment_data_dict: dict[str, Any]
    ) -> None:
        """Create plots showing parameter importance and correlation with respect to metrics for selected metrics.

        ``experiment_data_dict`` maps a metric key (e.g. ``"lfw"``) to a dict
        with ``"importance"`` and ``"correlation"`` entries (data manually
        copied from the wandb web interface).
        """
        plot_args = [
            ("parameter-importance-lfw.png", "lfw", "Dokładność LFW"),
            ("parameter-importance-rof-m.png", "rof-m", "Dokładność ROF-m"),
            ("parameter-importance-rof-s.png", "rof-s", "Dokładność ROF-s"),
            (
                "parameter-importance-val-accuracy.png",
                "val_accuracy",
                "Dokładność walidacyjna",
            ),
        ]

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_parameter_analysis_plot(
                importance_dict=experiment_data_dict[metric_key]["importance"],
                correlation_dict=experiment_data_dict[metric_key]["correlation"],
                output_dir=self.make_output_subdir("parameter-analysis"),
                filename=filename,
                title=metric_name,
            )

    def training_curves(self) -> None:
        """Create training curves for all runs in the sweep."""
        # NOTE(review): "Strata treningowy" looks grammatically off
        # ("Strata treningowa"?) — confirm before changing thesis figures.
        plot_args = [
            (
                "train-loss-over-epochs.png",
                "training/train_loss",
                "Strata treningowy",
            ),
            (
                "val-loss-over-epochs.png",
                "training/val_loss",
                "Strata walidacyjna",
            ),
            (
                "train-accuracy-over-epochs.png",
                "training/train_accuracy",
                "Dokładność treningowa",
            ),
            (
                "val-accuracy-over-epochs.png",
                "training/val_accuracy",
                "Dokładność walidacyjna",
            ),
        ]

        for filename, metric_key, metric_name in plot_args:
            self.plot_maker.make_loss_or_acc_plot(
                run_histories=self.sweep_data.run_histories,
                output_dir=self.make_output_subdir("training"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                show=False,
            )

    def augmentation_comparison(self) -> None:
        """Group of plots comparing different augmentations with respect to a metric for all metrics."""
        # Baseline values (last tuple element) hardcoded from the reference run.
        # Tuple layout: (filename, metric key, label, best-aggregate, xmin, xmax, baseline)
        plot_args = [
            (
                "aug-comparison-lfw.png",
                "summary_benchmark/lfw_accuracy",
                "Dokładność LFW",
                "max",
                0.90,
                1.0,
                0.971,
            ),
            (
                "aug-comparison-rof-m.png",
                "summary_benchmark/rof_masked_accuracy",
                "Dokładność ROF-m",
                "max",
                0.70,
                0.90,
                0.859,
            ),
            (
                "aug-comparison-rof-s.png",
                "summary_benchmark/rof_sunglasses_accuracy",
                "Dokładność ROF-s",
                "max",
                0.70,
                0.90,
                0.872,
            ),
            (
                "aug-comparison-eer.png",
                "summary_rococo-evaluation/eer",
                "EER",
                "min",
                0.0,
                0.5,
                0.102,
            ),
            (
                "aug-comparison-frr-at-far-zero.png",
                "summary_rococo-evaluation/frr_at_far_zero",
                "FRR @ FAR=0",
                "min",
                0.0,
                1.0,
                0.763,
            ),
        ]

        for (
            filename,
            metric_key,
            metric_name,
            aggregate_best,
            xmin,
            xmax,
            base,
        ) in plot_args:
            self.plot_maker.make_aug_comparison_plot(
                run_data=self.sweep_data.sweep_runs_data,
                output_dir=self.make_output_subdir("augmentation-comparison"),
                filename=filename,
                metric_key=metric_key,
                metric_name=metric_name,
                aggregate_best=aggregate_best,
                xmin=xmin,
                xmax=xmax,
                base=base,
                show=False,
            )

augmentation_comparison()

Group of plots comparing different augmentations with respect to a metric for all metrics.

Source code in src/plots/plot_groups.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def augmentation_comparison(self) -> None:
    """Group of plots comparing different augmentations with respect to a metric for all metrics."""
    # Baseline (``base``) values are hardcoded reference-run results.
    comparisons = [
        dict(
            filename="aug-comparison-lfw.png",
            metric_key="summary_benchmark/lfw_accuracy",
            metric_name="Dokładność LFW",
            aggregate_best="max",
            xmin=0.90,
            xmax=1.0,
            base=0.971,
        ),
        dict(
            filename="aug-comparison-rof-m.png",
            metric_key="summary_benchmark/rof_masked_accuracy",
            metric_name="Dokładność ROF-m",
            aggregate_best="max",
            xmin=0.70,
            xmax=0.90,
            base=0.859,
        ),
        dict(
            filename="aug-comparison-rof-s.png",
            metric_key="summary_benchmark/rof_sunglasses_accuracy",
            metric_name="Dokładność ROF-s",
            aggregate_best="max",
            xmin=0.70,
            xmax=0.90,
            base=0.872,
        ),
        dict(
            filename="aug-comparison-eer.png",
            metric_key="summary_rococo-evaluation/eer",
            metric_name="EER",
            aggregate_best="min",
            xmin=0.0,
            xmax=0.5,
            base=0.102,
        ),
        dict(
            filename="aug-comparison-frr-at-far-zero.png",
            metric_key="summary_rococo-evaluation/frr_at_far_zero",
            metric_name="FRR @ FAR=0",
            aggregate_best="min",
            xmin=0.0,
            xmax=1.0,
            base=0.763,
        ),
    ]

    for spec in comparisons:
        self.plot_maker.make_aug_comparison_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("augmentation-comparison"),
            show=False,
            **spec,
        )

augmentation_v_metrics()

Create plots showing the effect of data augmentation on various metrics.

Source code in src/plots/plot_groups.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def augmentation_v_metrics(self) -> None:
    """Create plots showing the effect of data augmentation on various metrics."""
    # filename -> (metric key in the run data, Polish axis label)
    specs = {
        "augmentation-vs-lfw.png": (
            "summary_benchmark/lfw_accuracy",
            "Dokładność LFW",
        ),
        "augmentation-vs-rof-m.png": (
            "summary_benchmark/rof_masked_accuracy",
            "Dokładność ROF-m",
        ),
        "augmentation-vs-rof-s.png": (
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        "augmentation-vs-train-accuracy.png": (
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        "augmentation-vs-val-accuracy.png": (
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    }

    for fname, (key, label) in specs.items():
        self.plot_maker.make_augmentation_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("augmentation"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

batch_size_v_metrics()

Create plots showing the effect of batch size on various metrics.

Source code in src/plots/plot_groups.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def batch_size_v_metrics(self) -> None:
    """Create plots showing the effect of batch size on various metrics."""
    # filename -> (metric key in the run data, Polish axis label)
    specs = {
        "mini-batch-size-vs-lfw.png": (
            "summary_benchmark/lfw_accuracy",
            "Dokładność LFW",
        ),
        "mini-batch-size-vs-rof-m.png": (
            "summary_benchmark/rof_masked_accuracy",
            "Dokładność ROF-m",
        ),
        "mini-batch-size-vs-rof-s.png": (
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        "mini-batch-size-vs-train-accuracy.png": (
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        "mini-batch-size-vs-val-accuracy.png": (
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    }

    for fname, (key, label) in specs.items():
        self.plot_maker.make_mini_batch_size_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("mini-batch-size"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

learning_rate_v_metrics()

Create plots showing the effect of learning rate on various metrics.

Source code in src/plots/plot_groups.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def learning_rate_v_metrics(self) -> None:
    """Create plots showing the effect of learning rate on various metrics."""
    # filename -> (metric key in the run data, Polish axis label)
    specs = {
        "lr-vs-lfw.png": ("summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        "lr-vs-rof-m.png": (
            "summary_benchmark/rof_masked_accuracy",
            "Dokładność ROF-m",
        ),
        "lr-vs-rof-s.png": (
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        "lr-vs-train-accuracy.png": (
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        "lr-vs-val-accuracy.png": (
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    }

    for fname, (key, label) in specs.items():
        self.plot_maker.make_lr_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("lr"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

margin_v_metrics()

Create plots showing the effect of margin on various metrics.

Source code in src/plots/plot_groups.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def margin_v_metrics(self) -> None:
    """Create plots showing the effect of margin on various metrics."""
    # filename -> (metric key in the run data, Polish axis label)
    specs = {
        "margin-vs-lfw.png": ("summary_benchmark/lfw_accuracy", "Dokładność LFW"),
        "margin-vs-rof-m.png": (
            "summary_benchmark/rof_masked_accuracy",
            "Dokładność ROF-m",
        ),
        "margin-vs-rof-s.png": (
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        "margin-vs-train-accuracy.png": (
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        "margin-vs-val-accuracy.png": (
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
    }

    # No ``show`` argument here — presumably make_margin_plot's default is
    # used; confirmed against the original call site.
    for fname, (key, label) in specs.items():
        self.plot_maker.make_margin_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("margin"),
            filename=fname,
            metric_key=key,
            metric_name=label,
        )

only_head_v_metrics()

Create plots showing the effect of only_head on various metrics and runtime.

Source code in src/plots/plot_groups.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def only_head_v_metrics(self) -> None:
    """Create plots showing the effect of only_head on various metrics and runtime."""
    # filename -> (metric key, Polish axis label); total runtime is plotted too.
    specs = {
        "only-head-vs-lfw.png": (
            "summary_benchmark/lfw_accuracy",
            "Dokładność LFW",
        ),
        "only-head-vs-rof-m.png": (
            "summary_benchmark/rof_masked_accuracy",
            "Dokładność ROF-m",
        ),
        "only-head-vs-rof-s.png": (
            "summary_benchmark/rof_sunglasses_accuracy",
            "Dokładność ROF-s",
        ),
        "only-head-vs-train-accuracy.png": (
            "max_training/train_accuracy",
            "Dokładność na zbiorze treningowym",
        ),
        "only-head-vs-val-accuracy.png": (
            "max_training/val_accuracy",
            "Dokładność na zbiorze walidacyjnym",
        ),
        "only-head-vs-runtime.png": ("runtime", "Czas obliczeń [s]"),
    }

    for fname, (key, label) in specs.items():
        self.plot_maker.make_only_head_plot(
            run_data=self.sweep_data.sweep_runs_data,
            output_dir=self.make_output_subdir("only-head"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

parameter_importance_and_correlation(experiment_data_dict)

Create plots showing parameter importance and correlation with respect to metrics for selected metrics.

Source code in src/plots/plot_groups.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def parameter_importance_and_correlation(
    self, experiment_data_dict: dict[str, Any]
) -> None:
    """Create plots showing parameter importance and correlation with respect to metrics for selected metrics."""
    # metric key in experiment_data_dict -> (output filename, plot title)
    targets = {
        "lfw": ("parameter-importance-lfw.png", "Dokładność LFW"),
        "rof-m": ("parameter-importance-rof-m.png", "Dokładność ROF-m"),
        "rof-s": ("parameter-importance-rof-s.png", "Dokładność ROF-s"),
        "val_accuracy": (
            "parameter-importance-val-accuracy.png",
            "Dokładność walidacyjna",
        ),
    }

    for key, (fname, title) in targets.items():
        entry = experiment_data_dict[key]
        self.plot_maker.make_parameter_analysis_plot(
            importance_dict=entry["importance"],
            correlation_dict=entry["correlation"],
            output_dir=self.make_output_subdir("parameter-analysis"),
            filename=fname,
            title=title,
        )

training_curves()

Create training curves for all runs in the sweep.

Source code in src/plots/plot_groups.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def training_curves(self) -> None:
    """Create training curves for all runs in the sweep."""
    # filename -> (history column, Polish axis label)
    specs = {
        "train-loss-over-epochs.png": (
            "training/train_loss",
            "Strata treningowy",
        ),
        "val-loss-over-epochs.png": (
            "training/val_loss",
            "Strata walidacyjna",
        ),
        "train-accuracy-over-epochs.png": (
            "training/train_accuracy",
            "Dokładność treningowa",
        ),
        "val-accuracy-over-epochs.png": (
            "training/val_accuracy",
            "Dokładność walidacyjna",
        ),
    }

    for fname, (key, label) in specs.items():
        self.plot_maker.make_loss_or_acc_plot(
            run_histories=self.sweep_data.run_histories,
            output_dir=self.make_output_subdir("training"),
            filename=fname,
            metric_key=key,
            metric_name=label,
            show=False,
        )

sweep_data

SweepData dataclass

Data from a single experiment sweep for plotting.

Attributes

sweep_id : str — the ID of the sweep. sweep_runs_data : DataFrame — summary data for all runs in the sweep, one row per run. run_histories : list[DataFrame] — one DataFrame per run containing its logged metric history; used for plotting training curves for individual runs.

Source code in src/plots/sweep_data.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@dataclass
class SweepData:
    """Data from a single experiment sweep for plotting.

    Attributes
    ----------
    sweep_id : str
        The ID of the sweep.
    sweep_runs_data : DataFrame
        DataFrame containing summary data for all runs in the sweep,
        one row per run (config, summary metrics, per-metric final/max values).
    run_histories : list[DataFrame]
        List of DataFrames, each containing the history of metrics for a single run.
        Used for plotting training curves for individual runs.
    """

    sweep_id: str
    sweep_runs_data: DataFrame
    run_histories: list[DataFrame]

wandb_client

WandbClient

Wrapper for Weights & Biases API client.

Handles fetching and caching of sweep data. Since the API request can be slow and I am worried about rate limits, fetched data is cached in a local JSON file.

Source code in src/plots/wandb_client.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class WandbClient:
    """Wrapper for Weights & Biases API client.

    Handles fetching and caching of sweep data.
    Since the API request can be slow and I am worried about rate limits,
    fetched data is cached in a local JSON file.
    """

    def __init__(self, project_name: str, cache_dir: Path):
        """Create the API handle and make sure the cache directory exists."""
        self.api = wandb.Api()
        self.project_name = project_name
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_sweep_data(self, sweep_id: str) -> "SweepData":
        """Fetch sweep data from cache or API."""

        if self._get_cache_filepath(sweep_id).exists():
            print(f"Loading sweep data from cache: {sweep_id}")
            return self._load_sweep_data_from_json(sweep_id)
        else:
            print(f"Fetching sweep data from API: {sweep_id}")
            sweep_data = self._fetch_sweep_data_from_api(sweep_id)
            self._save_sweep_data_to_json(sweep_data)
            return sweep_data

    def _fetch_sweep_data_from_api(self, sweep_id: str) -> "SweepData":
        """Download run summaries and histories for one sweep from the API.

        Only runs in the "finished" state are kept in the summary table;
        histories are collected for all runs.
        """
        sweep = self._get_sweep_object(sweep_id)
        sweep_runs_data = self._extract_sweep_run_data(sweep)
        sweep_runs_data = sweep_runs_data[sweep_runs_data["state"] == "finished"]
        assert isinstance(sweep_runs_data, DataFrame)
        run_histories = [self._extract_run_history(run) for run in sweep.runs]
        return SweepData(
            sweep_id=sweep_id,
            sweep_runs_data=sweep_runs_data,
            run_histories=run_histories,
        )

    def _get_cache_filepath(self, sweep_id: str) -> Path:
        """Return the cache-file path for a given sweep ID."""
        return self.cache_dir / f"sweep_{sweep_id}_data.json"

    def _save_sweep_data_to_json(self, sweep_data: "SweepData") -> None:
        """Serialize the sweep data to the cache file.

        ``default=str`` stringifies anything json can't encode natively
        (e.g. timestamps) — deliberate lossy serialization for plotting data.
        """
        with open(self._get_cache_filepath(sweep_data.sweep_id), "w", encoding="utf-8") as f:
            json.dump(
                {
                    "sweep_runs_data": sweep_data.sweep_runs_data.to_dict(
                        orient="records"
                    ),
                    "run_histories": [
                        rh.to_dict(orient="records") for rh in sweep_data.run_histories
                    ],
                },
                f,
                indent=4,
                default=str,
            )

    def _load_sweep_data_from_json(self, sweep_id: str) -> "SweepData":
        """Rebuild a SweepData object from the cached JSON file."""
        with open(self._get_cache_filepath(sweep_id), "r", encoding="utf-8") as f:
            data = json.load(f)
            sweep_runs_data = DataFrame(data["sweep_runs_data"])
            run_histories = [DataFrame(rh) for rh in data["run_histories"]]
            return SweepData(
                sweep_id=sweep_id,
                sweep_runs_data=sweep_runs_data,
                run_histories=run_histories,
            )

    def _extract_sweep_run_data(self, sweep) -> DataFrame:
        """Transform Sweep object into a DataFrame.

        Each row corresponds to a single run in the sweep with metrics and config as columns.
        """
        data = []

        for run in sweep.runs:
            # NOTE(review): raises KeyError if a run has no "_runtime" in its
            # summary — presumably wandb always records it; confirm.
            row = {
                "run_id": run.id,
                "run_name": run.name,
                "state": run.state,
                "created_at": run.created_at,
                "runtime": run.summary["_runtime"],
            }

            # Config parameters
            for key, value in run.config.items():
                row[f"config_{key}"] = value

            # Summary metrics
            for key, value in run.summary.items():
                if not key.startswith("_"):  # Skip internal wandb fields
                    row[f"summary_{key}"] = value

            # History metric values (final and max)
            history = run.history()
            if not history.empty:
                # ``history`` is non-empty here, so the last row always exists.
                for col in history.columns:
                    if not col.startswith("_"):
                        row[f"final_{col}"] = history[col].iloc[-1]

                accuracy_cols = [
                    col for col in history.columns if "accuracy" in col.lower()
                ]
                for col in accuracy_cols:
                    row[f"max_{col}"] = history[col].max()

            data.append(row)

        return DataFrame(data)

    def _extract_run_history(self, run) -> DataFrame:
        """Transform a Run object into a DataFrame of its history.

        History includes metrics logged during training.
        Each row corresponds to a single logging step (epoch).
        """
        history = run.history()
        metrics = [
            "training/train_loss",
            "training/val_loss",
            "training/train_accuracy",
            "training/val_accuracy",
        ]

        # Filter to specific metrics (plus _step and _timestamp).
        # .copy() so the column assignments below write to an independent
        # frame instead of a slice view (avoids SettingWithCopyWarning /
        # chained-assignment pitfalls).
        available_metrics = [m for m in metrics if m in history.columns]
        cols_to_keep = ["_step", "_timestamp"] + available_metrics
        history = history[cols_to_keep].copy()

        history["run_id"] = run.id
        history["run_name"] = run.name

        return history

    def _get_sweep_object(self, sweep_id: str):
        """Fetch the Sweep object for ``sweep_id`` within the configured project."""
        return self.api.sweep(f"{self.project_name}/{sweep_id}")

get_sweep_data(sweep_id)

Fetch sweep data from cache or API.

Source code in src/plots/wandb_client.py
24
25
26
27
28
29
30
31
32
33
34
def get_sweep_data(self, sweep_id: str) -> SweepData:
    """Fetch sweep data from cache or API."""
    # Cache hit: reuse the JSON snapshot instead of hitting the rate-limited API.
    if self._get_cache_filepath(sweep_id).exists():
        print(f"Loading sweep data from cache: {sweep_id}")
        return self._load_sweep_data_from_json(sweep_id)

    # Cache miss: fetch once, persist for future calls, then return.
    print(f"Fetching sweep data from API: {sweep_id}")
    fresh_data = self._fetch_sweep_data_from_api(sweep_id)
    self._save_sweep_data_to_json(fresh_data)
    return fresh_data