
Training Neural Networks with MXNet

[This article was first published on Jakub Glinka's Blog, and kindly contributed to R-bloggers.]

Multilayer perceptron

The multilayer perceptron (MLP) is the simplest feed-forward neural network. It overcomes the main limitation of the original perceptron, which could only learn linearly separable patterns from the data, by introducing at least one hidden layer that learns a representation of the data in which linear separation becomes possible.

In the first layer the MLP applies linear transformations to the data point $x$:

$$f_j(x) = w_j^{\top} x + b_j \quad \text{for } j = 1, \ldots, J$$

The number of such transformations, $J$, equals the number of nodes in the first hidden layer.
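
As a minimal base-R sketch of this first layer (the input x, weight matrix W, and bias vector b below are made-up illustrative values, not taken from the post):

  # hypothetical example: 2-dimensional input and J = 10 hidden nodes
  x <- c(0.5, -1.2)                      # data point
  W <- matrix(rnorm(10 * 2), nrow = 10)  # row j holds the weight vector w_j
  b <- rnorm(10)                         # biases b_j
  f <- as.vector(W %*% x + b)            # f_j(x) for j = 1, ..., 10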

Next, it applies a non-linear transformation to these outputs using a so-called activation function. Using a linear function as the activation would defeat the purpose of the MLP, since a composition of linear transformations is itself a linear transformation.

The most commonly used activation function is the so-called rectifier:

$$\phi(x) = \max(0, x)$$
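
In base R the rectifier can be written in one line (a sketch, not part of the original post):

  relu <- function(x) pmax(0, x)   # element-wise max(0, x)
  relu(c(-1.5, 0, 2))              # 0 0 2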

Finally, the outputs of the activation function are again combined using a linear transformation:

$$h_k(x) = \sum_j \phi(f_j(x)) \, v_j^{k} + c_k$$

At this point one can either repeat the activation step, extending the network with another hidden layer, or apply a final transformation of the outputs to fit the learning objective. For classification problems the most commonly used transformation is the softmax function:

$$\phi_k(h_1, \ldots, h_K) = \frac{\exp(h_k)}{\sum_s \exp(h_s)}$$

which maps a real-valued vector to a vector of probabilities.
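
A minimal R sketch of the softmax mapping (purely illustrative; for large inputs one would subtract max(h) first for numerical stability):

  softmax <- function(h) exp(h) / sum(exp(h))
  softmax(c(1, 2, 3))   # positive values that sum to 1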

For classification problems the most commonly used loss function is the cross-entropy between the class label $y \in \{1, \ldots, K\}$ and the probability returned by the softmax function:

$$L(\phi_k, y) = -\sum_{k=1}^{K} 1_{\{y = k\}} \log(\phi_k)$$

which is averaged over all training observations.
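
A minimal R sketch of this loss, for a single observation and averaged over a training set (the names and the matrix layout are assumptions for illustration only):

  # cross-entropy for one observation:
  # p is the softmax output vector, y the true class label in 1..K
  cross.entropy <- function(p, y) -log(p[y])

  # average over n observations:
  # P is an n x K matrix of predicted probabilities, y a vector of labels in 1..K
  mean.cross.entropy <- function(P, y) mean(-log(P[cbind(seq_along(y), y)]))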

Universal Approximation Theorem

According to the theorem, first proved by George Cybenko for the sigmoid activation function, a “feed-forward network with a single hidden layer containing a finite number of neurons (i.e., a multilayer perceptron) can approximate continuous functions on compact subsets of $\mathbb{R}^n$, under mild assumptions on the activation function.”

Let's put the MLP to the test then. For this purpose I will use the spirals dataset from the mlbench package.

MXNet

MXNet is an open-source deep learning framework that allows you to define, train, and deploy deep neural networks on a wide array of devices, from cloud infrastructure to mobile devices. It also allows mixing symbolic and imperative programming flavors, for example when writing custom loss functions and accuracy measures.

Read more: http://mxnet.io

Network configuration

The MXNet package exposes a so-called symbolic API to R users. Its purpose is to provide a user-friendly way of building neural networks while abstracting the computational details away to MXNet's specialized engine.

The most important symbols used below are:

- mx.symbol.Variable – placeholder for the input data,
- mx.symbol.FullyConnected – affine (fully connected) transformation,
- mx.symbol.Activation – non-linear activation function,
- mx.symbol.SoftmaxOutput – softmax output layer combined with the cross-entropy loss.

Below is an example of code that configures a perceptron with one hidden layer.

  ########### Network configuration ########

  library(mxnet)

  # input data variable
  act <- mx.symbol.Variable("data")
  
  # affine transformation
  fc <- mx.symbol.FullyConnected(act, num.hidden = 10)
  
  # non-linear activation 
  act <- mx.symbol.Activation(data = fc, act_type = "relu")

  # affine transformation
  fc <- mx.symbol.FullyConnected(act, num.hidden = 2)
  
  # softmax output and cross-entropy loss
  mlp <- mx.symbol.SoftmaxOutput(fc)

Preparing data

  library(mlbench)

  set.seed(2015)

  ############ spirals dataset ############

  s <- sample(x = c("train", "test"), 
              size = 1000, 
              prob = c(.8,.2),
              replace = TRUE)
  
  dta <- mlbench.spirals(n = 1000, cycles = 1.2, sd = .03)
  dta <- cbind(dta[["x"]], as.integer(dta[["classes"]]) - 1)
  colnames(dta) <- c("x","y","label")
  
  ########### train / test split ###########

  dta.train <- dta[s == "train",]
  dta.test <- dta[s == "test",]
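
A quick visual check of the generated spirals (a base-graphics sketch, not from the original post):

  ########## quick look at the data #########

  plot(dta.train[, "x"], dta.train[, "y"],
       col = dta.train[, "label"] + 1, pch = 19,
       xlab = "x", ylab = "y",
       main = "mlbench spirals (training split)")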

Network training

Feed-forward networks are trained using an iterative, gradient-descent type of algorithm. During a single forward pass only a subset of the data, called a batch, is used; the process is repeated until all training examples have been used, which is called an epoch. After every epoch MXNet reports the training accuracy:

  ############# basic training #############

  mx.set.seed(2014)
  model <- mx.model.FeedForward.create(
            symbol = mlp,
            X = dta.train[, c("x", "y")], 
            y = dta.train[, c("label")],
            num.round = 5,
            array.layout = "rowmajor",
            learning.rate = 1,
            eval.metric = mx.metric.accuracy)
## Start training with 1 devices
## [1] Train-accuracy=0.506510416666667
## [2] Train-accuracy=0.5
## [3] Train-accuracy=0.5
## [4] Train-accuracy=0.5
## [5] Train-accuracy=0.5

Custom call-back

In order to stop training when the improvement in accuracy falls below a certain tolerance, we need to add a custom callback to the feed-forward procedure. It is called after every epoch to check whether the algorithm is still making progress; if not, it terminates the optimization and returns the results.

  ######## custom stopping criterion #######

  mx.callback.train.stop <- function(tol = 1e-3, 
                                     mean.n = 1e2, 
                                     period = 100, 
                                     min.iter = 100
                                     ) {
    function(iteration, nbatch, env, verbose = TRUE) {
      # keep training by default; the stopping conditions below may switch this off
      continue <- TRUE
      if (nbatch == 0 & !is.null(env$metric)) {
          acc.train <- env$metric$get(env$train.metric)$value
          if (is.null(env$acc.log)) {
            env$acc.log <- acc.train
          } else {
            if ((abs(acc.train - mean(tail(env$acc.log, mean.n))) < tol &
                abs(acc.train - max(env$acc.log)) < tol &
                iteration > min.iter) | 
                acc.train == 1) {
              cat("Training finished with final accuracy: ", 
                  round(acc.train * 100, 2), " %\n", sep = "")
              continue <- FALSE 
            }
            env$acc.log <- c(env$acc.log, acc.train)
          }
      }
      if (iteration %% period == 0) {
        cat("[", iteration,"]"," training accuracy: ", 
            round(acc.train * 100, 2), " %\n", sep = "") 
      }
      return(continue)
    }
  }

  ###### training with custom stopping #####

  mx.set.seed(2014)
  model <- mx.model.FeedForward.create(
          symbol = mlp,
          X = dta.train[, c("x", "y")], 
          y = dta.train[, c("label")],
          num.round = 2000,
          array.layout = "rowmajor",
          learning.rate = 1,
          epoch.end.callback = mx.callback.train.stop(),
          eval.metric = mx.metric.accuracy,
          verbose = FALSE
          )
## [100] training accuracy: 90.07 %
## [200] training accuracy: 98.88 %
## [300] training accuracy: 99.33 %
## Training finished with final accuracy: 99.44 %
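
The test split created earlier is not used in the post itself; below is a hedged sketch of how the trained model could be scored on it, assuming the mxnet predict method returns class probabilities with one column per observation (as in the package's tutorials):

  ########## held-out test accuracy #########

  # predicted class probabilities: one column per test observation (assumption)
  preds <- predict(model, dta.test[, c("x", "y")], array.layout = "rowmajor")
  pred.label <- max.col(t(preds)) - 1        # back to 0/1 labels
  mean(pred.label == dta.test[, "label"])    # proportion correctly classified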

Results

Learning curve

Evolution of decision boundary


Code for this post can be found here: https://github.com/jakubglinka/posts/tree/master/neural_networks_part1
