Skip to content

Dataset

DummyFacePairsDataset

Bases: FacePairsDataset

Dummy in-memory dataset for testing

Source code in src/dataset/dummy.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DummyFacePairsDataset(FacePairsDataset):
    """Dummy in-memory dataset for testing.

    Holds 60 solid-colour 160x160 RGB image pairs. The first 30 are matches
    (an image and a copy of it, label ``1``). The last 30 are mismatches (two
    differently coloured images, label ``-1``).
    """

    def __init__(self, transform=None):
        """
        Args:
            transform: Callable applied to each image of a pair in
                ``__getitem__``. ``None`` returns the PIL images unchanged,
                as LFWDataset and ROFDataset do when no transform is given.
        """
        self.transform = transform

        # 60 pairs: 30 match (same color), 30 mismatch (different colors)
        self.pairs = []
        for i in range(30):
            color = (i * 8, 128, 200)
            img = Image.new("RGB", (160, 160), color)
            # Second side is a separate image object with identical pixels.
            self.pairs.append((img, img.copy(), 1))  # match

        for i in range(30):
            img1 = Image.new("RGB", (160, 160), (i * 8, 100, 150))
            img2 = Image.new("RGB", (160, 160), (255 - i * 8, 200, 50))
            self.pairs.append((img1, img2, -1))  # mismatch

    def __len__(self):
        """Return the number of pairs (always 60)."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image_1, image_2, label)``, transformed when a transform is set."""
        img1, img2, label = self.pairs[idx]
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, label

FacePairsDataset

Bases: Dataset, Sized, ABC

Interface for use with Face Pairs Benchmark

Source code in src/dataset/face_pairs.py
 7
 8
 9
10
11
12
13
14
class FacePairsDataset(Dataset, Sized, ABC):
    """Interface for use with Face Pairs Benchmark.

    Implementations are sized, indexable collections. Every item is a
    ``(first, second, label)`` triple describing one pair of face images.
    LFWDataset and DummyFacePairsDataset use ``1`` for a same-person pair
    and ``-1`` for a different-people pair.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of pairs in the dataset."""

    @abstractmethod
    def __getitem__(self, idx) -> tuple[Any, Any, int]:
        """Return the pair stored at position ``idx`` as ``(first, second, label)``."""

LFWDataset

Bases: FacePairsDataset

Torch Dataset for Labeled Faces in the Wild (LFW)

https://www.kaggle.com/datasets/jessicali9530/lfw-dataset

Source code in src/dataset/lfw.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class LFWDataset(FacePairsDataset):
    """Torch Dataset for Labeled Faces in the Wild (LFW)

    https://www.kaggle.com/datasets/jessicali9530/lfw-dataset

    Each pair is stored as ``(relative_path_1, relative_path_2, label)``.
    Paths are relative to ``root_dir``. The label is ``1`` for the same
    person and ``-1`` for different people. Images are loaded lazily in
    ``__getitem__``.
    """

    def __init__(
        self, root_dir: str, pairs: list[tuple], transform_1=None, transform_2=None
    ):
        """
        Args:
            root_dir: Directory the relative image paths in ``pairs`` are joined to.
            pairs: ``(relative_path_1, relative_path_2, label)`` triples.
            transform_1: Optional callable applied to the first image of a pair.
            transform_2: Optional callable applied to the second image of a pair.
        """
        self.root_dir = root_dir
        self.pairs = pairs
        self.transform_1 = transform_1
        self.transform_2 = transform_2

    @staticmethod
    def make_file_path(name: str, img_num: int) -> str:
        """Relative to root dir: ``<name>/<name>_<img_num zero-padded to 4>.jpg``."""
        return f"{name}/{name}_{img_num:04d}.jpg"

    @staticmethod
    def load_image(path: str):
        """Open the image at ``path`` and return an RGB copy of it.

        The file is closed before returning; ``convert`` returns a new,
        fully loaded image.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _rows(f):
        """Yield the comma-separated fields of every non-blank line after the header."""
        f.readline()  # skip header
        for line in f:
            line = line.strip()
            if line:
                yield line.split(",")

    @classmethod
    def _match_pair(cls, name: str, img_num_1: str, img_num_2: str) -> tuple:
        """Build a same-person pair (label ``1``) from raw CSV fields."""
        return (
            cls.make_file_path(name, int(img_num_1)),
            cls.make_file_path(name, int(img_num_2)),
            1,  # 1 for same person
        )

    @classmethod
    def _mismatch_pair(
        cls, name_1: str, img_num_1: str, name_2: str, img_num_2: str
    ) -> tuple:
        """Build a different-people pair (label ``-1``) from raw CSV fields."""
        return (
            cls.make_file_path(name_1, int(img_num_1)),
            cls.make_file_path(name_2, int(img_num_2)),
            -1,  # -1 for different people
        )

    @classmethod
    def from_match_and_mismatch_pairs(
        cls,
        root_dir: str,
        match_pairs_file: str,
        mismatch_pairs_file: str,
        transform_1=None,
        transform_2=None,
    ):
        """Build the dataset from separate match and mismatch CSV files.

        Both files start with a header line, and blank lines are ignored.
        Match rows are ``name,img_num_1,img_num_2``. Mismatch rows are
        ``name_1,img_num_1,name_2,img_num_2``. All match pairs come first,
        followed by all mismatch pairs.
        """
        match_pairs = []
        with open(match_pairs_file, "r") as f:
            for row in cls._rows(f):
                name, img_num_1, img_num_2 = row
                match_pairs.append(cls._match_pair(name, img_num_1, img_num_2))

        mismatch_pairs = []
        with open(mismatch_pairs_file, "r") as f:
            for row in cls._rows(f):
                name_1, img_num_1, name_2, img_num_2 = row
                mismatch_pairs.append(
                    cls._mismatch_pair(name_1, img_num_1, name_2, img_num_2)
                )

        return cls(root_dir, match_pairs + mismatch_pairs, transform_1, transform_2)

    @classmethod
    def test_set_from_pairs_file(
        cls,
        root_dir: str,
        pairs_file: str,
        transform_1=None,
        transform_2=None,
        n_total_examples: int | None = 6000,
    ):
        """Build the dataset from the combined ``pairs`` CSV file.

        A row whose last field is empty (``name,num_1,num_2,``) is a
        same-person pair. A row with four filled fields
        (``name_1,num_1,name_2,num_2``) is a different-people pair.

        Args:
            root_dir: Directory containing the per-person image folders.
            pairs_file: Path to the CSV file (first line is a header).
            transform_1: Optional callable applied to the first image.
            transform_2: Optional callable applied to the second image.
            n_total_examples: Maximum number of pairs to read. The full file
                holds 6000; ``None`` reads every row. Reading stops at end of
                file instead of failing when fewer rows exist, and blank lines
                are skipped.
        """
        pairs = []
        with open(pairs_file, "r") as f:
            for row_index, row in enumerate(cls._rows(f)):
                if n_total_examples is not None and row_index >= n_total_examples:
                    break
                if row[-1] == "":  # Match pair
                    name, img_num_1, img_num_2, _ = row
                    pairs.append(cls._match_pair(name, img_num_1, img_num_2))
                else:  # Mismatch pair
                    name_1, img_num_1, name_2, img_num_2 = row
                    pairs.append(
                        cls._mismatch_pair(name_1, img_num_1, name_2, img_num_2)
                    )

        return cls(root_dir, pairs, transform_1, transform_2)

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load the pair at ``idx`` and return ``(image_1, image_2, label)``."""
        img1_name, img2_name, label = self.pairs[idx]
        img1 = self.load_image(os.path.join(self.root_dir, img1_name))
        img2 = self.load_image(os.path.join(self.root_dir, img2_name))

        if self.transform_1 is not None:
            img1 = self.transform_1(img1)
        if self.transform_2 is not None:
            img2 = self.transform_2(img2)

        return img1, img2, label

make_file_path(name, img_num) staticmethod

Relative to root dir

Source code in src/dataset/lfw.py
22
23
24
25
@staticmethod
def make_file_path(name: str, img_num: int) -> str:
    """Return ``<name>/<name>_<NNNN>.jpg``, relative to the dataset root dir.

    ``img_num`` is zero-padded to at least four digits.
    """
    filename = f"{name}_{img_num:04d}.jpg"
    return "/".join((name, filename))

ROFDataset

Bases: FacePairsDataset

Torch dataset class for the Real World Occluded Faces dataset

https://github.com/ekremerakin/RealWorldOccludedFaces

Source code in src/dataset/rof.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class ROFDataset(FacePairsDataset):
    """Torch dataset class for the Real World Occluded Faces dataset

    https://github.com/ekremerakin/RealWorldOccludedFaces

    Every item pairs an image from the neutral subdirectory with an image
    from the subdirectory of the selected occlusion variant.
    """

    def __init__(
        self,
        root_dir: str,
        pairs: list[RofItem],
        occlusion: RofVariant,
        transform: Callable | None,
    ):
        """
        Args:
            root_dir: Dataset root directory. Kept for reference; the paths
                stored in ``pairs`` are used as given.
            pairs: Items holding the neutral/occluded image paths and a label.
            occlusion: Variant of the second image; must not be NEUTRAL.
            transform: Optional callable applied to both images of a pair.

        Raises:
            ValueError: If ``occlusion`` is ``RofVariant.NEUTRAL``.
        """
        if occlusion == RofVariant.NEUTRAL:
            raise ValueError(
                "ROFDataset requires an occlusion variant (MASKED or SUNGLASSES)."
            )

        self.root_dir = root_dir
        self.pairs = pairs
        self.occlusion = occlusion
        self.transform = transform

    @classmethod
    def create(
        cls,
        root_dir: str,
        pairs_file: str,
        occlusion: RofVariant,
        transform: Callable | None,
    ):
        """Build the dataset from a pairs CSV file.

        Neutral images resolve to
        ``<root_dir>/<neutral subdir>/<identity_neutral>/<neutral_filename>``.
        Occluded images resolve to
        ``<root_dir>/<occlusion subdir>/<identity_occluded><suffix>/<occluded_filename>``.
        The subdirectories and the suffix come from ``RofVariant``.
        """
        raw_pairs = cls._parse_pairs_file(pairs_file)

        occlusion_suffix = occlusion.identity_suffix()

        neutral_dir = os.path.join(root_dir, RofVariant.NEUTRAL.subdirectory())
        occluded_dir = os.path.join(root_dir, occlusion.subdirectory())

        pairs = []
        for raw_pair in raw_pairs:
            neutral_image_path = os.path.join(
                neutral_dir, raw_pair.identity_neutral, raw_pair.neutral_filename
            )
            occluded_image_path = os.path.join(
                occluded_dir,
                f"{raw_pair.identity_occluded}{occlusion_suffix}",
                raw_pair.occluded_filename,
            )

            pairs.append(
                RofItem(
                    neutral_file_path=neutral_image_path,
                    occluded_file_path=occluded_image_path,
                    label=raw_pair.label(),
                )
            )

        return cls(root_dir, pairs, occlusion, transform)

    @classmethod
    def sunglasses(
        cls,
        transform: Callable | None,
        root_dir: str = "data/rof",
        pairs_file: str = "data/rof/pairs_sunglasses.csv",
    ):
        """Build the sunglasses-occlusion dataset (default paths under ``data/rof``)."""
        return cls.create(root_dir, pairs_file, RofVariant.SUNGLASSES, transform)

    @classmethod
    def masked(
        cls,
        transform: Callable | None,
        root_dir: str = "data/rof",
        pairs_file: str = "data/rof/pairs_masked.csv",
    ):
        """Build the mask-occlusion dataset (default paths under ``data/rof``)."""
        return cls.create(root_dir, pairs_file, RofVariant.MASKED, transform)

    @staticmethod
    def _load_image(path: str):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _parse_pairs_file(pairs_file: str) -> list[RofPair]:
        """Parse the pairs CSV into ``RofPair`` records.

        The first row is a header. Columns are read by position:
        0 = neutral identity, 1 = neutral filename, 2 = occluded identity,
        3 = occluded filename. Blank rows are skipped.
        """
        # The csv module expects files opened with newline="" so that it
        # handles line endings itself.
        with open(pairs_file, "r", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header (an empty file yields no pairs)
            return [
                RofPair(
                    identity_neutral=row[0],
                    identity_occluded=row[2],
                    neutral_filename=row[1],
                    occluded_filename=row[3],
                )
                for row in reader
                if row  # csv.reader yields [] for blank lines
            ]

    def __len__(self) -> int:
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx) -> tuple[Any, Any, int]:
        """Returns a pair of images and their label (neutral image, occluded image, label)."""
        item = self.pairs[idx]
        img1 = self._load_image(item.neutral_file_path)
        img2 = self._load_image(item.occluded_file_path)

        # Compare against None so callables that happen to be falsy
        # (e.g. objects defining __len__ == 0) are still applied.
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, item.label

__getitem__(idx)

Returns a pair of images and their label (neutral image, occluded image, label).

Source code in src/dataset/rof.py
166
167
168
169
170
171
172
173
174
175
176
def __getitem__(self, idx) -> tuple[Any, Any, int]:
    """Returns a pair of images and their label (neutral image, occluded image, label)."""
    item = self.pairs[idx]
    img1 = self._load_image(item.neutral_file_path)
    img2 = self._load_image(item.occluded_file_path)

    # Compare against None so callables that happen to be falsy
    # (e.g. objects defining __len__ == 0) are still applied.
    if self.transform is not None:
        img1 = self.transform(img1)
        img2 = self.transform(img2)

    return img1, img2, item.label

dummy

DummyFacePairsDataset

Bases: FacePairsDataset

Dummy in-memory dataset for testing

Source code in src/dataset/dummy.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DummyFacePairsDataset(FacePairsDataset):
    """Dummy in-memory dataset for testing.

    Holds 60 solid-colour 160x160 RGB image pairs. The first 30 are matches
    (an image and a copy of it, label ``1``). The last 30 are mismatches (two
    differently coloured images, label ``-1``).
    """

    def __init__(self, transform=None):
        """
        Args:
            transform: Callable applied to each image of a pair in
                ``__getitem__``. ``None`` returns the PIL images unchanged,
                as LFWDataset and ROFDataset do when no transform is given.
        """
        self.transform = transform

        # 60 pairs: 30 match (same color), 30 mismatch (different colors)
        self.pairs = []
        for i in range(30):
            color = (i * 8, 128, 200)
            img = Image.new("RGB", (160, 160), color)
            # Second side is a separate image object with identical pixels.
            self.pairs.append((img, img.copy(), 1))  # match

        for i in range(30):
            img1 = Image.new("RGB", (160, 160), (i * 8, 100, 150))
            img2 = Image.new("RGB", (160, 160), (255 - i * 8, 200, 50))
            self.pairs.append((img1, img2, -1))  # mismatch

    def __len__(self):
        """Return the number of pairs (always 60)."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image_1, image_2, label)``, transformed when a transform is set."""
        img1, img2, label = self.pairs[idx]
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, label

face_pairs

FacePairsDataset

Bases: Dataset, Sized, ABC

Interface for use with Face Pairs Benchmark

Source code in src/dataset/face_pairs.py
 7
 8
 9
10
11
12
13
14
class FacePairsDataset(Dataset, Sized, ABC):
    """Interface for use with Face Pairs Benchmark.

    Implementations are sized, indexable collections. Every item is a
    ``(first, second, label)`` triple describing one pair of face images.
    LFWDataset and DummyFacePairsDataset use ``1`` for a same-person pair
    and ``-1`` for a different-people pair.
    """

    @abstractmethod
    def __len__(self):
        """Return the number of pairs in the dataset."""

    @abstractmethod
    def __getitem__(self, idx) -> tuple[Any, Any, int]:
        """Return the pair stored at position ``idx`` as ``(first, second, label)``."""

lfw

LFWDataset

Bases: FacePairsDataset

Torch Dataset for Labeled Faces in the Wild (LFW)

https://www.kaggle.com/datasets/jessicali9530/lfw-dataset

Source code in src/dataset/lfw.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class LFWDataset(FacePairsDataset):
    """Torch Dataset for Labeled Faces in the Wild (LFW)

    https://www.kaggle.com/datasets/jessicali9530/lfw-dataset

    Each pair is stored as ``(relative_path_1, relative_path_2, label)``.
    Paths are relative to ``root_dir``. The label is ``1`` for the same
    person and ``-1`` for different people. Images are loaded lazily in
    ``__getitem__``.
    """

    def __init__(
        self, root_dir: str, pairs: list[tuple], transform_1=None, transform_2=None
    ):
        """
        Args:
            root_dir: Directory the relative image paths in ``pairs`` are joined to.
            pairs: ``(relative_path_1, relative_path_2, label)`` triples.
            transform_1: Optional callable applied to the first image of a pair.
            transform_2: Optional callable applied to the second image of a pair.
        """
        self.root_dir = root_dir
        self.pairs = pairs
        self.transform_1 = transform_1
        self.transform_2 = transform_2

    @staticmethod
    def make_file_path(name: str, img_num: int) -> str:
        """Relative to root dir: ``<name>/<name>_<img_num zero-padded to 4>.jpg``."""
        return f"{name}/{name}_{img_num:04d}.jpg"

    @staticmethod
    def load_image(path: str):
        """Open the image at ``path`` and return an RGB copy of it.

        The file is closed before returning; ``convert`` returns a new,
        fully loaded image.
        """
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _rows(f):
        """Yield the comma-separated fields of every non-blank line after the header."""
        f.readline()  # skip header
        for line in f:
            line = line.strip()
            if line:
                yield line.split(",")

    @classmethod
    def _match_pair(cls, name: str, img_num_1: str, img_num_2: str) -> tuple:
        """Build a same-person pair (label ``1``) from raw CSV fields."""
        return (
            cls.make_file_path(name, int(img_num_1)),
            cls.make_file_path(name, int(img_num_2)),
            1,  # 1 for same person
        )

    @classmethod
    def _mismatch_pair(
        cls, name_1: str, img_num_1: str, name_2: str, img_num_2: str
    ) -> tuple:
        """Build a different-people pair (label ``-1``) from raw CSV fields."""
        return (
            cls.make_file_path(name_1, int(img_num_1)),
            cls.make_file_path(name_2, int(img_num_2)),
            -1,  # -1 for different people
        )

    @classmethod
    def from_match_and_mismatch_pairs(
        cls,
        root_dir: str,
        match_pairs_file: str,
        mismatch_pairs_file: str,
        transform_1=None,
        transform_2=None,
    ):
        """Build the dataset from separate match and mismatch CSV files.

        Both files start with a header line, and blank lines are ignored.
        Match rows are ``name,img_num_1,img_num_2``. Mismatch rows are
        ``name_1,img_num_1,name_2,img_num_2``. All match pairs come first,
        followed by all mismatch pairs.
        """
        match_pairs = []
        with open(match_pairs_file, "r") as f:
            for row in cls._rows(f):
                name, img_num_1, img_num_2 = row
                match_pairs.append(cls._match_pair(name, img_num_1, img_num_2))

        mismatch_pairs = []
        with open(mismatch_pairs_file, "r") as f:
            for row in cls._rows(f):
                name_1, img_num_1, name_2, img_num_2 = row
                mismatch_pairs.append(
                    cls._mismatch_pair(name_1, img_num_1, name_2, img_num_2)
                )

        return cls(root_dir, match_pairs + mismatch_pairs, transform_1, transform_2)

    @classmethod
    def test_set_from_pairs_file(
        cls,
        root_dir: str,
        pairs_file: str,
        transform_1=None,
        transform_2=None,
        n_total_examples: int | None = 6000,
    ):
        """Build the dataset from the combined ``pairs`` CSV file.

        A row whose last field is empty (``name,num_1,num_2,``) is a
        same-person pair. A row with four filled fields
        (``name_1,num_1,name_2,num_2``) is a different-people pair.

        Args:
            root_dir: Directory containing the per-person image folders.
            pairs_file: Path to the CSV file (first line is a header).
            transform_1: Optional callable applied to the first image.
            transform_2: Optional callable applied to the second image.
            n_total_examples: Maximum number of pairs to read. The full file
                holds 6000; ``None`` reads every row. Reading stops at end of
                file instead of failing when fewer rows exist, and blank lines
                are skipped.
        """
        pairs = []
        with open(pairs_file, "r") as f:
            for row_index, row in enumerate(cls._rows(f)):
                if n_total_examples is not None and row_index >= n_total_examples:
                    break
                if row[-1] == "":  # Match pair
                    name, img_num_1, img_num_2, _ = row
                    pairs.append(cls._match_pair(name, img_num_1, img_num_2))
                else:  # Mismatch pair
                    name_1, img_num_1, name_2, img_num_2 = row
                    pairs.append(
                        cls._mismatch_pair(name_1, img_num_1, name_2, img_num_2)
                    )

        return cls(root_dir, pairs, transform_1, transform_2)

    def __len__(self):
        """Return the number of pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load the pair at ``idx`` and return ``(image_1, image_2, label)``."""
        img1_name, img2_name, label = self.pairs[idx]
        img1 = self.load_image(os.path.join(self.root_dir, img1_name))
        img2 = self.load_image(os.path.join(self.root_dir, img2_name))

        if self.transform_1 is not None:
            img1 = self.transform_1(img1)
        if self.transform_2 is not None:
            img2 = self.transform_2(img2)

        return img1, img2, label

make_file_path(name, img_num) staticmethod

Relative to root dir

Source code in src/dataset/lfw.py
22
23
24
25
@staticmethod
def make_file_path(name: str, img_num: int) -> str:
    """Return ``<name>/<name>_<NNNN>.jpg``, relative to the dataset root dir.

    ``img_num`` is zero-padded to at least four digits.
    """
    filename = f"{name}_{img_num:04d}.jpg"
    return "/".join((name, filename))

rococo

RococoDataset dataclass

Representation of the Rococo dataset

Only stores filenames and information on the directory structure; instances do not interact with the filesystem (only the from_directory* constructors list the faces and frames directories).

The dataset is structured as follows:

<root_dir>
    ├── <faces_dir>
    │   ├── frame_1_face_1.jpg
    │   ├── frame_1_face_2.jpg
    │   └── ...
    └── <frames_dir>
        ├── frame_1_face_1_frame_2.jpg
        ├── frame_1_face_2_frame_3.jpg
        └── ...

Elements are (face, frame sequence) pairs. Only faces with at least 3 matching frames become elements, so not every face has one. The full list of faces is still kept, because the validation procedure searches and compares against all of them.

Face files are matched with corresponding frame files based on filename prefix. The filenames are misleading - face is identified by frame_X_face_Y, frames corresponding to the same face are identified by frame_X_face_Y_frame_Z and differ only by Z.

Also supports a dataset with multiple sequences of frames for the same face. The filenames of frames are in the format frame_X_face_Y_sequence_Z_frame_W. If there are multiple sequences, there will be multiple elements with the same face (one element for each sequence).

Source code in src/dataset/rococo.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
@dataclass
class RococoDataset:
    """Representation of the Rococo dataset

    Instances only store filenames and information on the directory structure;
    they do not touch the filesystem. Only the ``from_directory*`` constructors
    list the faces and frames directories.

    The dataset is structured as follows:

    ```
    <root_dir>
        ├── <faces_dir>
        │   ├── frame_1_face_1.jpg
        │   ├── frame_1_face_2.jpg
        │   └── ...
        └── <frames_dir>
            ├── frame_1_face_1_frame_2.jpg
            ├── frame_1_face_2_frame_3.jpg
            └── ...
    ```

    Elements are (face, frame sequence) pairs. Only faces with at least 3
    matching frames become elements, so not every face has one.
    ``all_faces`` still lists the candidate faces, because the validation
    procedure searches and compares against them.

    Face files are matched with corresponding frame files based on filename prefix.
    The filenames are misleading - face is identified by `frame_X_face_Y`,
    frames corresponding to the same face are identified by `frame_X_face_Y_frame_Z` and differ only by `Z`.

    Also supports a dataset with multiple sequences of frames for the same face.
    The filenames of frames are in the format `frame_X_face_Y_sequence_Z_frame_W`.
    If there are multiple sequences, there will be multiple elements with the same face (one element for each sequence).
    """

    elements: list[tuple[Face, FrameSequence]]
    all_faces: list[Face]
    root_dir: str
    faces_dir: str = "faces"
    frames_dir: str = "frames"

    @staticmethod
    def _load_faces_and_frames(
        root_dir: str, faces_dir: str, frames_dir: str
    ) -> tuple[dict[str, Face], dict[str, Frame]]:
        """Parse every file in the faces and frames directories, keyed by filename.

        Both dicts are filled in sorted-filename order.
        """
        face_filenames = sorted(os.listdir(os.path.join(root_dir, faces_dir)))
        frame_filenames = sorted(os.listdir(os.path.join(root_dir, frames_dir)))

        faces = {filename: Face.from_filename(filename) for filename in face_filenames}
        frames = {
            filename: Frame.from_filename(filename) for filename in frame_filenames
        }
        return faces, frames

    @classmethod
    def _from_elements(
        cls,
        faces: dict[str, Face],
        elements: list[tuple[Face, FrameSequence]],
        root_dir: str,
        faces_dir: str,
        frames_dir: str,
    ) -> Self:
        """Select the search faces via ``get_all_faces``, sort them by filename and build the instance."""
        all_faces = cls.get_all_faces(faces_dict=faces, elements=elements)
        return cls(
            elements=elements,
            all_faces=sorted(all_faces, key=lambda face: face.filename),
            root_dir=root_dir,
            faces_dir=faces_dir,
            frames_dir=frames_dir,
        )

    @classmethod
    def from_directory(
        cls, root_dir: str, faces_dir: str = "faces", frames_dir: str = "frames"
    ) -> Self:
        """Create a RococoDataset with one sequence of frames per face.

        Frames are named `frame_X_face_Y_frame_Z`. Faces with fewer than 3
        matching frames do not become elements.
        """
        faces, frames = cls._load_faces_and_frames(root_dir, faces_dir, frames_dir)
        face_frames = cls.get_matching_frames_for_faces(faces, frames)

        elements = [
            (faces[face_filename], sequence)
            for face_filename, sequence in face_frames.items()
        ]
        return cls._from_elements(faces, elements, root_dir, faces_dir, frames_dir)

    @classmethod
    def from_directory_with_multiple_sequences(
        cls, root_dir: str, faces_dir: str = "faces", frames_dir: str = "frames"
    ) -> Self:
        """Create a RococoDataset with multiple sequences of frames for the same face.

        For loading Rococo-synth dataset

        The filenames of frames are in the format `frame_X_face_Y_sequence_Z_frame_W`.
        There will be multiple elements with the same face (one element for each sequence).
        """
        faces, frames = cls._load_faces_and_frames(root_dir, faces_dir, frames_dir)
        face_sequences = cls.get_face_to_sequences_mapping(faces, frames)

        elements = [
            (faces[face_filename], sequence)
            for face_filename, sequences in face_sequences.items()
            for sequence in sequences.values()
        ]
        return cls._from_elements(faces, elements, root_dir, faces_dir, frames_dir)

    def face_abspath(self, face: Face) -> str:
        """Return ``<root_dir>/<faces_dir>/<face filename>``."""
        return os.path.join(self.root_dir, self.faces_dir, face.filename)

    def frame_abspath(self, frame: Frame) -> str:
        """Return ``<root_dir>/<frames_dir>/<frame filename>``."""
        return os.path.join(self.root_dir, self.frames_dir, frame.filename)

    def face_index(self, face: Face) -> int:
        """Return the position of given face in all_faces

        Raises ``ValueError`` (from ``list.index``) if the face is absent.
        """
        return self.all_faces.index(face)

    def __getitem__(self, index: int) -> tuple[Face, list[Frame]]:
        """Return the ``(face, frame sequence)`` element at ``index``."""
        return self.elements[index]

    def __len__(self) -> int:
        """Return the number of (face, frame sequence) elements."""
        return len(self.elements)

    @staticmethod
    def get_matching_frames_for_faces(
        faces: dict[str, Face], frames: dict[str, Frame], minimum_frames: int = 3
    ) -> dict[str, FrameSequence]:
        """Map face filenames to their frames, keeping faces with enough frames.

        A frame belongs to a face when both ``frame_number_1`` and
        ``face_number`` are equal. Faces with fewer than ``minimum_frames``
        matching frames are omitted. Faces keep the iteration order of
        ``faces``. Each sequence is ordered by ``frame_number_2``, as in
        ``get_face_to_sequences_mapping``; filename order alone would put
        ``frame_10`` before ``frame_2``.
        """
        # Group frames once instead of rescanning every frame for every face.
        frames_by_face: dict[tuple, list[Frame]] = {}
        for frame in frames.values():
            key = (frame.frame_number_1, frame.face_number)
            frames_by_face.setdefault(key, []).append(frame)

        face_frames = {}
        for face in faces.values():
            matching = frames_by_face.get((face.frame_number_1, face.face_number), [])
            if len(matching) >= minimum_frames:
                # sorted() builds a fresh list per face, so faces never share one.
                face_frames[face.filename] = sorted(
                    matching, key=lambda frame: frame.frame_number_2
                )
        return face_frames

    @staticmethod
    def get_face_to_sequences_mapping(
        faces: dict[str, Face], frames: dict[str, Frame], minimum_frames: int = 3
    ) -> dict[str, dict[int, FrameSequence]]:
        """Return a mapping of face filenames to sequences of frames

        The result maps each face filename to ``{sequence: frames}``, with each
        frame list ordered by ``frame_number_2``.

        Raises:
            AssertionError: If a frame has no sequence, if the faces that have
                frames differ from the faces in ``faces``, or if any sequence
                has fewer than ``minimum_frames`` frames. These are raised
                explicitly rather than via ``assert``, so they are not stripped
                under ``python -O``. The exception type is unchanged for callers.
        """
        face_to_sequences: dict[str, dict[int, FrameSequence]] = {}
        for frame in frames.values():
            if frame.sequence is None:
                raise AssertionError(f"Frame {frame.filename} has no sequence number")

            face_filename = Face.make_filename(frame.frame_number_1, frame.face_number)
            sequences = face_to_sequences.setdefault(face_filename, {})
            sequences.setdefault(frame.sequence, []).append(frame)

        # Compare the key sets, not just their sizes. A missing face and an
        # orphan face would otherwise cancel out and fail later with a KeyError.
        if face_to_sequences.keys() != faces.keys():
            raise AssertionError("Not all faces have matching sequences")
        if any(
            len(sequence) < minimum_frames
            for sequences_mapping in face_to_sequences.values()
            for sequence in sequences_mapping.values()
        ):
            raise AssertionError("Not all sequences have enough frames")

        # Ensure that the sequences are sorted by frame number
        for sequences_mapping in face_to_sequences.values():
            for sequence in sequences_mapping.values():
                sequence.sort(key=lambda frame: frame.frame_number_2)

        return face_to_sequences

    @staticmethod
    def get_all_faces(
        faces_dict: dict[str, Face], elements: list[tuple[Face, FrameSequence]]
    ) -> list[Face]:
        """Return a list of faces for searching matches

        All faces present in the `faces` directory. ``elements`` is unused here
        but lets subclasses narrow the candidates.
        """
        return list(faces_dict.values())

face_index(face)

Return the position of given face in all_faces

Source code in src/dataset/rococo.py
209
210
211
def face_index(self, face: Face) -> int:
    """Return the zero-based position of ``face`` within ``all_faces``.

    Lookup uses ``list.index``, so faces are compared with ``==``. A face
    that is not present raises ``ValueError``.
    """
    candidates = self.all_faces
    position = candidates.index(face)
    return position

from_directory_with_multiple_sequences(root_dir, faces_dir='faces', frames_dir='frames') classmethod

Create a RococoDataset with multiple sequences of frames for the same face.

For loading Rococo-synth dataset

The filenames of frames are in the format frame_X_face_Y_sequence_Z_frame_W. There will be multiple elements with the same face (one element for each sequence).

Source code in src/dataset/rococo.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@classmethod
def from_directory_with_multiple_sequences(
    cls, root_dir: str, faces_dir: str = "faces", frames_dir: str = "frames"
) -> Self:
    """Build a RococoDataset in which a face may own several frame sequences.

    Used to load the Rococo-synth dataset. Frame files are named
    `frame_X_face_Y_sequence_Z_frame_W`. Every (face, sequence) combination
    becomes its own element, so a face appears once per sequence.
    """
    face_names = sorted(os.listdir(os.path.join(root_dir, faces_dir)))
    frame_names = sorted(os.listdir(os.path.join(root_dir, frames_dir)))

    faces = {name: Face.from_filename(name) for name in face_names}
    frames = {name: Frame.from_filename(name) for name in frame_names}

    elements = []
    mapping = cls.get_face_to_sequences_mapping(faces, frames)
    for face_filename, sequences_by_id in mapping.items():
        for sequence in sequences_by_id.values():
            elements.append((faces[face_filename], sequence))

    candidates = cls.get_all_faces(faces_dict=faces, elements=elements)
    ordered_candidates = sorted(candidates, key=lambda face: face.filename)

    return cls(
        elements=elements,
        all_faces=ordered_candidates,
        root_dir=root_dir,
        faces_dir=faces_dir,
        frames_dir=frames_dir,
    )

get_all_faces(faces_dict, elements) staticmethod

Return a list of faces for searching matches

All faces present in the faces directory

Source code in src/dataset/rococo.py
281
282
283
284
285
286
287
288
289
@staticmethod
def get_all_faces(
    faces_dict: dict[str, Face], elements: list[tuple[Face, FrameSequence]]
) -> list[Face]:
    """Return every face from the `faces` directory as a match candidate.

    ``elements`` is not consulted here. It is part of the signature so that
    subclasses can narrow the candidates to faces that have frame sequences.
    """
    return [*faces_dict.values()]

get_face_to_sequences_mapping(faces, frames) staticmethod

Return a mapping of face filenames to sequences of frames

Source code in src/dataset/rococo.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
@staticmethod
def get_face_to_sequences_mapping(
    faces: dict[str, Face], frames: dict[str, Frame], minimum_frames: int = 3
) -> dict[str, dict[int, FrameSequence]]:
    """Return a mapping of face filenames to sequences of frames

    The result maps each face filename to ``{sequence: frames}``, with each
    frame list ordered by ``frame_number_2``.

    Raises:
        AssertionError: If a frame has no sequence, if the faces that have
            frames differ from the faces in ``faces``, or if any sequence has
            fewer than ``minimum_frames`` frames. These are raised explicitly
            rather than via ``assert``, so they are not stripped under
            ``python -O``. The exception type is unchanged for callers.
    """
    face_to_sequences: dict[str, dict[int, FrameSequence]] = {}
    for frame in frames.values():
        if frame.sequence is None:
            raise AssertionError(f"Frame {frame.filename} has no sequence number")

        face_filename = Face.make_filename(frame.frame_number_1, frame.face_number)
        sequences = face_to_sequences.setdefault(face_filename, {})
        sequences.setdefault(frame.sequence, []).append(frame)

    # Compare the key sets, not just their sizes. A missing face and an
    # orphan face would otherwise cancel out and fail later with a KeyError.
    if face_to_sequences.keys() != faces.keys():
        raise AssertionError("Not all faces have matching sequences")
    if any(
        len(sequence) < minimum_frames
        for sequences_mapping in face_to_sequences.values()
        for sequence in sequences_mapping.values()
    ):
        raise AssertionError("Not all sequences have enough frames")

    # Ensure that the sequences are sorted by frame number
    for sequences_mapping in face_to_sequences.values():
        for sequence in sequences_mapping.values():
            sequence.sort(key=lambda frame: frame.frame_number_2)

    return face_to_sequences

RococoSmallDataset dataclass

Bases: RococoDataset

Rococo dataset with a narrowed set of all_faces

The set of all_faces (used for searching matches) is narrowed to only those faces for which a sequence of frames is present in the dataset.

len(all_faces) == len(elements)

Source code in src/dataset/rococo.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
class RococoSmallDataset(RococoDataset):
    """Rococo dataset whose search candidates are limited to its own elements.

    Only faces for which a sequence of frames is present in the dataset are
    used as ``all_faces`` when searching for matches. A face that owns
    several sequences appears once per sequence.

    `len(all_faces) == len(elements)`
    """

    @override
    @staticmethod
    def get_all_faces(
        faces_dict: dict[str, Face], elements: list[tuple[Face, list[Frame]]]
    ) -> list[Face]:
        """Return the face of every element, in element order.

        ``faces_dict`` is ignored; only faces that have frame sequences
        become match candidates.
        """
        candidates = []
        for face, _frames in elements:
            candidates.append(face)
        return candidates

get_all_faces(faces_dict, elements) staticmethod

Return a list of faces for searching matches

Only those faces for which a sequence of frames is present in the dataset are returned.

Source code in src/dataset/rococo.py
301
302
303
304
305
306
307
308
309
310
@override
@staticmethod
def get_all_faces(
    faces_dict: dict[str, Face], elements: list[tuple[Face, list[Frame]]]
) -> list[Face]:
    """Return the face of every element, in element order.

    ``faces_dict`` is ignored; only faces that have frame sequences become
    match candidates.
    """
    candidates = []
    for face, _frames in elements:
        candidates.append(face)
    return candidates

rof

ROFDataset

Bases: FacePairsDataset

Torch dataset class for the Real World Occluded Faces dataset

https://github.com/ekremerakin/RealWorldOccludedFaces

Source code in src/dataset/rof.py
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class ROFDataset(FacePairsDataset):
    """Torch dataset class for the Real World Occluded Faces dataset

    Each item is ``(neutral image, occluded image, label)``; the occluded image
    comes from the variant chosen at construction (masked or sunglasses).

    https://github.com/ekremerakin/RealWorldOccludedFaces
    """

    def __init__(
        self,
        root_dir: str,
        pairs: list[RofItem],
        occlusion: RofVariant,
        transform: Callable | None,
    ):
        """Wrap an already-resolved list of pairs.

        Args:
            root_dir: Dataset root directory. Stored for reference; the paths
                inside ``pairs`` are opened as-is.
            pairs: Items whose file paths are opened directly by ``__getitem__``.
            occlusion: Occlusion variant of the second image of each pair.
            transform: Optional callable applied to each loaded RGB image.

        Raises:
            ValueError: If ``occlusion`` is ``RofVariant.NEUTRAL``.
        """
        self._check_occlusion(occlusion)

        self.root_dir = root_dir
        self.pairs = pairs
        self.occlusion = occlusion
        self.transform = transform

    @staticmethod
    def _check_occlusion(occlusion: RofVariant) -> None:
        """Raise ValueError unless ``occlusion`` is an occluded (non-NEUTRAL) variant."""
        if occlusion == RofVariant.NEUTRAL:
            raise ValueError(
                "ROFDataset requires an occlusion variant (MASKED or SUNGLASSES)."
            )

    @classmethod
    def create(
        cls,
        root_dir: str,
        pairs_file: str,
        occlusion: RofVariant,
        transform: Callable | None,
    ):
        """Build a dataset from a pairs CSV file.

        Neutral images resolve to
        ``<root_dir>/<NEUTRAL subdir>/<identity_neutral>/<neutral_filename>``;
        occluded images resolve to
        ``<root_dir>/<occlusion subdir>/<identity_occluded><suffix>/<occluded_filename>``.

        Raises:
            ValueError: If ``occlusion`` is ``RofVariant.NEUTRAL`` (checked before
                the pairs file is read), or if the pairs file is empty.
        """
        # Fail fast: reject an invalid variant before touching the filesystem.
        cls._check_occlusion(occlusion)
        raw_pairs = cls._parse_pairs_file(pairs_file)

        occlusion_suffix = occlusion.identity_suffix()
        neutral_dir = os.path.join(root_dir, RofVariant.NEUTRAL.subdirectory())
        occluded_dir = os.path.join(root_dir, occlusion.subdirectory())

        pairs = [
            RofItem(
                neutral_file_path=os.path.join(
                    neutral_dir, raw_pair.identity_neutral, raw_pair.neutral_filename
                ),
                occluded_file_path=os.path.join(
                    occluded_dir,
                    f"{raw_pair.identity_occluded}{occlusion_suffix}",
                    raw_pair.occluded_filename,
                ),
                label=raw_pair.label(),
            )
            for raw_pair in raw_pairs
        ]

        return cls(root_dir, pairs, occlusion, transform)

    @classmethod
    def sunglasses(
        cls,
        transform: Callable | None,
        root_dir: str = "data/rof",
        pairs_file: str = "data/rof/pairs_sunglasses.csv",
    ):
        """Build the sunglasses-occlusion variant of the dataset."""
        return cls.create(root_dir, pairs_file, RofVariant.SUNGLASSES, transform)

    @classmethod
    def masked(
        cls,
        transform: Callable | None,
        root_dir: str = "data/rof",
        pairs_file: str = "data/rof/pairs_masked.csv",
    ):
        """Build the mask-occlusion variant of the dataset."""
        return cls.create(root_dir, pairs_file, RofVariant.MASKED, transform)

    @staticmethod
    def _load_image(path: str):
        """Load the image at ``path`` as RGB, closing the underlying file handle."""
        # ``convert`` loads the pixel data and returns a new image, so the result
        # stays valid after the ``with`` block closes the source file.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _parse_pairs_file(pairs_file: str) -> list[RofPair]:
        """Parse the pairs CSV into ``RofPair`` records.

        The first row is a header and is skipped. Columns are read positionally
        as ``identity_neutral, neutral_filename, identity_occluded,
        occluded_filename``. Blank lines are ignored.

        Raises:
            ValueError: If the file contains no rows at all (not even a header).
        """
        # newline="" is what the csv module expects so that "\r\n" endings and
        # quoted fields containing line breaks are parsed correctly.
        with open(pairs_file, "r", newline="") as f:
            reader = csv.reader(f)
            if next(reader, None) is None:  # skip header
                raise ValueError(f"Pairs file {pairs_file!r} is empty (no header row).")
            return [
                RofPair(
                    identity_neutral=row[0],
                    identity_occluded=row[2],
                    neutral_filename=row[1],
                    occluded_filename=row[3],
                )
                for row in reader
                if row  # csv.reader yields [] for blank lines
            ]

    def __len__(self) -> int:
        """Number of image pairs."""
        return len(self.pairs)

    def __getitem__(self, idx) -> tuple[Any, Any, int]:
        """Returns a pair of images and their label (neutral image, occluded image, label)."""
        item = self.pairs[idx]
        img1 = self._load_image(item.neutral_file_path)
        img2 = self._load_image(item.occluded_file_path)

        # Explicit None check: a valid callable may define __len__/__bool__ and
        # evaluate as falsy, which must not disable the transform.
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, item.label

__getitem__(idx)

Returns a pair of images and their label (neutral image, occluded image, label).

Source code in src/dataset/rof.py
166
167
168
169
170
171
172
173
174
175
176
def __getitem__(self, idx) -> tuple[Any, Any, int]:
    """Returns a pair of images and their label (neutral image, occluded image, label).

    Both images are loaded with ``self._load_image`` and, when a transform is
    configured, each is passed through it independently.
    """
    item = self.pairs[idx]
    neutral = self._load_image(item.neutral_file_path)
    occluded = self._load_image(item.occluded_file_path)

    # Explicit None check: a valid callable may define __len__/__bool__ and
    # evaluate as falsy, which must not disable the transform.
    if self.transform is not None:
        neutral = self.transform(neutral)
        occluded = self.transform(occluded)

    return neutral, occluded, item.label

RofItem dataclass

Represents a single item in the ROFDataset

Label is 1 for matched pairs (neutral and occluded of the same person), -1 for mismatched pairs (neutral and occluded of different people).

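File paths already include the dataset root directory (they are built by joining root_dir with the variant subdirectory), so they are opened as-is.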

Source code in src/dataset/rof.py
52
53
54
55
56
57
58
59
60
61
62
63
64
@dataclass(frozen=True)
class RofItem:
    """Represents a single item in the ROFDataset

    Label is 1 for matched pairs (neutral and occluded of the same person),
    -1 for mismatched pairs (neutral and occluded of different people).

    File paths already include the dataset root directory (``ROFDataset.create``
    joins ``root_dir`` into them), so they are opened as-is rather than being
    resolved against the root again.
    """

    neutral_file_path: str  # path to the unoccluded (neutral) image
    occluded_file_path: str  # path to the masked / sunglasses image
    label: int  # 1 = same person, -1 = different people

RofVariant

Bases: Enum

Source code in src/dataset/rof.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class RofVariant(Enum):
    """Image variants of the ROF dataset: unoccluded, masked, or with sunglasses."""

    NEUTRAL = "neutral"
    MASKED = "masked"
    SUNGLASSES = "sunglasses"

    def identity_suffix(self) -> str:
        """Returns the suffix for the identity name in the dataset.

        NEUTRAL identities carry no suffix; occluded ones end in ``_wearing_<item>``.
        """
        suffix_by_variant = {
            RofVariant.NEUTRAL: "",
            RofVariant.MASKED: "_wearing_mask",
            RofVariant.SUNGLASSES: "_wearing_sunglasses",
        }
        return suffix_by_variant[self]

    def subdirectory(self) -> str:
        """Returns the subdirectory name for the variant (a folder under the dataset root)."""
        folder_by_variant = {
            RofVariant.NEUTRAL: "neutral",
            RofVariant.MASKED: "masked",
            RofVariant.SUNGLASSES: "sunglasses",
        }
        return folder_by_variant[self]

identity_suffix()

Returns the suffix for the identity name in the dataset.

Source code in src/dataset/rof.py
17
18
19
20
21
22
23
24
25
def identity_suffix(self) -> str:
    """Returns the suffix for the identity name in the dataset.

    NEUTRAL identities carry no suffix; occluded ones end in ``_wearing_<item>``.
    """
    suffix_by_variant = {
        RofVariant.NEUTRAL: "",
        RofVariant.MASKED: "_wearing_mask",
        RofVariant.SUNGLASSES: "_wearing_sunglasses",
    }
    return suffix_by_variant[self]

subdirectory()

Returns the subdirectory name for the variant.

Source code in src/dataset/rof.py
27
28
29
30
31
32
33
34
35
def subdirectory(self) -> str:
    """Returns the subdirectory name for the variant (a folder under the dataset root)."""
    folder_by_variant = {
        RofVariant.NEUTRAL: "neutral",
        RofVariant.MASKED: "masked",
        RofVariant.SUNGLASSES: "sunglasses",
    }
    return folder_by_variant[self]