|
441 | 441 | "Env = Environment(params={\"dimensionality\": \"1D\", \"boundary_conditions\": \"periodic\"})\n",
|
442 | 442 | "\n",
|
443 | 443 | "# Put agent (who will move randomly under the ratinabox Ornstein Uhlenbeck random motion policy) inside the environement\n",
|
444 |
| - "Ag = Agent(Env, params={'dt':0.01})\n", |
| 444 | + "Ag = Agent(Env, params={'dt':0.02})\n", |
445 | 445 | "Ag.speed_mean = 0\n",
|
446 | 446 | "Ag.speed_std = 0.3\n",
|
447 | 447 | "\n",
|
|
461 | 461 | " params={\n",
|
462 | 462 | " \"n\": n_cells,\n",
|
463 | 463 | " \"name\": \"ConjunctiveCells_left\",\n",
|
| 464 | + " # nb. this tutorial is now quite old so the way that FeedForwardLayer --- define in the main codebase --- activations are set (passing \"activation_function\" at initialisation) no longer matches the way DendriticCompartment --- defined above --- activations are set (setting \"activation_params\" after initialisation). Sorry about this! TODO: update DendriticCompartment to be FeedForwardLayer subclass\n", |
| 465 | + " \"activation_function\": {\n", |
| 466 | + " \"activation\": \"relu\",\n", |
| 467 | + " \"threshold\": 1,\n", |
| 468 | + " }\n", |
464 | 469 | " },\n",
|
465 | 470 | ")\n",
|
466 | 471 | "\n",
|
|
469 | 474 | " params={\n",
|
470 | 475 | " \"n\": n_cells,\n",
|
471 | 476 | " \"name\": \"ConjunctiveCells_right\",\n",
|
| 477 | + " \"activation_function\": {\n", |
| 478 | + " \"activation\": \"relu\",\n", |
| 479 | + " \"threshold\": 1,\n", |
| 480 | + " }\n", |
472 | 481 | " },\n",
|
473 | 482 | ")\n",
|
474 | 483 | "\n",
|
|
497 | 506 | " [-1, 1]\n",
|
498 | 507 | ") # thus right velocity excites these cells and rigleftht velocity shuts them off\n",
|
499 | 508 | "ConjunctiveCells_left.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n",
|
500 |
| - "ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n", |
501 |
| - "ConjunctiveCells_left.activation_params = {\n", |
502 |
| - " \"activation\": \"relu\",\n", |
503 |
| - " \"threshold\": 1,\n", |
504 |
| - " \"width_x\": 2,\n", |
505 |
| - "}\n", |
506 |
| - "ConjunctiveCells_right.activation_params = {\n", |
507 |
| - " \"activation\": \"relu\",\n", |
508 |
| - " \"threshold\": 1,\n", |
509 |
| - " \"width_x\": 2,\n", |
510 |
| - "}" |
| 509 | + "ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)" |
511 | 510 | ]
|
512 | 511 | },
|
513 | 512 | {
|
|
516 | 515 | "source": [
|
517 | 516 | "### Train the network\n",
|
518 | 517 | "\n",
|
519 |
| - "Train it for 20 minutes" |
| 518 | + "Train it for 60 minutes" |
520 | 519 | ]
|
521 | 520 | },
|
522 | 521 | {
|
|
525 | 524 | "metadata": {},
|
526 | 525 | "outputs": [],
|
527 | 526 | "source": [
|
528 |
| - "for i in tqdm(range(int(10 * 60 / Ag.dt))):\n", |
| 527 | + "for i in tqdm(range(int(60 * 60 / Ag.dt))):\n", |
529 | 528 | " # update agent\n",
|
530 | 529 | " Ag.update()\n",
|
531 | 530 | " # update firing rates of all the cell layers\n",
|
|
556 | 555 | "source": [
|
557 | 556 | "fig, ax = RingAttractor.plot_loss()\n",
|
558 | 557 | "\n",
|
| 558 | + "save_plots = False\n", |
559 | 559 | "if save_plots == True:\n",
|
560 | 560 | " tpl.saveFigure(fig, \"PI_loss\")"
|
561 | 561 | ]
|
|
0 commit comments