Skip to content

Commit 92775c7

Browse files
committed
Fix activation function bug causing this demo to break
1 parent 96296f0 commit 92775c7

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

demos/path_integration_example.ipynb

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@
441441
"Env = Environment(params={\"dimensionality\": \"1D\", \"boundary_conditions\": \"periodic\"})\n",
442442
"\n",
443443
"# Put agent (who will move randomly under the ratinabox Ornstein Uhlenbeck random motion policy) inside the environement\n",
444-
"Ag = Agent(Env, params={'dt':0.01})\n",
444+
"Ag = Agent(Env, params={'dt':0.02})\n",
445445
"Ag.speed_mean = 0\n",
446446
"Ag.speed_std = 0.3\n",
447447
"\n",
@@ -461,6 +461,11 @@
461461
" params={\n",
462462
" \"n\": n_cells,\n",
463463
" \"name\": \"ConjunctiveCells_left\",\n",
464+
" # nb. this tutorial is now quite old so the way that FeedForwardLayer --- define in the main codebase --- activations are set (passing \"activation_function\" at initialisation) no longer matches the way DendriticCompartment --- defined above --- activations are set (setting \"activation_params\" after initialisation). Sorry about this! TODO: update DendriticCompartment to be FeedForwardLayer subclass\n",
465+
" \"activation_function\": {\n",
466+
" \"activation\": \"relu\",\n",
467+
" \"threshold\": 1,\n",
468+
" }\n",
464469
" },\n",
465470
")\n",
466471
"\n",
@@ -469,6 +474,10 @@
469474
" params={\n",
470475
" \"n\": n_cells,\n",
471476
" \"name\": \"ConjunctiveCells_right\",\n",
477+
" \"activation_function\": {\n",
478+
" \"activation\": \"relu\",\n",
479+
" \"threshold\": 1,\n",
480+
" }\n",
472481
" },\n",
473482
")\n",
474483
"\n",
@@ -497,17 +506,7 @@
497506
" [-1, 1]\n",
498507
") # thus right velocity excites these cells and rigleftht velocity shuts them off\n",
499508
"ConjunctiveCells_left.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n",
500-
"ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)\n",
501-
"ConjunctiveCells_left.activation_params = {\n",
502-
" \"activation\": \"relu\",\n",
503-
" \"threshold\": 1,\n",
504-
" \"width_x\": 2,\n",
505-
"}\n",
506-
"ConjunctiveCells_right.activation_params = {\n",
507-
" \"activation\": \"relu\",\n",
508-
" \"threshold\": 1,\n",
509-
" \"width_x\": 2,\n",
510-
"}"
509+
"ConjunctiveCells_right.inputs[\"RingAttractor\"][\"w\"] = np.identity(n_cells)"
511510
]
512511
},
513512
{
@@ -516,7 +515,7 @@
516515
"source": [
517516
"### Train the network\n",
518517
"\n",
519-
"Train it for 20 minutes"
518+
"Train it for 60 minutes"
520519
]
521520
},
522521
{
@@ -525,7 +524,7 @@
525524
"metadata": {},
526525
"outputs": [],
527526
"source": [
528-
"for i in tqdm(range(int(10 * 60 / Ag.dt))):\n",
527+
"for i in tqdm(range(int(60 * 60 / Ag.dt))):\n",
529528
" # update agent\n",
530529
" Ag.update()\n",
531530
" # update firing rates of all the cell layers\n",
@@ -556,6 +555,7 @@
556555
"source": [
557556
"fig, ax = RingAttractor.plot_loss()\n",
558557
"\n",
558+
"save_plots = False\n",
559559
"if save_plots == True:\n",
560560
" tpl.saveFigure(fig, \"PI_loss\")"
561561
]

0 commit comments

Comments
 (0)