diff --git a/__pycache__/utils.cpython-37.pyc b/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000..43bb906 Binary files /dev/null and b/__pycache__/utils.cpython-37.pyc differ diff --git a/examples/.ipynb_checkpoints/.ipynb-checkpoint b/examples/.ipynb_checkpoints/.ipynb-checkpoint new file mode 100644 index 0000000..ab63616 --- /dev/null +++ b/examples/.ipynb_checkpoints/.ipynb-checkpoint @@ -0,0 +1,305 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import embeddings for checking. You can also import skip-thought and fasttext-idf" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import randsent as embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run bow embeddings for the task SST2. If you want to check embeddings for other tasks, change it in bow.py \"transfer_tasks\".\n", + "Available tasks: SST2, SST3, MRPC, ReadabilityCl, TagCl, PoemsCl, TREC, STS, SICK" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'SST2': {'devacc': 62.54,\n", + " 'acc': 62.42,\n", + " 'ndev': 18901,\n", + " 'ntest': 37804,\n", + " 'time': 265.33350253105164},\n", + " 'SST3': {'devacc': 86.57,\n", + " 'acc': 86.67,\n", + " 'ndev': 4831,\n", + " 'ntest': 9662,\n", + " 'time': 72.24070620536804},\n", + " 'MRPC': {'devacc': 99.77,\n", + " 'acc': 99.52,\n", + " 'f1': 99.74,\n", + " 'ndev': 35635,\n", + " 'ntest': 4400,\n", + " 'time': 375.37730598449707},\n", + " 'ReadabilityCl': {'devacc': 35.0,\n", + " 'acc': 32.68,\n", + " 'ndev': 640,\n", + " 'ntest': 1282,\n", + " 'time': 14.010723114013672},\n", + " 'TagCl': {'devacc': 32.98,\n", + " 'acc': 31.52,\n", + " 'ndev': 3605,\n", + " 'ntest': 7223,\n", + " 'time': 195.5709707736969},\n", + " 'PoemsCl': {'devacc': 37.89,\n", + " 'acc': 37.47,\n", + " 'ndev': 190,\n", 
+ " 'ntest': 371,\n", + " 'time': 4.036132574081421},\n", + " 'ProzaCl': {'devacc': 43.01,\n", + " 'acc': 44.02,\n", + " 'ndev': 193,\n", + " 'ntest': 393,\n", + " 'time': 15.277348279953003},\n", + " 'TREC': {'devacc': 26.83,\n", + " 'acc': 20.8,\n", + " 'ndev': 5452,\n", + " 'ntest': 500,\n", + " 'time': 58.839884757995605},\n", + " 'STS': {'devpearson': 0.5242887821205533,\n", + " 'pearson': 0.5437650890037948,\n", + " 'spearman': 0.5355568012377294,\n", + " 'mse': 1.7308604560287553,\n", + " 'yhat': array([2.36370519, 2.00206342, 4.23734959, 2.00206342, 3.59654932,\n", + " 3.59654932, 3.59654932, 1.31822499, 2.95111737, 1.6604875 ,\n", + " 1.66242505, 2.44285593, 1.74046474, 2.35618017, 3.51656278,\n", + " 2.25289381, 2.32740771, 2.24848697, 2.32740771, 2.18231911,\n", + " 1.83410028, 1.75027232, 1.70023542, 1.63157094, 1.79205703,\n", + " 2.15417119, 2.43763485, 2.45906199, 1.6283864 , 3.30575675,\n", + " 1.90526416, 1.91028498, 3.14743564, 2.05880412, 2.93087074,\n", + " 2.69487728, 2.50648641, 2.60015661, 1.95352261, 1.98956804,\n", + " 2.62633436, 4.11341089, 2.14624925, 2.22687568, 2.84438144,\n", + " 3.49207652, 4.27593892, 2.84438144, 2.65373036, 2.95353766,\n", + " 2.87559857, 2.84438144, 3.20639532, 2.14981751, 2.98501824,\n", + " 2.31962286, 4.2768316 , 2.83724866, 3.04421297, 4.20906151,\n", + " 2.82118762, 2.29178635, 4.11341089, 3.51544331, 2.01756746,\n", + " 4.11341089, 3.78029433, 3.63733601, 4.19698806, 2.13879235,\n", + " 1.80974696, 2.61720604, 2.1259183 , 2.3694854 , 1.99420547,\n", + " 2.75644658, 1.75851517, 2.92803615, 2.8899997 , 3.0219249 ,\n", + " 2.31612899, 2.86885545, 2.38608417, 3.51199391, 1.50825688,\n", + " 3.67271728, 2.17267587, 2.58816803, 2.43841363, 3.77384172,\n", + " 2.2091377 , 2.00872354, 2.26974412, 2.3298081 , 2.17164023,\n", + " 2.76191505, 3.36655696, 2.57718437, 3.68441821, 2.54227494,\n", + " 3.12225927, 2.87281117, 2.61621274, 2.8193385 , 2.23211779,\n", + " 2.23289466, 2.80592144, 1.81147585, 2.9491739 , 
3.11086212,\n", + " 2.55768208, 1.64975202, 2.41542841, 1.70755654, 2.15780174,\n", + " 2.07965625, 2.27954089, 2.26813636, 1.82425859, 1.9858315 ,\n", + " 2.05732191, 2.25564023, 2.01285644, 1.69766004, 1.88318016,\n", + " 2.7295166 , 2.10861523, 2.76711355, 2.76711355, 2.70359304,\n", + " 2.85731594, 2.19305353, 2.46330041, 2.65447506, 2.28927512,\n", + " 2.70116519, 3.01357137, 2.36597746, 3.02944084, 2.23293618,\n", + " 2.23409928, 2.48633213, 2.75423772, 2.23551204, 2.12712896,\n", + " 2.24741045, 2.20203449, 2.53940665, 2.06433245, 2.65563862,\n", + " 3.02739169, 2.14538316, 2.66246409, 2.63843243, 2.64534462,\n", + " 2.56239183, 2.81211761, 3.98729989, 3.72482352, 3.72482352,\n", + " 3.72482352, 3.00813903, 2.83357292, 3.72482352, 3.28048654,\n", + " 3.72482352, 3.72482352, 2.78214778, 3.66612506, 2.99105446,\n", + " 2.96428943, 2.73961376, 2.0406234 , 3.31337149, 2.53940665,\n", + " 2.98387995, 2.16538899, 3.23083933, 2.68703104, 3.02773364,\n", + " 2.75448835, 3.92846183, 3.38236713, 3.47968321, 4.27814279,\n", + " 2.42439519, 3.86452751, 3.72482352, 3.72482352, 4.04664071,\n", + " 4.48273602, 2.14408283, 2.32797889, 2.29960982, 2.49415862,\n", + " 3.03232156, 2.29149063, 2.63396502, 3.2307487 , 2.22032709,\n", + " 3.04666732, 2.81086401, 2.93390681, 2.49121863, 2.84724797,\n", + " 2.43735253, 2.82472951, 2.84566416, 3.01179767, 2.83164829,\n", + " 3.50378214, 3.24625112, 2.7758306 , 3.51910459, 2.07647324,\n", + " 2.83341413, 2.39531236, 2.44531281, 2.60321437, 2.94858754,\n", + " 2.70515415, 3.07154246, 3.13943927, 2.05664013, 2.72542439,\n", + " 3.11410937, 2.89040256, 2.73650738, 2.98638212, 2.80695424,\n", + " 2.25746051, 3.00147432, 1.93046722, 1.85738339, 1.70267584,\n", + " 2.50342129, 2.3310321 , 2.64650684, 2.99221873, 2.34387755,\n", + " 2.49883584, 2.30139482, 2.57759784, 2.46348407, 2.85524107,\n", + " 3.44659707, 2.42473114, 2.66312619, 3.02564657, 2.82930436,\n", + " 2.91551226, 2.58451452, 2.94109635, 2.61989807, 2.5930616 ,\n", + " 
3.05420198, 3.451849 , 3.21142909, 2.21492794, 3.49986332,\n", + " 3.59887265, 3.22679251, 3.55176508, 2.98802177, 2.53047819,\n", + " 3.40512996, 3.05697569, 2.93673084, 3.13701302, 2.9208231 ,\n", + " 2.56659417, 2.50628538, 3.11546238, 2.96990098, 2.5325186 ,\n", + " 2.88549045, 3.44357218, 2.87825762, 3.26205348, 2.31881699,\n", + " 4.00454746, 2.71984204, 3.34254077, 2.53973467, 3.10250212,\n", + " 4.00454746, 3.13132996, 3.00340921, 3.66071282, 2.77529614,\n", + " 3.75051885, 3.44132653, 3.26902208, 3.56477434, 3.31063151,\n", + " 3.6623096 , 2.87013908, 3.41187104, 3.2842134 , 2.8617395 ,\n", + " 3.78551285, 4.19404367, 3.01149618, 2.4430571 , 2.14478287,\n", + " 2.21709113, 3.51182353, 2.60456154, 2.99000609, 2.90482547,\n", + " 2.28814654, 3.1679123 , 3.63593517, 3.38464183, 3.64171571,\n", + " 3.22763394, 3.30061626, 3.30140211, 3.11866726, 3.33301216,\n", + " 3.483347 , 2.96663485, 3.31580541, 3.56937192, 3.15122145,\n", + " 2.92735533, 3.23771875, 2.37467594, 2.81506037, 2.69658211,\n", + " 2.99577266, 2.78711434, 2.8945125 , 3.05251072, 2.44018829,\n", + " 3.36047474, 3.04371458, 2.19288102, 2.05381062, 2.57207188,\n", + " 2.39746353, 2.03845177, 2.64414475, 2.8985422 , 2.82809741,\n", + " 2.92065848, 2.69390277, 3.17356502, 3.06770568, 3.03609954,\n", + " 3.23974908, 3.20566231, 3.54400219, 2.67198119, 2.96750184,\n", + " 2.44082667, 2.93898277, 2.48542126, 2.79835435, 3.17309757,\n", + " 3.25209016, 3.07758925, 3.35874911, 3.46104913, 2.89044763,\n", + " 3.2226252 , 2.97076908, 3.45605969, 2.94142868, 3.25289205,\n", + " 2.94577806, 3.80483842, 2.83442539, 3.74944438, 3.58972311,\n", + " 3.55565017, 3.52070293, 2.81467678, 3.4934348 , 2.91913259,\n", + " 2.82321196, 3.03400458, 2.8697807 , 2.61369262, 3.06949562,\n", + " 2.57156015, 2.65144379, 3.16112471, 2.94936052, 3.10811144,\n", + " 3.27669651, 3.33749451, 2.97127902, 3.29690162, 3.14906306,\n", + " 3.78068858, 3.64421532, 2.84387947, 2.66373737, 2.43763862,\n", + " 3.1451649 , 2.72879936, 
2.40679068, 3.12224179, 2.98630846,\n", + " 3.273885 , 3.05127962, 2.6901171 , 3.69571536, 3.32052726,\n", + " 3.32768044, 3.20715172, 3.2676289 , 3.01823581, 3.70303706,\n", + " 3.02115802, 2.66726685, 2.80152775, 2.95092472, 3.14252942,\n", + " 2.68350458, 2.56512773, 2.60552411, 3.19303863, 1.92858347,\n", + " 2.77228275, 2.77847243, 2.5711899 , 3.03368573, 2.39414079,\n", + " 2.00078022, 3.26473314, 2.85666998, 3.02423619, 2.99276656,\n", + " 1.86574623, 3.24524847, 3.30224398, 3.51981816, 2.56627555,\n", + " 3.35050094, 3.12341742, 2.86765604, 2.98225211, 3.54920612,\n", + " 2.89539596, 2.98298068, 2.59713513, 2.99765645, 3.23068462,\n", + " 3.17623057, 2.61168061, 2.98593915, 3.7802786 , 3.66432124,\n", + " 3.03015552, 2.70836815, 3.48517345, 3.35239874, 3.22211485,\n", + " 3.5293265 , 3.03387742, 3.70017151, 3.04253145, 2.55140119,\n", + " 3.06702048, 3.59344437, 2.59108468, 3.41746696, 3.07833336,\n", + " 3.12670837, 3.33886732, 2.57373958, 2.77859494, 3.81487393,\n", + " 3.26116236, 3.99766506, 3.14104449, 2.52387153, 3.00900869,\n", + " 2.26632847, 2.9732744 , 3.29688673, 2.97880374, 3.29114611,\n", + " 3.02736272, 2.83718502, 2.69804333, 2.79770155, 2.94849551,\n", + " 3.05144645, 2.30668522, 2.58388101, 3.04199901, 2.95844935,\n", + " 2.46034507, 2.89104179, 2.68307222, 3.19946846, 3.11602211,\n", + " 2.77863168, 2.802346 , 3.6964669 , 2.87921812, 2.881263 ,\n", + " 2.37192752, 2.92166112, 2.81005768, 3.07022922, 3.47576226,\n", + " 3.12359173, 3.00849605, 3.59357995, 3.21412423, 3.58475064,\n", + " 2.61012925, 2.26108191, 2.86243146, 2.65571226, 2.8039688 ,\n", + " 3.42760867, 3.51829935, 3.20991135, 2.75888825, 3.16482765,\n", + " 3.04940477, 3.00570773, 3.23316839, 2.72996912, 3.15913742,\n", + " 2.97063437, 3.64727248, 3.28573534, 2.95013794, 3.04750077,\n", + " 3.05467093, 3.18650276, 2.73353197, 3.16510623, 3.22810479,\n", + " 3.2076251 , 3.51000452, 2.9281648 , 3.38283958, 3.26525882,\n", + " 3.02802111, 2.70173231, 2.38292523, 2.65547086, 
2.80916064,\n", + " 2.82226988, 2.74150643, 2.93249377, 3.18017906, 2.84560999,\n", + " 2.94691054, 2.96157923, 2.95331388, 2.44365502, 2.61237966,\n", + " 2.79360472, 2.62759449, 3.22557481, 3.06648905, 3.25854691,\n", + " 3.25621037, 3.1505248 , 3.48037563, 2.86669094, 3.11874464,\n", + " 2.96759461, 3.49079181, 3.19203307, 2.777756 , 2.7301416 ,\n", + " 3.18390058, 3.35905614, 3.54080842, 3.10998711, 3.25152593,\n", + " 2.54742119, 2.92509462, 3.45057984, 2.91290241, 3.03967786,\n", + " 2.88798039, 3.09531388, 3.03053573, 3.11177012, 3.45693224,\n", + " 2.53471738, 3.23157164, 2.55146906, 2.99651783, 3.28342923,\n", + " 2.49960039, 2.7075137 , 2.83867343, 2.95936275, 3.21161325,\n", + " 3.30597045, 2.96694229, 3.41132965, 2.96997678, 2.96485671,\n", + " 3.02755116, 3.40448305, 3.13064481, 3.33698334, 3.15948001,\n", + " 3.22888051, 3.39076798, 3.45171947, 3.03095654, 3.04291523,\n", + " 3.37283502, 2.83579427, 3.18845505, 2.6143499 , 2.94242331,\n", + " 2.29881597, 3.17878283, 3.64490526, 2.7112379 , 3.06746125,\n", + " 2.95047082, 2.74231353, 3.18234011, 2.89490026, 2.98618356,\n", + " 3.14492072, 3.09812109, 3.05065991, 2.91330285, 3.0292787 ,\n", + " 2.94859651, 3.379981 , 2.83184124, 3.28327069, 2.61416009,\n", + " 2.6702219 , 3.37569815, 3.10544221, 3.04858022, 3.13468383,\n", + " 2.81595495, 2.82637397, 2.88384629, 3.11267512, 3.01334456,\n", + " 2.93134135, 3.69883825, 3.11138546, 3.12244305, 3.213401 ,\n", + " 3.10063194, 3.26220817, 2.93112164, 3.07953711, 3.38831352,\n", + " 3.08955135, 3.2831784 , 3.00780753, 2.81428926, 3.12008262,\n", + " 2.69118118, 3.32126937, 3.03070855, 2.80641215, 2.84047465,\n", + " 3.12537859, 2.46977066, 3.32905451, 2.74720771, 3.11677095,\n", + " 3.18536312, 2.73780753, 3.15357313, 3.39436152, 3.33737008,\n", + " 3.48188969, 3.20216505, 3.64215494, 2.67310467, 3.46786668,\n", + " 3.02968626, 3.19657227, 2.30994617, 3.05887499, 3.15318695,\n", + " 2.84682521, 2.98572388, 3.08281109, 3.18894134, 3.54700109,\n", + " 
3.03273243, 2.95468782, 3.09516688, 3.35482239, 2.75746833,\n", + " 3.20777377, 3.20184047, 3.34910021, 3.23446756, 2.90292806,\n", + " 2.38080544, 3.15942069, 3.28789218, 2.9231109 , 3.22827069,\n", + " 3.1514833 , 3.32214455, 2.93486338, 2.79330772, 3.02111109,\n", + " 2.93218711, 3.24844942, 2.90509383, 3.2384704 , 3.09319447,\n", + " 3.07797904, 3.45424227, 3.06210137, 3.47375872, 3.16706558,\n", + " 3.13825557, 3.20856016, 2.7097256 , 3.31463886, 3.17292167,\n", + " 3.02839066, 3.49152643, 3.40635147, 3.05746959, 3.21750846,\n", + " 2.94367695, 3.19057287, 3.22245819, 3.31454998, 2.8640388 ,\n", + " 3.89252803, 3.25268794, 3.0088939 , 3.21631937, 3.19481393,\n", + " 2.89930684, 3.16904083, 3.30156596, 3.4103966 , 2.63435557,\n", + " 3.31989469, 3.4422102 , 3.03194234, 3.01976077, 3.10166287,\n", + " 3.13653457, 3.37644516, 3.18862893, 2.57987028, 3.34295332,\n", + " 3.16091111, 3.40608641, 3.5167682 , 3.24129607, 3.30543648,\n", + " 3.11544525, 3.45893581, 3.3869574 , 3.2219131 , 3.0322293 ,\n", + " 2.82539359, 2.94139963, 3.28098132, 2.97146927, 3.31935797,\n", + " 3.19096029, 3.15377428, 3.72369686, 3.11100201, 3.25817911,\n", + " 3.32890571, 3.56832729, 3.34197157, 3.10624474, 2.84027675,\n", + " 3.2001786 , 3.37168305, 3.24964341, 3.05619156, 3.45994959,\n", + " 3.16372005, 3.08792445, 3.2969857 , 3.09615164, 3.6094668 ,\n", + " 3.48228108, 2.53915945, 3.52163506, 3.2962993 , 3.75129267,\n", + " 3.36231886, 2.49799371, 3.26869702, 3.38967431, 3.32217607,\n", + " 3.13128147, 3.46786305, 3.28618477, 2.82705589, 3.18026511,\n", + " 3.16543694, 3.55561133, 3.40027522, 3.10650304, 3.13304929,\n", + " 3.20397066, 3.30181689, 3.44012409, 3.39087135, 3.59150061,\n", + " 3.21720245, 3.34183837, 3.48396418, 3.19439971, 3.43802878,\n", + " 2.53245343, 3.39971318, 3.45624617, 3.32159916, 3.14860707,\n", + " 3.50829044, 3.38866973, 3.4352681 , 3.24563378, 3.35969868,\n", + " 3.37021024, 3.37119943, 3.34468816, 3.28526685, 3.22135253,\n", + " 2.99515533]),\n", + " 
'ndev': 1469,\n", + " 'ntest': 841,\n", + " 'time': 72.52497482299805},\n", + " 'SICK': {'devpearson': 0.6286113217546703,\n", + " 'pearson': 0.704970327284487,\n", + " 'spearman': 0.6286594135873013,\n", + " 'mse': 0.5296568019620134,\n", + " 'yhat': array([1.02385654, 2.76356009, 1.09348593, ..., 4.25048931, 3.4249529 ,\n", + " 4.09359651]),\n", + " 'ndev': 756,\n", + " 'ntest': 4169,\n", + " 'time': 79.13439679145813}}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings.check()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/.ipynb_checkpoints/bow_example-checkpoint.ipynb b/examples/.ipynb_checkpoints/randsent_example-checkpoint.ipynb similarity index 100% rename from examples/.ipynb_checkpoints/bow_example-checkpoint.ipynb rename to examples/.ipynb_checkpoints/randsent_example-checkpoint.ipynb diff --git a/examples/bow_example.ipynb b/examples/bow_example.ipynb deleted file mode 100644 index 901cf39..0000000 --- a/examples/bow_example.ipynb +++ /dev/null @@ -1,94 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Import embeddings for checking. You can also import skip-thought and fasttext-idf" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import bow as embeddings" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run bow embeddings for the task SST2. 
If you want to check embeddings for other tasks, change it in bow.py \"transfer_tasks\".\n", - "Available tasks: SST2, SST3, MRPC, ReadabilityCl, TagCl, PoemsCl, TREC, STS, SICK" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2019-02-03 11:56:07,119 : ***** Transfer task :SST binary *****\n", - "\n", - "\n", - "2019-02-03 11:57:03,030 : Found 97700 words with word vectors, out of 602174 words\n", - "2019-02-03 11:57:03,047 : Computing embedding for train\n", - "2019-02-03 11:57:17,005 : Computed train embeddings\n", - "2019-02-03 11:57:17,005 : Computing embedding for dev\n", - "2019-02-03 11:57:18,585 : Computed dev embeddings\n", - "2019-02-03 11:57:18,585 : Computing embedding for test\n", - "2019-02-03 11:57:21,545 : Computed test embeddings\n", - "2019-02-03 11:57:21,545 : Training pytorch-MLP-nhid0-rmsprop-bs128 with standard validation..\n", - "2019-02-03 12:01:27,050 : [('reg:1e-05', 62.2), ('reg:0.0001', 62.22), ('reg:0.001', 62.5), ('reg:0.01', 61.71)]\n", - "2019-02-03 12:01:27,176 : Validation : best param found is reg = 0.001 with score 62.5\n", - "2019-02-03 12:01:27,203 : Evaluating...\n", - "2019-02-03 12:02:13,433 : \n", - "Dev acc : 62.5 Test acc : 62.3 for SST binary \n", - "\n" - ] - }, - { - "data": { - "text/plain": [ - "{'devacc': 62.5, 'acc': 62.3, 'ndev': 18901, 'ntest': 37804}" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "embeddings.check()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.8" - } - }, - "nbformat": 4, - 
"nbformat_minor": 2 -} diff --git a/examples/models/__init__.py b/examples/models/__init__.py new file mode 100644 index 0000000..69a0791 --- /dev/null +++ b/examples/models/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from .esn import ESN +from .lstm import RandLSTM +from .borep import BOREP diff --git a/examples/models/borep.py b/examples/models/borep.py new file mode 100644 index 0000000..f4d7ec0 --- /dev/null +++ b/examples/models/borep.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + +from . import model_utils + +class BOREP(nn.Module): + + def __init__(self, params): + super(BOREP, self).__init__() + self.params = params + self.max_seq_len = params.max_seq_len + + self.projection = params.projection + self.proj = self.get_projection() + + self.position_enc = None + if params.pos_enc: + self.position_enc = torch.nn.Embedding(self.max_seq_len + 1, params.word_emb_dim, padding_idx=0) + self.position_enc.weight.data = model_utils.position_encoding_init(self.max_seq_len + 1, + params.word_emb_dim) + + if params.gpu: + self.cuda() + + def get_projection(self): + proj = nn.Linear(self.params.input_dim, self.params.output_dim) + if self.params.init == "orthogonal": + nn.init.orthogonal_(proj.weight) + elif self.params.init == "sparse": + nn.init.sparse_(proj.weight, sparsity=0.1) + elif self.params.init == "normal": + nn.init.normal_(proj.weight, std=0.1) + elif self.params.init == "uniform": + nn.init.uniform_(proj.weight, a=-0.1, b=0.1) + elif self.params.init == "kaiming": + nn.init.kaiming_uniform_(proj.weight) + elif self.params.init == "xavier": + 
nn.init.xavier_uniform_(proj.weight) + + nn.init.constant_(proj.bias, 0) + + if self.params.gpu: + proj = proj.cuda() + return proj + + def borep(self, x): + batch_sz, seq_len = x.size(1), x.size(0) + out = torch.FloatTensor(seq_len, batch_sz, self.params.output_dim).zero_() + for i in range(seq_len): + if self.projection: + emb = self.proj(x[i]) + else: + emb = x[i] + out[i] = emb + return out + + def forward(self, batch, se_params): + lengths, out, word_pos = model_utils.embed(batch, self.params, se_params, + position_enc=self.position_enc, to_reverse=0) + + out = self.borep(out) + out = model_utils.pool(out, lengths, self.params) + + if self.params.activation is not None: + out = self.params.activation(out) + + return out + + def encode(self, batch, params): + return self.forward(batch, params).cpu().detach().numpy() \ No newline at end of file diff --git a/examples/models/esn.py b/examples/models/esn.py new file mode 100644 index 0000000..5cfc546 --- /dev/null +++ b/examples/models/esn.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch.nn as nn +import torch + +import numpy as np + +from . 
class ESN(nn.Module):
    """Echo-state-network sentence encoder with fixed random weights.

    A random input matrix ``Win`` and a random recurrent (reservoir) matrix
    ``W`` drive a simple recurrence over the word embeddings; the hidden
    states are then pooled over time. With ``params.bidirectional`` a second,
    independently initialised reservoir reads the reversed sentence and both
    state sequences are concatenated feature-wise.

    Attributes read from ``params``: ``bidirectional``, ``max_seq_len``,
    ``input_dim``, ``output_dim``, ``spectral_radius``, ``leaky``,
    ``concat_inp``, ``stdv``, ``sparsity``, ``gpu`` and ``activation``.
    """

    def __init__(self, params):
        super(ESN, self).__init__()
        self.params = params

        self.bidirectional = params.bidirectional
        self.max_seq_len = params.max_seq_len

        self.n_inputs = params.input_dim
        self.n_reservoir = params.output_dim
        self.spectral_radius = params.spectral_radius
        self.leaky = params.leaky
        self.concat_inp = params.concat_inp
        self.stdv = params.stdv
        self.sparsity = params.sparsity

        # When the raw input is concatenated to the state, the reservoir is
        # shrunk so the final feature size still equals params.output_dim.
        if self.concat_inp:
            self.n_reservoir = self.n_reservoir - self.n_inputs
        if self.n_reservoir <= 0:
            raise ValueError(
                "output_dim must exceed input_dim when concat_inp is set")

        self.W, self.Win = self._random_weights()
        if self.bidirectional:
            self.W_rev, self.Win_rev = self._random_weights()

        if self.spectral_radius > 0:
            self._rescale_to_radius(self.W, self.spectral_radius)
            if self.bidirectional:
                self._rescale_to_radius(self.W_rev, self.spectral_radius)

        self.input_layer = nn.Linear(self.n_inputs, self.n_reservoir, bias=False)
        self.input_layer.weight = self.Win
        self.recurrent_layer = nn.Linear(self.n_reservoir, self.n_reservoir, bias=False)
        self.recurrent_layer.weight = self.W

        if self.bidirectional:
            # NOTE(review): the reverse layers keep randomly initialised
            # biases (bias=True) while the forward layers have none; this
            # asymmetry is preserved as-is -- confirm it is intended.
            self.input_layer_rev = nn.Linear(self.n_inputs, self.n_reservoir, bias=True)
            self.input_layer_rev.weight = self.Win_rev
            self.recurrent_layer_rev = nn.Linear(self.n_reservoir, self.n_reservoir, bias=True)
            self.recurrent_layer_rev.weight = self.W_rev

        # NOTE(review): zeroing entries after the rescaling above changes the
        # spectral radius again; the original ordering is kept.
        if self.sparsity > 0:
            self.sparse(self.recurrent_layer.weight.data, self.sparsity)
            if self.bidirectional:
                self.sparse(self.recurrent_layer_rev.weight.data, self.sparsity)

        if params.gpu:
            self.cuda()

    def _random_weights(self):
        """Draw one (recurrent, input) weight pair from uniform distributions."""
        w_rec = nn.Parameter(torch.Tensor(self.n_reservoir, self.n_reservoir))
        w_in = nn.Parameter(torch.Tensor(self.n_reservoir, self.n_inputs))
        w_rec.data.uniform_(-0.5, 0.5)
        w_in.data.uniform_(-self.stdv, self.stdv)
        return w_rec, w_in

    @staticmethod
    def _rescale_to_radius(weight, target_radius):
        """Scale ``weight`` in place so its largest |eigenvalue| is ``target_radius``."""
        eigvals = np.linalg.eigvals(weight.detach().cpu().numpy())
        radius = float(np.max(np.abs(eigvals)))
        if radius > 0:  # an all-zero matrix cannot be rescaled
            weight.data = weight.data * (target_radius / radius)

    def sparse(self, tensor, sparsity):
        """Zero ``ceil(sparsity * rows)`` randomly chosen entries in every column.

        ``tensor`` is a 2-D tensor and is modified in place.
        """
        rows, cols = tensor.shape
        num_zeros = int(math.ceil(sparsity * rows))

        with torch.no_grad():
            for col_idx in range(cols):
                zero_indices = torch.randperm(rows)[:num_zeros]
                tensor[zero_indices, col_idx] = 0

    def esn(self, out, torev):
        """Run the reservoir recurrence over a batch of embedded sentences.

        Args:
            out: tensor indexed as (batch, seq_len, n_inputs).
            torev: if true, use the reverse-direction layers.

        Returns:
            Tensor (seq_len, batch, n_reservoir); with ``concat_inp`` the
            inputs are appended on the feature axis.
        """
        batch_sz, seq_len = out.size(0), out.size(1)
        if torev:
            input_layer, recurrent_layer = self.input_layer_rev, self.recurrent_layer_rev
        else:
            input_layer, recurrent_layer = self.input_layer, self.recurrent_layer
        activation = self.params.activation

        # Zero initial state, on the same device/dtype as the inputs.
        prev = out.new_zeros(batch_sz, self.n_reservoir)
        states = []
        for i in range(seq_len):
            hid = input_layer(out[:, i, :]) + recurrent_layer(prev)
            if activation is not None:
                hid = activation(hid)
            # Leaky integration needs a real previous state, which exists
            # from step 1 onwards (the former `i > 1` test skipped step 1).
            if i > 0 and self.leaky > 0:
                hid = (1 - self.leaky) * hid + self.leaky * prev
            states.append(hid)
            prev = hid

        if states:
            hidden_states = torch.stack(states, 0)  # SxBxD
        else:
            hidden_states = out.new_zeros(0, batch_sz, self.n_reservoir)

        if self.concat_inp:
            return torch.cat([hidden_states, out.transpose(1, 0)], dim=2)
        return hidden_states

    def forward(self, batch, se_params):
        """Embed ``batch``, run the reservoir(s) and pool over time."""
        lengths, emb_fwd, _ = model_utils.embed(batch, self.params, se_params)
        out = self.esn(emb_fwd.transpose(1, 0), False)

        if self.bidirectional:
            # Reverse copies, so the caller's sentences are never reordered
            # by the embedding helper.
            rev_batch = [list(sent) for sent in batch]
            _, emb_rev, _ = model_utils.embed(rev_batch, self.params, se_params,
                                              to_reverse=1)
            out_rev = self.esn(emb_rev.transpose(1, 0), True)
            out = torch.cat([out, out_rev], dim=2)

        return model_utils.pool(out, lengths, self.params)

    def encode(self, batch, params):
        """Return the sentence embeddings of ``batch`` as a NumPy array."""
        return self.forward(batch, params).cpu().detach().numpy()
class RandLSTM(nn.Module):
    """Sentence encoder built on a randomly initialised ``nn.LSTM``.

    Attributes read from ``params``: ``bidirectional``, ``max_seq_len``,
    ``input_dim``, ``output_dim``, ``num_layers``, ``init``, ``gpu`` and
    ``activation``.
    """

    def __init__(self, params):
        super(RandLSTM, self).__init__()
        self.params = params
        self.max_seq_len = params.max_seq_len

        # Zero initial hidden/cell state, expanded per batch in lstm().
        self.e_hid_init = torch.zeros(1, 1, params.output_dim)
        self.e_cell_init = torch.zeros(1, 1, params.output_dim)

        self.output_dim = params.output_dim
        self.num_layers = params.num_layers
        is_bidirectional = bool(params.bidirectional)
        self.lm = nn.LSTM(params.input_dim, params.output_dim,
                          num_layers=self.num_layers,
                          bidirectional=is_bidirectional, batch_first=True)

        # From here on this attribute holds the number of directions (1 or
        # 2). It is derived from the boolean rather than `flag + 1`, so any
        # truthy flag yields exactly 2, matching what nn.LSTM was given.
        self.bidirectional = 2 if is_bidirectional else 1

        if params.init != "none":
            model_utils.param_init(self, params)

        if params.gpu:
            self.e_hid_init = self.e_hid_init.cuda()
            self.e_cell_init = self.e_cell_init.cuda()
            self.cuda()

    def lstm(self, inputs, lengths):
        """Run the LSTM over a padded batch.

        Args:
            inputs: tensor indexed as (batch, max_len, input_dim).
            lengths: 1-D integer tensor with the valid length of each row.

        Returns:
            All hidden states, batch-first, in the original batch order and
            zero-padded after each sequence's end.
        """
        bsz = inputs.size(0)
        # pack() here needs sequences sorted by decreasing length.
        lens, order = torch.sort(lengths, 0, True)

        n_states = self.num_layers * self.bidirectional
        h0 = self.e_hid_init.to(inputs.device).expand(
            n_states, bsz, self.output_dim).contiguous()
        c0 = self.e_cell_init.to(inputs.device).expand(
            n_states, bsz, self.output_dim).contiguous()

        packed = pack(inputs[order], lens.tolist(), batch_first=True)
        all_hids, _ = self.lm(packed, (h0, c0))

        # Sorting the permutation yields its inverse: undo the length sort.
        _, restore = torch.sort(order, 0)
        return unpack(all_hids, batch_first=True)[0][restore]

    def forward(self, batch, se_params):
        """Embed ``batch``, run the LSTM, pool and optionally activate."""
        lengths, out, _ = model_utils.embed(batch, self.params, se_params)
        out = out.transpose(1, 0)

        out = self.lstm(out, lengths)
        out = out.transpose(1, 0)

        out = model_utils.pool(out, lengths, self.params)

        if self.params.activation is not None:
            out = self.params.activation(out)

        return out

    def encode(self, batch, params):
        """Return the sentence embeddings of ``batch`` as a NumPy array."""
        return self.forward(batch, params).cpu().detach().numpy()


# ---- helpers from examples/models/model_utils.py ---------------------------

def position_encoding_init(n_position, emb_dim):
    """Build a sinusoidal position table of shape (n_position, emb_dim).

    Row 0 (the padding position) is all zeros. For the other rows, even
    columns hold ``sin`` and odd columns ``cos`` of
    ``pos / 10000 ** (2 * (j // 2) / emb_dim)``. Returns a float32 tensor.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]
    columns = np.arange(emb_dim)[None, :]
    angles = positions / np.power(10000, 2 * (columns // 2) / emb_dim)

    position_enc = np.zeros((n_position, emb_dim))
    position_enc[1:, 0::2] = np.sin(angles[1:, 0::2])
    position_enc[1:, 1::2] = np.cos(angles[1:, 1::2])
    return torch.from_numpy(position_enc).type(torch.FloatTensor)


def _pool_over_time(x, lengths, reduce_fn):
    """Reduce the valid prefix of every sequence in ``x`` to one vector.

    ``x`` is indexed as (seq_len, batch, features); ``reduce_fn`` maps a
    (length, features) slice to a (features,) vector. The float32 result is
    allocated on ``x``'s device. Sequences of length 0 keep an all-zero row
    instead of producing NaN or raising.
    """
    out = torch.zeros(x.size(1), x.size(2), dtype=torch.float32, device=x.device)  # BxF
    for i in range(x.size(1)):
        n_valid = int(lengths[i])
        if n_valid > 0:
            out[i] = reduce_fn(x[:n_valid, i, :])
    return out


def sum_pool(x, lengths):
    """Sum each sequence over its first ``lengths[i]`` time steps (BxF)."""
    return _pool_over_time(x, lengths, lambda seq: torch.sum(seq, 0))


def mean_pool(x, lengths):
    """Average each sequence over its first ``lengths[i]`` time steps (BxF)."""
    return _pool_over_time(x, lengths, lambda seq: torch.mean(seq, 0))


def max_pool(x, lengths):
    """Feature-wise maximum over the first ``lengths[i]`` time steps (BxF)."""
    return _pool_over_time(x, lengths, lambda seq: torch.max(seq, 0)[0])


def min_pool(x, lengths):
    """Feature-wise minimum over the first ``lengths[i]`` time steps (BxF)."""
    return _pool_over_time(x, lengths, lambda seq: torch.min(seq, 0)[0])
def hier_pool(x, lengths, n=5):
    """Hierarchical pooling: max over the means of all length-``n`` windows.

    Args:
        x: tensor indexed as (seq_len, batch, features).
        lengths: valid length of every sequence in the batch.
        n: window size.

    Returns:
        float32 tensor (batch, features) on ``x``'s device. Sequences no
        longer than ``n`` are mean-pooled; empty sequences give a zero row.
    """
    seq_len, batch_sz = x.size(0), x.size(1)
    out = torch.zeros(batch_sz, x.size(2), dtype=torch.float32, device=x.device)  # BxF
    for i in range(batch_sz):
        n_valid = min(int(lengths[i]), seq_len)
        if n_valid <= 0:
            continue
        seq = x[:n_valid, i, :]
        if n_valid <= n:
            out[i] = torch.mean(seq, 0)
            continue
        # unfold yields every window start 0 .. n_valid - n inclusive, so the
        # last window (ending on the final token) is included; the earlier
        # range(length - n) loop stopped one window short.
        windows = seq.unfold(0, n, 1)  # (n_valid - n + 1) x F x n
        out[i] = torch.max(windows.mean(2), 0)[0]
    return out


def pool(out, lengths, params):
    """Dispatch to the pooling function selected by ``params.pooling``.

    Accepted values: "mean", "max", "min", "hier", "sum".

    Raises:
        ValueError: for any other value.
    """
    if params.pooling == "mean":
        return mean_pool(out, lengths)
    if params.pooling == "max":
        return max_pool(out, lengths)
    if params.pooling == "min":
        return min_pool(out, lengths)
    if params.pooling == "hier":
        return hier_pool(out, lengths)
    if params.pooling == "sum":
        return sum_pool(out, lengths)
    raise ValueError("No valid pooling operation specified!")


def param_init(model, opts):
    """Re-initialise every parameter of ``model`` with more than one dimension.

    The scheme is chosen by ``opts.init`` ("orthogonal", "sparse", "normal",
    "uniform", "kaiming" or "xavier"); any other value leaves the model
    untouched. One-dimensional parameters (e.g. biases) are never changed.
    """
    initializers = {
        "orthogonal": nn.init.orthogonal_,
        "sparse": lambda p: nn.init.sparse_(p, sparsity=0.1),
        "normal": nn.init.normal_,
        "uniform": lambda p: nn.init.uniform_(p, a=-0.1, b=0.1),
        "kaiming": nn.init.kaiming_uniform_,
        "xavier": nn.init.xavier_uniform_,
    }
    init_fn = initializers.get(opts.init)
    if init_fn is None:
        return
    for p in model.parameters():
        if p.dim() > 1:
            init_fn(p)


def embed(batch, params, se_params, position_enc=None, to_reverse=0):
    """Turn a batch of tokenised sentences into padded word embeddings.

    Args:
        batch: list of token sequences; it is never modified.
        params: provides ``max_seq_len``, ``pos_enc`` and ``gpu``.
        se_params: provides ``word2id`` (token -> id) and ``lut``, a callable
            mapping an id tensor to embeddings.
        position_enc: callable mapping a position tensor to encodings;
            required when ``params.pos_enc`` is set.
        to_reverse: if true, each sentence is reversed before truncation.

    Returns:
        ``(lengths, out, word_pos)``: ``lengths`` is a 1-D long tensor of
        truncated sentence lengths, ``out`` holds the embeddings for an id
        matrix of shape (longest_sentence, batch) padded with id 0, and
        ``word_pos`` is the matching 1-based position matrix (0 = padding)
        or ``None`` when positional encoding is off.

    Raises:
        KeyError: if a token is missing from ``se_params.word2id``.
        ValueError: if ``params.pos_enc`` is set without ``position_enc``.
    """
    if params.pos_enc and position_enc is None:
        raise ValueError("params.pos_enc is set but position_enc is None")

    max_len = params.max_seq_len
    id_seqs = []
    for sent in batch:
        # Reverse a copy: list.reverse() would reorder the caller's data.
        tokens = list(reversed(sent)) if to_reverse else sent
        id_seqs.append([se_params.word2id[w] for w in tokens[:max_len]])

    cur_max_seq_len = max((len(ids) for ids in id_seqs), default=0)
    input_seq = torch.zeros(cur_max_seq_len, len(batch), dtype=torch.long)
    word_pos = None
    if params.pos_enc:
        word_pos = torch.zeros(cur_max_seq_len, len(batch), dtype=torch.long)

    # Fill one column per sentence instead of one tensor element at a time.
    for i, ids in enumerate(id_seqs):
        if not ids:
            continue
        input_seq[:len(ids), i] = torch.tensor(ids, dtype=torch.long)
        if word_pos is not None:
            word_pos[:len(ids), i] = torch.arange(1, len(ids) + 1)

    out = se_params.lut(input_seq)
    if params.gpu:
        out = out.cuda()

    if word_pos is not None:
        if params.gpu:
            word_pos = word_pos.cuda()
        out = out + position_enc(word_pos)

    lengths = torch.tensor([len(ids) for ids in id_seqs], dtype=torch.long)

    return lengths, out, word_pos
sentences in SentEval (default 16) + gpu = 1 # Whether to use GPU (default 0) + senteval_path = "./SentEval" # Path to SentEval (default ./SentEval). + word_emb_file = "./glove.840B.300d.txt" # Path to word embeddings file (default ./glove.840B.300d.txt) + word_emb_dim = 300 # Dimension of word embeddings (default 300) + + # Network parameters + input_dim = 300 # iutput feature dimensionality (default 300) + output_dim = 2048 # "Output feature dimensionality (default 4096) + max_seq_len = 96 # Sequence length (default 96) + bidirectional = 1 # Whether to be bidirectional (default 1). + init = 'none' # Type of initialization to use (either none, orthogonal, sparse, normal, uniform, kaiming, " + # "or xavier, default none). + activation = None # Activation function to apply to features (default none) + pooling = 'mean' # Type of pooling (either min, max, mean, hier, or sum, default max) + + # Embedding parameters + zero = 1 # whether to initialize word embeddings to zero (default 1) + pos_enc = 0 # Whether to do positional encoding (default 0) + pos_enc_concat = 0 # Whether to concat positional encoding to regular embedding (default 0) + random_word_embeddings = 0 # Whether to load pretrained embeddings (default 0) + + # Projection parameters + projection = 'same' # Type of projection (either none or same, default same) + + # ESN parameters + spectral_radius = 1 # Spectral radius for ESN (default 1.) 
+ leaky = 0 # Fraction of previous state to leak for ESN (default 0) + concat_inp = 0 # Whether to concatenate input to hidden state for ESN (default 0) + stdv = 1 # Width of uniform interval to sample weights for ESN (default 1) + sparsity = 0 # Sparsity of recurrent weights for ESN (default 0) + # LSTM parameters + num_layers = 1 #Number of layers for random LSTM (default 1).", default=1) + +# Create dictionary +def create_dictionary(sentences, threshold=0): + words = {} + for s in sentences: + for word in s: + words[word] = words.get(word, 0) + 1 + + if threshold > 0: + newwords = {} + for word in words: + if words[word] >= threshold: + newwords[word] = words[word] + words = newwords + words[''] = 1e9 + 4 + words[''] = 1e9 + 3 + words['

'] = 1e9 + 2 + + sorted_words = sorted(words.items(), key=lambda x: -x[1]) # inverse sort + id2word = [] + word2id = {} + for i, (w, _) in enumerate(sorted_words): + id2word.append(w) + word2id[w] = i + + return id2word, word2id + + +# Get word vectors from vocabulary (glove, word2vec, fasttext ..) +def get_wordvec(path_to_vec, word2id): + word_vec = {} + + with io.open(path_to_vec, 'r', encoding='utf-8') as f: + # if word2vec or fasttext file : skip first line "next(f)" + next(f) + for line in f: + word, vec = line.split(' ', 1) + if word in word2id: + word_vec[word] = np.fromstring(vec, sep=' ') + + logging.info('Found {0} words with word vectors, out of \ + {1} words'.format(len(word_vec), len(word2id))) + return word_vec + + +# SentEval prepare and batcher +def prepare(params, samples): + _, params.word2id = create_dictionary(samples) + params.word_vec = get_wordvec(PATH_TO_VEC, params.word2id) + params.wvec_dim = 300 + return + + +def batcher(params, batch): + batch = [sent if sent != [] else ['.'] for sent in batch] + embeddings = [] + + for sent in batch: + sentvec = [] + for word in sent: + if word in params.word_vec: + sentvec.append(params.word_vec[word]) + if not sentvec: + vec = np.zeros(params.wvec_dim) + sentvec.append(vec) + sentvec = np.mean(sentvec, 0) + embeddings.append(sentvec) + + embeddings = np.vstack(embeddings) + return embeddings + +def check(): + seed = 42 + np.random.seed(seed) + torch.manual_seed(seed) + # Set params for network + params = params_embedding() + + if params.gpu: + torch.cuda.manual_seed(seed) + if params.model == 'lstm': + network = RandLSTM(params) + elif params.model == 'esn': + network = ESN(params) + else: + network = BOREP(params) + + # Set params for SentEval + params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10, + 'classifier': {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128, 'tenacity': 3, 'epoch_size': 2}} + + se = senteval.engine.SE(params_senteval, batcher, prepare) + transfer_tasks 
= ['SST2', 'SST3', 'MRPC', 'ReadabilityCl', 'TagCl', 'PoemsCl', 'ProzaCl', 'TREC', 'STS', 'SICK'] + results = se.eval(transfer_tasks) + return results + diff --git a/examples/randsent_example.ipynb b/examples/randsent_example.ipynb new file mode 100644 index 0000000..1d96ec2 --- /dev/null +++ b/examples/randsent_example.ipynb @@ -0,0 +1,303 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Import embeddings for checking. You can also import skip-thought and fasttext-idf" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import randsent as embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run bow embeddings for the task SST2. If you want to check embeddings for other tasks, change it in bow.py \"transfer_tasks\".\n", + "Available tasks: SST2, SST3, MRPC, ReadabilityCl, TagCl, PoemsCl, TREC, STS, SICK" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'SST2': {'devacc': 62.54,\n", + " 'acc': 62.42,\n", + " 'ndev': 18901,\n", + " 'ntest': 37804,\n", + " 'time': 267.4362132549286},\n", + " 'SST3': {'devacc': 86.57,\n", + " 'acc': 86.67,\n", + " 'ndev': 4831,\n", + " 'ntest': 9662,\n", + " 'time': 72.43073320388794},\n", + " 'MRPC': {'devacc': 99.77,\n", + " 'acc': 99.52,\n", + " 'f1': 99.74,\n", + " 'ndev': 35635,\n", + " 'ntest': 4400,\n", + " 'time': 388.2232346534729},\n", + " 'ReadabilityCl': {'devacc': 35.0,\n", + " 'acc': 32.68,\n", + " 'ndev': 640,\n", + " 'ntest': 1282,\n", + " 'time': 15.202638149261475},\n", + " 'TagCl': {'devacc': 32.98,\n", + " 'acc': 31.52,\n", + " 'ndev': 3605,\n", + " 'ntest': 7223,\n", + " 'time': 203.70333123207092},\n", + " 'PoemsCl': {'devacc': 37.89,\n", + " 'acc': 37.47,\n", + " 'ndev': 190,\n", + " 'ntest': 371,\n", + " 'time': 4.4349260330200195},\n", + " 'ProzaCl': {'devacc': 43.01,\n", + " 
'acc': 44.02,\n", + " 'ndev': 193,\n", + " 'ntest': 393,\n", + " 'time': 18.596612215042114},\n", + " 'TREC': {'devacc': 26.83,\n", + " 'acc': 20.8,\n", + " 'ndev': 5452,\n", + " 'ntest': 500,\n", + " 'time': 66.55714130401611},\n", + " 'STS': {'devpearson': 0.5242887821205533,\n", + " 'pearson': 0.5437650890037948,\n", + " 'spearman': 0.5355568012377294,\n", + " 'mse': 1.7308604560287553,\n", + " 'yhat': array([2.36370519, 2.00206342, 4.23734959, 2.00206342, 3.59654932,\n", + " 3.59654932, 3.59654932, 1.31822499, 2.95111737, 1.6604875 ,\n", + " 1.66242505, 2.44285593, 1.74046474, 2.35618017, 3.51656278,\n", + " 2.25289381, 2.32740771, 2.24848697, 2.32740771, 2.18231911,\n", + " 1.83410028, 1.75027232, 1.70023542, 1.63157094, 1.79205703,\n", + " 2.15417119, 2.43763485, 2.45906199, 1.6283864 , 3.30575675,\n", + " 1.90526416, 1.91028498, 3.14743564, 2.05880412, 2.93087074,\n", + " 2.69487728, 2.50648641, 2.60015661, 1.95352261, 1.98956804,\n", + " 2.62633436, 4.11341089, 2.14624925, 2.22687568, 2.84438144,\n", + " 3.49207652, 4.27593892, 2.84438144, 2.65373036, 2.95353766,\n", + " 2.87559857, 2.84438144, 3.20639532, 2.14981751, 2.98501824,\n", + " 2.31962286, 4.2768316 , 2.83724866, 3.04421297, 4.20906151,\n", + " 2.82118762, 2.29178635, 4.11341089, 3.51544331, 2.01756746,\n", + " 4.11341089, 3.78029433, 3.63733601, 4.19698806, 2.13879235,\n", + " 1.80974696, 2.61720604, 2.1259183 , 2.3694854 , 1.99420547,\n", + " 2.75644658, 1.75851517, 2.92803615, 2.8899997 , 3.0219249 ,\n", + " 2.31612899, 2.86885545, 2.38608417, 3.51199391, 1.50825688,\n", + " 3.67271728, 2.17267587, 2.58816803, 2.43841363, 3.77384172,\n", + " 2.2091377 , 2.00872354, 2.26974412, 2.3298081 , 2.17164023,\n", + " 2.76191505, 3.36655696, 2.57718437, 3.68441821, 2.54227494,\n", + " 3.12225927, 2.87281117, 2.61621274, 2.8193385 , 2.23211779,\n", + " 2.23289466, 2.80592144, 1.81147585, 2.9491739 , 3.11086212,\n", + " 2.55768208, 1.64975202, 2.41542841, 1.70755654, 2.15780174,\n", + " 2.07965625, 
2.27954089, 2.26813636, 1.82425859, 1.9858315 ,\n", + " 2.05732191, 2.25564023, 2.01285644, 1.69766004, 1.88318016,\n", + " 2.7295166 , 2.10861523, 2.76711355, 2.76711355, 2.70359304,\n", + " 2.85731594, 2.19305353, 2.46330041, 2.65447506, 2.28927512,\n", + " 2.70116519, 3.01357137, 2.36597746, 3.02944084, 2.23293618,\n", + " 2.23409928, 2.48633213, 2.75423772, 2.23551204, 2.12712896,\n", + " 2.24741045, 2.20203449, 2.53940665, 2.06433245, 2.65563862,\n", + " 3.02739169, 2.14538316, 2.66246409, 2.63843243, 2.64534462,\n", + " 2.56239183, 2.81211761, 3.98729989, 3.72482352, 3.72482352,\n", + " 3.72482352, 3.00813903, 2.83357292, 3.72482352, 3.28048654,\n", + " 3.72482352, 3.72482352, 2.78214778, 3.66612506, 2.99105446,\n", + " 2.96428943, 2.73961376, 2.0406234 , 3.31337149, 2.53940665,\n", + " 2.98387995, 2.16538899, 3.23083933, 2.68703104, 3.02773364,\n", + " 2.75448835, 3.92846183, 3.38236713, 3.47968321, 4.27814279,\n", + " 2.42439519, 3.86452751, 3.72482352, 3.72482352, 4.04664071,\n", + " 4.48273602, 2.14408283, 2.32797889, 2.29960982, 2.49415862,\n", + " 3.03232156, 2.29149063, 2.63396502, 3.2307487 , 2.22032709,\n", + " 3.04666732, 2.81086401, 2.93390681, 2.49121863, 2.84724797,\n", + " 2.43735253, 2.82472951, 2.84566416, 3.01179767, 2.83164829,\n", + " 3.50378214, 3.24625112, 2.7758306 , 3.51910459, 2.07647324,\n", + " 2.83341413, 2.39531236, 2.44531281, 2.60321437, 2.94858754,\n", + " 2.70515415, 3.07154246, 3.13943927, 2.05664013, 2.72542439,\n", + " 3.11410937, 2.89040256, 2.73650738, 2.98638212, 2.80695424,\n", + " 2.25746051, 3.00147432, 1.93046722, 1.85738339, 1.70267584,\n", + " 2.50342129, 2.3310321 , 2.64650684, 2.99221873, 2.34387755,\n", + " 2.49883584, 2.30139482, 2.57759784, 2.46348407, 2.85524107,\n", + " 3.44659707, 2.42473114, 2.66312619, 3.02564657, 2.82930436,\n", + " 2.91551226, 2.58451452, 2.94109635, 2.61989807, 2.5930616 ,\n", + " 3.05420198, 3.451849 , 3.21142909, 2.21492794, 3.49986332,\n", + " 3.59887265, 3.22679251, 3.55176508, 
2.98802177, 2.53047819,\n", + " 3.40512996, 3.05697569, 2.93673084, 3.13701302, 2.9208231 ,\n", + " 2.56659417, 2.50628538, 3.11546238, 2.96990098, 2.5325186 ,\n", + " 2.88549045, 3.44357218, 2.87825762, 3.26205348, 2.31881699,\n", + " 4.00454746, 2.71984204, 3.34254077, 2.53973467, 3.10250212,\n", + " 4.00454746, 3.13132996, 3.00340921, 3.66071282, 2.77529614,\n", + " 3.75051885, 3.44132653, 3.26902208, 3.56477434, 3.31063151,\n", + " 3.6623096 , 2.87013908, 3.41187104, 3.2842134 , 2.8617395 ,\n", + " 3.78551285, 4.19404367, 3.01149618, 2.4430571 , 2.14478287,\n", + " 2.21709113, 3.51182353, 2.60456154, 2.99000609, 2.90482547,\n", + " 2.28814654, 3.1679123 , 3.63593517, 3.38464183, 3.64171571,\n", + " 3.22763394, 3.30061626, 3.30140211, 3.11866726, 3.33301216,\n", + " 3.483347 , 2.96663485, 3.31580541, 3.56937192, 3.15122145,\n", + " 2.92735533, 3.23771875, 2.37467594, 2.81506037, 2.69658211,\n", + " 2.99577266, 2.78711434, 2.8945125 , 3.05251072, 2.44018829,\n", + " 3.36047474, 3.04371458, 2.19288102, 2.05381062, 2.57207188,\n", + " 2.39746353, 2.03845177, 2.64414475, 2.8985422 , 2.82809741,\n", + " 2.92065848, 2.69390277, 3.17356502, 3.06770568, 3.03609954,\n", + " 3.23974908, 3.20566231, 3.54400219, 2.67198119, 2.96750184,\n", + " 2.44082667, 2.93898277, 2.48542126, 2.79835435, 3.17309757,\n", + " 3.25209016, 3.07758925, 3.35874911, 3.46104913, 2.89044763,\n", + " 3.2226252 , 2.97076908, 3.45605969, 2.94142868, 3.25289205,\n", + " 2.94577806, 3.80483842, 2.83442539, 3.74944438, 3.58972311,\n", + " 3.55565017, 3.52070293, 2.81467678, 3.4934348 , 2.91913259,\n", + " 2.82321196, 3.03400458, 2.8697807 , 2.61369262, 3.06949562,\n", + " 2.57156015, 2.65144379, 3.16112471, 2.94936052, 3.10811144,\n", + " 3.27669651, 3.33749451, 2.97127902, 3.29690162, 3.14906306,\n", + " 3.78068858, 3.64421532, 2.84387947, 2.66373737, 2.43763862,\n", + " 3.1451649 , 2.72879936, 2.40679068, 3.12224179, 2.98630846,\n", + " 3.273885 , 3.05127962, 2.6901171 , 3.69571536, 3.32052726,\n", + 
" 3.32768044, 3.20715172, 3.2676289 , 3.01823581, 3.70303706,\n", + " 3.02115802, 2.66726685, 2.80152775, 2.95092472, 3.14252942,\n", + " 2.68350458, 2.56512773, 2.60552411, 3.19303863, 1.92858347,\n", + " 2.77228275, 2.77847243, 2.5711899 , 3.03368573, 2.39414079,\n", + " 2.00078022, 3.26473314, 2.85666998, 3.02423619, 2.99276656,\n", + " 1.86574623, 3.24524847, 3.30224398, 3.51981816, 2.56627555,\n", + " 3.35050094, 3.12341742, 2.86765604, 2.98225211, 3.54920612,\n", + " 2.89539596, 2.98298068, 2.59713513, 2.99765645, 3.23068462,\n", + " 3.17623057, 2.61168061, 2.98593915, 3.7802786 , 3.66432124,\n", + " 3.03015552, 2.70836815, 3.48517345, 3.35239874, 3.22211485,\n", + " 3.5293265 , 3.03387742, 3.70017151, 3.04253145, 2.55140119,\n", + " 3.06702048, 3.59344437, 2.59108468, 3.41746696, 3.07833336,\n", + " 3.12670837, 3.33886732, 2.57373958, 2.77859494, 3.81487393,\n", + " 3.26116236, 3.99766506, 3.14104449, 2.52387153, 3.00900869,\n", + " 2.26632847, 2.9732744 , 3.29688673, 2.97880374, 3.29114611,\n", + " 3.02736272, 2.83718502, 2.69804333, 2.79770155, 2.94849551,\n", + " 3.05144645, 2.30668522, 2.58388101, 3.04199901, 2.95844935,\n", + " 2.46034507, 2.89104179, 2.68307222, 3.19946846, 3.11602211,\n", + " 2.77863168, 2.802346 , 3.6964669 , 2.87921812, 2.881263 ,\n", + " 2.37192752, 2.92166112, 2.81005768, 3.07022922, 3.47576226,\n", + " 3.12359173, 3.00849605, 3.59357995, 3.21412423, 3.58475064,\n", + " 2.61012925, 2.26108191, 2.86243146, 2.65571226, 2.8039688 ,\n", + " 3.42760867, 3.51829935, 3.20991135, 2.75888825, 3.16482765,\n", + " 3.04940477, 3.00570773, 3.23316839, 2.72996912, 3.15913742,\n", + " 2.97063437, 3.64727248, 3.28573534, 2.95013794, 3.04750077,\n", + " 3.05467093, 3.18650276, 2.73353197, 3.16510623, 3.22810479,\n", + " 3.2076251 , 3.51000452, 2.9281648 , 3.38283958, 3.26525882,\n", + " 3.02802111, 2.70173231, 2.38292523, 2.65547086, 2.80916064,\n", + " 2.82226988, 2.74150643, 2.93249377, 3.18017906, 2.84560999,\n", + " 2.94691054, 2.96157923, 
2.95331388, 2.44365502, 2.61237966,\n", + " 2.79360472, 2.62759449, 3.22557481, 3.06648905, 3.25854691,\n", + " 3.25621037, 3.1505248 , 3.48037563, 2.86669094, 3.11874464,\n", + " 2.96759461, 3.49079181, 3.19203307, 2.777756 , 2.7301416 ,\n", + " 3.18390058, 3.35905614, 3.54080842, 3.10998711, 3.25152593,\n", + " 2.54742119, 2.92509462, 3.45057984, 2.91290241, 3.03967786,\n", + " 2.88798039, 3.09531388, 3.03053573, 3.11177012, 3.45693224,\n", + " 2.53471738, 3.23157164, 2.55146906, 2.99651783, 3.28342923,\n", + " 2.49960039, 2.7075137 , 2.83867343, 2.95936275, 3.21161325,\n", + " 3.30597045, 2.96694229, 3.41132965, 2.96997678, 2.96485671,\n", + " 3.02755116, 3.40448305, 3.13064481, 3.33698334, 3.15948001,\n", + " 3.22888051, 3.39076798, 3.45171947, 3.03095654, 3.04291523,\n", + " 3.37283502, 2.83579427, 3.18845505, 2.6143499 , 2.94242331,\n", + " 2.29881597, 3.17878283, 3.64490526, 2.7112379 , 3.06746125,\n", + " 2.95047082, 2.74231353, 3.18234011, 2.89490026, 2.98618356,\n", + " 3.14492072, 3.09812109, 3.05065991, 2.91330285, 3.0292787 ,\n", + " 2.94859651, 3.379981 , 2.83184124, 3.28327069, 2.61416009,\n", + " 2.6702219 , 3.37569815, 3.10544221, 3.04858022, 3.13468383,\n", + " 2.81595495, 2.82637397, 2.88384629, 3.11267512, 3.01334456,\n", + " 2.93134135, 3.69883825, 3.11138546, 3.12244305, 3.213401 ,\n", + " 3.10063194, 3.26220817, 2.93112164, 3.07953711, 3.38831352,\n", + " 3.08955135, 3.2831784 , 3.00780753, 2.81428926, 3.12008262,\n", + " 2.69118118, 3.32126937, 3.03070855, 2.80641215, 2.84047465,\n", + " 3.12537859, 2.46977066, 3.32905451, 2.74720771, 3.11677095,\n", + " 3.18536312, 2.73780753, 3.15357313, 3.39436152, 3.33737008,\n", + " 3.48188969, 3.20216505, 3.64215494, 2.67310467, 3.46786668,\n", + " 3.02968626, 3.19657227, 2.30994617, 3.05887499, 3.15318695,\n", + " 2.84682521, 2.98572388, 3.08281109, 3.18894134, 3.54700109,\n", + " 3.03273243, 2.95468782, 3.09516688, 3.35482239, 2.75746833,\n", + " 3.20777377, 3.20184047, 3.34910021, 3.23446756, 
2.90292806,\n", + " 2.38080544, 3.15942069, 3.28789218, 2.9231109 , 3.22827069,\n", + " 3.1514833 , 3.32214455, 2.93486338, 2.79330772, 3.02111109,\n", + " 2.93218711, 3.24844942, 2.90509383, 3.2384704 , 3.09319447,\n", + " 3.07797904, 3.45424227, 3.06210137, 3.47375872, 3.16706558,\n", + " 3.13825557, 3.20856016, 2.7097256 , 3.31463886, 3.17292167,\n", + " 3.02839066, 3.49152643, 3.40635147, 3.05746959, 3.21750846,\n", + " 2.94367695, 3.19057287, 3.22245819, 3.31454998, 2.8640388 ,\n", + " 3.89252803, 3.25268794, 3.0088939 , 3.21631937, 3.19481393,\n", + " 2.89930684, 3.16904083, 3.30156596, 3.4103966 , 2.63435557,\n", + " 3.31989469, 3.4422102 , 3.03194234, 3.01976077, 3.10166287,\n", + " 3.13653457, 3.37644516, 3.18862893, 2.57987028, 3.34295332,\n", + " 3.16091111, 3.40608641, 3.5167682 , 3.24129607, 3.30543648,\n", + " 3.11544525, 3.45893581, 3.3869574 , 3.2219131 , 3.0322293 ,\n", + " 2.82539359, 2.94139963, 3.28098132, 2.97146927, 3.31935797,\n", + " 3.19096029, 3.15377428, 3.72369686, 3.11100201, 3.25817911,\n", + " 3.32890571, 3.56832729, 3.34197157, 3.10624474, 2.84027675,\n", + " 3.2001786 , 3.37168305, 3.24964341, 3.05619156, 3.45994959,\n", + " 3.16372005, 3.08792445, 3.2969857 , 3.09615164, 3.6094668 ,\n", + " 3.48228108, 2.53915945, 3.52163506, 3.2962993 , 3.75129267,\n", + " 3.36231886, 2.49799371, 3.26869702, 3.38967431, 3.32217607,\n", + " 3.13128147, 3.46786305, 3.28618477, 2.82705589, 3.18026511,\n", + " 3.16543694, 3.55561133, 3.40027522, 3.10650304, 3.13304929,\n", + " 3.20397066, 3.30181689, 3.44012409, 3.39087135, 3.59150061,\n", + " 3.21720245, 3.34183837, 3.48396418, 3.19439971, 3.43802878,\n", + " 2.53245343, 3.39971318, 3.45624617, 3.32159916, 3.14860707,\n", + " 3.50829044, 3.38866973, 3.4352681 , 3.24563378, 3.35969868,\n", + " 3.37021024, 3.37119943, 3.34468816, 3.28526685, 3.22135253,\n", + " 2.99515533]),\n", + " 'ndev': 1469,\n", + " 'ntest': 841,\n", + " 'time': 80.568594455719},\n", + " 'SICK': {'devpearson': 
0.6286113217546703,\n", + " 'pearson': 0.704970327284487,\n", + " 'spearman': 0.6286594135873013,\n", + " 'mse': 0.5296568019620134,\n", + " 'yhat': array([1.02385654, 2.76356009, 1.09348593, ..., 4.25048931, 3.4249529 ,\n", + " 4.09359651]),\n", + " 'ndev': 756,\n", + " 'ntest': 4169,\n", + " 'time': 82.33699464797974}}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings.check() #lstm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/untitled b/examples/untitled new file mode 100644 index 0000000..e69de29 diff --git a/examples/utils.py b/examples/utils.py new file mode 100644 index 0000000..9b9c6dc --- /dev/null +++ b/examples/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + +def load_vecs(params, word2id, zero=True): + embs = nn.Embedding(len(word2id), params.word_emb_dim, padding_idx=word2id['

']) + if zero: + nn.init.constant_(embs.weight, 0.) + + matches = 0 + n_words = len(word2id) + embedding_vocab = [] + with open(params.word_emb_file) as f: + for line in f: + word = line.split(' ', 1)[0] + embedding_vocab.append(word) + embedding_vocab = set(embedding_vocab) + + word_map = {} + for word in word2id: + if word in embedding_vocab: + word_map[word] = word2id[word] + else: + new_word = word.lower().capitalize() + if new_word in embedding_vocab: + word_map[new_word] = word2id[word] + else: + new_word = word.lower() + if new_word in embedding_vocab: + word_map[new_word] = word2id[word] + + with open(params.word_emb_file) as f: + for line in f: + word = line.split(' ', 1)[0] + if word != '

': + if word in word_map: + glove_vect = torch.FloatTensor(list(map(float, line.split(' ', 1)[1].split(' ')))) + embs.weight.data[word_map[word]][:300].copy_(torch.FloatTensor(glove_vect)) + + matches += 1 + if matches == n_words: break + return embs + +def init_word_embeds(emb, opts): + if opts.init == "none": + emb.weight = nn.Embedding(emb.weight.size()[0], emb.weight.size()[1]).weight + if opts.init == "orthogonal": + nn.init.orthogonal_(emb.weight) + elif opts.init == "sparse": + nn.init.sparse_(emb.weight, sparsity=0.1) + elif opts.init == "normal": + nn.init.normal_(emb.weight, std=0.1) + elif opts.init == "uniform": + nn.init.uniform_(emb.weight, a=-0.1, b=0.1) + elif opts.init == "kaiming": + nn.init.kaiming_uniform_(emb.weight) + elif opts.init == "xavier": + nn.init.xavier_uniform_(emb.weight) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..9b9c6dc --- /dev/null +++ b/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn + +def load_vecs(params, word2id, zero=True): + embs = nn.Embedding(len(word2id), params.word_emb_dim, padding_idx=word2id['

']) + if zero: + nn.init.constant_(embs.weight, 0.) + + matches = 0 + n_words = len(word2id) + embedding_vocab = [] + with open(params.word_emb_file) as f: + for line in f: + word = line.split(' ', 1)[0] + embedding_vocab.append(word) + embedding_vocab = set(embedding_vocab) + + word_map = {} + for word in word2id: + if word in embedding_vocab: + word_map[word] = word2id[word] + else: + new_word = word.lower().capitalize() + if new_word in embedding_vocab: + word_map[new_word] = word2id[word] + else: + new_word = word.lower() + if new_word in embedding_vocab: + word_map[new_word] = word2id[word] + + with open(params.word_emb_file) as f: + for line in f: + word = line.split(' ', 1)[0] + if word != '

': + if word in word_map: + glove_vect = torch.FloatTensor(list(map(float, line.split(' ', 1)[1].split(' ')))) + embs.weight.data[word_map[word]][:300].copy_(torch.FloatTensor(glove_vect)) + + matches += 1 + if matches == n_words: break + return embs + +def init_word_embeds(emb, opts): + if opts.init == "none": + emb.weight = nn.Embedding(emb.weight.size()[0], emb.weight.size()[1]).weight + if opts.init == "orthogonal": + nn.init.orthogonal_(emb.weight) + elif opts.init == "sparse": + nn.init.sparse_(emb.weight, sparsity=0.1) + elif opts.init == "normal": + nn.init.normal_(emb.weight, std=0.1) + elif opts.init == "uniform": + nn.init.uniform_(emb.weight, a=-0.1, b=0.1) + elif opts.init == "kaiming": + nn.init.kaiming_uniform_(emb.weight) + elif opts.init == "xavier": + nn.init.xavier_uniform_(emb.weight)