Skip to content

Common

get_facenet_pytorch_inception_resnet_v1(from_file=None, only_head=True, device=None, model_weights_key='model_state_dict')

Get pre-trained InceptionResnetV1 prepared for fine-tuning.

Optionally load weights from a file. By default, all weights except those of the last linear layer are frozen. If `only_head` is `False`, all weights are trainable.

Source code in `src/common/models.py` (lines 6–38):
def get_facenet_pytorch_inception_resnet_v1(
    from_file: str | None = None,
    only_head: bool = True,
    device: torch.device | None = None,
    model_weights_key: str = "model_state_dict",
) -> InceptionResnetV1:
    """Get pre-trained InceptionResnetV1 prepared for fine-tuning.

    The network is constructed with the ``"vggface2"`` pretrained weights.
    Optionally load weights from a file.
    By default all weights except the last linear layer are frozen.
    If `only_head` is False, all weights are trainable.

    Args:
        from_file: Optional path to a file written with ``torch.save``. It may
            hold either a bare state dict or a dict that stores the state dict
            under ``model_weights_key``.
        only_head: If True, set ``requires_grad=False`` on every parameter
            except those of ``model.last_linear``.
        device: Optional device the model is moved to after loading.
        model_weights_key: Key under which the model state dict is stored
            when ``from_file`` holds a full training checkpoint.

    Returns:
        The configured ``InceptionResnetV1`` model.
    """

    model = InceptionResnetV1(pretrained="vggface2")

    if from_file is not None:
        # The model has not been moved to ``device`` yet, so deserialize all
        # tensors onto the CPU. Without ``map_location`` a checkpoint saved
        # from a GPU cannot be loaded on a CPU-only machine.
        # ``weights_only=True`` restricts unpickling to tensors and plain
        # containers.
        state_dict = torch.load(
            from_file, map_location="cpu", weights_only=True
        )
        if model_weights_key in state_dict:
            # Full training checkpoint: extract the nested model weights.
            state_dict = state_dict[model_weights_key]

        model.load_state_dict(state_dict)

    if only_head:
        # Freeze everything, then re-enable gradients for the final linear
        # layer only.
        # NOTE(review): any layers after ``last_linear`` in the upstream
        # architecture stay frozen too — confirm this is intended.
        for param in model.parameters():
            param.requires_grad = False

        for param in model.last_linear.parameters():
            param.requires_grad = True

    if device is not None:
        model.to(device)

    return model

get_facenet_pytorch_mtcnn_detector(image_size=160, margin=32, select_largest=True, device=None)

Load a pretrained MTCNN face detector in evaluation mode.

Source code in `src/common/models.py` (lines 41–54):
def get_facenet_pytorch_mtcnn_detector(
    image_size: int = 160,
    margin: int = 32,
    select_largest: bool = True,
    device: torch.device | None = None,
) -> MTCNN:
    """Build a pretrained MTCNN detector and switch it to evaluation mode.

    Args:
        image_size: Passed through to ``MTCNN``.
        margin: Passed through to ``MTCNN``.
        select_largest: Passed through to ``MTCNN``.
        device: Passed through to ``MTCNN``; ``None`` leaves the choice to it.

    Returns:
        The ``MTCNN`` instance after ``eval()`` has been called on it.
    """
    detector = MTCNN(
        device=device,
        image_size=image_size,
        margin=margin,
        select_largest=select_largest,
    )
    # ``eval()`` mutates the module in place; return the same instance.
    detector.eval()
    return detector

models

get_facenet_pytorch_inception_resnet_v1(from_file=None, only_head=True, device=None, model_weights_key='model_state_dict')

Get pre-trained InceptionResnetV1 prepared for fine-tuning.

Optionally load weights from a file. By default, all weights except those of the last linear layer are frozen. If `only_head` is `False`, all weights are trainable.

Source code in `src/common/models.py` (lines 6–38):
def get_facenet_pytorch_inception_resnet_v1(
    from_file: str | None = None,
    only_head: bool = True,
    device: torch.device | None = None,
    model_weights_key: str = "model_state_dict",
) -> InceptionResnetV1:
    """Get pre-trained InceptionResnetV1 prepared for fine-tuning.

    The network is constructed with the ``"vggface2"`` pretrained weights.
    Optionally load weights from a file.
    By default all weights except the last linear layer are frozen.
    If `only_head` is False, all weights are trainable.

    Args:
        from_file: Optional path to a file written with ``torch.save``. It may
            hold either a bare state dict or a dict that stores the state dict
            under ``model_weights_key``.
        only_head: If True, set ``requires_grad=False`` on every parameter
            except those of ``model.last_linear``.
        device: Optional device the model is moved to after loading.
        model_weights_key: Key under which the model state dict is stored
            when ``from_file`` holds a full training checkpoint.

    Returns:
        The configured ``InceptionResnetV1`` model.
    """

    model = InceptionResnetV1(pretrained="vggface2")

    if from_file is not None:
        # The model has not been moved to ``device`` yet, so deserialize all
        # tensors onto the CPU. Without ``map_location`` a checkpoint saved
        # from a GPU cannot be loaded on a CPU-only machine.
        # ``weights_only=True`` restricts unpickling to tensors and plain
        # containers.
        state_dict = torch.load(
            from_file, map_location="cpu", weights_only=True
        )
        if model_weights_key in state_dict:
            # Full training checkpoint: extract the nested model weights.
            state_dict = state_dict[model_weights_key]

        model.load_state_dict(state_dict)

    if only_head:
        # Freeze everything, then re-enable gradients for the final linear
        # layer only.
        # NOTE(review): any layers after ``last_linear`` in the upstream
        # architecture stay frozen too — confirm this is intended.
        for param in model.parameters():
            param.requires_grad = False

        for param in model.last_linear.parameters():
            param.requires_grad = True

    if device is not None:
        model.to(device)

    return model

get_facenet_pytorch_mtcnn_detector(image_size=160, margin=32, select_largest=True, device=None)

Load a pretrained MTCNN face detector in evaluation mode.

Source code in `src/common/models.py` (lines 41–54):
def get_facenet_pytorch_mtcnn_detector(
    image_size: int = 160,
    margin: int = 32,
    select_largest: bool = True,
    device: torch.device | None = None,
) -> MTCNN:
    """Build a pretrained MTCNN detector and switch it to evaluation mode.

    Args:
        image_size: Passed through to ``MTCNN``.
        margin: Passed through to ``MTCNN``.
        select_largest: Passed through to ``MTCNN``.
        device: Passed through to ``MTCNN``; ``None`` leaves the choice to it.

    Returns:
        The ``MTCNN`` instance after ``eval()`` has been called on it.
    """
    detector = MTCNN(
        device=device,
        image_size=image_size,
        margin=margin,
        select_largest=select_largest,
    )
    # ``eval()`` mutates the module in place; return the same instance.
    detector.eval()
    return detector