
lo-shot's Introduction

Papers related to 'Less Than One'-Shot (LO-Shot) Learning

Papers found in this repo

Paper 1 - 'Less Than One'-Shot Learning: Learn N Classes from M<N Samples

Preprint - https://arxiv.org/abs/2009.08449

Published - In AAAI 2021 Proceedings

Code and appendix - Paper1 directory

TL;DR - Explore the decision landscapes generated by soft-label k-Nearest Neighbors classifiers in the 'less than one'-shot learning setting.
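To illustrate how M < N samples can still separate N classes, here is a minimal sketch of a distance-weighted soft-label kNN classifier. The inverse-distance weighting rule, the function name `soft_label_knn_predict`, and the toy prototypes are assumptions made for illustration; this is not code from the Paper1 directory.

```python
import numpy as np

def soft_label_knn_predict(prototypes, soft_labels, x, k=2, eps=1e-12):
    """Predict a hard class for point x from soft-label prototypes.

    prototypes:  (M, d) array of prototype locations
    soft_labels: (M, C) array; row i is prototype i's label distribution
    Assumed rule: sum the k nearest soft labels, each weighted by
    1 / distance, and return the arg-max class.
    """
    dists = np.linalg.norm(prototypes - x, axis=1)
    nearest = np.argsort(dists)[:k]
    weights = 1.0 / (dists[nearest] + eps)   # eps avoids division by zero
    scores = weights @ soft_labels[nearest]  # (k,) @ (k, C) -> (C,)
    return int(np.argmax(scores))

# Two 1-D prototypes (M = 2) that together encode three classes (N = 3).
prototypes = np.array([[0.0], [1.0]])
soft_labels = np.array([[0.6, 0.4, 0.0],
                        [0.0, 0.4, 0.6]])
for x in (0.05, 0.5, 0.95):
    print(x, soft_label_knn_predict(prototypes, soft_labels, np.array([x])))
# prints 0.05 0, then 0.5 1, then 0.95 2
```

The middle class never "owns" a prototype: it wins only where the two neighbours' weights are close to equal and their shared 0.4 mass adds up.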

Press coverage - LO-Shot Learning has received significant press coverage.

Online demo - Binder

Paper 2 - Optimal 1-NN Prototypes for Pathological Geometries

Preprint - https://arxiv.org/abs/2011.00228

Published - In PeerJ Computer Science

Code - Paper2 directory

TL;DR - Design optimal 1-NN prototypes even in pathological cases where most prototype methods fail.
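Any prototype-design method needs a way to score a candidate set. A minimal sketch of that check, assuming plain Euclidean 1-NN and training accuracy as the criterion; the helper names are hypothetical, and the code only evaluates a given prototype set rather than constructing the optimal prototypes described in the paper.

```python
import numpy as np

def one_nn_predict(prototypes, proto_labels, X):
    """Label each row of X with the label of its nearest prototype."""
    # (n, M) matrix of Euclidean distances from every point to every prototype
    dists = np.linalg.norm(X[:, None, :] - prototypes[None, :, :], axis=2)
    return proto_labels[np.argmin(dists, axis=1)]

def prototype_accuracy(prototypes, proto_labels, X, y):
    """Fraction of training points the prototype set classifies correctly."""
    return float(np.mean(one_nn_predict(prototypes, proto_labels, X) == y))

# Toy data: class 0 on the left, class 1 on the right.
X = np.array([[0.0, 0.0], [1.0, 0.0], [4.0, 0.0], [5.0, 0.0]])
y = np.array([0, 0, 1, 1])
good = np.array([[0.5, 0.0], [4.5, 0.0]])
bad = np.array([[0.5, 0.0], [0.6, 0.0]])
labels = np.array([0, 1])
print(prototype_accuracy(good, labels, X, y))  # 1.0
print(prototype_accuracy(bad, labels, X, y))   # 0.75
```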

Paper 3 - One Line to Rule Them All: Generating LO-Shot Soft-Label Prototypes

Preprint - https://arxiv.org/abs/2102.07834

Published - In IJCNN 2021 Proceedings

Code - Paper3 directory

TL;DR - Represent your training dataset with fewer prototypes than even the number of classes found in the data.

Paper 4 - Can humans do less-than-one-shot learning?

Preprint - https://arxiv.org/abs/2202.04670

Published - In CogSci 2022 Proceedings

Code - LOSLP directory

TL;DR - Humans can also do less-than-one-shot learning.

Papers found in other repos

Paper - Soft-Label Dataset Distillation and Text Dataset Distillation

Preprint - https://arxiv.org/abs/1910.02551v3

Code - https://github.com/ilia10000/dataset-distillation

TL;DR - Experiments with soft-label dataset distillation (an algorithm for generating small synthetic datasets that train neural networks to the same performance as when training on the original data) provided the first evidence of LO-Shot Learning in neural networks.

lo-shot's People

Contributors

ilia10000, dependabot[bot]


lo-shot's Issues

Clarification of behaviour for recursive line generation

Hi,

While generating lines for 768-dimensional data with a large number of classes (1609), I obtained the following lines from find_lines_R_multiD with k of roughly 800:

[array([2, 2]) array([581, 780, 275]) array([26, 26]) array([7, 7])
 array([603, 158]) array([9, 9]) array([11, 11]) array([12, 12])
 array([1023,   13,  912]) array([590, 478])
 array([  15,  806,  171, 1536, 1392,  172,  888,  387]) array([16, 16])
 array([258, 258])
 array([  10,   62,  434, 1476, 1335,  672,  830,  400,  784,   19,   23,
        1205,  395,  889,  697,  676,  131, 1129])
 array([ 557, 1409,  219, 1183, 1444,  242])
 array([ 633,  783,  124,  517, 1046,   39,   22, 1161, 1202,   42,  534,
         714, 1070,  553,  491])
 array([ 937,  786, 1222]) array([597,  24]) array([29, 29])
 array([31, 31]) array([ 845,  973,   55,  361,  249, 1006,  169])
 array([579, 383]) array([34, 34]) array([36, 36]) array([37, 37])
 array([  93,  345,  530,  907,  441,   96,  389,  291,  663,  730, 1505,
        1336, 1345,  694])
 array([41, 41]) array([1447,  378]) array([  43,   43, 1162])
 array([44, 44]) array([1519,   45]) array([47, 47])
 array([ 904, 1072, 1346,  438,  447,  881,  167,  442,  198,  617,  168,
        1203,  921,   48,  945,  803, 1155, 1424])
 array([50, 50]) array([1272, 1313, 1187,  338]) array([53, 53])
 array([ 196,  854, 1290,  632,   54,  449]) array([744, 113])
 array([843, 785,  32, 364, 837]) array([59, 59]) array([799, 615])
 array([464, 203]) array([575, 409]) array([227,  38])
 array([1528,  723, 1445,  408,   68,  761,  283]) array([452, 321])
 array([1412,  386,  278,  506, 1466,  545]) array([77, 77])
 array([ 682,   78, 1538, 1556,  469]) array([585, 450])
 array([ 481,  382,  327,  920, 1218, 1270,  448,  767, 1554, 1201, 1382,
        1374, 1425,   82,  623, 1060,  148,  548,  958,  743])
 array([605, 399,  83, 511]) array([ 813,  343,  108,  966, 1565,  572])
 array([775, 751]) array([444, 153]) array([1240,   91,  844,  587,  302])
 array([718, 483]) array([490,  95]) array([630, 254,  71])
 array([691, 111]) array([571, 102]) array([101, 101]) array([103, 103])
 array([104, 104]) array([374, 374]) array([170, 170]) array([112, 112])
 array([1367,  115]) array([1471,  118,  432, 1354, 1125,  234, 1495])
 array([119, 119]) array([245, 812, 735, 120, 948]) array([121, 121])
 array([644, 558]) array([528, 528]) array([126, 126]) array([792, 792])
 array([ 81, 816, 507,  25,  30]) array([296, 296]) array([130, 130])
 array([ 514, 1434,  132]) array([133, 133]) array([608, 536])
 array([137, 137]) array([138, 138])
 array([ 638,  320,  906,  915,    1,   67, 1073,  527,  698,  642,  602,
         708,  294,  727,  873,  725,  419,  724,    6,  420,  567,   35,
        1366,  688,  776,  367,  140,  216,  415,  898,  531,  946, 1225,
         252,  957,  277,  394,    0,  263,   51,   52,  731, 1546,  446,
         641,  716,  900,   88,  304, 1330,  974, 1259,   49,  328,  815,
         348,  467, 1234,  455,  351,   80,  544, 1082, 1074,  217,  869,
        1059, 1478, 1261,  612, 1583,  625,  324,  695,  139,  681,  496,
         872,  679,   64,  692,  944,  188,  522,  281, 1062, 1545, 1574,
          28,  894, 1561, 1524,  860,  319,  354,  337,  839, 1145,  502,
         418,  668, 1567,  151,  346,  225,  919,  684,  410,  891,   18,
         927, 1050, 1208, 1220,  458,  505, 1110, 1590, 1094,  311, 1507,
        1178,  793,  308, 1362,  817, 1267, 1585,  713,  941,  707,  764,
         289,  908,  256, 1204, 1468, 1509,  466, 1598,  513,  800,  660,
         229,  629,  593,  685,  573, 1182,  542,  240, 1053,  192,  565,
         220,  350, 1496, 1464,  582, 1163,  290, 1104,  288,  570,  779,
        1106, 1340, 1258,  204, 1214,  190, 1426,  230,  214, 1431, 1298,
         122,  100, 1433, 1236, 1602,  954,  300, 1278,  849,   92,  495,
         982,  175,   69, 1482,  379,  622,  482, 1274, 1390,  989,  871,
         566,  251,  955, 1407, 1308,  895,  330, 1099,  257,  413, 1215,
         390,  634,  934,  835,  569,  516, 1085,  174,  365, 1479,  651,
         584, 1301,  696, 1530,  404, 1031,  236, 1173,  454,  286,  377,
        1241,  355, 1238,  986,  261,  564,  244,  671,  134, 1591, 1189,
        1132, 1030, 1075,  589,  653,  309,  690,  712])
 array([141, 141])
 array([ 762,  142,  601,  260,  159,  271,   86,  719,  358, 1310,  852])
 array([1408,  143])
 array([ 193,  808,  850, 1533,  842,  902,   94,  144,  269,  828,  421,
         221, 1436,  233,  347,  451,  381,  739,   66,  733,  195,  384,
         962])
 array([ 152,  145,   74,  604,  181, 1317, 1055,  248])
 array([1487,  238, 1143,  146,  189, 1275,  988, 1102,  429])
 array([ 182,   56,  788,  616, 1223, 1087, 1097,  885,  994,  425,  314,
        1322,  149,  117,  201,  431])
 array([ 675,  606,  594,  437,  747, 1532,  561, 1227,  185, 1558, 1522,
        1276,  778, 1153, 1315,  591, 1078])
 array([855, 239])
 array([ 678, 1216,  191,  840,  901,   60,  546,  480,   61,  334,  128,
         123,  529, 1320, 1321,  899, 1467, 1212, 1089,  753,  315, 1377,
          65,  155,  521, 1124])
 array([156, 156])
 array([ 453, 1361,  157,  515,  184,  703,  317, 1430,  576, 1356])
 array([160, 160]) array([161, 161]) array([ 209,  162,  423, 1432])
 array([ 646, 1526, 1008,  463,  715]) array([166, 166]) array([661, 661])
 array([722, 722]) array([285,  90])
 array([ 176,  807,  611, 1399,  975, 1175])
 array([178, 882, 884, 457, 187]) array([179, 179]) array([180, 180])
 array([440, 435]) array([1038,  183, 1262, 1531]) array([186, 186])
 array([298, 200]) array([459, 459]) array([235, 235])
 array([547, 401, 360]) array([ 877,  199,  554,  658, 1047, 1029])
 array([ 202,  202, 1380]) array([1112,  472,  639,  206])
 array([ 537,  215,  650,  373,  609,  376,  610, 1372,  532,  648, 1523,
         154,  657, 1489,  136, 1400,  293, 1460,  997, 1194,  107, 1357])
 array([1224,  211])
 array([ 323,  411,   99, 1364, 1457,  868,  693,  362,  771])
 array([1135, 1383,  114]) array([748, 476, 526]) array([223, 223])
 array([226, 226]) array([ 795,  231,  998, 1448, 1402,   79,  272,  498])
 array([232, 232]) array([237, 237]) array([1245,  802, 1101])
 array([243, 243]) array([600, 600]) array([942, 247, 588, 928, 680])
 array([250, 250]) array([ 656,  991, 1488, 1353, 1048, 1391, 1348,  523])
 array([255,  70]) array([259, 259])
 array([ 829,  823,  559,  598, 1580,  306, 1014, 1061,  331,  683,  880,
         177, 1256, 1548,  535, 1133,  706,  487,   76, 1011, 1517,  222,
         494, 1396,  369,  687, 1254, 1166,  810, 1088,  329, 1077, 1253,
         150, 1415, 1344, 1360, 1600,  325,  628,   27,  110,   98, 1325,
         287, 1035,   87,  832, 1116, 1307,  923,  393,   57,  673,  241,
          73, 1056, 1324, 1441, 1285,  782,  512,  677,  665,  949,  335,
         412,  282,  951])
 array([264, 264]) array([266, 266]) array([1393,  777]) array([268, 268])
 array([818, 818]) array([276, 276]) array([910, 709, 841])
 array([ 147,  279,  574, 1484,  210]) array([307, 307])
 array([ 556,  599,  652,  265,  717,  388, 1388,   17,  356,  774, 1499,
        1450, 1306, 1198, 1562, 1200,  500,  436,  274,    8, 1470,   89,
        1504,  280,  284, 1076,  607,    5,  127, 1271, 1027])
 array([1036,  292]) array([805, 805]) array([825, 825]) array([752, 736])
 array([301, 301]) array([826, 826]) array([310, 310]) array([313, 313])
 array([704, 342, 791, 643]) array([318, 318]) array([322, 322])
 array([ 352,  163,  749,  391,  772, 1206,  109,   63,  406, 1150])
 array([336, 336]) array([732, 732]) array([339, 339])
 array([896, 363, 543, 809, 341]) array([ 953, 1022,  763])
 array([797, 797])
 array([1327, 1057,  702,  647, 1492,  586,  848, 1370,  533,  759])
 array([ 939,  787, 1179, 1510,  398,  353, 1594, 1159,  766])
 array([621,  40]) array([357, 357]) array([1389,  701,  380])
 array([368, 368]) array([372, 372])
 array([ 164,  740,  375,  550,  492, 1001, 1474,  909,  486])
 array([1323, 1049, 1242,  857,  562]) array([385, 385])
 array([ 207,  801,  662,  838,  208,  964, 1233,  538,  911,  851,  552,
        1597,  965,  983,  853, 1032, 1015,   21,  359, 1114,  981,  595,
         416, 1054, 1572])
 array([ 396, 1020, 1421,  689,  541,  370])
 array([ 686, 1248,  212,  746,  253,  105,  194, 1064, 1456,  620,   97,
         822,  568,  861, 1009,  303,  397, 1304, 1287,  878])
 array([539, 539]) array([407, 407]) array([1117,  224]) array([781, 781])
 array([705, 461]) array([424, 424]) array([549, 549])
 array([1213,  426,  426, 1566]) array([430, 430])
 array([  46,  728,   72, 1148,  205,  267,  789,  197,  846,  627,  773,
        1581,  297, 1086,  299,  443, 1575, 1485, 1529, 1156])
 array([669, 666]) array([711, 475, 305]) array([456, 456])
 array([460, 460]) array([471, 471]) array([473, 473])
 array([ 619,  824,  654,  470,  738, 1294,  770,  580,  757, 1568,  968,
         729, 1527,  635])
 array([479, 479]) array([758, 405]) array([484, 484]) array([485, 485])
 array([1142,  497, 1542, 1595,  246,  433]) array([499, 499])
 array([ 344,  519,  734,  489,  930,  349, 1514, 1333,  504,  636,  428,
        1028, 1219,  984,  106, 1319,  741,  560, 1543,  555])
 array([509, 509]) array([1039,  510,  833]) array([520, 520])
 array([745, 745]) array([ 551, 1418,  804, 1394,  518])
 array([  58,  819,  796,  814, 1569,  943]) array([1113,  821])
 array([856, 856, 577]) array([659, 659]) array([640, 640])
 array([596, 596]) array([811, 887, 614, 631]) array([624, 624])
 array([649, 649]) array([667, 392]) array([1378, 1000, 1363,  769,  798])
 array([710, 710]) array([720, 720]) array([737, 737])
 array([1401,  754,  445, 1286,  858]) array([756, 756])
 array([1540,  836,  578, 1026])
 array([ 270,  863,  173,  890,  862, 1586, 1314,  503,  864,  403,  563,
        1557,   14,    4,  700])
 array([865, 865]) array([866, 866]) array([867, 867]) array([870, 870])
 array([874, 874]) array([875, 875]) array([876, 876]) array([879, 879])
 array([883, 883]) array([892, 892]) array([1334,  893]) array([897, 897])
 array([903, 903]) array([905, 905]) array([913, 913]) array([918, 918])
 array([922, 922]) array([924, 924]) array([925, 925]) array([926, 926])
 array([929, 929]) array([1341,  931]) array([932, 932]) array([933, 933])
 array([936, 936]) array([940, 940]) array([947, 947]) array([950, 950])
 array([956, 956]) array([985, 960]) array([961, 961]) array([963, 963])
 array([967, 967]) array([969, 592]) array([971, 971])
 array([1025,  972,  295]) array([976, 976, 699]) array([977, 977])
 array([978, 978]) array([ 979,  917,  422, 1051,  540]) array([980, 980])
 array([987, 987]) array([990, 990]) array([992, 992]) array([993, 993])
 array([995, 995]) array([996, 996]) array([999, 999]) array([1002, 1002])
 array([1003, 1003]) array([1004, 1004]) array([1005, 1005])
 array([1007, 1007]) array([1010, 1010]) array([1012, 1012])
 array([1013, 1013]) array([1016, 1016]) array([1018, 1018])
 array([1019, 1019]) array([1021, 1021]) array([1033, 1033])
 array([1034, 1034,  674]) array([1037, 1037]) array([1040, 1040])
 array([1041, 1041]) array([1042, 1042]) array([1043, 1043])
 array([1044, 1044]) array([1045, 1045]) array([1052, 1052])
 array([1071, 1058]) array([1063, 1063]) array([1066, 1066])
 array([1067, 1067]) array([1068, 1068]) array([1069, 1069])
 array([1079, 1079]) array([1080, 1080]) array([1081, 1081])
 array([1573, 1065, 1083]) array([1579, 1584, 1084])
 array([  84, 1090, 1090,   85]) array([1091, 1091]) array([1092, 1092])
 array([1093, 1093]) array([1095, 1095]) array([1103, 1103])
 array([1105, 1105]) array([1107, 1107]) array([1109, 1109])
 array([1111, 1111]) array([1115, 1115]) array([1118, 1118])
 array([1120, 1120]) array([1121, 1188, 1279, 1331, 1490, 1190,  218])
 array([1122, 1122]) array([1512, 1126,  721]) array([1127, 1172,  402])
 array([1128, 1128]) array([508, 165]) array([1130, 1130])
 array([1131, 1131]) array([1134, 1134]) array([1136,  645,  332])
 array([1137, 1137]) array([1138, 1138]) array([1139, 1477,  655,  914])
 array([1140, 1140])
 array([ 312, 1283,  794, 1459,  750,  366,  618,   33,  116, 1537, 1541,
        1193, 1358,  427,  670,  135, 1098, 1151,  952,  583, 1469, 1555,
         760,  465, 1601, 1462,  213, 1508,  228, 1096, 1384,  462, 1266,
        1141,  742,    3,  959, 1452, 1186, 1397,  439,  488])
 array([1144, 1144]) array([1146, 1146]) array([1147, 1147])
 array([1149, 1149]) array([1152, 1152]) array([1154, 1154])
 array([1157, 1157]) array([1589, 1158]) array([1160, 1160])
 array([1164, 1164]) array([1167, 1167]) array([1168, 1168])
 array([1169, 1169]) array([1170, 1170]) array([1463, 1171])
 array([1176, 1176]) array([1177, 1177]) array([ 501, 1449,  414,  273])
 array([1180, 1180]) array([1181, 1181]) array([1185, 1185])
 array([1191, 1191]) array([1192, 1192]) array([1195, 1195])
 array([1196, 1196]) array([1199, 1199]) array([1207, 1207])
 array([1209, 1209]) array([1210, 1210]) array([1211, 1211])
 array([1217, 1217]) array([1221, 1221])
 array([1292,  938,  831, 1226,  417]) array([1228, 1228])
 array([1229, 1229]) array([ 125, 1230,  637, 1252,  834])
 array([1231, 1231]) array([1232, 1232]) array([1235, 1235])
 array([1239, 1239]) array([1243, 1243]) array([1244, 1244])
 array([1246, 1246]) array([1247, 1247]) array([1249, 1249])
 array([1250, 1250]) array([1251, 1251]) array([1255, 1255])
 array([ 371, 1295, 1257,  326]) array([1260, 1260]) array([1263, 1263])
 array([1264, 1264]) array([1265, 1265]) array([1268, 1268])
 array([1269, 1269]) array([1273, 1273]) array([1277, 1277])
 array([1280, 1280]) array([1282, 1282]) array([1284, 1284])
 array([1288, 1288]) array([1289, 1289]) array([1291, 1291])
 array([1293, 1293]) array([1296,  477]) array([1297, 1297])
 array([1299, 1299]) array([1302, 1302]) array([1303, 1303])
 array([1305, 1305]) array([1309, 1309]) array([1311, 1311])
 array([1312, 1312]) array([1316, 1316]) array([1318, 1318])
 array([1498,  847]) array([1326, 1326]) array([1328, 1328])
 array([1329, 1329]) array([1332, 1332]) array([1337,  827])
 array([1338, 1338]) array([1339, 1339]) array([1342, 1342])
 array([1343, 1343]) array([1347, 1347]) array([1349,  493])
 array([1350, 1350]) array([1351, 1351]) array([1352, 1352])
 array([1359, 1359]) array([1365, 1365]) array([1368, 1368])
 array([1369, 1369]) array([1371, 1371]) array([1373, 1373])
 array([1375, 1375])
 array([ 768,  726, 1017,  886,  316,  790, 1281,  935,  340, 1119, 1237,
        1123, 1197,  765,  664, 1491,  468, 1355, 1376,  333,  613, 1483,
        1174,  525,  916, 1443, 1024,  970, 1184,  859, 1472, 1300, 1454,
         820])
 array([1379, 1379]) array([1381, 1381,   75]) array([1385, 1385])
 array([1386, 1386]) array([1387,  626]) array([1599, 1395])
 array([1398, 1398]) array([1403, 1403]) array([1404, 1404])
 array([1405,  755, 1165,  474, 1417,  524, 1535,   20,  129, 1473, 1100])
 array([1406, 1406]) array([1410, 1410]) array([1411, 1411])
 array([1413, 1413]) array([1414, 1414]) array([1416, 1416])
 array([1419, 1419]) array([1420, 1420]) array([1422, 1422])
 array([1423, 1423]) array([1427, 1427]) array([1428, 1428])
 array([1429, 1429]) array([1435, 1435]) array([1437, 1437])
 array([1438, 1438]) array([1439, 1439]) array([1440, 1440])
 array([1442, 1442]) array([1446, 1446]) array([1451, 1451])
 array([1453, 1453, 1108]) array([1455, 1455]) array([1458, 1458])
 array([1461, 1461]) array([1465, 1465]) array([1475, 1475])
 array([1480, 1480]) array([1481, 1481]) array([1486, 1486])
 array([1493, 1493]) array([1494, 1494]) array([1497, 1497])
 array([1500, 1500]) array([1501, 1501]) array([1502, 1502])
 array([1503, 1503]) array([1506, 1506]) array([1511, 1511])
 array([1513, 1513]) array([1515, 1515]) array([1516, 1516])
 array([1518, 1518]) array([1520, 1520]) array([1521, 1521])
 array([1525, 1525]) array([1534, 1534]) array([1539, 1539])
 array([1544, 1544]) array([1547, 1547]) array([1549, 1549])
 array([1550, 1550]) array([1551, 1551]) array([1552,  262])
 array([1553, 1553]) array([1559, 1559]) array([1560, 1560])
 array([1563, 1563]) array([1564, 1564]) array([1570, 1570])
 array([1571, 1571]) array([1576, 1576]) array([1577, 1577])
 array([1578, 1578]) array([1582, 1582]) array([1587, 1587])
 array([1588, 1588]) array([1592, 1592]) array([1593, 1593])
 array([1596, 1596])]

Many of these lines consist of a single point listed as both endpoints (e.g. array([2, 2])), which seems geometrically incorrect. Could you (a) let me know whether this is expected and (b) if not, suggest possible workarounds? Thanks!
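Pending an answer to (a), a quick diagnostic is to count how many returned lines collapse to a single centroid. The sketch below is hypothetical (`split_degenerate_lines` is not part of the repository); it assumes each line is an array of centroid indices, as in the find_lines_R_multiD output shown above, and it only removes repeated indices without changing how the lines are generated.

```python
import numpy as np

def split_degenerate_lines(lines):
    """Separate lines spanning at least two distinct centroids from
    degenerate ones whose entries all refer to the same centroid."""
    proper, degenerate = [], []
    for line in lines:
        # dict.fromkeys drops repeated indices while keeping their order
        unique = np.array(list(dict.fromkeys(np.asarray(line).tolist())))
        (proper if len(unique) >= 2 else degenerate).append(unique)
    return proper, degenerate

lines = [np.array([2, 2]), np.array([581, 780, 275]), np.array([43, 43, 1162])]
proper, degenerate = split_degenerate_lines(lines)
print([l.tolist() for l in proper])      # [[581, 780, 275], [43, 1162]]
print([l.tolist() for l in degenerate])  # [[2]]
```

Note that entries such as array([43, 43, 1162]) are not degenerate but still repeat an endpoint, which may be worth reporting alongside the single-point lines.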

Errors while computing lines using recursive regression

Hi,

While running the algorithm for generating lines using recursive regression, I encountered the following exceptions:

R[write to console]: Error in xtx_in %*% t(v) : non-conformable arguments

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py", line 268, in eval
    value, visible = ro.r("withVisible({%s\n})" % code)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py", line 438, in __call__
    res = self.eval(p)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 199, in __call__
    .__call__(*args, **kwargs))
  File "/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py", line 125, in __call__
    res = super(Function, self).__call__(*new_args, **new_kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py", line 45, in _
    cdata = function(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py", line 680, in __call__
    raise embedded.RRuntimeError(_rinterface._geterrmessage())
rpy2.rinterface_lib.embedded.RRuntimeError: Error in xtx_in %*% t(v) : non-conformable arguments

I have not been able to find the root cause of this exception; it would be great if a workaround or solution could be suggested!

Another exception that comes up frequently is:

Traceback (most recent call last):
  File "<ipython-input-16-1a65a5bb4343>", line 66, in <module>
    correct_soft_label_KNN, soft_label_points_KNN = classify_with_soft_label_KNN(lines, test_classifications, labeled_centroids, labeled_test_data)
  File "<ipython-input-14-443ac67806c7>", line 8, in classify_with_soft_label_KNN
    distX, distY = get_line_prototypes(line, labeled_centroids[0])
  File "<ipython-input-7-e7dbf273c244>", line 137, in get_line_prototypes
    distY[0,line], distY[1,line] = x.value[0:n], x.value[n:]
TypeError: 'NoneType' object is not subscriptable

I suspect that cvxpy is unable to find a solution in this case, which leaves x.value as None. Again, if a workaround or best practice could be suggested, that would be great. Thanks!
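One defensive pattern, assuming `get_line_prototypes` builds a `cvxpy.Problem` and reads `x.value` right after `solve()`, is to check the problem's status first so that an infeasible or unbounded problem fails with a readable message. `split_solution` is a hypothetical helper; the status strings match cvxpy's `OPTIMAL` and `OPTIMAL_INACCURATE` constants, and the `SimpleNamespace` objects only stand in for a solved problem and variable so the sketch runs without cvxpy installed.

```python
from types import SimpleNamespace

import numpy as np

# cvxpy reports solver outcomes as strings; these two mean x.value is usable.
USABLE_STATUSES = {"optimal", "optimal_inaccurate"}

def split_solution(problem, x, n):
    """Return (first n entries, remaining entries) of x.value, or raise a
    descriptive error instead of the opaque 'NoneType' TypeError."""
    if problem.status not in USABLE_STATUSES or x.value is None:
        raise ValueError(f"solver returned status {problem.status!r}; no value for x")
    value = np.asarray(x.value)
    return value[:n], value[n:]

# Stand-ins for a solved problem; with cvxpy, pass the real Problem and Variable.
solved = SimpleNamespace(status="optimal")
x = SimpleNamespace(value=np.array([0.2, 0.8, 0.5, 0.5]))
print(split_solution(solved, x, 2))
```

With a real problem, the caller can catch the `ValueError` and skip or re-solve that line (for example with relaxed constraints) instead of crashing the whole experiment.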

Unable to find lines using recursive_regression for higher dimensions

I was running the experiments, specifically trying to generate lines for high-dimensional data points. I was able to reproduce the error in this notebook; the snippet can be run under the Multi-dim Experiments section.

seeds=range(500)


num_samples=11 #np.array([10,20,30,40,50,60,70,80,90,100])*2
#points_surf = np.random.poisson(lam=7, size=(49,2))*20
#print(len(np.unique(points_surf, axis=1)))
num_classes=3
num_lines=3
#num_dims=2
for num_dims in range(2,11):
    fail_count=0
    tc_list =[]
    tp_list =[]
    knnc_list=[]
    nline_list=[]
    for seed in tqdm(seeds):
        # hardcode dimension count to 768
        clist,plist,knn_correct, true_num_lines = multiD_experiment(num_samples, num_classes, num_lines, seed, False, brute=False, max_diff=0.01, center_box=(-20,20), num_dims=768)
        total_points = sum(plist)
        total_correct = sum(clist)
        tp_list.append(total_points)
        tc_list.append(total_correct)
        knnc_list.append(knn_correct)
        nline_list.append(true_num_lines)
        #print("Correctly predicted: {0}/{1}".format(total_correct,total_points))
        #print("Vanilla kNN predicted: {0}/{1}".format(knn_correct,total_points))
        
        if len(tp_list)==100:
            break
    print("Dimension: {0}".format(num_dims))
    print(np.mean(tc_list), np.mean(knnc_list), np.mean(nline_list), fail_count )
    print(np.std(tc_list), np.std(knnc_list))

I got the following output, which seems to suggest that this problem is not mathematically solvable:

  0%|          | 0/500 [00:00<?, ?it/s]            0          1          2  ...        766       767 My Hopes And Dreams
0   18.608422  11.431569   7.906970  ...   0.177640  8.986997                 1.0
1   18.728446  13.081384   4.447207  ...   1.386950  9.313738                 1.0
2    2.167781   7.601209   3.927630  ...  19.934411 -3.890110                 0.0
3    1.765849   7.463933   5.255089  ...  20.456780 -5.806031                 0.0
4    8.885765  15.784543  13.242956  ...   6.850871 -5.471370                 2.0
5   18.770369   9.909713   6.214410  ...  -1.898342  7.821227                 1.0
6   20.053602  11.852070   8.755898  ...  -1.006263  8.715530                 1.0
7    8.621864  16.963792  13.193678  ...   7.565738 -2.090369                 2.0
8    2.566521   8.031120   4.064921  ...  18.256428 -4.626897                 0.0
9    0.434333   9.727282   3.156674  ...  20.090864 -3.733480                 0.0
10   9.285816  15.687726  12.829242  ...   7.120174 -3.397527                 2.0

[11 rows x 769 columns]
[1] 3 2 3
[1] 2
R[write to console]: Error in solve.default(t(x) %*% x) : 
  system is computationally singular: reciprocal condition number = 6.69256e-22


Error in solve.default(t(x) %*% x) : 
  system is computationally singular: reciprocal condition number = 6.69256e-22
  0%|          | 0/500 [00:01<?, ?it/s]
---------------------------------------------------------------------------
RRuntimeError                             Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py in eval(self, code)
    267                 # Need the newline in case the last line in code is a comment.
--> 268                 value, visible = ro.r("withVisible({%s\n})" % code)
    269             except (ri.embedded.RRuntimeError, ValueError) as exception:

13 frames
/usr/local/lib/python3.7/dist-packages/rpy2/robjects/__init__.py in __call__(self, string)
    437         p = rinterface.parse(string)
--> 438         res = self.eval(p)
    439         return conversion.rpy2py(res)

/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py in __call__(self, *args, **kwargs)
    198         return (super(SignatureTranslatedFunction, self)
--> 199                 .__call__(*args, **kwargs))
    200 

/usr/local/lib/python3.7/dist-packages/rpy2/robjects/functions.py in __call__(self, *args, **kwargs)
    124                 new_kwargs[k] = conversion.py2rpy(v)
--> 125         res = super(Function, self).__call__(*new_args, **new_kwargs)
    126         res = conversion.rpy2py(res)

/usr/local/lib/python3.7/dist-packages/rpy2/rinterface_lib/conversion.py in _(*args, **kwargs)
     44     def _(*args, **kwargs):
---> 45         cdata = function(*args, **kwargs)
     46         # TODO: test cdata is of the expected CType

/usr/local/lib/python3.7/dist-packages/rpy2/rinterface.py in __call__(self, *args, **kwargs)
    679             if error_occured[0]:
--> 680                 raise embedded.RRuntimeError(_rinterface._geterrmessage())
    681         return res

RRuntimeError: Error in solve.default(t(x) %*% x) : 
  system is computationally singular: reciprocal condition number = 6.69256e-22


During handling of the above exception, another exception occurred:

RInterpreterError                         Traceback (most recent call last)
<ipython-input-27-1abdead09a56> in <module>()
     15     nline_list=[]
     16     for seed in tqdm(seeds):
---> 17         clist,plist,knn_correct, true_num_lines = multiD_experiment(num_samples, num_classes, num_lines, seed, False, brute=False, max_diff=0.01, center_box=(-20,20), num_dims=768)
     18         total_points = sum(plist)
     19         total_correct = sum(clist)

<ipython-input-19-91b63ce35f6a> in multiD_experiment(num_samples, num_classes, num_lines, random_state, visualize, brute, max_diff, center_box, num_dims)
     29         lines=[line_order(centroids, np.array(line)) for line in find_lines_brute(centroids,num_lines)]
     30     else:
---> 31         lines = [line_order_no_endpoints(centroids, np.array(line)) for line in find_lines_R_multiD(dat,centroids,dims=num_dims, k=num_lines, max_diff=max_diff)]
     32 
     33     if visualize:

<ipython-input-26-8869bbce1340> in find_lines_R_multiD(dat, centroids, dims, k, max_diff)
    387     print(df)
    388     #result1=[]
--> 389     get_ipython().magic('R -i df -i k -i max_diff -i dims -o result1 result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)')
    390     lines=[list(r) for r in result1]
    391     #print(lines)

/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in magic(self, arg_s)
   2158         magic_name, _, magic_arg_s = arg_s.partition(' ')
   2159         magic_name = magic_name.lstrip(prefilter.ESC_MAGIC)
-> 2160         return self.run_line_magic(magic_name, magic_arg_s)
   2161 
   2162     #-------------------------------------------------------------------------

/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_line_magic(self, magic_name, line)
   2079                 kwargs['local_ns'] = sys._getframe(stack_depth).f_locals
   2080             with self.builtin_trap:
-> 2081                 result = fn(*args,**kwargs)
   2082             return result
   2083 

<decorator-gen-119> in R(self, line, cell, local_ns)

/usr/local/lib/python3.7/dist-packages/IPython/core/magic.py in <lambda>(f, *a, **k)
    186     # but it's overkill for just that one bit of state.
    187     def magic_deco(arg):
--> 188         call = lambda f, *a, **k: f(*a, **k)
    189 
    190         if callable(arg):

/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py in R(self, line, cell, local_ns)
    781             if not e.stdout.endswith(e.err):
    782                 print(e.err)
--> 783             raise e
    784         finally:
    785             if self.device in ['png', 'svg']:

/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py in R(self, line, cell, local_ns)
    754             if line_mode:
    755                 for line in code.split(';'):
--> 756                     text_result, result, visible = self.eval(line)
    757                     text_output += text_result
    758                 if text_result:

/usr/local/lib/python3.7/dist-packages/rpy2/ipython/rmagic.py in eval(self, code)
    271                 warning_or_other_msg = self.flush()
    272                 raise RInterpreterError(code, str(exception),
--> 273                                         warning_or_other_msg)
    274             text_output = self.flush()
    275             return text_output, value, visible[0]

RInterpreterError: Failed to parse and evaluate line 'result1 <- recursive_reg(as.matrix(df[,-(dims+1)]), df[,dims+1]+1, k = k, max_diff = max_diff)'.
R error message: 'Error in solve.default(t(x) %*% x) : \n  system is computationally singular: reciprocal condition number = 6.69256e-22'

Please let me know how to proceed for such cases. Thanks!
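The error has a direct algebraic reading. The printed data frame has 11 rows and 768 feature columns, so, assuming the `x` inside `recursive_reg` is built from that matrix, `x` has rank at most 11 and the 768 × 768 matrix `t(x) %*% x` cannot be inverted. (Note also that the loop passes `num_dims=768` on every iteration, so the `Dimension:` printout does not reflect the dimensionality actually used.) The NumPy sketch below reproduces the rank deficiency with random data; the minimum-norm least-squares fit via `np.linalg.lstsq` is a swapped-in alternative shown for illustration, not part of `recursive_reg`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_dims = 11, 768             # sizes taken from the printed data frame
X = rng.normal(size=(n_samples, n_dims))
y = rng.normal(size=n_samples)

XtX = X.T @ X                           # 768 x 768, the matrix solve() tries to invert
rank = np.linalg.matrix_rank(XtX)
print(f"rank of X^T X: {rank} (needs {n_dims} to be invertible)")

# The normal equations need rank == n_dims, i.e. at least as many linearly
# independent samples as dimensions. One workaround is the minimum-norm
# least-squares solution, which never inverts X^T X and exists for any rank.
beta, *_ = np.linalg.lstsq(X, y, rcond=None)
print("max residual:", np.max(np.abs(X @ beta - y)))
```

Other options under the same diagnosis are to use more samples than dimensions, reduce the dimensionality first (e.g. with PCA), or add a ridge term to `t(x) %*% x` before solving.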

Confirmation of expected behaviour of mathutils.geometry.intersect_point_line for higher dimensions

According to the documentation of mathutils.geometry.intersect_point_line, it returns the point on a line that is closest to the given point.

However, the method does not seem to work for higher dimensions: the returned point has at most 3 dimensions. A quick check reproduces this behaviour:

from mathutils.geometry import intersect_point_line
import numpy as np

pointA = np.arange(50)
pointB = np.arange(100, 150)
pointC = np.arange(200, 250)
pointD = np.arange(300, 350)
centroids = [pointA, pointB, pointC, pointD]
print(centroids)
for centroid in centroids:
  print(intersect_point_line(centroid, centroids[0], centroids[-1]))
[array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]), array([100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
       113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
       126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
       139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149]), array([200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212,
       213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
       226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
       239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249]), array([300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312,
       313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
       326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338,
       339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349])]
(Vector((0.0, 1.0, 2.0)), 0.0)
(Vector((100.0, 101.0, 102.0)), 0.3333333432674408)
(Vector((200.0, 201.0, 202.0)), 0.6666666865348816)
(Vector((300.0, 301.0, 302.0)), 1.0)

The points are 50-dimensional; however, only a 3D point is returned. I assume this is incorrect behaviour and that we would need to implement a separate function for higher dimensions. Please let me know if this is correct. Thanks!
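Projecting a point onto a line needs only two dot products, so a dimension-agnostic replacement is short. With t = ((p − a) · (b − a)) / ((b − a) · (b − a)), the closest point on the line through a and b is a + t(b − a). The NumPy sketch below assumes the caller wants the same (closest point, factor) pair that `intersect_point_line` prints above; `closest_point_on_line` is a hypothetical name, not part of mathutils or this repository.

```python
import numpy as np

def closest_point_on_line(point, line_a, line_b):
    """N-dimensional analogue of the (point, factor) pair printed above.

    Returns the orthogonal projection of `point` onto the infinite line
    through line_a and line_b, plus the factor t such that
    projection = line_a + t * (line_b - line_a).
    """
    a = np.asarray(line_a, dtype=float)
    direction = np.asarray(line_b, dtype=float) - a
    offset = np.asarray(point, dtype=float) - a
    t = np.dot(offset, direction) / np.dot(direction, direction)
    return a + t * direction, t

centroids = [np.arange(50), np.arange(100, 150), np.arange(200, 250), np.arange(300, 350)]
for centroid in centroids:
    projection, t = closest_point_on_line(centroid, centroids[0], centroids[-1])
    # shape stays (50,); t is 0, 1/3, 2/3, 1 up to floating-point rounding
    print(projection.shape, t)
```

Because t is not clamped to [0, 1], points beyond either endpoint project onto the extended line, which is also what the factor returned by `intersect_point_line` describes.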
