Pair Matching (AKA Verification)

A binary classification task: deciding whether a pair of inputs “matches” (both inputs belong to the same class) or not.

ThresholdSiamese

 ThresholdSiamese (*args, distance_metric=<function
                   normalized_squared_euclidean_distance>)

Combines a deep similarity metric with a threshold to produce a binary classifier.
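
To make the idea concrete, here is a minimal pure-Python sketch of a distance-plus-threshold classifier. It assumes that a "normalized squared Euclidean distance" means L2-normalizing both embeddings before taking the squared distance, and that a pair is predicted as a match when the distance falls below the threshold. The helper names (`l2_normalize`, `normalized_sq_dist`, `is_match`) are illustrative and are not the library's implementation.

```python
import math

def l2_normalize(v):
    # Scale a vector to unit length (assumes v is not the zero vector).
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def normalized_sq_dist(a, b):
    # Squared Euclidean distance between the unit-length versions of a and b.
    # For unit vectors this lies in [0, 4]: 0 = same direction, 4 = opposite.
    a, b = l2_normalize(a), l2_normalize(b)
    return sum((x - y) ** 2 for x, y in zip(a, b))

def is_match(a, b, threshold):
    # Hypothetical decision rule: a small distance means "same class".
    return normalized_sq_dist(a, b) < threshold
```

In this sketch the embeddings `a` and `b` stand in for the backbone's outputs, so the only extra parameter the classifier needs is the threshold.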

from fastai_datasets.all import *
pairs = Pairs(Imagenette(160), .1)
dls = pairs.dls(after_item=Resize(128),
                after_batch=Normalize.from_stats(*imagenet_stats))

Training

We can construct the backbone from a pretrained classifier:

classifier = resnet34(weights=ResNet34_Weights.DEFAULT)
siamese = ThresholdSiamese(create_body(model=classifier, cut=-1))

As explained in Threshold.fit, we can fit the threshold directly:

ThresholdSiamese.fit_threshold

 ThresholdSiamese.fit_threshold (train_dl:fastai.data.load.DataLoader)

Picks the threshold that maximizes accuracy on a dataloader.

siamese.fit_threshold(dls.train)

learn = Learner(dls, siamese, metrics=accuracy)
learn.validate()
(#2) [0.5470428466796875,0.8622449040412903]
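
For intuition, fitting a threshold on fixed distances can be sketched as a one-dimensional search. The following is a simplified illustration, not the library's algorithm: it assumes the pair distances and binary labels (1 = match) have already been collected from the dataloader, and it tries the midpoints between consecutive distinct distances as candidate thresholds. The function name `fit_threshold_sketch` is hypothetical.

```python
def fit_threshold_sketch(distances, labels):
    # distances: one float per pair; labels: 1 if the pair matches, else 0.
    # A pair is predicted as a match when its distance is below the threshold.
    ds = sorted(set(distances))
    # Candidate thresholds: below all distances, between each consecutive
    # pair of distinct distances, and above all distances.
    candidates = [ds[0] - 1.0]
    candidates += [(lo + hi) / 2 for lo, hi in zip(ds, ds[1:])]
    candidates += [ds[-1] + 1.0]

    def accuracy(t):
        correct = sum((d < t) == bool(y) for d, y in zip(distances, labels))
        return correct / len(labels)

    best = max(candidates, key=accuracy)
    return best, accuracy(best)
```

Because accuracy only changes when the threshold crosses one of the observed distances, checking these midpoints is enough to find a maximizer for the given data. This simple version recomputes accuracy for every candidate, so it is quadratic in the number of pairs.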

We can finetune the backbone in multiple ways:

Cross Entropy

Since ThresholdSiamese is a classifier, we can finetune it using the cross entropy loss, which will fit the backbone and the threshold together:

learn.fit(5, 1e-4)
epoch  train_loss  valid_loss  accuracy  time
0      0.507261    0.415415    0.900510  00:16
1      0.401674    0.354739    0.928571  00:17
2      0.337908    0.326078    0.920918  00:10
3      0.293121    0.304959    0.931122  00:09
4      0.259705    0.288149    0.933673  00:09
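
One way to see why a cross-entropy objective can train the backbone and the threshold together is to treat the gap between threshold and distance as a logit. The sketch below is an assumption made for illustration, since the library's actual formulation is not shown here: it maps `threshold - distance` through a sigmoid to get a match probability and computes the binary cross entropy for a single pair. The function names are hypothetical.

```python
import math

def match_probability(distance, threshold):
    # Sigmoid of (threshold - distance): 0.5 exactly at the threshold,
    # approaching 1 as the distance shrinks and 0 as it grows.
    return 1.0 / (1.0 + math.exp(distance - threshold))

def pair_cross_entropy(distance, threshold, is_match):
    # Binary cross entropy for one pair under the assumed probability model.
    p = match_probability(distance, threshold)
    return -math.log(p if is_match else 1.0 - p)
```

Under this model the loss depends on both the distance (produced by the backbone) and the threshold, so a gradient step can adjust either one.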

Contrastive Loss

Alternatively, we can finetune the backbone on its own with a contrastive loss, and then fit the threshold again:

classifier = resnet34(weights=ResNet34_Weights.DEFAULT)
siamese = DistanceSiamese(create_body(model=classifier, cut=-1))
learn = Learner(dls, siamese, ContrastiveLoss())
learn.fit(5, 1e-4)
epoch  train_loss  valid_loss  time
0      0.394097    0.317473    00:09
1      0.272054    0.211472    00:09
2      0.199547    0.171662    00:09
3      0.152234    0.149710    00:09
4      0.118441    0.140471    00:09
siamese = ThresholdSiamese(siamese.backbone)
siamese.fit_threshold(dls.train)

learn = Learner(dls, siamese, metrics=accuracy)
learn.validate()
(#2) [0.36787182092666626,0.9081632494926453]
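
For readers unfamiliar with contrastive objectives, the following sketch shows a common textbook form of the loss for a single pair. It is not necessarily what `ContrastiveLoss` computes: the margin value, the squared terms, and the function name `contrastive_loss_sketch` are assumptions made for illustration.

```python
def contrastive_loss_sketch(distance, is_match, margin=1.0):
    # Matching pairs are penalized by their squared distance, which pulls
    # them together. Non-matching pairs are penalized only while they are
    # closer than the margin, which pushes them apart up to that margin.
    if is_match:
        return distance ** 2
    return max(0.0, margin - distance) ** 2
```

Note that this objective involves no threshold at all, which is why the threshold is fitted as a separate step after training.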

Visualizing Results

learn.show_results()

If we plot the distance histogram, we can also see the threshold:

siamese.plot_distance_histogram(dls.train)