Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions Port_Weights_Assert.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gGWhoJ0axBQz"
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('./assert')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MuiAFtFUxIFD"
},
"outputs": [],
"source": [
"import torch\n",
"from model import E2E"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9VnaBzts2BK4"
},
"outputs": [],
"source": [
"import itertools\n",
"protocols = ['pa', 'la']\n",
"networks = ['attentive_filtering_network', 'dilated_resnet', 'senet34', 'senet50']\n",
"all_networks = list(itertools.product(protocols, networks))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5bt6169p4nbv"
},
"outputs": [],
"source": [
"\n",
"def port_weights(protocol, network):\n",
" models_dict = {'attentive_filtering_network': 5, 'dilated_resnet': 1, 'senet34':7, \n",
" 'senet50': 6}\n",
" model_params = {\n",
" 'MODEL_SELECT' : models_dict[network], # which model \n",
" 'NUM_SPOOF_CLASS' : 2, # x-class classification\n",
" 'FOCAL_GAMMA' : None, # gamma parameter for focal loss; if obj is not focal loss, set this to None \n",
" 'NUM_RESNET_BLOCK' : 5, # number of resnet blocks in ResNet \n",
" 'AFN_UPSAMPLE' : 'Bilinear', # upsampling method in AFNet: Conv or Bilinear\n",
" 'AFN_ACTIVATION' : 'sigmoid', # activation function in AFNet: sigmoid, softmaxF, softmaxT\n",
" 'NUM_HEADS' : 3, # number of heads for multi-head att in SAFNet \n",
" 'SAFN_HIDDEN' : 10, # hidden dim for SAFNet\n",
" 'SAFN_DIM' : 'T', # SAFNet attention dim: T or F\n",
" 'RNN_HIDDEN' : 128, # hidden dim for RNN\n",
" 'RNN_LAYERS' : 4, # number of hidden layers for RNN\n",
" 'RNN_BI': True, # bidirecitonal/unidirectional for RNN\n",
" 'DROPOUT_R' : 0.0, # dropout rate \n",
" }\n",
" model = E2E(**model_params)\n",
" pa_weights = torch.load(f'./ASSERT/pretrained/{protocol}/{network}', map_location='cpu', encoding='bytes')\n",
" # Convert the first level keys.\n",
" data_dict = dict(pa_weights)\n",
" for key in list(data_dict):\n",
" if type(key) is bytes:\n",
" data_dict[key.decode()] = data_dict[key]\n",
" data_dict.pop(key)\n",
" data_dict['state_dict'] = dict(data_dict['state_dict'])\n",
" for key in list(data_dict['state_dict']):\n",
" if type(key) is bytes:\n",
" data_dict['state_dict'][key.decode()] = data_dict['state_dict'][key]\n",
" data_dict['state_dict'].pop(key)\n",
" model.load_state_dict(data_dict['state_dict'])\n",
" torch.save(data_dict, f'./ASSERT/pretrained/{protocol}/{network}.py3.ckpt')\n",
" print(f\"Ported {network} - {protocol}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LFH4ZAwB7yAE",
"outputId": "3ab0a784-8d91-4a8e-ad54-ba2ec1c8ff3e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"attentive filtering network\n",
"Ported attentive_filtering_network - pa\n",
"resnet\n",
"Ported dilated_resnet - pa\n",
"squeeze-and-excitation network\n",
"Ported senet34 - pa\n",
"squeeze-and-excitation network\n",
"Ported senet50 - pa\n",
"attentive filtering network\n",
"Ported attentive_filtering_network - la\n",
"resnet\n",
"Ported dilated_resnet - la\n",
"squeeze-and-excitation network\n",
"Ported senet34 - la\n",
"squeeze-and-excitation network\n",
"Ported senet50 - la\n"
]
}
],
"source": [
"for _p, _n in all_networks:\n",
" port_weights(_p, _n)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "forensic_examiner_audio",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]"
},
"vscode": {
"interpreter": {
"hash": "e2b5310373df8c4f0bc118e06d390d9464bd5fe0a9f4e308bd14694ffbb1bd37"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
10 changes: 9 additions & 1 deletion assert/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@ def E2E(MODEL_SELECT, NUM_SPOOF_CLASS, NUM_RESNET_BLOCK, AFN_UPSAMPLE, AFN_ACTIV
elif MODEL_SELECT == 5:
print('attentive filtering network')
model = attentive_filtering_network.SpoofSmallAFNet257_400(NUM_SPOOF_CLASS, AFN_UPSAMPLE, AFN_ACTIVATION, NUM_RESNET_BLOCK, FOCAL_LOSS)

elif MODEL_SELECT == 6:
print('squeeze-and-excitation network')
#model = senet.se_resnet18(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
#model = senet.se_resnet34(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
# model = senet.se_resnet34(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
model = senet.se_resnet50(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
#model = senet.se_resnet101(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
#model = senet.se_resnet152(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
elif MODEL_SELECT == 7:
print('squeeze-and-excitation network')
#model = senet.se_resnet18(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
model = senet.se_resnet34(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
# model = senet.se_resnet50(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
#model = senet.se_resnet101(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)
#model = senet.se_resnet152(num_classes=NUM_SPOOF_CLASS, focal_loss=FOCAL_LOSS)

return model

Expand Down
Loading