Skip to content

Commit ce7a0fc

Browse files
jtgrasbcmichelenstroferrgcoedtgaebemichaelcdevin
authored
Autograd -> Jax conversion (sandialabs#433)
* Update CONTRIBUTING.md to indicate PRs should be to the new `dev` branch * Update RELEASING.md to reflect new workflow with the `dev` branch * update docstrings (sandialabs#326) * damping naming and consistently change radiation damping (sandialabs#328) * issue 321 fd_to_td() bug (sandialabs#329) * bug bix : DC and Nyquist frequency should not be devided by two before ifft * Changed td_to_fd to scale single sided frequency components rather than TD signal * minor bug fix from issue332 sandialabs#332 * nodf -> ndof (sandialabs#334) * add DOI for Daniel's paper (sandialabs#336) * Lower tolerance for new test to fix CI failing occasionally * hyperlinks no longer have formatting, plus other small adjustments (sandialabs#348) * Merge to dev, not main (sandialabs#349) * Dev version of documentation site (sandialabs#347) * added initial file changes based on sphinx_multiversion docs and WEC-Sim implementation * removed sphinx-multiversion since it is no longer supported and made manual multiversion * now uses absolute paths, commented out linkcheck for debugging * fixed docstring errors in utilities module * updating files again that somehow got reverted * fixing path in conf.py * don't run tutorials (will revert later) * handle file moves correctly, fixed if statement to make other versions appear * fixed two bugs in versions template * reverted temp changes, changes latest to main * switched latest to main * main branch now in root directory of pages * fixed URLs with change from last commit * make other branches visible before building * switched main branch tag for more testing * fixed typo * switched dev branch to an existing branch * renamed main to latest, changed version.html file name to avoid confusion * added prints about moving files so Sphinx output isn't misleading * fixed typo with quotations * changed versions.html name back because that broke things I guess * modified contributing documentation to reflect changes * add logic to remove duplicate 'latest' branch * Fixed pathing when already on latest * remove typo * Troubleshooting complete, switching back to correct branches for deployment * Removed extra word in docstring * removed redundant function * fixed pathing so returns to same file (and fixes tutorial/API docs) * changed latest branch for demonstration * switched back latest branch for deployment * updated with new Capytaine docs URL * Add warnings when adding inertia and hydrostatic stiffness automatically (sandialabs#346) * CI workflow cleanup (sandialabs#352) * removed conda environment from workflows since newer capytaine/wavespectra work with Windows * fixed unnecessary capitalization * still create CI conda environment to fix Mac environment failures * added conda env fully back in, push workflow deploys docs, split PR workflow * conda environment activates again * mambaforge instead of miniforge * manual cache reset * reset to older version of setup-miniconda to troubleshoot * Updated workflows to newest Python version and changed references to supported versions (sandialabs#390) Co-authored-by: jtgrasb <[email protected]> * Revert to Python 3.12 (sandialabs#394) * Try specifying subversion * Test new cache * revert to 3.12 * Revert comment back to normal * use dev for docs and restrict sphinx (sandialabs#396) * Remove Sphinx version requirement (sandialabs#409) * v3.0.3 * v3.1 * Trying to convert tutorial 1 * Convert to jax progress * post-processing * clear outputs * wavebot tutorial running * wavebot tutorial running * Update to jax and numpy * Revert wavebot execution count * update pyproject.toml * Specify jax version for mac * try jaxlib * no jaxlib * add jax to environment manually * add jaxlib to env * ad jaxlib to pyproject * install jax manually for macos * conda init * install jax and jaxlib on macos * try arm64 * remove arm * make core optimization jittable * try pinning jax version * revert previous * install entirely w. pip * Add verbose outputs to testing * editable mode * no mamba * editable mode * update random inputs * remove cache environment * use pytest-cov * remove cache clear * Fix push.yml --------- Co-authored-by: Carlos A. Michelén Ströfer <[email protected]> Co-authored-by: Ryan Coe <[email protected]> Co-authored-by: Daniel Gaebele <[email protected]> Co-authored-by: Michael Devin <[email protected]> Co-authored-by: mcdevin <[email protected]>
1 parent 34d2d88 commit ce7a0fc

File tree

13 files changed

+238
-256
lines changed

13 files changed

+238
-256
lines changed

.github/workflows/pr.yml

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,52 +23,13 @@ jobs:
2323
# Checkout the WecOptTool repo
2424
- uses: actions/checkout@v3
2525

26-
# Cache the conda environment >>>
27-
# The cache key includes the OS, the date, the hash of the pyproject.toml file, and the cache number.
28-
# A new cache is created if:
29-
# - this is the first time today that this file is run (date changes)
30-
# - the content of pyproject.toml changes
31-
# - you manually change the value of the CACHE_NUMBER below
32-
# Else the existing cache is used.
26+
# Caching of environment has been removed for simplicity
3327
- name: Setup Miniforge
3428
uses: conda-incubator/setup-miniconda@v2
3529
with:
3630
miniforge-variant: Miniforge3
3731
miniforge-version: latest
3832
activate-environment: test-env
39-
use-mamba: true
40-
41-
# save the date to include in the cache key
42-
- name: Get Date
43-
id: get-date
44-
run: echo "DATE=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_ENV
45-
shell: bash
46-
47-
# create a conda yaml file
48-
- name: Create environment.yml file
49-
run: |
50-
echo "name: test-env" >> environment.yml;
51-
echo " " >> environment.yml
52-
echo "dependencies:" >> environment.yml
53-
echo " - python=${{ matrix.python-version }}" >> environment.yml
54-
echo " - capytaine" >> environment.yml
55-
echo " - wavespectra" >> environment.yml
56-
cat environment.yml
57-
58-
# use the cache if it exists
59-
- uses: actions/cache@v3
60-
env:
61-
CACHE_NUMBER: 0 # increase to reset cache manually
62-
with:
63-
path: ${{ env.CONDA }}/envs
64-
key: conda-${{ runner.os }}--${{ runner.arch }}--${{ env.DATE }}-${{ hashFiles('pyproject.toml') }}-${{ env.CACHE_NUMBER }}
65-
id: cache
66-
67-
# if cache key has changed, create new cache
68-
- name: Update environment
69-
run: mamba env update -n test-env -f environment.yml
70-
if: steps.cache.outputs.cache-hit != 'true'
71-
# <<< Cache the conda environment
7233

7334
# install libglu on ubuntu.
7435
- name: Install libglu
@@ -81,13 +42,13 @@ jobs:
8142
run: |
8243
conda activate test-env
8344
python3 -m pip install --upgrade pip
84-
pip3 install gmsh pygmsh coveralls pytest
45+
pip3 install gmsh pygmsh coveralls pytest pytest-cov
8546
pip3 install .
8647
8748
# run all tests & coverage
8849
- name: Run Test
8950
shell: bash -l {0}
90-
run: coverage run -m pytest
51+
run: pytest --cov=wecopttool --cov-report=xml
9152

9253
# upload coverage data
9354
- name: Upload coverage data to coveralls.io

.github/workflows/push.yml

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,52 +23,12 @@ jobs:
2323
# Checkout the WecOptTool repo
2424
- uses: actions/checkout@v3
2525

26-
# Cache the conda environment >>>
27-
# The cache key includes the OS, the date, the hash of the pyproject.toml file, and the cache number.
28-
# A new cache is created if:
29-
# - this is the first time today that this file is run (date changes)
30-
# - the content of pyproject.toml changes
31-
# - you manually change the value of the CACHE_NUMBER below
32-
# Else the existing cache is used.
3326
- name: Setup Miniforge
3427
uses: conda-incubator/setup-miniconda@v2
3528
with:
3629
miniforge-variant: Miniforge3
3730
miniforge-version: latest
3831
activate-environment: test-env
39-
use-mamba: true
40-
41-
# save the date to include in the cache key
42-
- name: Get Date
43-
id: get-date
44-
run: echo "DATE=$(/bin/date -u '+%Y%m%d')" >> $GITHUB_ENV
45-
shell: bash
46-
47-
# create a conda yaml file
48-
- name: Create environment.yml file
49-
run: |
50-
echo "name: test-env" >> environment.yml;
51-
echo " " >> environment.yml
52-
echo "dependencies:" >> environment.yml
53-
echo " - python=${{ matrix.python-version }}" >> environment.yml
54-
echo " - capytaine" >> environment.yml
55-
echo " - wavespectra" >> environment.yml
56-
cat environment.yml
57-
58-
# use the cache if it exists
59-
- uses: actions/cache@v3
60-
env:
61-
CACHE_NUMBER: 0 # increase to reset cache manually
62-
with:
63-
path: ${{ env.CONDA }}/envs
64-
key: conda-${{ runner.os }}--${{ runner.arch }}--${{ env.DATE }}-${{ hashFiles('pyproject.toml') }}-${{ env.CACHE_NUMBER }}
65-
id: cache
66-
67-
# if cache key has changed, create new cache
68-
- name: Update environment
69-
run: mamba env update -n test-env -f environment.yml
70-
if: steps.cache.outputs.cache-hit != 'true'
71-
# <<< Cache the conda environment
7232

7333
# install libglu on ubuntu.
7434
- name: Install libglu
@@ -81,13 +41,13 @@ jobs:
8141
run: |
8242
conda activate test-env
8343
python3 -m pip install --upgrade pip
84-
pip3 install gmsh pygmsh coveralls pytest
44+
pip3 install gmsh pygmsh coveralls pytest pytest-cov
8545
pip3 install .
8646
8747
# run all tests & coverage
8848
- name: Run Test
8949
shell: bash -l {0}
90-
run: coverage run -m pytest
50+
run: pytest --cov=wecopttool --cov-report=xml
9151

9252
# upload coverage data
9353
- name: Upload coverage data to coveralls.io

examples/tutorial_1_WaveBot.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
"metadata": {},
3131
"outputs": [],
3232
"source": [
33-
"import autograd.numpy as np\n",
33+
"import numpy as np\n",
34+
"import jax.numpy as jnp\n",
3435
"import capytaine as cpy\n",
3536
"from capytaine.io.meshio import load_from_meshio\n",
3637
"import matplotlib.pyplot as plt\n",
@@ -372,7 +373,7 @@
372373
"\n",
373374
"def const_f_pto(wec, x_wec, x_opt, wave): # Format for scipy.optimize.minimize\n",
374375
" f = pto.force(wec, x_wec, x_opt, wave, nsubsteps)\n",
375-
" return f_max - np.abs(f.flatten())\n",
376+
" return f_max - jnp.abs(f.flatten())\n",
376377
"\n",
377378
"ineq_cons = {'type': 'ineq',\n",
378379
" 'fun': const_f_pto,\n",
@@ -706,7 +707,7 @@
706707
"## Update PTO constraints and forcing\n",
707708
"def const_f_pto_2(wec, x_wec, x_opt, wave):\n",
708709
" f = pto_2.force_on_wec(wec, x_wec, x_opt, wave, nsubsteps)\n",
709-
" return f_max - np.abs(f.flatten())\n",
710+
" return f_max - jnp.abs(f.flatten())\n",
710711
"ineq_cons_2 = {'type': 'ineq', 'fun': const_f_pto_2}\n",
711712
"constraints_2 = [ineq_cons_2]\n",
712713
"f_add_2 = {'PTO': pto_2.force_on_wec}"
@@ -989,7 +990,7 @@
989990
"\n",
990991
" def const_f_pto(wec, x_wec, x_opt, wave):\n",
991992
" f = pto.force(wec, x_wec, x_opt, wave, nsubsteps)\n",
992-
" return f_max - np.abs(f.flatten())\n",
993+
" return f_max - jnp.abs(f.flatten())\n",
993994
"\n",
994995
" ineq_cons = {'type': 'ineq', 'fun': const_f_pto}\n",
995996
" constraints = [ineq_cons]\n",

examples/tutorial_2_AquaHarmonics.ipynb

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
"source": [
2525
"import capytaine as cpy\n",
2626
"from capytaine.io.meshio import load_from_meshio\n",
27-
"import autograd.numpy as np\n",
27+
"import numpy as np\n",
28+
"import jax.numpy as jnp\n",
2829
"import matplotlib.pyplot as plt\n",
2930
"from matplotlib import cm\n",
3031
"from scipy.optimize import brute\n",
@@ -376,10 +377,10 @@
376377
"source": [
377378
"def f_buoyancy(wec, x_wec, x_opt, wave, nsubsteps=1):\n",
378379
" \"\"\"Only the zeroth order component (doesn't include linear stiffness)\"\"\"\n",
379-
" return displacement * rho * g * np.ones([wec.ncomponents*nsubsteps, wec.ndof])\n",
380+
" return displacement * rho * g * jnp.ones([wec.ncomponents*nsubsteps, wec.ndof])\n",
380381
"\n",
381382
"def f_gravity(wec, x_wec, x_opt, wave, nsubsteps=1):\n",
382-
" return -1 * wec.inertia_matrix.item() * g * np.ones([wec.ncomponents*nsubsteps, wec.ndof])\n",
383+
" return -1 * wec.inertia_matrix.item() * g * jnp.ones([wec.ncomponents*nsubsteps, wec.ndof])\n",
383384
"\n",
384385
"def f_pretension_wec(wec, x_wec, x_opt, wave, nsubsteps=1):\n",
385386
" \"\"\"Pretension force as it acts on the WEC\"\"\"\n",
@@ -389,18 +390,18 @@
389390
"\n",
390391
"def f_pto_passive(wec, x_wec, x_opt, wave, nsubsteps=1):\n",
391392
" pos = wec.vec_to_dofmat(x_wec)\n",
392-
" vel = np.dot(wec.derivative_mat,pos)\n",
393-
" acc = np.dot(wec.derivative_mat, vel)\n",
393+
" vel = jnp.dot(wec.derivative_mat,pos)\n",
394+
" acc = jnp.dot(wec.derivative_mat, vel)\n",
394395
" time_matrix = wec.time_mat_nsubsteps(nsubsteps)\n",
395396
" spring = -(gear_ratios['spring']*airspring['gamma']*airspring['area']*\n",
396397
" airspring['press_init']/airspring['vol_init']) * pos\n",
397-
" f_spring = np.dot(time_matrix,spring)\n",
398+
" f_spring = jnp.dot(time_matrix,spring)\n",
398399
" fric = -(friction_pto + \n",
399400
" friction['Bpneumatic_spring_static1']*\n",
400401
" gear_ratios['spring']) * vel\n",
401-
" f_fric = np.dot(time_matrix,fric)\n",
402+
" f_fric = jnp.dot(time_matrix,fric)\n",
402403
" inertia = inertia_pto * acc\n",
403-
" f_inertia = np.dot(time_matrix,inertia)\n",
404+
" f_inertia = jnp.dot(time_matrix,inertia)\n",
404405
" return f_spring + f_fric + f_inertia\n",
405406
"\n",
406407
"def f_pto_line(wec, x_wec, x_opt, wave, nsubsteps=1):\n",
@@ -452,18 +453,18 @@
452453
"\n",
453454
"def const_peak_torque_pto(wec, x_wec, x_opt, wave):\n",
454455
" torque = pto.force(wec, x_wec, x_opt, wave, nsubsteps)\n",
455-
" return torque_peak_max - np.abs(torque.flatten())\n",
456+
" return torque_peak_max - jnp.abs(torque.flatten())\n",
456457
"\n",
457458
"def const_speed_pto(wec, x_wec, x_opt, wave):\n",
458459
" rot_vel = pto.velocity(wec, x_wec, x_opt, wave, nsubsteps)\n",
459-
" return rot_speed_max - np.abs(rot_vel.flatten())\n",
460+
" return rot_speed_max - jnp.abs(rot_vel.flatten())\n",
460461
"\n",
461462
"def const_power_pto(wec, x_wec, x_opt, wave):\n",
462463
" power_mech = (\n",
463464
" pto.velocity(wec, x_wec, x_opt, wave, nsubsteps) *\n",
464465
" pto.force(wec, x_wec, x_opt, wave, nsubsteps)\n",
465466
" )\n",
466-
" return power_max - np.abs(power_mech.flatten())\n",
467+
" return power_max - jnp.abs(power_mech.flatten())\n",
467468
"\n",
468469
"def constrain_min_tension(wec, x_wec, x_opt, wave):\n",
469470
" total_tension = -1*f_pto_line(wec, x_wec, x_opt, wave, nsubsteps)\n",

examples/tutorial_3_LUPA.ipynb

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
"import gmsh, pygmsh\n",
4141
"import capytaine as cpy\n",
4242
"from capytaine.io.meshio import load_from_meshio\n",
43-
"import autograd.numpy as np\n",
43+
"import numpy as np\n",
44+
"import jax.numpy as jnp\n",
4445
"import matplotlib.pyplot as plt\n",
4546
"import xarray as xr\n",
4647
"from scipy.optimize import brute\n",
@@ -766,8 +767,8 @@
766767
"# maximum stroke\n",
767768
"stroke_max = 0.5 # m\n",
768769
"def const_stroke_pto(wec, x_wec, x_opt, wave): \n",
769-
" pos = pto.position(wec, x_wec, x_opt, wave, nsubsteps)\n",
770-
" return stroke_max - np.abs(pos.flatten())\n",
770+
" pos = pto.position(wec, x_wec, x_opt, waves, nsubsteps)\n",
771+
" return stroke_max - jnp.abs(pos.flatten())\n",
771772
"\n",
772773
"## GENERATOR\n",
773774
"# peak torque\n",
@@ -776,20 +777,20 @@
776777
" \"\"\"Instantaneous torque must not exceed max torque Tmax - |T| >=0 \n",
777778
" \"\"\"\n",
778779
" torque = pto.force(wec, x_wec, x_opt, wave, nsubsteps) / gear_ratio(radius)\n",
779-
" return generator['max_torque'] - np.abs(torque.flatten())\n",
780+
" return generator['max_torque'] - jnp.abs(torque.flatten())\n",
780781
"\n",
781782
"# continuous torque\n",
782783
"def const_torque_pto(wec, x_wec, x_opt, wave, radius=default_radius): \n",
783784
" \"\"\"RMS torque must not exceed max continous torque \n",
784785
" Tmax_conti - Trms >=0 \"\"\"\n",
785786
" torque = pto.force(wec, x_wec, x_opt, wave, nsubsteps) / gear_ratio(radius)\n",
786-
" torque_rms = np.sqrt(np.mean(torque.flatten()**2))\n",
787+
" torque_rms = jnp.sqrt(jnp.mean(torque.flatten()**2))\n",
787788
" return generator['continuous_torque'] - torque_rms\n",
788789
"\n",
789790
"# max speed\n",
790791
"def const_speed_pto(wec, x_wec, x_opt, wave, radius=default_radius): \n",
791792
" rot_vel = pto.velocity(wec, x_wec, x_opt, wave, nsubsteps) * gear_ratio(radius)\n",
792-
" return generator['max_speed'] - np.abs(rot_vel.flatten())\n",
793+
" return generator['max_speed'] - jnp.abs(rot_vel.flatten())\n",
793794
"\n",
794795
"## Constraints\n",
795796
"constraints = [\n",
@@ -1263,7 +1264,7 @@
12631264
" stroke_max = 0.5 # m\n",
12641265
" def const_stroke_pto(wec, x_wec, x_opt, wave): \n",
12651266
" pos = pto.position(wec, x_wec, x_opt, wave, nsubsteps)\n",
1266-
" return stroke_max - np.abs(pos.flatten())\n",
1267+
" return stroke_max - jnp.abs(pos.flatten())\n",
12671268
"\n",
12681269
" ## GENERATOR\n",
12691270
" # peak torque\n",
@@ -1272,20 +1273,20 @@
12721273
" \"\"\"Instantaneous torque must not exceed max torque Tmax - |T| >=0 \n",
12731274
" \"\"\"\n",
12741275
" torque = pto.force(wec, x_wec, x_opt, wave, nsubsteps) / gear_ratio(radius)\n",
1275-
" return generator['max_torque'] - np.abs(torque.flatten())\n",
1276+
" return generator['max_torque'] - jnp.abs(torque.flatten())\n",
12761277
"\n",
12771278
" # continuous torque\n",
12781279
" def const_torque_pto(wec, x_wec, x_opt, wave): \n",
12791280
" \"\"\"RMS torque must not exceed max continous torque \n",
12801281
" Tmax_conti - Trms >=0 \"\"\"\n",
12811282
" torque = pto.force(wec, x_wec, x_opt, wave, nsubsteps) / gear_ratio(radius)\n",
1282-
" torque_rms = np.sqrt(np.mean(torque.flatten()**2))\n",
1283+
" torque_rms = jnp.sqrt(jnp.mean(torque.flatten()**2))\n",
12831284
" return generator['continuous_torque'] - torque_rms\n",
12841285
"\n",
12851286
" # max speed\n",
12861287
" def const_speed_pto(wec, x_wec, x_opt, wave): \n",
12871288
" rot_vel = pto.velocity(wec, x_wec, x_opt, wave, nsubsteps) * gear_ratio(radius)\n",
1288-
" return generator['max_speed'] - np.abs(rot_vel.flatten())\n",
1289+
" return generator['max_speed'] - jnp.abs(rot_vel.flatten())\n",
12891290
"\n",
12901291
" ## Constraints\n",
12911292
" constraints = [\n",

0 commit comments

Comments
 (0)