diff --git a/.github/workflows/pr-test.yaml b/.github/workflows/pr-test.yaml new file mode 100644 index 00000000..3106cd9b --- /dev/null +++ b/.github/workflows/pr-test.yaml @@ -0,0 +1,95 @@ +name: MediSwarm PR Validation + +on: + schedule: + - cron: '0 5 * * 0' + pull_request: + branches: + - main + - dev + +permissions: + contents: read + +jobs: + validate-swarm: + runs-on: self-hosted + timeout-minutes: 45 + + env: + DATADIR: /mnt/sda1/Odelia_challange/ODELIA_Challenge_unilateral/ + SCRATCHDIR: /mnt/scratch + SITE_NAME: UKA + PYTHONUNBUFFERED: 1 + + steps: + - name: Checkout repository (with submodules) + uses: actions/checkout@v3 + with: + submodules: true + fetch-depth: 0 + + - name: Get Docker image version + id: get_version + run: | + VERSION=$(./getVersionNumber.sh) + echo "VERSION=$VERSION" + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Build Docker image and startup kits for test project + run: ./buildDockerImageAndStartupKits.sh -p tests/provision/dummy_project_for_testing.yml -c tests/local_vpn/client_configs + + - name: Show workspace path for test project + run: | + echo "WORKSPACE_PATH: ${{ env.WORKSPACE_PATH }}" + find workspace -maxdepth 1 -type d -name "odelia_*_dummy_project_for_testing" || echo "No workspace found" + + - name: Run integration test checking documentation on github + continue-on-error: false + run: | + ./runIntegrationTests.sh check_files_on_github + + - name: Run controller unit tests + continue-on-error: false + run: | + ./runIntegrationTests.sh run_unit_tests_controller + + - name: Run dummy training standalone + continue-on-error: false + run: | + ./runIntegrationTests.sh run_dummy_training_standalone + + - name: Run dummy training in simulation mode + continue-on-error: false + run: | + ./runIntegrationTests.sh run_dummy_training_simulation_mode + + - name: Run dummy training in proof-of-concept mode + continue-on-error: false + run: | + ./runIntegrationTests.sh run_dummy_training_poc_mode + + - name: Run 3DCNN training in simulation mode + continue-on-error: false + run: | + ./runIntegrationTests.sh run_3dcnn_simulation_mode + + - name: Run integration test creating startup kits + continue-on-error: false + run: | + ./runIntegrationTests.sh create_startup_kits + + - name: Run intergration test listing licenses + continue-on-error: false + run: | + ./runIntegrationTests.sh run_list_licenses + + - name: Run integration test Docker GPU preflight check + continue-on-error: false + run: | + ./runIntegrationTests.sh run_docker_gpu_preflight_check + + - name: Run integration test Data access preflight check + continue-on-error: false + run: | + ./runIntegrationTests.sh run_data_access_preflight_check diff --git a/.github/workflows/update-apt-versions.yml b/.github/workflows/update-apt-versions.yml index baeeea8f..8e4eedc8 100644 --- a/.github/workflows/update-apt-versions.yml +++ b/.github/workflows/update-apt-versions.yml @@ -1,31 +1,24 @@ -name: Auto Update APT Versions +name: Auto Update APT Versions (Self-hosted) on: schedule: - # Every day at 05:00 UTC - - cron: '0 5 * * *' + # run eveyday at 04:00 UTC + - cron: '0 4 * * *' workflow_dispatch: jobs: update-apt: - name: Update APT Package Versions in Dockerfile - runs-on: ubuntu-latest + runs-on: self-hosted + timeout-minutes: 60 steps: - name: Checkout repository (with submodules) uses: actions/checkout@v3 with: submodules: true + fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - - - name: Install dependencies - run: sudo apt-get update && sudo apt-get install -y git apt-utils - - - name: Configure Git for CI + - name: Set up Git run: | git config --global user.email "ci@github.com" git config --global user.name "GitHub CI" @@ -33,17 +26,15 @@ jobs: - name: Create and switch to apt-update branch run: | git checkout -b ci/apt-update || git switch ci/apt-update - - - name: Make update script executable - run: chmod +x scripts/ci/update_apt_versions.sh - - name: Run APT update script - run: scripts/ci/update_apt_versions.sh + run: | + chmod +x scripts/ci/update_apt_versions.sh + scripts/ci/update_apt_versions.sh - name: Show git diff for debugging - run: git diff + run: git diff || true - - name: Push ci/apt-update to origin + - name: Push apt-update branch if: env.NO_CHANGES == 'false' run: git push origin ci/apt-update --force @@ -53,8 +44,10 @@ jobs: with: commit-message: "chore: update apt versions in Dockerfile_ODELIA" branch: ci/apt-update + branch-suffix: timestamp title: "chore: Update APT versions in Dockerfile" body: | This PR automatically updates APT package version numbers in `Dockerfile_ODELIA` based on a rebuild and inspection of installation logs. base: main + delete-branch: false \ No newline at end of file diff --git a/.gitignore b/.gitignore index 57af2b10..86efbffa 100644 --- a/.gitignore +++ b/.gitignore @@ -180,3 +180,6 @@ provision # Ignore provisioned files /workspace/ + +# Ignore directory for caching pre-trained models +docker_config/torch_home_cache diff --git a/README.md b/README.md index cd8837fb..2eb29561 100644 --- a/README.md +++ b/README.md @@ -1,244 +1,43 @@ -# Introduction -MediSwarm is an open-source project dedicated to advancing medical deep learning through swarm intelligence, leveraging the NVFlare platform. Developed in collaboration with the Odelia consortium, this repository aims to create a decentralized and collaborative framework for medical research and applications. - -## Key Features -- **Swarm Learning:** Utilizes swarm intelligence principles to improve model performance and adaptability. -- **NVFlare Integration:** Built on NVFlare, providing robust and scalable federated learning capabilities. -- **Data Privacy:** Ensures data security and compliance with privacy regulations by keeping data local to each institution. -- **Collaborative Research:** Facilitates collaboration among medical researchers and institutions for enhanced outcomes. -- **Extensible Framework:** Designed to support various medical applications and easily integrate with existing workflows. - -## Prerequisites -### Hardware recommendations -* 64 GB of RAM (32 GB is the absolute minimum) -* 16 CPU cores (8 is the absolute minimum) -* an NVIDIA GPU with 48 GB of RAM (24 GB is the minimum) -* 8 TB of Storage (4 TB is the absolute minimum) - -We demonstrate that the system can run on lightweight hardware like this. For less than 10k EUR, you can configure systems from suppliers like Lambda, Dell Precision, and Dell Alienware. - -### Operating System -* Ubuntu 20.04 LTS - -### Software -* Docker -* openvpn -* git - -### Cloning the repository - ```bash - git clone https://github.com/KatherLab/MediSwarm.git --recurse-submodules - ``` -* The last argument is necessary because we are using a git submodule for the (ODELIA fork of NVFlare)[https://github.com/KatherLab/NVFlare_MediSwarm] -* If you have cloned it without this argument, use `git submodule update --init --recursive` - -### VPN -A VPN is necessary so that the swarm nodes can communicate with each other securely across firewalls. For that purpose, -1. Install OpenVPN - ```bash - sudo apt-get install openvpn - ``` -2. If you have a graphical user interface(GUI), follow this guide to connect to the VPN: [VPN setup guide(GUI).pdf](assets/VPN%20setup%20guide%28GUI%29.pdf) -3. If you have a command line interface(CLI), follow this guide to connect to the VPN: [VPN setup guide(CLI).md](assets/VPN%20setup%20guide%28CLI%29.md) - -# Usage for Swarm Participants -## Setup -1. Make sure your compute node satisfies the specification and has the necessary software installed. -2. Clone the repository and connect the client node to the VPN as described above. -3. TODO anything else? - -## Prepare Dataset -1. TODO which data is expected in which folder structure + table structure - -## Prepare Training Participation -1. Extract startup kit provided by swarm operator - -## Run Pre-Flight Check -1. Directories - ```bash - export SITE_NAME= # TODO should be defined above, also needed for dataset location - export DATADIR= - export SCRATCHDIR= - ``` -2. From the directory where you unpacked the startup kit, - ```bash - cd $SITE_NAME/startup - ``` -3. Verify that your Docker/GPU setup is working - ```bash - ./docker.sh --scratch_dir $SCRATCHDIR --GPU device=0 --dummy_training - ``` - * This will pull the Docker image, which might take a while. - * If you have multiple GPUs and 0 is busy, use a different one. - * The “training” itself should take less than minute and does not yield a meaningful classification performance. -4. Verify that your local data can be accessed and the model can be trained locally - ```bash - ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --preflight_check - ``` - * Training time depends on the size of the local dataset - -## Start Swarm Node -1. From the directory where you unpacked the startup kit - ```bash - cd $SITE_NAME/startup # skip this if you just ran the pre-flight check - ``` -2. Start the client - ```bash - ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --start_client - ``` -3. Console output is captured in `nohup.out`, which may have been created by the root user in the container, so make it readable: - ```bash - sudo chmod a+r nohup.out - ``` -4. Output files - * TODO describe - -## Run Local Training -1. From the directory where you unpacked the startup kit - ```bash - cd $SITE_NAME/startup - ``` -2. Start local training - ```bash - /docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU all --local_training - ``` - * TODO update when handling of the number of epochs has been implemented -3. Output files - * TODO describe - -# Usage for MediSwarm and Application Code Developers -## Versioning of ODELIA Docker Images -If needed, update the version number in file (odelia_image.version)[odelia_image.version]. It will be used automatically for the Docker image and startup kits. - -## Build the Docker Image and Startup Kits -The Docker image contains all dependencies for administrative purposes (dashboard, command-line provisioning, admin console, server) as well as for running the 3DCNN pipeline under the pytorch-lightning framework. -The project description specifies the swarm nodes etc. to be used for a swarm training. - ```bash - cd MediSwarm - ./buildDockerImageAndStartupKits.sh -p application/provision/ - ``` - -1. Make sure you have no uncommitted changes. -2. If package versions are still not available, you may have to check what the current version is and update the `Dockerfile` accordingly. Version numbers are hard-coded to avoid issues due to silently different versions being installed. -3. After successful build (and after verifying that everything works as expected, i.e., local tests, building startup kits, running local trainings in the startup kit), you can manually push the image to DockerHub, provided you have the necessary rights. Make sure you are not re-using a version number for this purpose. - -## Running Local Tests - ```bash - ./runTestsInDocker.sh - ``` - -You should see -1. several expected errors and warnings printed from unit tests that should succeed overall, and a coverage report -2. output of a successful simulation run with two nodes -3. output of a successful proof-of-concept run run with two nodes -4. output of a set of startup kits being generated -5. output of a dummy training run using one of the startup kits - -Optionally, uncomment running NVFlare unit tests in `_runTestsInsideDocker.sh`. - -## Distributing Startup Kits -Distribute the startup kits to the clients. - -## Running the Application -1. **CIFAR-10 example:** - See [cifar10/README.md](application/jobs/cifar10/README.md) -2. **Minimal PyTorch CNN example:** - See [application/jobs/minimal_training_pytorch_cnn/README.md](application/jobs/minimal_training_pytorch_cnn/README.md) -3. **3D CNN for classifying breast tumors:** - See [3dcnn_ptl/README.md](application/jobs/3dcnn_ptl/README.md) - -## Contributing Application Code -1. Take a look at application/jobs/minimal_training_pytorch_cnn for a minimal example how pytorch code can be adapted to work with NVFlare -2. Take a look at application/jobs/3dcnn_ptl for a more relastic example of pytorch code that can run in the swarm -3. Use the local tests to check if the code is swarm-ready -4. TODO more detailed instructions - -# Usage for Swarm Operators -## Setting up a Swarm -Production mode is designed for secure, real-world deployments. It supports both local and remote setups, whether on-premise or in the cloud. For more details, refer to the [NVFLARE Production Mode](https://nvflare.readthedocs.io/en/2.4.1/real_world_fl.html). - -To set up production mode, follow these steps: - -## Edit `/etc/hosts` -Ensure that your `/etc/hosts` file includes the correct host mappings. All hosts need to be able to communicate to the server node. - -For example, add the following line (replace `` with the server's actual IP address): - -```plaintext - dl3.tud.de dl3 -``` - -## Create Startup Kits -### Via Script (recommended) -1. Use, e.g., the file `application/provision/project_MEVIS_test.yml`, adapt as needed (network protocol etc.) -2. Call `buildStartupKits.sh /path/to/project_configuration.yml` to build the startup kits -3. Startup kits are generated to `workspace//prod_00/` -4. Deploy startup kits to the respective server/clients - -### Via the Dashboard (not recommended) -```bash -docker run -d --rm \ - --ipc=host -p 8443:8443 \ - --name=odelia_swarm_admin \ - -v /var/run/docker.sock:/var/run/docker.sock \ - \ - /bin/bash -c "nvflare dashboard --start --local --cred :" -``` -using some credentials chosen for the swarm admin account. - -Access the dashboard in a web browser at `https://localhost:8443` log in with these credentials, and configure the project: -1. enter project short name, name, description -2. enter docker download link: jefftud/nvflare-pt-dev:3dcnn -3. if needed, enter dates -4. click save -5. Server Configuration > Server (DNS name): -6. click make project public - -#### Register client per site -Access the dashboard at `https://:8443`. - -1. register a user -2. enter organziation (corresponding to the site) -3. enter role (e.g., org admin) -4. add a site (note: must not contain spaces, best use alphanumerical name) -5. specify number of GPUs and their memory - -#### Approve clients and finish configuration -Access the dashboard at `https://localhost:8443` log in with the admin credentials. -1. Users Dashboard > approve client user -2. Client Sites > approve client sites -3. Project Home > freeze project - -## Download startup kits -After setting up the project admin configuration, server and clients can download their startup kits. Store the passwords somewhere, they are only displayed once (or you can download them again). - -## Starting a Swarm Training -1. Connect the *server* host to the VPN as described above. -2. Start the *server* startup kit using the respective `startup/docker.sh` script with the option to start the server -3. Provide the *client* startup kits to the swarm participants (be aware that email providers or other channels may prevent encrypted archives) -4. Make sure the participants have started their clients via the respective startup kits, see below -5. Start the *admin* startup kit using the respective `startup/docker.sh` script to start the admin console -6. Deploy a job by `submit_job ` - - -# License -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. - -# Maintainers -[Jeff](https://github.com/Ultimate-Storm) -[Ole Schwen](mailto:ole.schwen@mevis.fraunhofer.de) -[Steffen Renisch](mailto:steffen.renisch@mevis.fraunhofer.de) - -# Contributing -Feel free to dive in! [Open an issue](https://github.com/KatherLab/MediSwarm/issues) or submit pull requests. - -# Credits -This project utilizes platforms and resources from the following repositories: - -- **[NVFLARE](https://github.com/NVIDIA/NVFlare)**: NVFLARE (NVIDIA Federated Learning Application Runtime Environment) is an open-source framework that provides a robust and scalable platform for federated learning applications. We have integrated NVFLARE to efficiently handle the federated learning aspects of our project. - -Special thanks to the contributors and maintainers of these repositories for their valuable work and support. - ---- - -For more details about NVFLARE and its features, please visit the [NVFLARE GitHub repository](https://github.com/NVIDIA/NVFlare). +# MediSwarm + +An open-source platform advancing medical AI via privacy-preserving swarm learning, based on NVFlare and developed with +the ODELIA consortium. + +[![PR Tests](https://github.com/KatherLab/MediSwarm/actions/workflows/pr-test.yaml/badge.svg)](https://github.com/KatherLab/MediSwarm/actions/workflows/pr-test.yaml) +[![Build](https://github.com/KatherLab/MediSwarm/actions/workflows/update-apt-versions.yml/badge.svg)](https://github.com/KatherLab/MediSwarm/actions/workflows/update-apt-versions.yml) + +## Quick Start for Your Role + +Choose your role and follow the instructions: + +- [Swarm Participant (Medical Site / Data Scientist)](assets/readme/README.participant.md) +- [Developer (Docker, Code, Pipeline)](assets/readme/README.developer.md) +- [Swarm Operator (Provisioning, VPN, Server)](assets/readme/README.operator.md) + +## Overview + +MediSwarm enables: + +- **Privacy-preserving training** of deep learning models on distributed medical datasets +- **Decentralized collaboration** between institutions +- **Dockerized, reproducible** experiments built on NVFlare + +## License + +MIT — see [LICENSE](LICENSE). + +## Maintainers + +- [Jeff](https://github.com/Ultimate-Storm) +- [Ole Schwen](mailto:ole.schwen@mevis.fraunhofer.de) +- [Steffen Renisch](mailto:steffen.renisch@mevis.fraunhofer.de) + +## Contributing + +Contributions welcome! [Open an issue](https://github.com/KatherLab/MediSwarm/issues) or submit a PR. + +## Credits + +Built on: + +- [NVFLARE](https://github.com/NVIDIA/NVFlare) diff --git a/_buildStartupKits.sh b/_buildStartupKits.sh index bf3fec04..47f064c8 100755 --- a/_buildStartupKits.sh +++ b/_buildStartupKits.sh @@ -1,15 +1,33 @@ #!/usr/bin/env bash -if [ "$#" -ne 2 ]; then - echo "Usage: _buildStartupKits.sh SWARM_PROJECT.yml VERSION_STRING" +set -euo pipefail + +if [ "$#" -lt 3 ]; then + echo "Usage: _buildStartupKits.sh SWARM_PROJECT.yml VERSION_STRING CONTAINER_NAME [VPN_CREDENTIALS_DIR]" exit 1 fi PROJECT_YML=$1 VERSION=$2 +CONTAINER_NAME=$3 +MOUNT_VPN_CREDENTIALS_DIR="" +if [ "$#" -eq 4 ]; then + MOUNT_VPN_CREDENTIALS_DIR="-v $4:/vpn_credentials/" +fi sed -i 's#__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__#'$VERSION'#' $PROJECT_YML -docker run --rm -it -u $(id -u):$(id -g) -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v ./:/workspace/ -w /workspace/ jefftud/odelia:$VERSION /bin/bash -c "nvflare provision -p $PROJECT_YML && ./_generateStartupKitArchives.sh $PROJECT_YML $VERSION" +ARGUMENTS="$PROJECT_YML $VERSION" + +echo "Building startup kits: $ARGUMENTS" +docker run --rm \ + -u $(id -u):$(id -g) \ + -v /etc/passwd:/etc/passwd \ + -v /etc/group:/etc/group \ + -v ./:/workspace/ \ + $MOUNT_VPN_CREDENTIALS_DIR \ + -w /workspace/ \ + $CONTAINER_NAME \ + /bin/bash -c "nvflare provision -p $PROJECT_YML && ./_generateStartupKitArchives.sh $ARGUMENTS"|| { echo "Docker run failed"; exit 1; } sed -i 's#'$VERSION'#__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__#' $PROJECT_YML diff --git a/_generateStartupKitArchives.sh b/_generateStartupKitArchives.sh index c1055153..e76161fc 100755 --- a/_generateStartupKitArchives.sh +++ b/_generateStartupKitArchives.sh @@ -1,11 +1,20 @@ #!/usr/bin/env bash +set -e + OUTPUT_FOLDER=workspace/`grep "^name: " $1 | sed 's/name: //'` TARGET_FOLDER=`ls -d $OUTPUT_FOLDER/prod_* | tail -n 1` LONG_VERSION=$2 cd $TARGET_FOLDER + for startupkit in `ls .`; do + VPN_CREDENTIALS_FILE=/vpn_credentials/${startupkit}_client.ovpn + if [[ -f $VPN_CREDENTIALS_FILE ]]; then + cp $VPN_CREDENTIALS_FILE ${startupkit}/startup/vpn_client.ovpn + else + echo "$VPN_CREDENTIALS_FILE does not exist, omitting VPN credentials for ${startupkit} in startup kit" + fi zip -rq ${startupkit}_$LONG_VERSION.zip $startupkit echo "Generated startup kit $TARGET_FOLDER/${startupkit}_$LONG_VERSION.zip" done diff --git a/_runTestsInsideDocker.sh b/_runTestsInsideDocker.sh deleted file mode 100755 index 794d7320..00000000 --- a/_runTestsInsideDocker.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env bash - -# run unit tests of ODELIA swarm learning and report coverage -export MPLCONFIGDIR=/tmp -cd /MediSwarm/tests/unit_tests/controller -PYTHONPATH=/MediSwarm/controller/controller python3 -m coverage run --source=/MediSwarm/controller/controller -m unittest discover -coverage report -m -rm .coverage - -# uncomment to run NVFlare's unit tests (takes about 2 minutes and will install python packages in the container) -# cd /MediSwarm/docker_config/NVFlare -# ./runtest.sh -c -r -# coverage report -m -# cd .. - -# run standalone version of minimal example -cd /MediSwarm/application/jobs/minimal_training_pytorch_cnn/app/custom/ -export TRAINING_MODE="local_training" -./main.py - -# run simulation mode for minimal example -cd /MediSwarm -export TRAINING_MODE="swarm" -nvflare simulator -w /tmp/minimal_training_pytorch_cnn -n 2 -t 2 application/jobs/minimal_training_pytorch_cnn -c simulated_node_0,simulated_node_1 - -# run proof-of-concept mode for minimal example -cd /MediSwarm -export TRAINING_MODE="swarm" -nvflare poc prepare -c poc_client_0 poc_client_1 -nvflare poc prepare-jobs-dir -j application/jobs/ -nvflare poc start -ex admin@nvidia.com -sleep 15 -echo "Will submit job now after sleeping 15 seconds to allow the background process to complete" -nvflare job submit -j application/jobs/minimal_training_pytorch_cnn -sleep 60 -echo "Will shut down now after sleeping 60 seconds to allow the background process to complete" -sleep 2 -nvflare poc stop diff --git a/application/jobs/3dcnn_ptl/app/custom/data/augmentation/augmentations_3d.py b/application/jobs/3dcnn_ptl/app/custom/data/augmentation/augmentations_3d.py deleted file mode 100644 index cc206cf0..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/augmentation/augmentations_3d.py +++ /dev/null @@ -1,244 +0,0 @@ -import torchio as tio -from typing import Tuple, Union, Optional, Dict -from numbers import Number -import nibabel as nib -import numpy as np -import torch -from torchio.typing import TypeRangeFloat -from torchio.transforms.transform import TypeMaskingMethod -from torchio import Subject, Image - - -class SubjectToTensor: - """Transforms TorchIO Subjects into a Python dict and changes axes order from TorchIO to Torch.""" - - def __call__(self, subject: Subject) -> Dict[str, torch.Tensor]: - """Transforms the given subject. - - Args: - subject (Subject): The subject to be transformed. - - Returns: - Dict[str, torch.Tensor]: A dictionary with transformed subject data. - """ - return {key: val.data.swapaxes(1, -1) if isinstance(val, Image) else val for key, val in subject.items()} - - -class ImageToTensor: - """Transforms TorchIO Image into a Numpy/Torch Tensor and changes axes order from TorchIO [B, C, W, H, D] to Torch [B, C, D, H, W].""" - - def __call__(self, image: Image) -> torch.Tensor: - """Transforms the given image. - - Args: - image (Image): The image to be transformed. - - Returns: - torch.Tensor: The transformed image tensor. - """ - return image.data.swapaxes(1, -1) - - -def parse_per_channel(per_channel: Union[bool, list], channels: int) -> list: - """Parses the per_channel argument. - - Args: - per_channel (Union[bool, list]): Whether to apply per channel. - channels (int): The number of channels. - - Returns: - list: A list of channel tuples. - """ - if isinstance(per_channel, bool): - if per_channel: - return [(ch,) for ch in range(channels)] - else: - return [tuple(ch for ch in range(channels))] - else: - return per_channel - - -class ZNormalization(tio.ZNormalization): - """Add option 'per_channel' to apply znorm for each channel independently and percentiles to clip values first.""" - - def __init__( - self, - percentiles: TypeRangeFloat = (0, 100), - per_channel: Union[bool, list] = True, - masking_method: TypeMaskingMethod = None, - **kwargs - ): - super().__init__(masking_method=masking_method, **kwargs) - self.percentiles = percentiles - self.per_channel = per_channel - - def apply_normalization( - self, - subject: Subject, - image_name: str, - mask: torch.Tensor, - ) -> None: - """Applies normalization to the given subject. - - Args: - subject (Subject): The subject to normalize. - image_name (str): The name of the image to normalize. - mask (torch.Tensor): The mask tensor. - """ - image = subject[image_name] - per_channel = parse_per_channel(self.per_channel, image.shape[0]) - - image.set_data(torch.cat([ - self._znorm(image.data[chs,], mask[chs,], image_name, image.path) - for chs in per_channel]) - ) - - def _znorm(self, image_data: torch.Tensor, mask: torch.Tensor, image_name: str, image_path: str) -> torch.Tensor: - """Applies z-normalization to the given image data. - - Args: - image_data (torch.Tensor): The image data to normalize. - mask (torch.Tensor): The mask tensor. - image_name (str): The name of the image. - image_path (str): The path of the image. - - Returns: - torch.Tensor: The normalized image data. - - Raises: - RuntimeError: If standard deviation is 0 for masked values. - """ - cutoff = torch.quantile(image_data.masked_select(mask).float(), torch.tensor(self.percentiles) / 100.0) - torch.clamp(image_data, *cutoff.to(image_data.dtype).tolist(), out=image_data) - - standardized = self.znorm(image_data, mask) - if standardized is None: - message = ( - 'Standard deviation is 0 for masked values' - f' in image "{image_name}" ({image_path})' - ) - raise RuntimeError(message) - return standardized - - -class RescaleIntensity(tio.RescaleIntensity): - """Add option 'per_channel' to apply rescale for each channel independently.""" - - def __init__( - self, - out_min_max: TypeRangeFloat = (0, 1), - percentiles: TypeRangeFloat = (0, 100), - masking_method: TypeMaskingMethod = None, - in_min_max: Optional[Tuple[float, float]] = None, - per_channel: Union[bool, list] = True, - # Bool or List of tuples containing channel indices that should be normalized together - **kwargs - ): - super().__init__(out_min_max, percentiles, masking_method, in_min_max, **kwargs) - self.per_channel = per_channel - - def apply_normalization( - self, - subject: Subject, - image_name: str, - mask: torch.Tensor, - ) -> None: - """Applies normalization to the given subject. - - Args: - subject (Subject): The subject to normalize. - image_name (str): The name of the image to normalize. - mask (torch.Tensor): The mask tensor. - """ - image = subject[image_name] - per_channel = parse_per_channel(self.per_channel, image.shape[0]) - - image.set_data(torch.cat([ - self.rescale(image.data[chs,], mask[chs,], image_name) - for chs in per_channel]) - ) - - -class Pad(tio.Pad): - """Fixed version of TorchIO Pad. - - Pads with zeros for LabelMaps independent of padding mode (e.g., don't pad with mean). - Pads with global (not per axis) 'maximum', 'mean', 'median', 'minimum' if any of these padding modes were selected. - """ - - def apply_transform(self, subject: Subject) -> Subject: - """Applies padding to the given subject. - - Args: - subject (Subject): The subject to pad. - - Returns: - Subject: The padded subject. - """ - assert self.bounds_parameters is not None - low = self.bounds_parameters[::2] - for image in self.get_images(subject): - new_origin = nib.affines.apply_affine(image.affine, -np.array(low)) - new_affine = image.affine.copy() - new_affine[:3, 3] = new_origin - kwargs: Dict[str, Union[str, float]] - if isinstance(self.padding_mode, Number): - kwargs = { - 'mode': 'constant', - 'constant_values': self.padding_mode, - } - elif isinstance(image, tio.LabelMap): # FIX - kwargs = { - 'mode': 'constant', - 'constant_values': 0, - } - else: - if self.padding_mode in ['maximum', 'mean', 'median', 'minimum']: - if self.padding_mode == 'maximum': - constant_values = image.data.min() - elif self.padding_mode == 'mean': - constant_values = image.data.to(torch.float).mean().to(image.data.dtype) - elif self.padding_mode == 'median': - constant_values = image.data.median() - elif self.padding_mode == 'minimum': - constant_values = image.data.min() - kwargs = { - 'mode': 'constant', - 'constant_values': constant_values, - } - else: - kwargs = {'mode': self.padding_mode} - pad_params = self.bounds_parameters - paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:] - padded = np.pad(image.data, paddings, **kwargs) # type: ignore[call-overload] # noqa: E501 - image.set_data(torch.as_tensor(padded)) - image.affine = new_affine - return subject - - -class CropOrPad(tio.CropOrPad): - """Fixed version of TorchIO CropOrPad. - - Pads with zeros for LabelMaps independent of padding mode (e.g., don't pad with mean). - Pads with global (not per axis) 'maximum', 'mean', 'median', 'minimum' if any of these padding modes were selected. - """ - - def apply_transform(self, subject: Subject) -> Subject: - """Applies cropping or padding to the given subject. - - Args: - subject (Subject): The subject to crop or pad. - - Returns: - Subject: The cropped or padded subject. - """ - subject.check_consistent_space() - padding_params, cropping_params = self.compute_crop_or_pad(subject) - padding_kwargs = {'padding_mode': self.padding_mode} - if padding_params is not None: - pad = Pad(padding_params, **padding_kwargs) - subject = pad(subject) # type: ignore[assignment] - if cropping_params is not None: - crop = tio.Crop(cropping_params) - subject = crop(subject) # type: ignore[assignment] - return subject diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datamodules/datamodule.py b/application/jobs/3dcnn_ptl/app/custom/data/datamodules/datamodule.py deleted file mode 100644 index b8c8cb44..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datamodules/datamodule.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytorch_lightning as pl -import torch -from torch.utils.data.dataloader import DataLoader -import torch.multiprocessing as mp -from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler - - -class DataModule(pl.LightningDataModule): - """ - LightningDataModule for handling dataset loading and batching. - - Attributes: - ds_train (object): Training dataset. - ds_val (object): Validation dataset. - ds_test (object): Test dataset. - batch_size (int): Batch size for dataloaders. - num_workers (int): Number of workers for data loading. - seed (int): Random seed for reproducibility. - pin_memory (bool): If True, pin memory for faster data transfer to GPU. - weights (list): Weights for the weighted random sampler. - """ - - def __init__( - self, - ds_train: object = None, - ds_val: object = None, - ds_test: object = None, - batch_size: int = 1, - num_workers: int = mp.cpu_count(), - seed: int = 0, - pin_memory: bool = False, - weights: list = None - ): - """ - Initializes the DataModule with datasets and parameters. - - Args: - ds_train (object, optional): Training dataset. Defaults to None. - ds_val (object, optional): Validation dataset. Defaults to None. - ds_test (object, optional): Test dataset. Defaults to None. - batch_size (int, optional): Batch size. Defaults to 1. - num_workers (int, optional): Number of workers. Defaults to mp.cpu_count(). - seed (int, optional): Random seed. Defaults to 0. - pin_memory (bool, optional): Pin memory. Defaults to False. - weights (list, optional): Weights for sampling. Defaults to None. - """ - super().__init__() - self.hyperparameters = {**locals()} - self.hyperparameters.pop('__class__') - self.hyperparameters.pop('self') - - self.ds_train = ds_train - self.ds_val = ds_val - self.ds_test = ds_test - - self.batch_size = batch_size - self.num_workers = num_workers - self.seed = seed - self.pin_memory = pin_memory - self.weights = weights - - def train_dataloader(self) -> DataLoader: - """ - Returns the training dataloader. - - Returns: - DataLoader: DataLoader for the training dataset. - - Raises: - AssertionError: If the training dataset is not initialized. - """ - generator = torch.Generator() - generator.manual_seed(self.seed) - - if self.ds_train is not None: - if self.weights is not None: - sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator) - else: - sampler = RandomSampler(self.ds_train, replacement=False, generator=generator) - return DataLoader( - self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers, - sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory - ) - - raise AssertionError("A training set was not initialized.") - - def val_dataloader(self) -> DataLoader: - """ - Returns the validation dataloader. - - Returns: - DataLoader: DataLoader for the validation dataset. - - Raises: - AssertionError: If the validation dataset is not initialized. - """ - generator = torch.Generator() - generator.manual_seed(self.seed) - if self.ds_val is not None: - return DataLoader( - self.ds_val, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - generator=generator, drop_last=False, pin_memory=self.pin_memory - ) - - raise AssertionError("A validation set was not initialized.") - - def test_dataloader(self) -> DataLoader: - """ - Returns the test dataloader. - - Returns: - DataLoader: DataLoader for the test dataset. - - Raises: - AssertionError: If the test dataset is not initialized. - """ - generator = torch.Generator() - generator.manual_seed(self.seed) - if self.ds_test is not None: - return DataLoader( - self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - generator=generator, drop_last=False, pin_memory=self.pin_memory - ) - - raise AssertionError("A test dataset was not initialized.") diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/__init__.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/__init__.py deleted file mode 100644 index e34b2577..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -This package initializes the necessary modules and classes for the project. -""" - -from .dataset_3d import SimpleDataset3D -from .dataset_3d_collab import DUKE_Dataset3D_collab -from .dataset_3d_duke import DUKE_Dataset3D -from .dataset_3d_duke_external import DUKE_Dataset3D_external - -__all__ = [name for name in dir() if not name.startswith('_')] diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d.py deleted file mode 100644 index 234a6d8b..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d.py +++ /dev/null @@ -1,124 +0,0 @@ -from pathlib import Path -import torch.utils.data as data -import torchio as tio -from data.augmentation.augmentations_3d import ImageToTensor, RescaleIntensity, ZNormalization - -class SimpleDataset3D(data.Dataset): - """ - A simple dataset class for 3D medical images using TorchIO for preprocessing and augmentation. - - Args: - path_root (str): Root directory of the dataset. - item_pointers (list, optional): List of file paths. Defaults to []. - crawler_glob (str, optional): Glob pattern for crawling files. Defaults to '*.nii.gz'. - transform (callable, optional): Transformations to apply to the data. Defaults to None. - image_resize (tuple, optional): Desired output image size. Defaults to None. - flip (bool, optional): Whether to apply random flipping. Defaults to False. - image_crop (tuple, optional): Desired crop size. Defaults to None. - norm (str, optional): Normalization method. Defaults to 'znorm_clip'. - to_tensor (bool, optional): Whether to convert images to tensor. Defaults to True. - """ - - def __init__( - self, - path_root, - item_pointers=[], - crawler_glob='*.nii.gz', - transform=None, - image_resize=None, - flip=False, - image_crop=None, - norm='znorm_clip', - to_tensor=True, - ): - super().__init__() - self.path_root = Path(path_root) - self.crawler_glob = crawler_glob - - if transform is None: - self.transform = tio.Compose([ - tio.Resize(image_resize) if image_resize is not None else tio.Lambda(lambda x: x), - tio.RandomFlip((0, 1, 2)) if flip else tio.Lambda(lambda x: x), - tio.CropOrPad(image_crop) if image_crop is not None else tio.Lambda(lambda x: x), - self.get_norm(norm), - ImageToTensor() if to_tensor else tio.Lambda(lambda x: x) # [C, W, H, D] -> [C, D, H, W] - ]) - else: - self.transform = transform - - if item_pointers: - self.item_pointers = item_pointers - else: - self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_glob) - - def __len__(self): - """Returns the number of items in the dataset.""" - return len(self.item_pointers) - - def __getitem__(self, index): - """ - Retrieves an item from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: A dictionary with 'uid' and 'source' keys. - """ - rel_path_item = self.item_pointers[index] - path_item = self.path_root / rel_path_item - img = self.load_item(path_item) - return {'uid': str(rel_path_item), 'source': self.transform(img)} - - def load_item(self, path_item): - """ - Loads an image from the given path. - - Args: - path_item (Path): Path to the image file. - - Returns: - tio.ScalarImage: Loaded image. - """ - return tio.ScalarImage(path_item) - - @classmethod - def run_item_crawler(cls, path_root, crawler_glob, **kwargs): - """ - Crawls the directory to find items matching the glob pattern. - - Args: - path_root (Path): Root directory to start crawling. - crawler_glob (str): Glob pattern to match files. - - Returns: - list: List of relative file paths. - """ - return [path.relative_to(path_root) for path in Path(path_root).rglob(f'{crawler_glob}')] - - @staticmethod - def get_norm(norm): - """ - Returns the normalization transform based on the provided norm string. - - Args: - norm (str): Normalization method name. - - Returns: - tio.Transform: The normalization transform. - """ - if norm is None: - return tio.Lambda(lambda x: x) - elif isinstance(norm, str): - if norm == 'min-max': - return RescaleIntensity((-1, 1), per_channel=True, masking_method=lambda x: x > 0) - elif norm == 'min-max_clip': - return RescaleIntensity((-1, 1), per_channel=True, percentiles=(0.5, 99.5), masking_method=lambda x: x > 0) - elif norm == 'znorm': - return ZNormalization(per_channel=True, masking_method=lambda x: x > 0) - elif norm == 'znorm_clip': - return ZNormalization(per_channel=True, percentiles=(0.5, 99.5), masking_method=lambda x: x > 0) - else: - raise ValueError("Unknown normalization") - else: - return norm diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_collab.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_collab.py deleted file mode 100755 index c867aec7..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_collab.py +++ /dev/null @@ -1,106 +0,0 @@ -from pathlib import Path -import pandas as pd -from data.datasets import SimpleDataset3D - - -class DUKE_Dataset3D_collab(SimpleDataset3D): - """ - DUKE Collaboration Dataset for 3D medical images, extending SimpleDataset3D. - - Args: - path_root (str): Root directory of the dataset. - item_pointers (list, optional): List of file paths. Defaults to None. - crawler_glob (str, optional): Glob pattern for crawling files. Defaults to '*.nii.gz'. - transform (callable, optional): Transformations to apply to the data. Defaults to None. - image_resize (tuple, optional): Desired output image size. Defaults to None. - flip (bool, optional): Whether to apply random flipping. Defaults to False. - image_crop (tuple, optional): Desired crop size. Defaults to None. - norm (str, optional): Normalization method. Defaults to 'znorm_clip'. - to_tensor (bool, optional): Whether to convert images to tensor. Defaults to True. - """ - - def __init__( - self, - path_root, - item_pointers=None, - crawler_glob='*.nii.gz', - transform=None, - image_resize=None, - flip=False, - image_crop=None, - norm='znorm_clip', - to_tensor=True - ): - if item_pointers is None: - item_pointers = [] - super().__init__(path_root, item_pointers, crawler_glob, transform, image_resize, flip, image_crop, norm, - to_tensor) - df = pd.read_csv(self.path_root.parent / 'datasheet.csv') - - df = df[[df.columns[0], df.columns[1]]] # Only pick relevant columns: Patient ID, Tumor Side, Bilateral - existing_folders = {folder.name for folder in Path(path_root).iterdir() if folder.is_dir()} - self.df = df[df['PATIENT'].isin(existing_folders)] - self.df = self.df.set_index('PATIENT', drop=True) - self.item_pointers = self.df.index[self.df.index.isin(self.item_pointers)].tolist() - - def __getitem__(self, index): - """ - Retrieves an item from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: A dictionary with 'uid', 'source', and 'target' keys. - """ - uid = self.item_pointers[index] - item_dir = self.path_root / uid - nii_gz_files = list(item_dir.glob('**/*.nii.gz')) - file_name = 'SUB_4.nii.gz' - - if len(nii_gz_files) > 1: - sub_4_path = item_dir / file_name - if sub_4_path in nii_gz_files: - path_item = sub_4_path - else: - path_item = nii_gz_files[0] - elif nii_gz_files: - path_item = nii_gz_files[0] - else: - raise FileNotFoundError(f"No .nii.gz files found in {item_dir}") - - img = self.load_item(path_item) - target = self.df.loc[uid]['Malign'] - return {'uid': uid, 'source': self.transform(img), 'target': target} - - @classmethod - def run_item_crawler(cls, path_root, crawler_ext, **kwargs): - """ - Crawls the directory to find items matching the glob pattern. - - Args: - path_root (Path): Root directory to start crawling. - crawler_ext (str): Extension to match files. - - Returns: - list: List of relative file paths. - """ - return [path.relative_to(path_root).name for path in Path(path_root).iterdir() if path.is_dir()] - - def get_labels(self): - """ - Gets the labels for the dataset items. - - Returns: - array: Array of labels. - """ - return self.df['Malign'].values - - def __len__(self): - """ - Returns the number of items in the dataset. - - Returns: - int: Number of items. - """ - return len(self.item_pointers) diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke.py deleted file mode 100644 index ab8a4338..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke.py +++ /dev/null @@ -1,105 +0,0 @@ -from pathlib import Path -import pandas as pd -from data.datasets import SimpleDataset3D - - -class DUKE_Dataset3D(SimpleDataset3D): - """ - DUKE Dataset for 3D medical images, extending SimpleDataset3D. - - Args: - path_root (str): Root directory of the dataset. - item_pointers (list, optional): List of file paths. Defaults to []. - crawler_glob (str, optional): Glob pattern for crawling files. Defaults to '*.nii.gz'. - transform (callable, optional): Transformations to apply to the data. Defaults to None. - image_resize (tuple, optional): Desired output image size. Defaults to None. - flip (bool, optional): Whether to apply random flipping. Defaults to False. - image_crop (tuple, optional): Desired crop size. Defaults to None. - norm (str, optional): Normalization method. Defaults to 'znorm_clip'. - to_tensor (bool, optional): Whether to convert images to tensor. Defaults to True. - sequence (str, optional): Sequence type to use for loading images. Defaults to 'sub'. - """ - - def __init__( - self, - path_root, - item_pointers=[], - crawler_glob='*.nii.gz', - transform=None, - image_resize=None, - flip=False, - image_crop=None, - norm='znorm_clip', - to_tensor=True, - sequence='sub' - ): - super().__init__(path_root, item_pointers, crawler_glob, transform, image_resize, flip, image_crop, norm, - to_tensor) - df = pd.read_excel(self.path_root.parent / 'Clinical_and_Other_Features.xlsx', header=[0, 1, 2]) - df = df[df[df.columns[38]] == 0] # check if cancer is bilateral=1, unilateral=0 or NC - df = df[[df.columns[0], df.columns[36], - df.columns[38]]] # Only pick relevant columns: Patient ID, Tumor Side, Bilateral - df.columns = ['PatientID', 'Location', 'Bilateral'] # Simplify columns as: Patient ID, Tumor Side - dfs = [] - existing_folders = {folder.name for folder in Path(path_root).iterdir() if folder.is_dir()} - - for side in ["left", 'right']: - dfs.append(pd.DataFrame({ - 'PatientID': df["PatientID"].str.split('_').str[2] + f"_{side}", - 'Malign': df[["Location", "Bilateral"]].apply(lambda ds: (ds[0] == side[0].upper()) | (ds[1] == 1), - axis=1) - })) - - self.df = df[df['PatientID'].isin(existing_folders)] - self.df = self.df.set_index('PatientID', drop=True) - self.df = pd.concat(dfs, ignore_index=True).set_index('PatientID', drop=True) - self.item_pointers = self.df.index[self.df.index.isin(self.item_pointers)].tolist() - self.sequence = sequence - - def __getitem__(self, index): - """ - Retrieves an item from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: A dictionary with 'uid', 'source', and 'target' keys. - """ - uid = self.item_pointers[index] - path_item = [self.path_root / uid / name for name in [f'{self.sequence}.nii.gz']] - img = self.load_item(path_item) - target = self.df.loc[uid]['Malign'] - return {'uid': uid, 'source': self.transform(img), 'target': target} - - @classmethod - def run_item_crawler(cls, path_root, crawler_ext, **kwargs): - """ - Crawls the directory to find items matching the glob pattern. - - Args: - path_root (Path): Root directory to start crawling. - crawler_ext (str): Extension to match files. - - Returns: - list: List of relative file paths. - """ - return [path.relative_to(path_root).name for path in Path(path_root).iterdir() if path.is_dir()] - - def get_labels(self): - """ - Gets the labels for the dataset items. - - Returns: - list: List of labels. - """ - return self.df.loc[self.item_pointers, 'Malign'].tolist() - - def __len__(self): - """ - Returns the number of items in the dataset. - - Returns: - int: Number of items. - """ - return len(self.item_pointers) diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke_external.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke_external.py deleted file mode 100755 index 3dc2d1ad..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/dataset_3d_duke_external.py +++ /dev/null @@ -1,69 +0,0 @@ -from pathlib import Path -import pandas as pd -from data.datasets import SimpleDataset3D - -class DUKE_Dataset3D_external(SimpleDataset3D): - """ - DUKE External Dataset for 3D medical images, extending SimpleDataset3D. - - Args: - path_root (str): Root directory of the dataset. - item_pointers (list, optional): List of file paths. Defaults to None. - crawler_glob (str, optional): Glob pattern for crawling files. Defaults to '*.nii.gz'. - transform (callable, optional): Transformations to apply to the data. Defaults to None. - image_resize (tuple, optional): Desired output image size. Defaults to None. - flip (bool, optional): Whether to apply random flipping. Defaults to False. - image_crop (tuple, optional): Desired crop size. Defaults to None. - norm (str, optional): Normalization method. Defaults to 'znorm_clip'. - to_tensor (bool, optional): Whether to convert images to tensor. Defaults to True. - """ - - def __init__( - self, - path_root, - item_pointers=None, - crawler_glob='*.nii.gz', - transform=None, - image_resize=None, - flip=False, - image_crop=None, - norm='znorm_clip', - to_tensor=True - ): - if item_pointers is None: - item_pointers = [] - super().__init__(path_root, item_pointers, crawler_glob, transform, image_resize, flip, image_crop, norm, to_tensor) - df = pd.read_csv(self.path_root.parent / 'segmentation_metadata_unilateral.csv') - df = df[[df.columns[0], df.columns[5]]] # Only pick relevant columns: Patient ID, Tumor Side, Bilateral - self.df = df.set_index('PATIENT', drop=True) - self.item_pointers = self.df.index[self.df.index.isin(self.item_pointers)].tolist() - - def __getitem__(self, index): - """ - Retrieves an item from the dataset. - - Args: - index (int): Index of the item to retrieve. - - Returns: - dict: A dictionary with 'uid', 'source', and 'target' keys. - """ - uid = self.item_pointers[index] - path_item = [self.path_root / uid / name for name in ['Sub.nii.gz']] - img = self.load_item(path_item) - target = self.df.loc[uid]['Malign'] - return {'uid': uid, 'source': self.transform(img), 'target': target} - - @classmethod - def run_item_crawler(cls, path_root, crawler_ext, **kwargs): - """ - Crawls the directory to find items matching the glob pattern. - - Args: - path_root (Path): Root directory to start crawling. - crawler_ext (str): Extension to match files. - - Returns: - list: List of relative file paths. - """ - return [path.relative_to(path_root).name for path in Path(path_root).iterdir() if path.is_dir()] diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datasets/simple_dataset_3d.py b/application/jobs/3dcnn_ptl/app/custom/data/datasets/simple_dataset_3d.py deleted file mode 100644 index 99a80047..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/data/datasets/simple_dataset_3d.py +++ /dev/null @@ -1,138 +0,0 @@ -from pathlib import Path -import torch.utils.data as data -import torchio as tio - -from data.augmentation.augmentations_3d import ImageToTensor, RescaleIntensity, ZNormalization - - -class SimpleDataset3D(data.Dataset): - """ - A simple 3D dataset class that handles loading, transforming, and normalizing 3D medical images. - - Attributes: - path_root (Path): The root directory of the dataset. - crawler_glob (str): Glob pattern for crawling the dataset directory. - transform (callable): Transformation to be applied to the images. - item_pointers (list): List of relative paths to the dataset items. - """ - - def __init__( - self, - path_root: str, - item_pointers: list = [], - crawler_glob: str = '*.nii.gz', - transform: tio.transforms.Transform = None, - image_resize: tuple = None, - flip: bool = False, - image_crop: tuple = None, - norm: str = 'znorm_clip', - to_tensor: bool = True, - ): - """ - Initializes the dataset with the given parameters. - - Args: - path_root (str): The root directory of the dataset. - item_pointers (list, optional): List of item pointers. Defaults to []. - crawler_glob (str, optional): Glob pattern for crawling the dataset directory. Defaults to '*.nii.gz'. - transform (callable, optional): Transformation to be applied to the images. Defaults to None. - image_resize (tuple, optional): Size to resize images to. Defaults to None. - flip (bool, optional): Whether to apply random flipping. Defaults to False. - image_crop (tuple, optional): Size to crop or pad images to. Defaults to None. - norm (str, optional): Normalization method. Defaults to 'znorm_clip'. - to_tensor (bool, optional): Whether to convert images to tensors. Defaults to True. - """ - super().__init__() - self.path_root = Path(path_root) - self.crawler_glob = crawler_glob - - if transform is None: - self.transform = tio.Compose([ - tio.Resize(image_resize) if image_resize is not None else tio.Lambda(lambda x: x), - tio.RandomFlip((0, 1, 2)) if flip else tio.Lambda(lambda x: x), - tio.CropOrPad(image_crop) if image_crop is not None else tio.Lambda(lambda x: x), - self.get_norm(norm), - ImageToTensor() if to_tensor else tio.Lambda(lambda x: x) # [C, W, H, D] -> [C, D, H, W] - ]) - else: - self.transform = transform - - if len(item_pointers): - self.item_pointers = item_pointers - else: - self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_glob) - - def __len__(self) -> int: - """Returns the number of items in the dataset. - - Returns: - int: Number of items in the dataset. - """ - return len(self.item_pointers) - - def __getitem__(self, index: int) -> dict: - """Gets the item at the given index. - - Args: - index (int): Index of the item. - - Returns: - dict: A dictionary with 'uid' and 'source' keys. - """ - rel_path_item = self.item_pointers[index] - path_item = self.path_root / rel_path_item - img = self.load_item(path_item) - return {'uid': str(rel_path_item), 'source': self.transform(img)} - - def load_item(self, path_item: Path) -> tio.ScalarImage: - """Loads the item from the given path. - - Args: - path_item (Path): Path to the item. - - Returns: - tio.ScalarImage: The loaded image. - """ - return tio.ScalarImage(path_item) - - @classmethod - def run_item_crawler(cls, path_root: Path, crawler_glob: str, **kwargs) -> list: - """Crawls the dataset directory and returns a list of item pointers. - - Args: - path_root (Path): Root directory of the dataset. - crawler_glob (str): Glob pattern for crawling the dataset directory. - - Returns: - list: List of relative paths to the dataset items. - """ - return [path.relative_to(path_root) for path in Path(path_root).rglob(f'{crawler_glob}')] - - @staticmethod - def get_norm(norm: str) -> tio.transforms.Transform: - """Gets the normalization transform based on the given norm type. - - Args: - norm (str): Normalization method. - - Returns: - tio.transforms.Transform: The normalization transform. - - Raises: - ValueError: If the normalization method is unknown. - """ - if norm is None: - return tio.Lambda(lambda x: x) - elif isinstance(norm, str): - if norm == 'min-max': - return RescaleIntensity((-1, 1), per_channel=True, masking_method=lambda x: x > 0) - elif norm == 'min-max_clip': - return RescaleIntensity((-1, 1), per_channel=True, percentiles=(0.5, 99.5), masking_method=lambda x: x > 0) - elif norm == 'znorm': - return ZNormalization(per_channel=True, masking_method=lambda x: x > 0) - elif norm == 'znorm_clip': - return ZNormalization(per_channel=True, percentiles=(0.5, 99.5), masking_method=lambda x: x > 0) - else: - raise ValueError("Unknown normalization") - else: - return norm diff --git a/application/jobs/3dcnn_ptl/app/custom/env_config.py b/application/jobs/3dcnn_ptl/app/custom/env_config.py deleted file mode 100755 index c108f3c3..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/env_config.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -from datetime import datetime - - -def load_environment_variables(): - """Load environment variables and return them as a dictionary.""" - return { - 'task_data_name': os.getenv('DATA_FOLDER', 'DUKE'), - 'scratch_dir': os.getenv('SCRATCH_DIR', '/scratch/'), - 'data_dir': os.getenv('DATA_DIR', '/data/'), - 'max_epochs': int(os.getenv('MAX_EPOCHS', 100)), - 'min_peers': int(os.getenv('MIN_PEERS', 2)), - 'max_peers': int(os.getenv('MAX_PEERS', 7)), - 'local_compare_flag': os.getenv('LOCAL_COMPARE_FLAG', 'False').lower() == 'true', - 'use_adaptive_sync': os.getenv('USE_ADAPTIVE_SYNC', 'False').lower() == 'true', - 'sync_frequency': int(os.getenv('SYNC_FREQUENCY', 1024)), - 'model_name': os.getenv('MODEL_NAME', 'ResNet50'), - 'prediction_flag': os.getenv('PREDICT_FLAG', 'ext') - } - -def load_prediction_modules(prediction_flag): - """Dynamically load prediction modules based on the prediction flag.""" - from predict import predict - return predict, prediction_flag - -def prepare_dataset(task_data_name, data_dir, site_name): - - - """Prepare the dataset based on task data name.""" - print('task_data_name: ', task_data_name) - print("Current Directory ", os.getcwd()) - - # Check if data_dir contains only DUKE_ext - try: - available_dirs = next(os.walk(data_dir))[1] # List directories directly under data_dir - except StopIteration: - print(f"No directories found under data_dir: {data_dir}") - raise ValueError("No directories found under data_dir") - if 'DUKE_ext' in available_dirs: - print("Only DUKE_ext directory found under data_dir. Setting task_data_name to DUKE_ext.") - task_data_name = "DUKE_ext" - - dataset_class = None - if task_data_name == "multi_ext": - from data.datasets import DUKE_Dataset3D_collab as dataset_class - elif task_data_name == "DUKE_ext": - from data.datasets import DUKE_Dataset3D_external as dataset_class - elif task_data_name == "DUKE": - from data.datasets import DUKE_Dataset3D as dataset_class - else: - print(f"Invalid task data name specified: {task_data_name}") - - - if dataset_class: - return dataset_class(flip=True, path_root=os.path.join(data_dir, site_name)), task_data_name - else: - raise ValueError("Invalid task data name specified") - -def generate_run_directory(scratch_dir, task_data_name, model_name, local_compare_flag): - """Generate the directory path for the current run.""" - current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") - mode = 'local_compare' if local_compare_flag else 'swarm_learning' - # make dir if not exist - if not os.path.exists(scratch_dir): - os.makedirs(scratch_dir) - return os.path.join(scratch_dir, f"{current_time}_{task_data_name}_{model_name}_{mode}") - -def cal_weightage(train_size): - estimated_full_dataset_size = 808 # exact training size of Duke 80% dataset, which is the largest across multiple nodes - weightage = int(100 * train_size / estimated_full_dataset_size) - if weightage > 100: - weightage = 100 - return weightage - -def cal_max_epochs(preset_max_epochs, weightage): - return int(preset_max_epochs / (weightage / 100)) \ No newline at end of file diff --git a/application/jobs/3dcnn_ptl/app/custom/model_selector.py b/application/jobs/3dcnn_ptl/app/custom/model_selector.py deleted file mode 100755 index f51bf0d3..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/model_selector.py +++ /dev/null @@ -1,80 +0,0 @@ -from models import ResNet, VisionTransformer, EfficientNet, DenseNet121, UNet3D - - -def select_model(model_name: str): - """ - Selects and returns a model based on the provided model name. - - Args: - model_name (str): The name of the model to select. - - Returns: - nn.Module: The selected model. - - Raises: - ValueError: If an invalid model name is provided. - """ - print('Using model:', model_name) - - # Define ResNet layer configurations - resnet_layers = { - 'ResNet18': [2, 2, 2, 2], - 'ResNet34': [3, 4, 6, 3], - 'ResNet50': [3, 4, 6, 3], - 'ResNet101': [3, 4, 23, 3], - 'ResNet152': [3, 8, 36, 3], - } - - try: - if model_name in resnet_layers: - layers = resnet_layers[model_name] - model = ResNet(in_ch=1, out_ch=1, spatial_dims=3, layers=layers) - elif model_name in ['efficientnet_l1', 'efficientnet_l2', 'efficientnet_b4', 'efficientnet_b7']: - model = EfficientNet(model_name=model_name, in_ch=1, out_ch=1, spatial_dims=3) - elif model_name.startswith('EfficientNet3D'): - # Define EfficientNet3D configurations based on model_name - blocks_args_str = { - 'EfficientNet3Db0': [ - "r1_k3_s11_e1_i32_o16_se0.25", - "r2_k3_s22_e6_i16_o24_se0.25", - "r2_k5_s22_e6_i24_o40_se0.25", - "r3_k3_s22_e6_i40_o80_se0.25", - "r3_k5_s11_e6_i80_o112_se0.25", - "r4_k5_s22_e6_i112_o192_se0.25", - "r1_k3_s11_e6_i192_o320_se0.25" - ], - 'EfficientNet3Db4': [ - "r1_k3_s11_e1_i48_o24_se0.25", - "r3_k3_s22_e6_i24_o32_se0.25", - "r3_k5_s22_e6_i32_o56_se0.25", - "r4_k3_s22_e6_i56_o112_se0.25", - "r4_k5_s11_e6_i112_o160_se0.25", - "r5_k5_s22_e6_i160_o272_se0.25", - "r2_k3_s11_e6_i272_o448_se0.25" - ], - 'EfficientNet3Db7': [ - "r1_k3_s11_e1_i32_o32_se0.25", - "r4_k3_s22_e6_i32_o48_se0.25", - "r4_k5_s22_e6_i48_o80_se0.25", - "r4_k3_s22_e6_i80_o160_se0.25", - "r6_k5_s11_e6_i160_o256_se0.25", - "r6_k5_s22_e6_i256_o384_se0.25", - "r3_k3_s11_e6_i384_o640_se0.25" - ], - }[model_name[-2:]] # Extract b0, b4, b7 from model_name - model = EfficientNet(in_ch=1, out_ch=1, spatial_dims=3) - elif model_name == 'DenseNet121': - model = DenseNet121(in_ch=1, out_ch=1, spatial_dims=3) - elif model_name == 'UNet3D': - model = UNet3D(in_ch=1, out_ch=1, spatial_dims=3) - elif model_name == 'VisionTransformer': - model = VisionTransformer(in_ch=1, out_ch=1, spatial_dims=3) - else: - raise ValueError("Invalid network model specified") - - return model - except KeyError as e: - raise ValueError(f"Model configuration for {model_name} not found: {e}") - except Exception as e: - raise RuntimeError(f"Error while creating the model {model_name}: {e}") - diff --git a/application/jobs/3dcnn_ptl/app/custom/models/__init__.py b/application/jobs/3dcnn_ptl/app/custom/models/__init__.py deleted file mode 100644 index b2d2c392..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -This package initializes the necessary modules and classes for the project. - -Modules: - base_model: Contains basic models including VeryBasicModel, BasicModel, and BasicClassifier. - resnet: Contains the ResNet model implementation. -""" - -from .base_model import VeryBasicModel, BasicModel, BasicClassifier -from .resnet import ResNet -from .densenet import DenseNet121 -from .efficientNet import EfficientNet -from .uNet3D import UNet3D -from .vit import VisionTransformer - -__all__ = ['VeryBasicModel', 'BasicModel', 'BasicClassifier', 'ResNet', 'DenseNet121', 'EfficientNet', 'UNet3D', 'VisionTransformer'] diff --git a/application/jobs/3dcnn_ptl/app/custom/models/base_model.py b/application/jobs/3dcnn_ptl/app/custom/models/base_model.py deleted file mode 100644 index 2c998c83..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/base_model.py +++ /dev/null @@ -1,269 +0,0 @@ -from typing import List, Union -from pathlib import Path -import json -import torch -import torch.nn as nn -import torch.nn.functional as F -import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.types import EPOCH_OUTPUT -from torchmetrics import AUROC, Accuracy - - -class VeryBasicModel(pl.LightningModule): - """ - A very basic model class extending LightningModule with basic functionality. - - Attributes: - _step_train (int): Counter for training steps. - _step_val (int): Counter for validation steps. - _step_test (int): Counter for test steps. - """ - - def __init__(self): - super().__init__() - self.save_hyperparameters() - self._step_train = -1 - self._step_val = -1 - self._step_test = -1 - - def forward(self, x_in): - """Forward pass. Must be implemented by subclasses.""" - raise NotImplementedError - - def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): - """Step function for training, validation, and testing. Must be implemented by subclasses.""" - raise NotImplementedError - - def _epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]], state: str): - """Epoch end function.""" - return - - def training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): - self._step_train += 1 - return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx) - - def validation_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): - self._step_val += 1 - return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx) - - def test_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): - self._step_test += 1 - return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx) - - def training_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "train") - return super().training_epoch_end(outputs) - - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "val") - return super().validation_epoch_end(outputs) - - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "test") - return super().test_epoch_end(outputs) - - @classmethod - def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path): - """Saves the best model checkpoint path. - - Args: - path_checkpoint_dir (str): Directory to save the checkpoint. - best_model_path (str): Path to the best model. - """ - with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f: - json.dump({'best_model_epoch': Path(best_model_path).name}, f) - - @classmethod - def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs): - """Gets the best model checkpoint path. - - Args: - path_checkpoint_dir (str): Directory containing the checkpoint. - version (int, optional): Version of the checkpoint. Defaults to 0. - - Returns: - Path: Path to the best checkpoint. - """ - path_version = 'lightning_logs/version_' + str(version) - with open(Path(path_checkpoint_dir) / path_version / 'best_checkpoint.json', 'r') as f: - path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch']) - return Path(path_checkpoint_dir) / path_rel_best_checkpoint - - @classmethod - def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs): - """Loads the best model checkpoint. - - Args: - path_checkpoint_dir (str): Directory containing the checkpoint. - version (int, optional): Version of the checkpoint. Defaults to 0. - - Returns: - LightningModule: The loaded model. - """ - path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version) - return cls.load_from_checkpoint(path_best_checkpoint, **kwargs) - - def load_pretrained(self, checkpoint_path, map_location=None, **kwargs): - """Loads pretrained weights from a checkpoint. - - Args: - checkpoint_path (str): Path to the checkpoint. - map_location (str, optional): Device to map the checkpoint. Defaults to None. - - Returns: - LightningModule: The model with loaded weights. - """ - if checkpoint_path.is_dir(): - checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs) - - with pl_legacy_patch(): - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - return self.load_weights(checkpoint["state_dict"], **kwargs) - - def load_weights(self, pretrained_weights, strict=True, **kwargs): - """Loads weights into the model. - - Args: - pretrained_weights (dict): Pretrained weights. - strict (bool, optional): Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module’s `state_dict` function. Defaults to True. - - Returns: - LightningModule: The model with loaded weights. - """ - filter_fn = kwargs.get('filter', lambda key: key in pretrained_weights) - init_weights = self.state_dict() - pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter_fn(key)} - init_weights.update(pretrained_weights) - self.load_state_dict(init_weights, strict=strict) - return self - - -class BasicModel(VeryBasicModel): - """ - A basic model class with optimizer and learning rate scheduler configurations. - - Attributes: - optimizer (Optimizer): The optimizer to use. - optimizer_kwargs (dict): Keyword arguments for the optimizer. - lr_scheduler (Scheduler): The learning rate scheduler to use. - lr_scheduler_kwargs (dict): Keyword arguments for the learning rate scheduler. - """ - - def __init__( - self, - optimizer=torch.optim.AdamW, - optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2}, - lr_scheduler=None, - lr_scheduler_kwargs={}, - ): - super().__init__() - self.save_hyperparameters() - self.optimizer = optimizer - self.optimizer_kwargs = optimizer_kwargs - self.lr_scheduler = lr_scheduler - self.lr_scheduler_kwargs = lr_scheduler_kwargs - - def configure_optimizers(self): - """Configures the optimizers and learning rate schedulers. - - Returns: - list: List containing the optimizer and optionally the learning rate scheduler. - """ - optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs) - if self.lr_scheduler is not None: - lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs) - return [optimizer], [lr_scheduler] - else: - return [optimizer] - - -class BasicClassifier(BasicModel): - """ - A basic classifier model with loss function and metrics. - - Attributes: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int): Number of spatial dimensions. - loss (Loss): The loss function. - loss_kwargs (dict): Keyword arguments for the loss function. - auc_roc (ModuleDict): Dictionary of AUROC metrics. - acc (ModuleDict): Dictionary of Accuracy metrics. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int, - loss=torch.nn.CrossEntropyLoss, - loss_kwargs={}, - optimizer=torch.optim.AdamW, - optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2}, - lr_scheduler=None, - lr_scheduler_kwargs={}, - aucroc_kwargs={"task": "binary"}, - acc_kwargs={"task": "binary"} - ): - super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs) - self.in_ch = in_ch - self.out_ch = out_ch - self.spatial_dims = spatial_dims - self.loss = loss(**loss_kwargs) - self.loss_kwargs = loss_kwargs - - self.auc_roc = nn.ModuleDict({state: AUROC(**aucroc_kwargs) for state in ["train_", "val_", "test_"]}) - self.acc = nn.ModuleDict({state: Accuracy(**acc_kwargs) for state in ["train_", "val_", "test_"]}) - - def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): - """Step function for training, validation, and testing. - - Args: - batch (dict): Input batch. - batch_idx (int): Batch index. - state (str): State of the model ('train', 'val', 'test'). - step (int): Current step. - optimizer_idx (int): Index of the optimizer. - - Returns: - Tensor: Loss value. - """ - source, target = batch['source'], batch['target'] - target = target[:, None].float() - batch_size = source.shape[0] - - # Run Model - pred = self(source) - - # Compute Loss - logging_dict = {} - logging_dict['loss'] = self.loss(pred, target) - - # Compute Metrics - with torch.no_grad(): - self.acc[state + "_"].update(pred, target) - self.auc_roc[state + "_"].update(pred, target) - - # Log Scalars - for metric_name, metric_val in logging_dict.items(): - self.log(f"{state}/{metric_name}", metric_val.cpu() if hasattr(metric_val, 'cpu') else metric_val, - batch_size=batch_size, on_step=True, on_epoch=True) - - return logging_dict['loss'] - - def _epoch_end(self, outputs, state): - """Epoch end function. - - Args: - outputs (Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]): Outputs of the epoch. - state (str): State of the model ('train', 'val', 'test'). - """ - batch_size = len(outputs) - for name, value in [("ACC", self.acc[state + "_"]), ("AUC_ROC", self.auc_roc[state + "_"])]: - self.log(f"{state}/{name}", value.compute().cpu(), batch_size=batch_size, on_step=False, on_epoch=True) - value.reset() diff --git a/application/jobs/3dcnn_ptl/app/custom/models/densenet.py b/application/jobs/3dcnn_ptl/app/custom/models/densenet.py deleted file mode 100755 index 6ef6fa3d..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/densenet.py +++ /dev/null @@ -1,59 +0,0 @@ -from .base_model import BasicClassifier -import monai.networks.nets as nets -import torch -import torch.nn.functional as F - -class DenseNet121(BasicClassifier): - """ - DenseNet121 model for classification tasks. - - Attributes: - model (nn.Module): The DenseNet model from MONAI. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the DenseNet121 model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - self.model = nets.DenseNet264(spatial_dims=spatial_dims, in_channels=in_ch, out_channels=out_ch) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the DenseNet121 model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - pred_hor = self.model(x_in) - return pred_hor diff --git a/application/jobs/3dcnn_ptl/app/custom/models/efficientNet.py b/application/jobs/3dcnn_ptl/app/custom/models/efficientNet.py deleted file mode 100755 index 33c20c0d..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/efficientNet.py +++ /dev/null @@ -1,243 +0,0 @@ -from .base_model import BasicClassifier -import torch -import torch.nn.functional as F -import timm -import monai.networks.nets as nets - -class EfficientNet(BasicClassifier): - """ - EfficientNet model for 2D classification tasks. - - Attributes: - model (nn.Module): The EfficientNet model from TIMM. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - model_name: str = 'efficientnet_l2', - pretrained: bool = False, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the EfficientNet model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - model_name (str, optional): Name of the EfficientNet model. Defaults to 'efficientnet_l2'. - pretrained (bool, optional): Whether to use pretrained weights. Defaults to False. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=in_ch, num_classes=out_ch) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the EfficientNet model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - batch_size, _, num_slices, height, width = x_in.shape - x_in = x_in.view(batch_size * num_slices, 1, height, width) # Reshape to [batch_size * num_slices, 1, height, width] - - pred_hor = self.model(x_in) # Process each slice with EfficientNet - - # Reshape the output back to [batch_size, num_slices, out_ch] - out_ch = pred_hor.shape[1] - pred_hor = pred_hor.view(batch_size, num_slices, out_ch) - - # Combine the results from each slice (e.g., by averaging or max-pooling) - combined_pred = torch.mean(pred_hor, dim=1) - - return combined_pred - - -class EfficientNet3D(BasicClassifier): - """ - EfficientNet model for 3D classification tasks. - - Attributes: - model (nn.Module): The EfficientNet model from MONAI. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - blocks_args_str: list = None, - width_coefficient: float = 1.0, - depth_coefficient: float = 1.0, - dropout_rate: float = 0.2, - image_size: int = 224, - norm: tuple = ('batch', {'eps': 0.001, 'momentum': 0.01}), - drop_connect_rate: float = 0.2, - depth_divisor: int = 8, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the EfficientNet3D model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - blocks_args_str (list, optional): List of block arguments. Defaults to None. - width_coefficient (float, optional): Width coefficient for EfficientNet. Defaults to 1.0. - depth_coefficient (float, optional): Depth coefficient for EfficientNet. Defaults to 1.0. - dropout_rate (float, optional): Dropout rate. Defaults to 0.2. - image_size (int, optional): Image size. Defaults to 224. - norm (tuple, optional): Normalization configuration. Defaults to ('batch', {'eps': 0.001, 'momentum': 0.01}). - drop_connect_rate (float, optional): Drop connect rate. Defaults to 0.2. - depth_divisor (int, optional): Depth divisor. Defaults to 8. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - if blocks_args_str is None: - blocks_args_str = [ - "r1_k3_s11_e1_i32_o16_se0.25", - "r2_k3_s22_e6_i16_o24_se0.25", - "r2_k5_s22_e6_i24_o40_se0.25", - "r3_k3_s22_e6_i40_o80_se0.25", - "r3_k5_s11_e6_i80_o112_se0.25", - "r4_k5_s22_e6_i112_o192_se0.25", - "r1_k3_s11_e6_i192_o320_se0.25"] - self.model = nets.EfficientNet(blocks_args_str, spatial_dims, in_ch, out_ch, - width_coefficient, depth_coefficient, dropout_rate, - image_size, norm, drop_connect_rate, depth_divisor) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the EfficientNet3D model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - pred_hor = self.model(x_in) - return pred_hor - - -class EfficientNet3Db7(BasicClassifier): - """ - EfficientNetB7 model for 3D classification tasks. - - Attributes: - model (nn.Module): The EfficientNetB7 model from MONAI. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - blocks_args_str: list = None, - width_coefficient: float = 1.0, - depth_coefficient: float = 1.0, - dropout_rate: float = 0.2, - image_size: int = 224, - norm: tuple = ('batch', {'eps': 0.001, 'momentum': 0.01}), - drop_connect_rate: float = 0.2, - depth_divisor: int = 8, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the EfficientNet3Db7 model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - blocks_args_str (list, optional): List of block arguments. Defaults to None. - width_coefficient (float, optional): Width coefficient for EfficientNet. Defaults to 1.0. - depth_coefficient (float, optional): Depth coefficient for EfficientNet. Defaults to 1.0. - dropout_rate (float, optional): Dropout rate. Defaults to 0.2. - image_size (int, optional): Image size. Defaults to 224. - norm (tuple, optional): Normalization configuration. Defaults to ('batch', {'eps': 0.001, 'momentum': 0.01}). - drop_connect_rate (float, optional): Drop connect rate. Defaults to 0.2. - depth_divisor (int, optional): Depth divisor. Defaults to 8. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - if blocks_args_str is None: - blocks_args_str = [ - "r1_k3_s11_e1_i32_o32_se0.25", - "r4_k3_s22_e6_i32_o48_se0.25", - "r4_k5_s22_e6_i48_o80_se0.25", - "r4_k3_s22_e6_i80_o160_se0.25", - "r6_k5_s11_e6_i160_o256_se0.25", - "r6_k5_s22_e6_i256_o384_se0.25", - "r3_k3_s11_e6_i384_o640_se0.25", - ] - - self.model = nets.EfficientNet(blocks_args_str, spatial_dims, in_ch, out_ch, - width_coefficient, depth_coefficient, dropout_rate, - image_size, norm, drop_connect_rate, depth_divisor) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the EfficientNet3Db7 model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - pred_hor = self.model(x_in) - return pred_hor diff --git a/application/jobs/3dcnn_ptl/app/custom/models/resnet.py b/application/jobs/3dcnn_ptl/app/custom/models/resnet.py deleted file mode 100644 index fd94a674..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/resnet.py +++ /dev/null @@ -1,68 +0,0 @@ -from models.base_model import BasicClassifier -import monai.networks.nets as nets -import torch - - -class ResNet(BasicClassifier): - """ - ResNet model for classification tasks. - - Attributes: - model (nn.Module): The ResNet model from MONAI. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - block: str = 'basic', - layers: list = [3, 4, 6, 3], - block_inplanes: list = [64, 128, 256, 512], - feed_forward: bool = True, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the ResNet model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - block (str, optional): Block type for ResNet. Defaults to 'basic'. - layers (list, optional): List of layer configurations. Defaults to [3, 4, 6, 3]. - block_inplanes (list, optional): List of block in-plane sizes. Defaults to [64, 128, 256, 512]. - feed_forward (bool, optional): Whether to use feed forward. Defaults to True. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - self.model = nets.ResNet( - block, layers, block_inplanes, spatial_dims, in_ch, 7, 1, False, 'B', 1.0, out_ch, feed_forward, True - ) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the ResNet model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - return self.model(x_in) diff --git a/application/jobs/3dcnn_ptl/app/custom/models/uNet3D.py b/application/jobs/3dcnn_ptl/app/custom/models/uNet3D.py deleted file mode 100755 index aba13554..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/uNet3D.py +++ /dev/null @@ -1,147 +0,0 @@ -from .base_model import BasicClassifier -import monai.networks.nets as nets -import torch - -class UNet3D(BasicClassifier): - """ - UNet3D model for 3D segmentation tasks. - - Attributes: - model (nn.Module): The UNet3D model from MONAI. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - channels: tuple = (16, 32, 64, 128, 256), - strides: tuple = (2, 2, 2, 2), - num_res_units: int = 2, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the UNet3D model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - channels (tuple, optional): Tuple of channel sizes. Defaults to (16, 32, 64, 128, 256). - strides (tuple, optional): Tuple of stride sizes. Defaults to (2, 2, 2, 2). - num_res_units (int, optional): Number of residual units. Defaults to 2. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - self.model = nets.UNet( - dimensions=spatial_dims, - in_channels=in_ch, - out_channels=out_ch, - channels=channels, - strides=strides, - num_res_units=num_res_units - ) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the UNet3D model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - pred_hor = self.model(x_in) - return pred_hor - - def _generate_predictions(self, source: torch.Tensor) -> torch.Tensor: - """ - Generates predictions for the given input tensor. - - Args: - source (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Predicted tensor. - """ - return self.forward(source) - - def _step(self, batch: dict, batch_idx: int, phase: str, optimizer_idx: int = 0) -> torch.Tensor: - """ - Performs a step in the training or validation phase. - - Args: - batch (dict): Input batch. - batch_idx (int): Batch index. - phase (str): Current phase ('train' or 'val'). - optimizer_idx (int, optional): Index of the optimizer. Defaults to 0. - - Returns: - torch.Tensor: Loss value. - """ - source, target = batch['source'], batch['target'] - - if phase == "train": - pred = self._generate_predictions(source) - elif phase == "val": - pred = self._generate_predictions(source) - else: - raise ValueError(f"Invalid phase: {phase}") - - target = target.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(pred).float() # Cast target to float - loss = self.loss(pred, target) - - logging_dict = {f"{phase}_loss": loss} - - if phase == "val": - logging_dict["y_true"] = target - logging_dict["y_pred"] = pred - - logging_dict = {k: v.mean() for k, v in logging_dict.items()} # Add this line before logging - self.log_dict(logging_dict, on_step=(phase == "train"), on_epoch=True, prog_bar=True, logger=True) - - return loss - - def validation_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: - """ - Performs a step in the validation phase. - - Args: - batch (dict): Input batch. - batch_idx (int): Batch index. - optimizer_idx (int, optional): Index of the optimizer. Defaults to 0. - - Returns: - torch.Tensor: Loss value. - """ - return self._step(batch, batch_idx, "val", optimizer_idx) - - def training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor: - """ - Performs a step in the training phase. - - Args: - batch (dict): Input batch. - batch_idx (int): Batch index. - optimizer_idx (int, optional): Index of the optimizer. Defaults to 0. - - Returns: - torch.Tensor: Loss value. - """ - return self._step(batch, batch_idx, "train", optimizer_idx) diff --git a/application/jobs/3dcnn_ptl/app/custom/models/vit.py b/application/jobs/3dcnn_ptl/app/custom/models/vit.py deleted file mode 100755 index 92bba936..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/models/vit.py +++ /dev/null @@ -1,128 +0,0 @@ -from .base_model import BasicClassifier -import torch -import torch.nn as nn -import torch.nn.functional as F -import timm -from timm.models.vision_transformer import VisionTransformer as TimmVisionTransformer - -class VisionTransformer(BasicClassifier): - """ - VisionTransformer model for 3D classification tasks. - - Attributes: - model (nn.Module): The VisionTransformer3D model. - """ - - def __init__( - self, - in_ch: int, - out_ch: int, - spatial_dims: int = 3, - model_name: str = 'vit_base_patch16_224', - pretrained: bool = False, - loss=torch.nn.BCEWithLogitsLoss, - loss_kwargs: dict = {}, - optimizer=torch.optim.AdamW, - optimizer_kwargs: dict = {'lr': 1e-4}, - lr_scheduler=None, - lr_scheduler_kwargs: dict = {}, - aucroc_kwargs: dict = {"task": "binary"}, - acc_kwargs: dict = {"task": "binary"} - ): - """ - Initializes the VisionTransformer model with the given parameters. - - Args: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int, optional): Number of spatial dimensions. Defaults to 3. - model_name (str, optional): Name of the VisionTransformer model. Defaults to 'vit_base_patch16_224'. - pretrained (bool, optional): Whether to use pretrained weights. Defaults to False. - loss (callable, optional): Loss function. Defaults to torch.nn.BCEWithLogitsLoss. - loss_kwargs (dict, optional): Keyword arguments for the loss function. Defaults to {}. - optimizer (Optimizer, optional): Optimizer. Defaults to torch.optim.AdamW. - optimizer_kwargs (dict, optional): Keyword arguments for the optimizer. Defaults to {'lr': 1e-4}. - lr_scheduler (Scheduler, optional): Learning rate scheduler. Defaults to None. - lr_scheduler_kwargs (dict, optional): Keyword arguments for the learning rate scheduler. Defaults to {}. - aucroc_kwargs (dict, optional): Keyword arguments for AUROC. Defaults to {"task": "binary"}. - acc_kwargs (dict, optional): Keyword arguments for Accuracy. Defaults to {"task": "binary"}. - """ - super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler, - lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs) - self.model = VisionTransformer3D(model_name, pretrained=pretrained, in_chans=in_ch, num_classes=out_ch) - - def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Forward pass of the VisionTransformer model. - - Args: - x_in (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - print(x_in.shape) - pred_hor = self.model(x_in) - print(pred_hor.shape) - pred_hor = self.model(x_in) - return pred_hor - -class PatchEmbed3D(nn.Module): - """ - 3D Patch Embedding module for Vision Transformer. - - Attributes: - proj (nn.Module): The convolutional projection layer. - """ - - def __init__(self, in_chans: int, embed_dim: int, patch_size: tuple): - """ - Initializes the PatchEmbed3D module. - - Args: - in_chans (int): Number of input channels. - embed_dim (int): Embedding dimension. - patch_size (tuple): Size of the patches. - """ - super().__init__() - self.proj = nn.Sequential( - nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size), - nn.Flatten(2) - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the PatchEmbed3D module. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Output tensor. - """ - B, C, D, H, W = x.shape - x = self.proj(x) - x = x.transpose(1, 2) - return x - -class VisionTransformer3D(TimmVisionTransformer): - """ - 3D Vision Transformer model extending TimmVisionTransformer. - - Attributes: - patch_embed (nn.Module): The 3D Patch Embedding module. - """ - - def __init__(self, *args, **kwargs): - """ - Initializes the VisionTransformer3D model with the given parameters. - - Args: - *args: Positional arguments for the TimmVisionTransformer. - **kwargs: Keyword arguments for the TimmVisionTransformer. - """ - super().__init__(*args, **kwargs) - in_chans = kwargs.get("in_chans", 3) - embed_dim = kwargs.get("embed_dim", 768) - patch_size = kwargs.get("patch_size", (2, 16, 16)) - self.patch_embed = PatchEmbed3D(in_chans, embed_dim, patch_size) diff --git a/application/jobs/3dcnn_ptl/app/custom/predict.py b/application/jobs/3dcnn_ptl/app/custom/predict.py deleted file mode 100755 index 3249bc69..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/predict.py +++ /dev/null @@ -1,232 +0,0 @@ -#!/usr/bin/env python3 - -import torch -import numpy as np -from pathlib import Path -import logging -from tqdm import tqdm -from sklearn.metrics import confusion_matrix, f1_score, precision_recall_curve, average_precision_score -import matplotlib.pyplot as plt -import seaborn as sns -import pandas as pd -from data.datasets import DUKE_Dataset3D, DUKE_Dataset3D_external, DUKE_Dataset3D_collab -from data.datamodules import DataModule -from utils.roc_curve import plot_roc_curve, cm2acc, cm2x -from models import ResNet, VisionTransformer, EfficientNet, DenseNet121, UNet3D - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -def predict(model_dir, test_data_dir, model_name, last_flag, prediction_flag, cohort_flag='aachen'): - """ - Predicts and evaluates the model on the test dataset. - - Args: - model_dir (str): Directory containing the model. - test_data_dir (str): Directory containing the test data. - model_name (str): Name of the model to use. - last_flag (bool): Whether to use the last checkpoint or the best checkpoint. - prediction_flag (str): Flag to indicate which dataset to use ('ext', 'internal', 'collab'). - cohort_flag (str, optional): Cohort flag for the output directory name. Defaults to 'aachen'. - """ - try: - path_run = Path(model_dir) - path_out = Path(path_run, f"{prediction_flag}_{cohort_flag}") - logger.info(f"Output path: {path_out.absolute()}") - path_out.mkdir(parents=True, exist_ok=True) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - fontdict = {'fontsize': 10, 'fontweight': 'bold'} - - # Load Data - if prediction_flag == 'ext': - ds = DUKE_Dataset3D_external(flip=False, path_root=test_data_dir) - elif prediction_flag == 'internal': - ds = DUKE_Dataset3D(flip=False, path_root=test_data_dir) - elif prediction_flag == 'collab': - ds = DUKE_Dataset3D_collab(flip=False, path_root=test_data_dir) - else: - raise ValueError("Invalid prediction_flag specified") - - logger.info(f"Number of test samples: {len(ds)}") - dm = DataModule(ds_test=ds, batch_size=1) - - # Initialize Model - model = initialize_model(model_name, path_run, last_flag) - model.to(device) - model.eval() - - results = {'uid': [], 'GT': [], 'NN': [], 'NN_pred': []} - threshold = 0.5 - - for batch in tqdm(dm.test_dataloader()): - source, target = batch['source'], batch['target'] - - # Run Model - pred = model(source.to(device)).cpu() - pred_proba = torch.sigmoid(pred).squeeze() - pred_binary = (pred_proba > threshold).long() - - results['GT'].extend(target.tolist()) - results['NN'].extend(pred_binary.tolist() if isinstance(pred_binary.tolist(), list) else [pred_binary.tolist()]) - results['NN_pred'].extend(pred_proba.tolist() if isinstance(pred_proba.tolist(), list) else [pred_proba.tolist()]) - results['uid'].extend(batch['uid']) - - df = pd.DataFrame(results) - save_results(df, path_out, last_flag) - evaluate_results(df, path_out, last_flag, fontdict) - - del model - torch.cuda.empty_cache() - except Exception as e: - logger.error(f"Error in predict function: {e}") - raise - -def initialize_model(model_name, path_run, last_flag): - """ - Initializes the model based on the provided model name. - - Args: - model_name (str): Name of the model to initialize. - path_run (Path): Path to the model directory. - last_flag (bool): Whether to use the last checkpoint or the best checkpoint. - - Returns: - nn.Module: The initialized model. - """ - try: - layers = None - if model_name in ['ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152']: - layers = {'ResNet18': [2, 2, 2, 2], 'ResNet34': [3, 4, 6, 3], 'ResNet50': [3, 4, 6, 3], 'ResNet101': [3, 4, 23, 3], 'ResNet152': [3, 8, 36, 3]}[model_name] - if last_flag: - return ResNet.load_last_checkpoint(path_run, version=0, layers=layers) - return ResNet.load_best_checkpoint(path_run, version=0, layers=layers) - - if model_name in ['efficientnet_l1', 'efficientnet_l2', 'efficientnet_b4', 'efficientnet_b7']: - if last_flag: - return EfficientNet.load_last_checkpoint(path_run, version=0, model_name=model_name) - return EfficientNet.load_best_checkpoint(path_run, version=0, model_name=model_name) - - if model_name.startswith('EfficientNet3D'): - blocks_args_str = { - 'EfficientNet3Db0': ["r1_k3_s11_e1_i32_o16_se0.25", "r2_k3_s22_e6_i16_o24_se0.25", "r2_k5_s22_e6_i24_o40_se0.25", "r3_k3_s22_e6_i40_o80_se0.25", "r3_k5_s11_e6_i80_o112_se0.25", "r4_k5_s22_e6_i112_o192_se0.25", "r1_k3_s11_e6_i192_o320_se0.25"], - 'EfficientNet3Db4': ["r1_k3_s11_e1_i48_o24_se0.25", "r3_k3_s22_e6_i24_o32_se0.25", "r3_k5_s22_e6_i32_o56_se0.25", "r4_k3_s22_e6_i56_o112_se0.25", "r4_k5_s11_e6_i112_o160_se0.25", "r5_k5_s22_e6_i160_o272_se0.25", "r2_k3_s11_e6_i272_o448_se0.25"], - 'EfficientNet3Db7': ["r1_k3_s11_e1_i32_o32_se0.25", "r4_k3_s22_e6_i32_o48_se0.25", "r4_k5_s22_e6_i48_o80_se0.25", "r4_k3_s22_e6_i80_o160_se0.25", "r6_k5_s11_e6_i160_o256_se0.25", "r6_k5_s22_e6_i256_o384_se0.25", "r3_k3_s11_e6_i384_o640_se0.25"] - }[model_name] - if last_flag: - return EfficientNet3D.load_last_checkpoint(path_run, version=0, blocks_args_str=blocks_args_str) - return EfficientNet3D.load_best_checkpoint(path_run, version=0, blocks_args_str=blocks_args_str) - - if model_name == 'DenseNet121': - if last_flag: - return DenseNet121.load_last_checkpoint(path_run, version=0) - return DenseNet121.load_best_checkpoint(path_run, version=0) - - if model_name == 'UNet3D': - if last_flag: - return UNet3D.load_last_checkpoint(path_run, version=0) - return UNet3D.load_best_checkpoint(path_run, version=0) - - raise ValueError("Invalid network model specified") - except Exception as e: - logger.error(f"Error in initialize_model function: {e}") - raise - -def save_results(df, path_out, last_flag): - """ - Saves the prediction results to a CSV file. - - Args: - df (pd.DataFrame): DataFrame containing the results. - path_out (Path): Path to the output directory. - last_flag (bool): Whether to save results for the last checkpoint or the best checkpoint. - """ - try: - file_name = 'results_last.csv' if last_flag else 'results.csv' - df.to_csv(path_out / file_name, index=False) - except Exception as e: - logger.error(f"Error in save_results function: {e}") - raise - -def evaluate_results(df, path_out, last_flag, fontdict): - """ - Evaluates the prediction results and saves metrics and plots. - - Args: - df (pd.DataFrame): DataFrame containing the results. - path_out (Path): Path to the output directory. - last_flag (bool): Whether to save results for the last checkpoint or the best checkpoint. - fontdict (dict): Font dictionary for plot titles and labels. - """ - try: - f1 = f1_score(df['GT'], df['NN']) - logger.info(f"F1 Score: {f1:.2f}") - - cm = confusion_matrix(df['GT'], df['NN']) - tn, fp, fn, tp = cm.ravel() - n = len(df) - logger.info(f"Confusion Matrix: TN {tn} ({tn / n * 100:.2f}%), FP {fp} ({fp / n * 100:.2f}%), FN {fn} ({fn / n * 100:.2f}%), TP {tp} ({tp / n * 100:.2f}%)") - - fig, axis = plt.subplots(ncols=1, nrows=1, figsize=(6, 6)) - y_pred_lab = np.asarray(df['NN_pred']) - y_true_lab = np.asarray(df['GT']) - tprs, fprs, auc_val, thrs, opt_idx, cm = plot_roc_curve(y_true_lab, y_pred_lab, axis, fontdict=fontdict) - fig.tight_layout() - file_name = 'roc_last.png' if last_flag else 'roc.png' - fig.savefig(path_out / file_name, dpi=300) - - precision, recall, _ = precision_recall_curve(y_true_lab, y_pred_lab) - ap = average_precision_score(y_true_lab, y_pred_lab) - - ppv = tp / (tp + fp) - npv = tn / (tn + fn) - - acc = cm2acc(cm) - _, _, sens, spec = cm2x(cm) - df_cm = pd.DataFrame(data=cm, columns=['False', 'True'], index=['False', 'True']) - fig, axis = plt.subplots(1, 1, figsize=(4, 4)) - sns.heatmap(df_cm, ax=axis, cbar=False, fmt='d', annot=True) - axis.set_title(f'Confusion Matrix ACC={acc:.2f}', fontdict=fontdict) - axis.set_xlabel('Prediction', fontdict=fontdict) - axis.set_ylabel('True', fontdict=fontdict) - fig.tight_layout() - file_name = 'confusion_matrix_last.png' if last_flag else 'confusion_matrix.png' - fig.savefig(path_out / file_name, dpi=300) - - logger.info(f"Malign Objects: {np.sum(y_true_lab)}") - logger.info(f"Confusion Matrix {cm}") - logger.info(f"Sensitivity {sens:.2f}") - logger.info(f"Specificity {spec:.2f}") - - with open(path_out / 'metrics.txt', 'w') as f: - f.write(f"AUC: {auc_val:.2f}\n") - f.write(f"F1 Score: {f1:.2f}\n") - f.write(f"Sensitivity: {sens:.2f}\n") - f.write(f"Specificity: {spec:.2f}\n") - f.write(f"PPV: {ppv:.2f}\n") - f.write(f"NPV: {npv:.2f}\n") - f.write(f"ACC: {acc:.2f}\n") - f.write(f"AP: {ap:.2f}\n") - - print(f"AUC: {auc_val:.2f}") - print(f"F1 Score: {f1:.2f}") - print(f"Sensitivity: {sens:.2f}") - print(f"Specificity: {spec:.2f}") - print(f"PPV: {ppv:.2f}") - print(f"NPV: {npv:.2f}") - print(f"ACC: {acc:.2f}") - print(f"AP: {ap:.2f}") - except Exception as e: - logger.error(f"Error in evaluate_results function: {e}") - raise - -if __name__ == "__main__": - wouter_data_path = "/mnt/sda1/swarm-learning/wouter_data/preprocessed_re/" - athens_data_path = "/mnt/sda1/swarm-learning/athens_data/preprocessed_athens/" - predict( - model_dir=Path('/mnt/sda1/odelia_paper_trained_results/2023_07_04_180000_DUKE_ext_ResNet101_swarm_learning'), - test_data_dir=athens_data_path, - model_name='ResNet101', - last_flag=False, - prediction_flag='collab', - cohort_flag='athens' - ) diff --git a/application/jobs/3dcnn_ptl/app/custom/threedcnn_ptl.py b/application/jobs/3dcnn_ptl/app/custom/threedcnn_ptl.py deleted file mode 100644 index 4a298c76..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/threedcnn_ptl.py +++ /dev/null @@ -1,162 +0,0 @@ -from sklearn.model_selection import train_test_split -from torch.utils.data import DataLoader, Subset -from collections import Counter -import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger -from data.datamodules import DataModule -from model_selector import select_model -from env_config import load_environment_variables, load_prediction_modules, prepare_dataset, generate_run_directory - -import os -import logging - - -def get_num_epochs_per_round(site_name: str) -> int: - #TODO: Set max_epochs based on the data set size - NUM_EPOCHS_FOR_SITE = { "TUD_1": 2, - "TUD_2": 4, - "TUD_3": 8, - "MEVIS_1": 2, - "MEVIS_2": 4, - "UKA": 2, - } - - if site_name in NUM_EPOCHS_FOR_SITE.keys(): - MAX_EPOCHS = NUM_EPOCHS_FOR_SITE[site_name] - else: - MAX_EPOCHS = 5 - - print(f"Site name: {site_name}") - print(f"Max epochs set to: {MAX_EPOCHS}") - - return MAX_EPOCHS - - -def set_up_logging(): - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - return logger - - -def set_up_data_module(env_vars, logger, site_name: str): - ds, task_data_name = prepare_dataset(env_vars['task_data_name'], env_vars['data_dir'], site_name=site_name) - - labels = ds.get_labels() - - # Generate indices and perform stratified split - indices = list(range(len(ds))) - train_indices, val_indices = train_test_split(indices, test_size=0.2, stratify=labels, random_state=42) - - # Create training and validation subsets - ds_train = Subset(ds, train_indices) - ds_val = Subset(ds, val_indices) - - # Extract training labels using the train_indices - train_labels = [labels[i] for i in train_indices] - label_counts = Counter(train_labels) - - # Calculate the total number of samples in the training set - total_samples = len(train_labels) - - # Print the percentage of the training set for each label - for label, count in label_counts.items(): - percentage = (count / total_samples) * 100 - logger.info(f"Label '{label}': {percentage:.2f}% of the training set, Exact count: {count}") - - logger.info(f"Total number of different labels in the training set: {len(label_counts)}") - - ads_val_data = DataLoader(ds_val, batch_size=2, shuffle=False) - logger.info(f'ads_val_data type: {type(ads_val_data)}') - - train_size = len(ds_train) - val_size = len(ds_val) - logger.info(f'Train size: {train_size}') - logger.info(f'Val size: {val_size}') - - dm = DataModule( - ds_train=ds_train, - ds_val=ds_val, - batch_size=1, - num_workers=16, - pin_memory=True, - ) - - return dm - - -def create_run_directory(env_vars): - path_run_dir = generate_run_directory(env_vars['scratch_dir'], env_vars['task_data_name'], env_vars['model_name'], env_vars['local_compare_flag']) - return path_run_dir - - -def prepare_training(logger, max_epochs:int , site_name: str): - try: - env_vars = load_environment_variables() - path_run_dir = create_run_directory(env_vars) - if not torch.cuda.is_available(): - raise(RuntimeError("This example does not work without GPU")) - accelerator = 'gpu' - logger.info(f"Using {accelerator} for training") - - data_module = set_up_data_module(env_vars, logger, site_name) - - # max_epochs = env_vars['max_epochs'] - # cal_max_epochs = cal_max_epochs(max_epochs, cal_weightage(train_size)) - # logger.info(f"Max epochs set to: {cal_max_epochs}") - - # Initialize the model - model_name = env_vars['model_name'] - model = select_model(model_name) - logger.info(f"Using model: {model_name}") - - to_monitor = "val/AUC_ROC" - min_max = "max" - log_every_n_steps = 1 - - checkpointing = ModelCheckpoint( - dirpath=str(path_run_dir), - monitor=to_monitor, - save_last=True, - save_top_k=2, - mode=min_max, - ) - - trainer = Trainer( - accelerator=accelerator, - precision=16, - default_root_dir=str(path_run_dir), - callbacks=[checkpointing], - enable_checkpointing=True, - check_val_every_n_epoch=1, - log_every_n_steps=log_every_n_steps, - max_epochs=max_epochs, - num_sanity_val_steps=2, - logger=TensorBoardLogger(save_dir=path_run_dir) - ) - - except Exception as e: - logger.error(f"Error in set_up_training: {e}") - raise - - return data_module, model, checkpointing, trainer, path_run_dir, env_vars - - -def validate_and_train(logger, data_module, model, trainer) -> None: - logger.info("--- Validate global model ---") - trainer.validate(model, datamodule=data_module) - - logger.info("--- Train new model ---") - trainer.fit(model, datamodule=data_module) - - -def finalize_training(logger, model, checkpointing, trainer, path_run_dir, env_vars) -> None: - model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path) - predict, prediction_flag = load_prediction_modules(env_vars['prediction_flag']) - test_data_path = os.path.join(env_vars['data_dir'], env_vars['task_data_name'], 'test') - if os.path.exists(test_data_path): - predict(path_run_dir, test_data_path, env_vars['model_name'], last_flag=False, prediction_flag=prediction_flag) - else: - logger.info('No test data found, not running evaluation') - logger.info('Training completed successfully') diff --git a/application/jobs/3dcnn_ptl/app/custom/utils/roc_curve.py b/application/jobs/3dcnn_ptl/app/custom/utils/roc_curve.py deleted file mode 100644 index 2cfd4a9d..00000000 --- a/application/jobs/3dcnn_ptl/app/custom/utils/roc_curve.py +++ /dev/null @@ -1,132 +0,0 @@ -import numpy as np -from sklearn.metrics import roc_curve, auc, confusion_matrix -import matplotlib - -def plot_roc_curve(y_true, y_score, axis, bootstrapping=1000, drop_intermediate=False, fontdict={}): - """ - Plots the ROC curve with bootstrapping. - - Args: - y_true (array-like): True binary labels. - y_score (array-like): Target scores. - axis (matplotlib.axes.Axes): Matplotlib axis object. - bootstrapping (int, optional): Number of bootstrap samples. Defaults to 1000. - drop_intermediate (bool, optional): Whether to drop some intermediate thresholds. Defaults to False. - fontdict (dict, optional): Dictionary of font properties. Defaults to {}. - - Returns: - tuple: tprs, fprs, auc_val, thrs, opt_idx, conf_matrix - """ - # ----------- Bootstrapping ------------ - tprs, aucs, thrs = [], [], [] - mean_fpr = np.linspace(0, 1, 100) - rand_idxs = np.random.randint(0, len(y_true), size=(bootstrapping, len(y_true))) # Note: with replacement - for rand_idx in rand_idxs: - y_true_set = y_true[rand_idx] - y_score_set = y_score[rand_idx] - fpr, tpr, thresholds = roc_curve(y_true_set, y_score_set, drop_intermediate=drop_intermediate) - tpr_interp = np.interp(mean_fpr, fpr, tpr) # must be interpolated to gain constant/equal fpr positions - tprs.append(tpr_interp) - aucs.append(auc(fpr, tpr)) - optimal_idx = np.argmax(tpr - fpr) - thrs.append(thresholds[optimal_idx]) - - mean_tpr = np.mean(tprs, axis=0) - mean_tpr[-1] = 1.0 - std_tpr = np.std(tprs, axis=0, ddof=1) - tprs_upper = np.minimum(mean_tpr + std_tpr, 1) - tprs_lower = np.maximum(mean_tpr - std_tpr, 0) - - # ------ Averaged based on bootstrapping ------ - mean_auc = np.mean(aucs) - std_auc = np.std(aucs, ddof=1) - - # --------- Specific Case ------------- - fprs, tprs, thrs = roc_curve(y_true, y_score, drop_intermediate=drop_intermediate) - auc_val = auc(fprs, tprs) - opt_idx = np.argmax(tprs - fprs) - opt_tpr = tprs[opt_idx] - opt_fpr = fprs[opt_idx] - - y_scores_bin = y_score >= thrs[opt_idx] # WARNING: Must be >= not > - conf_matrix = confusion_matrix(y_true, y_scores_bin) # [[TN, FP], [FN, TP]] - - axis.plot(fprs, tprs, color='b', label=rf"ROC (AUC = {auc_val:.2f} $\pm$ {std_auc:.2f})", lw=2, alpha=.8) - axis.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, label=r'$\pm$ 1 std. dev.') - axis.hlines(y=opt_tpr, xmin=0.0, xmax=opt_fpr, color='g', linestyle='--') - axis.vlines(x=opt_fpr, ymin=0.0, ymax=opt_tpr, color='g', linestyle='--') - axis.plot(opt_fpr, opt_tpr, color='g', marker='o') - axis.plot([0, 1], [0, 1], linestyle='--', color='k') - axis.set_xlim([0.0, 1.0]) - axis.set_ylim([0.0, 1.0]) - - axis.legend(loc='lower right') - axis.set_xlabel('1 - Specificity', fontdict=fontdict) - axis.set_ylabel('Sensitivity', fontdict=fontdict) - - axis.grid(color='#dddddd') - axis.set_axisbelow(True) - axis.tick_params(colors='#dddddd', which='both') - for xtick in axis.get_xticklabels(): - xtick.set_color('k') - for ytick in axis.get_yticklabels(): - ytick.set_color('k') - for child in axis.get_children(): - if isinstance(child, matplotlib.spines.Spine): - child.set_color('#dddddd') - - return tprs, fprs, auc_val, thrs, opt_idx, conf_matrix - -def cm2acc(cm): - """ - Calculates accuracy from the confusion matrix. - - Args: - cm (array-like): Confusion matrix [[TN, FP], [FN, TP]]. - - Returns: - float: Accuracy. - """ - tn, fp, fn, tp = cm.ravel() - return (tn + tp) / (tn + tp + fn + fp) - -def safe_div(x, y): - """ - Safely divides two numbers, returning NaN if the denominator is zero. - - Args: - x (float): Numerator. - y (float): Denominator. - - Returns: - float: Result of division or NaN if denominator is zero. - """ - if y == 0: - return float('nan') - return x / y - -def cm2x(cm): - """ - Calculates various metrics from the confusion matrix. - - Args: - cm (array-like): Confusion matrix [[TN, FP], [FN, TP]]. - - Returns: - tuple: (ppv, npv, tpr, tnr) - ppv (float): Positive predictive value. - npv (float): Negative predictive value. - tpr (float): True positive rate (sensitivity, recall). - tnr (float): True negative rate (specificity). - """ - tn, fp, fn, tp = cm.ravel() - pp = tp + fp # predicted positive - pn = fn + tn # predicted negative - p = tp + fn # actual positive - n = fp + tn # actual negative - - ppv = safe_div(tp, pp) # positive predictive value - npv = safe_div(tn, pn) # negative predictive value - tpr = safe_div(tp, p) # true positive rate (sensitivity, recall) - tnr = safe_div(tn, n) # true negative rate (specificity) - return ppv, npv, tpr, tnr diff --git a/application/jobs/3dcnn_ptl/README.md b/application/jobs/ODELIA_ternary_classification/README.md similarity index 90% rename from application/jobs/3dcnn_ptl/README.md rename to application/jobs/ODELIA_ternary_classification/README.md index 25b686ff..d10de042 100644 --- a/application/jobs/3dcnn_ptl/README.md +++ b/application/jobs/ODELIA_ternary_classification/README.md @@ -32,7 +32,7 @@ docker run -it --rm \ Before running a swarm dummy training, first make sure the code works in non-swarm mode. ```bash -cd application/jobs/3dcnn_ptl/app/custom/ +cd application/jobs/ODELIA_ternary_classification/app/custom/ export TRAINING_MODE="local_training" export SITE_NAME= export NUM_EPOCHS=1 @@ -45,10 +45,10 @@ cd /workspace The FL Simulator is a lightweight tool that uses threads to simulate multiple clients. It is useful for quick local testing and debugging. Run the following command to start the simulator: ```bash -nvflare simulator -w /tmp/3dcnn_ptl -n 2 -t 2 application/jobs/3dcnn_ptl -c simulated_node_0,simulated_node_1 +nvflare simulator -w /tmp/ODELIA_ternary_classification -n 2 -t 2 application/jobs/ODELIA_ternary_classification -c simulated_node_0,simulated_node_1 ``` -* `-w /tmp/3dcnn_ptl`: Specifies the working directory. +* `-w /tmp/ODELIA_ternary_classification`: Specifies the working directory. * `-n 2`: Sets the number of clients. * `-t 2`: Specifies the number of threads. * `-c simulated_node_0,simulated_node_1`: Names the two simulated nodes. diff --git a/application/jobs/3dcnn_ptl/app/config/config_fed_client.conf b/application/jobs/ODELIA_ternary_classification/app/config/config_fed_client.conf similarity index 96% rename from application/jobs/3dcnn_ptl/app/config/config_fed_client.conf rename to application/jobs/ODELIA_ternary_classification/app/config/config_fed_client.conf index 1b0fa456..42ac6ebb 100644 --- a/application/jobs/3dcnn_ptl/app/config/config_fed_client.conf +++ b/application/jobs/ODELIA_ternary_classification/app/config/config_fed_client.conf @@ -82,11 +82,12 @@ path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor" args { model { - path = "models.resnet.ResNet" - args { - in_ch = 1 - out_ch = 1 - } + path = "models.mst.MST" + args { + n_input_channels = 1 + num_classes = 3 + spatial_dims = 3 + } } } } diff --git a/application/jobs/3dcnn_ptl/app/config/config_fed_server.conf b/application/jobs/ODELIA_ternary_classification/app/config/config_fed_server.conf similarity index 96% rename from application/jobs/3dcnn_ptl/app/config/config_fed_server.conf rename to application/jobs/ODELIA_ternary_classification/app/config/config_fed_server.conf index d408e3f0..fe11655d 100644 --- a/application/jobs/3dcnn_ptl/app/config/config_fed_server.conf +++ b/application/jobs/ODELIA_ternary_classification/app/config/config_fed_server.conf @@ -16,7 +16,7 @@ workflows = [ path = "controller.SwarmServerController" args { # can also set aggregation clients and train clients, see class for all available args - num_rounds = 30 + num_rounds = 20 start_task_timeout = 360000 progress_timeout = 360000 end_workflow_timeout = 360000 diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/data/augmentation/augmentations_3d.py b/application/jobs/ODELIA_ternary_classification/app/custom/data/augmentation/augmentations_3d.py new file mode 100644 index 00000000..043f93c0 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/data/augmentation/augmentations_3d.py @@ -0,0 +1,124 @@ +from typing import Union, Optional, Sequence + +import torchio as tio +from torchio.typing import TypeRangeFloat, TypeTripletInt +from torchio.transforms.transform import TypeMaskingMethod +from torchio import Subject, Image + +import torch +import numpy as np + + +class ImageOrSubjectToTensor(object): + """Converts a torchio Image or Subject to a tensor format by swapping axes.""" + + def __call__(self, input: Union[Image, Subject]): + if isinstance(input, Subject): + return {key: val.data.swapaxes(1, -1) if isinstance(val, Image) else val for key, val in input.items()} + else: + return input.data.swapaxes(1, -1) + + +def parse_per_channel(per_channel, channels): + if isinstance(per_channel, bool): + if per_channel == True: + return [(ch,) for ch in range(channels)] + else: + return [tuple(ch for ch in range(channels))] + else: + return per_channel + + +class ZNormalization(tio.ZNormalization): + """Z-Normalization with support for per-channel and per-slice options, and percentile-based clipping.""" + + def __init__( + self, + percentiles: TypeRangeFloat = (0, 100), + per_channel=True, + per_slice=False, + masking_method: TypeMaskingMethod = None, + **kwargs + ): + super().__init__(masking_method=masking_method, **kwargs) + self.percentiles = percentiles + self.per_channel = per_channel + self.per_slice = per_slice + + def apply_normalization(self, subject: Subject, image_name: str, mask: torch.Tensor) -> None: + image = subject[image_name] + per_channel = parse_per_channel(self.per_channel, image.shape[0]) + per_slice = parse_per_channel(self.per_slice, image.shape[-1]) + + image.set_data( + torch.cat([ + torch.cat([ + self._znorm(image.data[chs,][:, :, :, sl, ], mask[chs,][:, :, :, sl, ], image_name, image.path) + for sl in per_slice], dim=-1) + for chs in per_channel]) + ) + + def _znorm(self, image_data, mask, image_name, image_path): + cutoff = torch.quantile(image_data.masked_select(mask).float(), torch.tensor(self.percentiles) / 100.0) + torch.clamp(image_data, *cutoff.to(image_data.dtype).tolist(), out=image_data) + standardized = self.znorm(image_data, mask) + if standardized is None: + raise RuntimeError( + f'Standard deviation is 0 for masked values in image "{image_name}" ({image_path})' + ) + return standardized + + +class CropOrPad(tio.CropOrPad): + """Crop or pad a subject with optional random center logic for padding.""" + + def __init__( + self, + target_shape: Union[int, TypeTripletInt, None] = None, + padding_mode: Union[str, float] = 0, + mask_name: Optional[str] = None, + labels: Optional[Sequence[int]] = None, + random_center=False, + **kwargs + ): + super().__init__( + target_shape=target_shape, + padding_mode=padding_mode, + mask_name=mask_name, + labels=labels, + **kwargs + ) + self.random_center = random_center + + def _get_six_bounds_parameters(self, parameters: np.ndarray): + result = [] + for number in parameters: + if self.random_center: + ini = np.random.randint(low=0, high=number + 1) + else: + ini = int(np.ceil(number / 2)) + fin = number - ini + result.extend([ini, fin]) + return tuple(result) + + def apply_transform(self, subject: tio.Subject) -> tio.Subject: + subject.check_consistent_space() + padding_params, cropping_params = self.compute_crop_or_pad(subject) + padding_kwargs = {'padding_mode': self.padding_mode} + + if padding_params is not None: + if self.random_center: + random_padding_params = [] + for i in range(0, len(padding_params), 2): + s = padding_params[i] + padding_params[i + 1] + r = np.random.randint(0, s + 1) + random_padding_params.extend([r, s - r]) + padding_params = random_padding_params + pad = tio.Pad(padding_params, **padding_kwargs) + subject = pad(subject) + + if cropping_params is not None: + crop = tio.Crop(cropping_params) + subject = crop(subject) + + return subject diff --git a/application/jobs/3dcnn_ptl/app/custom/data/datamodules/__init__.py b/application/jobs/ODELIA_ternary_classification/app/custom/data/datamodules/__init__.py similarity index 100% rename from application/jobs/3dcnn_ptl/app/custom/data/datamodules/__init__.py rename to application/jobs/ODELIA_ternary_classification/app/custom/data/datamodules/__init__.py diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/data/datamodules/datamodule.py b/application/jobs/ODELIA_ternary_classification/app/custom/data/datamodules/datamodule.py new file mode 100644 index 00000000..6bbb42ef --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/data/datamodules/datamodule.py @@ -0,0 +1,99 @@ +import pytorch_lightning as pl +import torch +from torch.utils.data.dataloader import DataLoader +import torch.multiprocessing as mp +from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler + + +class DataModule(pl.LightningDataModule): + """Flexible LightningDataModule with weighted or random sampling support.""" + + def __init__( + self, + ds_train: object = None, + ds_val: object = None, + ds_test: object = None, + batch_size: int = 1, + batch_size_val: int = None, + batch_size_test: int = None, + num_train_samples: int = None, + num_workers: int = mp.cpu_count(), + seed: int = 0, + pin_memory: bool = False, + weights: list = None + ): + super().__init__() + self.hyperparameters = {**locals()} + self.hyperparameters.pop('__class__') + self.hyperparameters.pop('self') + + self.ds_train = ds_train + self.ds_val = ds_val + self.ds_test = ds_test + + self.batch_size = batch_size + self.batch_size_val = batch_size if batch_size_val is None else batch_size_val + self.batch_size_test = batch_size if batch_size_test is None else batch_size_test + self.num_train_samples = num_train_samples + self.num_workers = num_workers + self.seed = seed + self.pin_memory = pin_memory + self.weights = weights + + def train_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + + if self.ds_train is not None: + if self.weights is not None: + num_samples = len(self.weights) if self.num_train_samples is None else self.num_train_samples + sampler = WeightedRandomSampler(self.weights, num_samples=num_samples, generator=generator) + else: + num_samples = len(self.ds_train) if self.num_train_samples is None else self.num_train_samples + sampler = RandomSampler(self.ds_train, num_samples=num_samples, replacement=False, generator=generator) + + return DataLoader( + self.ds_train, + batch_size=self.batch_size, + num_workers=self.num_workers, + sampler=sampler, + generator=generator, + drop_last=True, + pin_memory=self.pin_memory + ) + + raise AssertionError("A training set was not initialized.") + + def val_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + + if self.ds_val is not None: + return DataLoader( + self.ds_val, + batch_size=self.batch_size_val, + num_workers=self.num_workers, + shuffle=False, + generator=generator, + drop_last=False, + pin_memory=self.pin_memory + ) + + raise AssertionError("A validation set was not initialized.") + + def test_dataloader(self): + generator = torch.Generator() + generator.manual_seed(self.seed) + + if self.ds_test is not None: + return DataLoader( + self.ds_test, + batch_size=self.batch_size_test, + num_workers=self.num_workers, + shuffle=False, + generator=generator, + drop_last=False, + pin_memory=self.pin_memory + ) + + raise AssertionError("A test test set was not initialized.") diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/__init__.py b/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/__init__.py new file mode 100644 index 00000000..a2c7d909 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/__init__.py @@ -0,0 +1,11 @@ +""" +This package initializes the necessary modules and classes for the project. +""" + +# from .dataset_3d import SimpleDataset3D +# from .dataset_3d_collab import DUKE_Dataset3D_collab +# from .dataset_3d_duke import DUKE_Dataset3D +# from .dataset_3d_duke_external import DUKE_Dataset3D_external +from .dataset_3d_odelia import ODELIA_Dataset3D + +# __all__ = [name for name in dir() if not name.startswith('_')] diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/dataset_3d_odelia.py b/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/dataset_3d_odelia.py new file mode 100644 index 00000000..eba4aa12 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/data/datasets/dataset_3d_odelia.py @@ -0,0 +1,151 @@ +from pathlib import Path +import pandas as pd +import torch.utils.data as data +import torchio as tio +import torch +import numpy as np +from sklearn.preprocessing import OneHotEncoder + +from data.augmentation.augmentations_3d import ImageOrSubjectToTensor, ZNormalization, CropOrPad + + +class ODELIA_Dataset3D(data.Dataset): + PATH_ROOT = Path('/data') + ALL_INSTITUTIONS = ['CAM', 'MHA', 'RSH', 'UKA', 'UMCU', 'VHIO', 'RUMC', 'USZ'] + DATA_DIR = { + "original": "data", + "unilateral": "data_unilateral" + } + META_DIR = { + "original": "metadata", + "unilateral": "metadata_unilateral" + } + CLASS_LABELS = { + 'original': { + 'Lesion_Left': ['No', 'Benign', 'Malignant'], + 'Lesion_Right': ['No', 'Benign', 'Malignant'], + }, + 'unilateral': { + 'Lesion': ['No', 'Benign', 'Malignant'], + } + } + + def __init__( + self, + path_root=None, + institutions=None, + fold=0, + labels=None, # None = all labels or list of labels + config=None, # original, unilateral + split=None, + fraction=None, + transform=None, + random_flip=False, + random_rotate=False, + random_inverse=False, + noise=False, + to_tensor=True, + + ): + self.path_root = Path(self.PATH_ROOT if path_root is None else path_root) + self.split = split + self.config = config + self.class_labels = self.CLASS_LABELS[config] + self.meta_dir = self.META_DIR[config] + self.data_dir = self.DATA_DIR[config] + self.labels = list(self.class_labels.keys()) if labels is None else labels + self.class_labels_num = [len(self.class_labels[l]) for l in self.labels] # For CORN Loss -1 + + if (institutions is None) or (institutions == "ODELIA"): + institutions = self.ALL_INSTITUTIONS + elif isinstance(institutions, str): + institutions = [institutions] + self.institutions = institutions + + flip_axes = (0, 1) if config == "original" else (0, 1, 2) # Do not flip horizontal axis 2, otherwise labels incorrect + if transform is None: + self.transform = tio.Compose([ + tio.ToCanonical() if config == "original" else tio.Lambda(lambda x: x), + tio.Resample((0.7, 0.7, 3)) if config == "original" else tio.Lambda(lambda x: x), + + tio.Flip((1, 0)), # Just for viewing, otherwise upside down + CropOrPad((448, 448, 32), random_center=random_rotate) if config == "original" else CropOrPad( + (224, 224, 32), random_center=random_rotate), + + ZNormalization(per_channel=True, per_slice=False, + masking_method=lambda x: (x > x.min()) & (x < x.max()), percentiles=(0.5, 99.5)), + + tio.OneOf([ + # tio.Lambda(lambda x: x.moveaxis(1, 2) if torch.rand((1,),)[0]<0.5 else x ) if random_rotate else tio.Lambda(lambda x: x), # WARNING: 1,2 if Subject, 2, 3 if tensor + tio.RandomAffine(scales=0, degrees=(0, 0, 0, 0, 0, 90), translation=0, isotropic=True, + default_pad_value='minimum') if random_rotate else tio.Lambda(lambda x: x), + tio.RandomFlip(flip_axes) if random_flip else tio.Lambda(lambda x: x), # WARNING: Padding mask + ]), + tio.Lambda(lambda x: -x if torch.rand((1,), )[0] < 0.5 else x, + types_to_apply=[tio.INTENSITY]) if random_inverse else tio.Lambda(lambda x: x), + tio.RandomNoise(std=(0.0, 0.25)) if noise else tio.Lambda(lambda x: x), + + ImageOrSubjectToTensor() if to_tensor else tio.Lambda(lambda x: x) + ]) + else: + self.transform = transform + + # Get split + dfs = [] + for institution in self.institutions: + path_metadata = self.path_root / institution / self.meta_dir + df = self.load_split(path_metadata / 'split.csv', fold=fold, split=split, fraction=fraction) + df['Institution'] = institution + + # Verify files exist + # uids = self.run_item_crawler(self.path_root/institution/'data_unilateral') + # df = df[df['UID'].isin(uids)] + + # Merge with annotations + df_anno = pd.read_csv(path_metadata / 'annotation.csv', dtype={'UID': str, 'PatientID': str}) + df = df.merge(df_anno, on='UID', how='inner') + + dfs.append(df) + df = pd.concat(dfs).reset_index(drop=True) + + self.item_pointers = df.index.tolist() + self.df = df + + def __len__(self): + return len(self.item_pointers) + + def load_img(self, path_img): + return tio.ScalarImage(path_img) + + def load_map(self, path_img): + return tio.LabelMap(path_img) + + def __getitem__(self, index): + idx = self.item_pointers[index] + item = self.df.loc[idx] + uid = item['UID'] + institution = item['Institution'] + + target = np.stack(item[self.labels].values) + + path_folder = self.path_root / institution / self.data_dir / uid + # img = self.load_img([path_folder/f'{name}.nii.gz' for name in [ 'Pre', 'Sub_1', 'T2']]) + img = self.load_img(path_folder / 'Sub_1.nii.gz') + img = self.transform(img) + + return {'uid': uid, 'source': img, 'target': target} + + @classmethod + def load_split(cls, filepath_or_buffer=None, fold=0, split=None, fraction=None): + # WARNING: PatientID must be read as string otherwise leading zeros are cut off + df = pd.read_csv(filepath_or_buffer, dtype={'UID': str}) + df = df[df['Fold'] == fold] + if split is not None: + df = df[df['Split'] == split] + if fraction is not None: + df = df.sample(frac=fraction, random_state=0).reset_index() + return df + + @classmethod + def run_item_crawler(cls, path_root, **kwargs): + return [path.relative_to(path_root).name for path in Path(path_root).iterdir() if path.is_dir()] diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/env_config.py b/application/jobs/ODELIA_ternary_classification/app/custom/env_config.py new file mode 100755 index 00000000..93efb091 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/env_config.py @@ -0,0 +1,75 @@ +import os +from datetime import datetime +from pathlib import Path + + +def load_environment_variables(): + return { + 'site_name': os.environ['SITE_NAME'], + 'task_data_name': os.environ.get('DATA_FOLDER', 'Odelia'), + 'scratch_dir': os.environ['SCRATCH_DIR'], + 'data_dir': os.environ['DATA_DIR'], + 'max_epochs': int(os.environ.get('MAX_EPOCHS', 100)), + 'min_peers': int(os.environ.get('MIN_PEERS', 2)), + 'max_peers': int(os.environ.get('MAX_PEERS', 10)), + 'local_compare_flag': os.environ.get('LOCAL_COMPARE_FLAG', 'False').lower() == 'true', + 'use_adaptive_sync': os.environ.get('USE_ADAPTIVE_SYNC', 'False').lower() == 'true', + 'sync_frequency': int(os.environ.get('SYNC_FREQUENCY', 1024)), + 'model_name': os.environ.get('MODEL_NAME', 'ResNet101'), + 'prediction_flag': os.environ.get('PREDICT_FLAG', 'ext'), + 'mediswarm_version': os.environ.get('MEDISWARM_VERSION', 'unset'), + } + + +def load_prediction_modules(prediction_flag): + from predict import predict + return predict, prediction_flag + + +def prepare_odelia_dataset(): + # parser removed, now read from environment + institution = os.environ.get('INSTITUTION', os.environ['SITE_NAME']) # TODO think about how this should be handled + model = os.environ.get('MODEL_NAME', 'MST') + config = os.environ.get('CONFIG', 'unilateral') + + current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") + run_name = f'{model}_{config}_{current_time}' + path_run_dir = Path.cwd() / 'runs' / institution / run_name + path_run_dir.mkdir(parents=True, exist_ok=True) + + from data.datasets import ODELIA_Dataset3D + ds_train = ODELIA_Dataset3D(institutions=institution, split='train', config=config, + random_flip=True, random_rotate=True, random_inverse=False, noise=True) + ds_val = ODELIA_Dataset3D(institutions=institution, split='val', config=config) + + print(f"Total samples loaded: {len(ds_train)} (train) + {len(ds_val)} (val)") + print(f"Train set: {len(ds_train)}, Val set: {len(ds_val)}") + # print(f"Labels in val: {[sample['label'] for sample in ds_val]}") + + return ds_train, ds_val, path_run_dir, run_name + + +def generate_run_directory(scratch_dir, task_data_name, model_name, local_compare_flag): + current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") + mode = 'local_compare' if local_compare_flag else 'swarm_learning' + if not os.path.exists(scratch_dir): + os.makedirs(scratch_dir) + return os.path.join(scratch_dir, f"{current_time}_{task_data_name}_{model_name}_{mode}") + + +# TODO: Implement dynamic weightage calculation based on actual dataset size +def cal_weightage(train_size): + """ + Placeholder function for calculating training weightage. + Currently unused. + """ + pass # To be implemented + + +# TODO: Implement max epochs adjustment logic based on weightage +def cal_max_epochs(preset_max_epochs, weightage): + """ + Placeholder function for dynamically adjusting max epochs. + Currently unused. + """ + pass # To be implemented diff --git a/application/jobs/3dcnn_ptl/app/custom/main.py b/application/jobs/ODELIA_ternary_classification/app/custom/main.py similarity index 71% rename from application/jobs/3dcnn_ptl/app/custom/main.py rename to application/jobs/ODELIA_ternary_classification/app/custom/main.py index 747cc353..b86d6665 100755 --- a/application/jobs/3dcnn_ptl/app/custom/main.py +++ b/application/jobs/ODELIA_ternary_classification/app/custom/main.py @@ -1,28 +1,35 @@ #!/usr/bin/env python3 import os +import torch import nvflare.client.lightning as flare import nvflare.client as flare_util -import torch import threedcnn_ptl TRAINING_MODE = os.getenv("TRAINING_MODE") TM_PREFLIGHT_CHECK = "preflight_check" -TM_LOCAL_TRAINING="local_training" +TM_LOCAL_TRAINING = "local_training" TM_SWARM = "swarm" +if not TRAINING_MODE: + raise ValueError("TRAINING_MODE environment variable must be set") if TRAINING_MODE == TM_SWARM: flare_util.init() - SITE_NAME=flare.get_site_name() + SITE_NAME = flare.get_site_name() NUM_EPOCHS = threedcnn_ptl.get_num_epochs_per_round(SITE_NAME) elif TRAINING_MODE in [TM_PREFLIGHT_CHECK, TM_LOCAL_TRAINING]: - SITE_NAME=os.getenv("SITE_NAME") - NUM_EPOCHS = int(os.getenv("NUM_EPOCHS")) + SITE_NAME = os.getenv("SITE_NAME") + if not SITE_NAME: + raise ValueError("SITE_NAME environment variable must be set for local training") + try: + NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "1")) + except ValueError: + raise ValueError("NUM_EPOCHS must be an integer") else: - raise Exception(f"Illegal TRAINING_MODE {TRAINING_MODE}") + raise ValueError(f"Unsupported TRAINING_MODE: {TRAINING_MODE}") def main(): @@ -30,8 +37,11 @@ def main(): Main function for training and evaluating the model using NVFlare and PyTorch Lightning. """ logger = threedcnn_ptl.set_up_logging() + try: - data_module, model, checkpointing, trainer, path_run_dir, env_vars = threedcnn_ptl.prepare_training(logger, NUM_EPOCHS, SITE_NAME) + data_module, model, checkpointing, trainer, path_run_dir, env_vars = threedcnn_ptl.prepare_training( + logger, NUM_EPOCHS, SITE_NAME + ) if TRAINING_MODE == TM_SWARM: flare.patch(trainer) # Patch trainer to enable swarm learning @@ -55,5 +65,6 @@ def main(): logger.error(f"Error in main function: {e}") raise + if __name__ == "__main__": main() diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/models/__init__.py b/application/jobs/ODELIA_ternary_classification/app/custom/models/__init__.py new file mode 100644 index 00000000..be8e0a1b --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/models/__init__.py @@ -0,0 +1,9 @@ +""" +This package initializes the necessary modules and classes for the project. +""" + +from .base_model import VeryBasicModel, BasicModel, BasicClassifier +from .resnet import ResNet +from .mst import MST + +__all__ = ['VeryBasicModel', 'BasicModel', 'BasicClassifier', 'ResNet', 'MST'] diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/models/base_model.py b/application/jobs/ODELIA_ternary_classification/app/custom/models/base_model.py new file mode 100644 index 00000000..a01b5b03 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/models/base_model.py @@ -0,0 +1,179 @@ +from pathlib import Path +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from torchmetrics import AUROC, Accuracy + + +class VeryBasicModel(pl.LightningModule): + """Base LightningModule with training, validation, and test hooks stubbed out.""" + + def __init__(self, save_hyperparameters=True): + super().__init__() + if save_hyperparameters: + self.save_hyperparameters() + self._step_train = -1 + self._step_val = -1 + self._step_test = -1 + + def forward(self, x, cond=None): + raise NotImplementedError + + def _step(self, batch: dict, batch_idx: int, state: str, step: int): + raise NotImplementedError + + def _epoch_end(self, state: str): + return + + def training_step(self, batch: dict, batch_idx: int): + self._step_train += 1 + return self._step(batch, batch_idx, "train", self._step_train) + + def validation_step(self, batch: dict, batch_idx: int): + self._step_val += 1 + return self._step(batch, batch_idx, "val", self._step_val) + + def test_step(self, batch: dict, batch_idx: int): + self._step_test += 1 + return self._step(batch, batch_idx, "test", self._step_test) + + def on_train_epoch_end(self) -> None: + self._epoch_end("train") + + def on_validation_epoch_end(self) -> None: + self._epoch_end("val") + + def on_test_epoch_end(self) -> None: + self._epoch_end("test") + + @classmethod + def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path): + with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f: + json.dump({'best_model_epoch': Path(best_model_path).name}, f) + + @classmethod + def _get_best_checkpoint_path(cls, path_checkpoint_dir, **kwargs): + with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'r') as f: + path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch']) + return Path(path_checkpoint_dir) / path_rel_best_checkpoint + + @classmethod + def load_best_checkpoint(cls, path_checkpoint_dir, **kwargs): + path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir) + return cls.load_from_checkpoint(path_best_checkpoint, **kwargs) + + def load_pretrained(self, checkpoint_path, map_location=None, **kwargs): + if checkpoint_path.is_dir(): + checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs) + + checkpoint = torch.load(checkpoint_path, map_location=map_location) + return self.load_weights(checkpoint["state_dict"], **kwargs) + + def load_weights(self, pretrained_weights, strict=True, **kwargs): + filter = kwargs.get('filter', lambda key: key in pretrained_weights) + init_weights = self.state_dict() + pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter(key)} + init_weights.update(pretrained_weights) + self.load_state_dict(init_weights, strict=strict) + return self + + +class BasicModel(VeryBasicModel): + """Extension of VeryBasicModel that includes optimizer and scheduler configuration.""" + + def __init__( + self, + optimizer=torch.optim.Adam, + optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2}, + lr_scheduler=None, + lr_scheduler_kwargs={}, + save_hyperparameters=True + ): + super().__init__(save_hyperparameters=save_hyperparameters) + if save_hyperparameters: + self.save_hyperparameters() + self.optimizer = optimizer + self.optimizer_kwargs = optimizer_kwargs + self.lr_scheduler = lr_scheduler + self.lr_scheduler_kwargs = lr_scheduler_kwargs + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs) + if self.lr_scheduler is not None: + lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs) + return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch", "frequency": 1}] + return [optimizer] + + +class BasicClassifier(BasicModel): + """Generic classifier with dynamic metric and loss configuration based on task type.""" + + def __init__( + self, + in_ch, + out_ch, + spatial_dims, + loss_kwargs={}, + optimizer=torch.optim.AdamW, + optimizer_kwargs={'lr': 1e-4, 'weight_decay': 1e-2}, + lr_scheduler=None, + lr_scheduler_kwargs={}, + aucroc_kwargs={}, + acc_kwargs={}, + save_hyperparameters=True + ): + super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs) + self.in_ch = in_ch + self.out_ch = out_ch + self.spatial_dims = spatial_dims + + loss = torch.nn.CrossEntropyLoss + + self.loss = loss(**loss_kwargs) + self.loss_kwargs = loss_kwargs + + aucroc_kwargs.update({"task": "multiclass", 'num_classes': out_ch}) + acc_kwargs.update({"task": "multiclass", 'num_classes': out_ch}) + + self.auc_roc = nn.ModuleDict({state: AUROC(**aucroc_kwargs) for state in ["train_", "val_", "test_"]}) + self.acc = nn.ModuleDict({state: Accuracy(**acc_kwargs) for state in ["train_", "val_", "test_"]}) + + def _step(self, batch: dict, batch_idx: int, state: str, step: int): + source = batch['source'] + target = batch['target'] + batch_size = source.shape[0] + self.batch_size = batch_size + + pred = self(source) + loss_val = self.compute_loss(pred, target) + target_squeezed = torch.squeeze(target, 1) # TODO Why is this necessary and is it the right thing to do? + self.acc[state + "_"].update(pred, target_squeezed) + self.auc_roc[state + "_"].update(pred, target_squeezed) + + self.log(f"{state}/loss", loss_val, batch_size=batch_size, on_step=True, on_epoch=True) + return loss_val + + def _epoch_end(self, state): + acc_value = self.acc[state + "_"].compute() + auc_roc_value = self.auc_roc[state + "_"].compute() + self.log(f"{state}/ACC", acc_value, batch_size=self.batch_size, on_step=False, on_epoch=True) + self.log(f"{state}/AUC_ROC", auc_roc_value, batch_size=self.batch_size, on_step=False, on_epoch=True) + # For ModelCheckpoint, also log as "val/AUC_ROC" if state == "val" + if state == "val": + self.log("val/AUC_ROC", auc_roc_value, batch_size=self.batch_size, on_step=False, on_epoch=True) + # print some debug information + print(f"Epoch {self.current_epoch} - {state} ACC: {acc_value:.4f}, AUC_ROC: {auc_roc_value:.4f}") + self.acc[state + "_"].reset() + self.auc_roc[state + "_"].reset() + + def compute_loss(self, pred, target): + target_squeezed = torch.squeeze(target, 1) # TODO Why is this necessary and is it the right thing to do? + return self.loss(pred, target_squeezed) + + def logits2labels(self, logits): + return torch.argmax(logits, dim=1, keepdim=True) + + def logits2probabilities(self, logits): + return F.softmax(logits, dim=1) diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/models/mst.py b/application/jobs/ODELIA_ternary_classification/app/custom/models/mst.py new file mode 100644 index 00000000..540441ae --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/models/mst.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +from einops import rearrange +from x_transformers import Encoder + +from .base_model import BasicClassifier + + +class TransformerEncoder(Encoder): + """Override the default forward to match input formatting.""" + + def forward(self, x, mask=None, src_key_padding_mask=None): + src_key_padding_mask = ~src_key_padding_mask if src_key_padding_mask is not None else None + mask = ~mask if mask is not None else None + return super().forward(x=x, context=None, mask=src_key_padding_mask, context_mask=None, attn_mask=mask) + + +class _MST(nn.Module): + """Multi-slice transformer for 3D volume input classification or regression.""" + + def __init__( + self, + out_ch=1, + backbone_type="dinov2", + model_size=None, + slice_fusion_type="transformer" + ): + super().__init__() + self.backbone_type = backbone_type + self.slice_fusion_type = slice_fusion_type + + if backbone_type == "dinov2": + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True + self.backbone = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_size}14') + self.backbone.mask_token = None + emb_ch = self.backbone.num_features + else: + raise ValueError("Unknown backbone_type") + + self.emb_ch = emb_ch + + if slice_fusion_type == "transformer": + self.slice_fusion = TransformerEncoder( + dim=emb_ch, + heads=12 if emb_ch % 12 == 0 else 8, + ff_mult=1, + attn_dropout=0.0, + pre_norm=True, + depth=1, + attn_flash=True, + ff_no_bias=True, + rotary_pos_emb=True, + ) + self.cls_token = nn.Parameter(torch.randn(1, 1, emb_ch)) + elif slice_fusion_type in ["average", "none"]: + self.slice_fusion = None + else: + raise ValueError("Unknown slice_fusion_type") + + self.linear = nn.Linear(emb_ch, out_ch) + + def forward(self, x): + B, *_ = x.shape + x = rearrange(x, 'b c d h w -> (b c d) h w') + x = x[:, None].repeat(1, 3, 1, 1) # Gray to RGB + + x = self.backbone(x) # (B * D, E) + x = rearrange(x, '(b d) e -> b d e', b=B) + + if self.slice_fusion_type == 'none': + return x + elif self.slice_fusion_type == 'transformer': + x = torch.cat([x, self.cls_token.repeat(B, 1, 1)], dim=1) + x = self.slice_fusion(x) + elif self.slice_fusion_type == 'average': + x = x.mean(dim=1, keepdim=True) + + x = self.linear(x[:, -1]) + return x + + +class MST(BasicClassifier): + """MST-based classifier using ViT or ResNet as backbone.""" + + def __init__( + self, + n_input_channels: int, + num_classes: int, + spatial_dims: int, + backbone_type="dinov2", + model_size="s", + slice_fusion_type="transformer", + optimizer_kwargs={'lr': 1e-6}, + **kwargs + ): + super().__init__(n_input_channels, num_classes, spatial_dims, optimizer_kwargs=optimizer_kwargs, **kwargs) + self.mst = _MST(out_ch=num_classes, backbone_type=backbone_type, model_size=model_size, + slice_fusion_type=slice_fusion_type) + + def forward(self, x): + return self.mst(x) diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/models/resnet.py b/application/jobs/ODELIA_ternary_classification/app/custom/models/resnet.py new file mode 100644 index 00000000..49503b26 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/models/resnet.py @@ -0,0 +1,79 @@ +from models import BasicClassifier +import monai.networks.nets as nets +import torch.nn as nn +from einops import rearrange + + +class _ResNet(nn.Module): + """Wrapper for MONAI ResNet models supporting 3D/2D input.""" + + def __init__(self, n_input_channels: int, num_classes: int, spatial_dims: int, resnet_variant: int): + super().__init__() + Model = { + 10: nets.resnet10, + 18: nets.resnet18, + 34: nets.resnet34, + 50: nets.resnet50, + 101: nets.resnet101, + 152: nets.resnet152 + }.get(resnet_variant) + if Model is None: + raise ValueError(f"Unsupported ResNet model number: {resnet_variant}") + + shortcut_type = { + 10: 'B', + 18: 'A', + 34: 'A', + 50: 'B', + 101: 'B', + 152: 'B', + }.get(resnet_variant) + + bias_downsample = { + 10: False, + 18: True, + 34: True, + 50: False, + 101: False, + 152: False, + }.get(resnet_variant) + + num_channels = { + 10: 512, + 18: 512, + 34: 512, + 50: 2048, + 101: 2048, + 152: 2048, + }.get(resnet_variant) + + self.model = Model(n_input_channels=n_input_channels, spatial_dims=spatial_dims, num_classes=num_classes, + feed_forward=False, shortcut_type=shortcut_type, bias_downsample=bias_downsample, pretrained=True) + self.model.fc = nn.Linear(num_channels, + num_classes) + + def forward(self, x): + return self.model(x) + + +class ResNet(BasicClassifier): + """ResNet-based classifier using MONAI backbones.""" + + def __init__(self, n_input_channels: int, num_classes: int, spatial_dims: int, resnet_variant: int, **kwargs): + super().__init__(n_input_channels, num_classes, spatial_dims, **kwargs) + self.model = _ResNet(n_input_channels, num_classes, spatial_dims, resnet_variant) + + def forward(self, x): + return self.model(x) + + +''' +class ResNetRegression(BasicRegression): + """ResNet-based regression model using MONAI backbones.""" + def __init__(self, n_input_channels: int, num_classes: int , spatial_dims: int, resnet_variant: str, **kwargs): + super().__init__(n_input_channels, num_classes, spatial_dims, **kwargs) + self.model = _ResNet(n_input_channels, num_classes, resnet_variant) + + def forward(self, x): + return self.model(x) +''' diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/threedcnn_ptl.py b/application/jobs/ODELIA_ternary_classification/app/custom/threedcnn_ptl.py new file mode 100644 index 00000000..ad291652 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/threedcnn_ptl.py @@ -0,0 +1,156 @@ +from sklearn.model_selection import train_test_split +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from data.datamodules import DataModule +from models import ResNet, MST +from env_config import load_environment_variables, prepare_odelia_dataset, generate_run_directory +import torch.multiprocessing as mp + +import logging + + +def get_num_epochs_per_round(site_name: str) -> int: + NUM_EPOCHS_FOR_SITE = { + "TUD_1": 2, "TUD_2": 4, "TUD_3": 8, + "MEVIS_1": 2, "MEVIS_2": 4, + } + max_epochs = NUM_EPOCHS_FOR_SITE.get(site_name, 5) + print(f"Site name: {site_name}") + print(f"Max epochs set to: {max_epochs}") + return max_epochs + + +def set_up_logging(): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def set_up_data_module(logger): + torch.set_float32_matmul_precision('high') + ds_train, ds_val, path_run_dir, run_name = prepare_odelia_dataset() + num_classes = sum(ds_train.class_labels_num) + logger.info(f"Dataset path: {ds_train}") + logger.info(f"Run directory: {path_run_dir}") + logger.info(f"Run name: {run_name}") + # logger.info(f"Number of classes: {num_classes}") # number of possible classes, not number of classes present, thus misleading + logger.info(f"Length of train dataset: {len(ds_train)}") + logger.info(f"Length of val dataset: {len(ds_val)}") + + dm = DataModule( + ds_train=ds_train, + ds_val=ds_val, + ds_test=ds_val, + batch_size=1, + pin_memory=True, + weights=None, + num_workers=mp.cpu_count(), + ) + + # # Log label distribution + # distribution = dm.get_train_label_distribution(lambda sample: sample['label']) + # logger.info(f"Total samples in training set: {distribution['total']}") + # for label, pct in distribution['percentages'].items(): + # logger.info(f"Label '{label}': {pct:.2f}% of training set, Count: {distribution['counts'][label]}") + # logger.info(f"Number of unique labels: {len(distribution['counts'])}") + + loss_kwargs = {} + + return dm, path_run_dir, run_name, num_classes, loss_kwargs + + +def create_run_directory(env_vars): + return generate_run_directory( + env_vars['scratch_dir'], + env_vars['task_data_name'], + env_vars['model_name'], + env_vars['local_compare_flag'] + ) + + +def prepare_training(logger, max_epochs: int, site_name: str): + try: + env_vars = load_environment_variables() + data_module, path_run_dir, run_name, num_classes, loss_kwargs = set_up_data_module(logger) + + if not torch.cuda.is_available(): + raise RuntimeError("This example requires a GPU") + + logger.info(f"Running code version {env_vars['mediswarm_version']}") + logger.info(f"Using GPU for training") + + model_name = env_vars['model_name'] + + model = None + if model_name in ['ResNet10', 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152']: + resnet_variant = int(model_name[6:]) + model = ResNet(n_input_channels=1, + num_classes=num_classes, + spatial_dims=3, + resnet_variant=resnet_variant, + loss_kwargs=loss_kwargs) + elif model_name == 'MST': + model = MST(n_input_channels=1, + num_classes=num_classes, + spatial_dims=3, + loss_kwargs=loss_kwargs) + + logger.info(f"Using model: {model_name}") + + to_monitor = "val/ACC" + min_max = "max" + log_every_n_steps = 50 + + ''' + early_stopping = EarlyStopping( + monitor=to_monitor, + min_delta=0.0, + patience=25, + mode=min_max + ) + ''' + checkpointing = ModelCheckpoint( + dirpath=str(path_run_dir), + monitor=to_monitor, + save_last=True, + save_top_k=1, + mode=min_max, + ) + + trainer = Trainer( + accelerator='gpu', + accumulate_grad_batches=1, + precision='16-mixed', + default_root_dir=str(path_run_dir), + callbacks=[checkpointing], + enable_checkpointing=True, + check_val_every_n_epoch=1, + log_every_n_steps=log_every_n_steps, + max_epochs=max_epochs, + num_sanity_val_steps=2, + logger=TensorBoardLogger(save_dir=path_run_dir) + ) + + except Exception as e: + logger.error(f"Error in prepare_training: {e}") + raise + + return data_module, model, checkpointing, trainer, path_run_dir, env_vars + + +def validate_and_train(logger, data_module, model, trainer) -> None: + logger.info("--- Validate global model ---") + trainer.validate(model, datamodule=data_module) + + logger.info("--- Train new model ---") + trainer.fit(model, datamodule=data_module) + + +def finalize_training(logger, model, checkpointing, trainer, path_run_dir, env_vars) -> None: + model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path) + + logger.info('Prediction currently not implemented.') + + logger.info('Training completed successfully.') diff --git a/application/jobs/ODELIA_ternary_classification/app/custom/utils/roc_curve.py b/application/jobs/ODELIA_ternary_classification/app/custom/utils/roc_curve.py new file mode 100644 index 00000000..0ac2f4f2 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/custom/utils/roc_curve.py @@ -0,0 +1,238 @@ +import numpy as np +import matplotlib +from sklearn.metrics import roc_curve, auc, confusion_matrix + + +def auc_bootstrapping(y_true, y_score, bootstrapping=1000, drop_intermediate=False): + """Perform bootstrapping to compute variability of ROC curve and AUC. + + Args: + y_true (np.ndarray): True binary labels. + y_score (np.ndarray): Predicted scores or probabilities. + bootstrapping (int): Number of bootstrap samples. + drop_intermediate (bool): Whether to drop some thresholds for faster computation. + + Returns: + Tuple[list, list, list, np.ndarray]: + - List of interpolated TPRs, + - List of AUCs, + - List of optimal thresholds, + - Mean FPR values used for interpolation. + """ + tprs, aucs, thrs = [], [], [] + mean_fpr = np.linspace(0, 1, 100) + rng = np.random.default_rng(seed) + + # Generate bootstrap samples with replacement + rand_idxs = rng.integers(0, len(y_true), size=(bootstrapping, len(y_true))) + + for rand_idx in rand_idxs: + y_true_set = y_true[rand_idx] + y_score_set = y_score[rand_idx] + + # Compute ROC for the sample + fpr, tpr, thresholds = roc_curve(y_true_set, y_score_set, drop_intermediate=drop_intermediate) + + # Interpolate TPRs to a common FPR scale + tpr_interp = np.interp(mean_fpr, fpr, tpr) + tprs.append(tpr_interp) + aucs.append(auc(fpr, tpr)) + + # Identify optimal threshold (Youden's J statistic) + optimal_idx = np.argmax(tpr - fpr) + thrs.append(thresholds[optimal_idx]) + + return tprs, aucs, thrs, mean_fpr + + +def plot_roc_curve(y_true, y_score, axis, bootstrapping=1000, drop_intermediate=False, fontdict={}, + name='ROC', color='b', show_wp=True): + """Plot ROC curve with bootstrapped AUC and shaded confidence interval. + + Args: + y_true (np.ndarray): True binary labels. + y_score (np.ndarray): Predicted probabilities or scores. + axis (matplotlib.axes.Axes): Axis to plot on. + bootstrapping (int): Number of bootstrap samples. + drop_intermediate (bool): Drop thresholds for faster computation. + fontdict (dict): Font styling dictionary. + name (str): Curve label. + color (str): Line color. + show_wp (bool): Show working point (optimal threshold marker). + + Returns: + Tuple[np.ndarray, np.ndarray, float, np.ndarray, int]: + FPR, TPR, AUC, thresholds, and index of optimal threshold. + """ + # Bootstrapping + tprs, aucs, thrs, mean_fpr = auc_bootstrapping(y_true, y_score, bootstrapping, drop_intermediate) + + mean_tpr = np.nanmean(tprs, axis=0) + mean_tpr[-1] = 1.0 # Ensure proper endpoint + std_tpr = np.nanstd(tprs, axis=0, ddof=1) + tprs_upper = np.minimum(mean_tpr + std_tpr, 1) + tprs_lower = np.maximum(mean_tpr - std_tpr, 0) + + mean_auc = np.nanmean(aucs) + std_auc = np.nanstd(aucs, ddof=1) + + # Compute actual ROC + fprs, tprs_, thrs_ = roc_curve(y_true, y_score, drop_intermediate=drop_intermediate) + auc_val = auc(fprs, tprs_) + opt_idx = np.argmax(tprs_ - fprs) + opt_tpr = tprs_[opt_idx] + opt_fpr = fprs[opt_idx] + + # Plot ROC + axis.plot(fprs, tprs_, color=color, label=rf"{name} (AUC={auc_val:.2f}$\pm${std_auc:.2f})", lw=2, alpha=.8) + axis.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, label=r'$\pm$ 1 std. dev.') + + if show_wp: + axis.hlines(y=opt_tpr, xmin=0.0, xmax=opt_fpr, color=color, linestyle='--') + axis.vlines(x=opt_fpr, ymin=0.0, ymax=opt_tpr, color=color, linestyle='--') + axis.plot(opt_fpr, opt_tpr, color=color, marker='o') + + axis.plot([0, 1], [0, 1], linestyle='--', color='k') + axis.set_xlim([0.0, 1.0]) + axis.set_ylim([0.0, 1.0]) + + axis.legend(loc='lower right') + axis.set_xlabel('1 - Specificity', fontdict=fontdict) + axis.set_ylabel('Sensitivity', fontdict=fontdict) + + # Aesthetic tweaks + axis.grid(color='#dddddd') + axis.set_axisbelow(True) + axis.tick_params(colors='#dddddd', which='both') + for xtick in axis.get_xticklabels(): + xtick.set_color('k') + for ytick in axis.get_yticklabels(): + ytick.set_color('k') + for child in axis.get_children(): + if isinstance(child, matplotlib.spines.Spine): + child.set_color('#dddddd') + + return fprs, tprs_, auc_val, thrs_, opt_idx + + +def cm2acc(cm): + """Calculate accuracy from a 2x2 confusion matrix.""" + tn, fp, fn, tp = cm.ravel() + return (tn + tp) / (tn + tp + fn + fp) + + +def safe_div(x, y): + """Safely divide x by y, return NaN if y is zero.""" + return float('nan') if y == 0 else x / y + + +def specificity_at_fixed_sensitivity(y_true, y_scores, tpr, thresholds, sensitivity_target=0.90): + """Calculate specificity at a given sensitivity level. + + Args: + y_true (np.ndarray): Ground truth labels. + y_scores (np.ndarray): Predicted scores. + tpr (np.ndarray): True positive rates from ROC. + thresholds (np.ndarray): Thresholds from ROC. + sensitivity_target (float): Desired sensitivity level. + + Returns: + float: Specificity at the closest sensitivity. + """ + idx = np.argmin(np.abs(tpr - sensitivity_target)) + chosen_threshold = thresholds[idx] + y_pred = (y_scores >= chosen_threshold).astype(int) + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + return tn / (tn + fp) + + +def sensitivity_at_fixed_specificity(y_true, y_scores, fpr, thresholds, specificity_target=0.90): + """Calculate sensitivity at a given specificity level. + + Args: + y_true (np.ndarray): Ground truth labels. + y_scores (np.ndarray): Predicted scores. + fpr (np.ndarray): False positive rates from ROC. + thresholds (np.ndarray): Thresholds from ROC. + specificity_target (float): Desired specificity level. + + Returns: + float: Sensitivity at the closest specificity. + """ + specificity = 1 - fpr + idx = np.argmin(np.abs(specificity - specificity_target)) + chosen_threshold = thresholds[idx] + y_pred = (y_scores >= chosen_threshold).astype(int) + tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() + return tp / (tp + fn) + + +def cm2x(cm, average='macro', pos_label=1): + """Compute PPV, NPV, Sensitivity (TPR), and Specificity (TNR) from confusion matrix. + + Args: + cm (np.ndarray): Confusion matrix. + average (str): 'binary', 'micro', 'macro', or 'weighted'. + pos_label (int): Class considered positive in binary mode. + + Returns: + dict: Dictionary with PPV, NPV, TPR, and TNR. + """ + num_classes = cm.shape[0] + metrics_per_class = {} + + if average == 'micro': + TP = np.sum([cm[i, i] for i in range(num_classes)]) + FP = np.sum([cm[:, i].sum() - cm[i, i] for i in range(num_classes)]) + FN = np.sum([cm[i, :].sum() - cm[i, i] for i in range(num_classes)]) + TN = np.sum([cm.sum() - (cm[i, :].sum() + cm[:, i].sum() - cm[i, i]) for i in range(num_classes)]) + + return { + "PPV": safe_div(TP, TP + FP), + "NPV": safe_div(TN, TN + FN), + "TPR": safe_div(TP, TP + FN), + "TNR": safe_div(TN, TN + FP), + } + + for i in range(num_classes): + TP = cm[i, i] + FP = cm[:, i].sum() - TP + FN = cm[i, :].sum() - TP + TN = cm.sum() - (TP + FP + FN) + + metrics_per_class[i] = { + "PPV": safe_div(TP, TP + FP), + "NPV": safe_div(TN, TN + FN), + "TPR": safe_div(TP, TP + FN), + "TNR": safe_div(TN, TN + FP), + } + + if average == 'binary': + if pos_label not in metrics_per_class: + raise ValueError(f"pos_label={pos_label} not in class labels: {list(metrics_per_class.keys())}") + return metrics_per_class[pos_label] + + ppv_vals = [metrics_per_class[i]["PPV"] for i in range(num_classes)] + npv_vals = [metrics_per_class[i]["NPV"] for i in range(num_classes)] + tpr_vals = [metrics_per_class[i]["TPR"] for i in range(num_classes)] + tnr_vals = [metrics_per_class[i]["TNR"] for i in range(num_classes)] + + if average == 'macro': + return { + "PPV": np.mean(ppv_vals), + "NPV": np.mean(npv_vals), + "TPR": np.mean(tpr_vals), + "TNR": np.mean(tnr_vals), + } + + if average == 'weighted': + support = cm.sum(axis=1) + weights = support / support.sum() + return { + "PPV": np.sum(weights * np.array(ppv_vals)), + "NPV": np.sum(weights * np.array(npv_vals)), + "TPR": np.sum(weights * np.array(tpr_vals)), + "TNR": np.sum(weights * np.array(tnr_vals)), + } + + raise ValueError("Invalid average method. Choose from {'binary', 'micro', 'macro', 'weighted'}.") diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/README.md b/application/jobs/ODELIA_ternary_classification/app/scripts/README.md new file mode 100644 index 00000000..caf02444 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/README.md @@ -0,0 +1,90 @@ +# Preprocessing scripts for ODELIA - Breast MRI Classification + +## Step 1: Download [DUKE](https://sites.duke.edu/mazurowski/resources/breast-cancer-mri-dataset/) Dataset + +* Create a folder `DUKE` with a subfolder `data_raw` +* [Download](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70226903) files form The Cancer Imaging + Archive (TCIA) into `data_raw` +* Make sure to download the dataset in the "classical" structure (PatientID - StudyInstanceUID - SeriesInstanceUID) +* Place all tables in a folder "metadata" +* The folder structure should look like: + ```bash + DUKE + ├── data_raw + │ ├── Breast_MRI_001 + │ │ ├── 1.3.6.1.4.1.14519 + | | | ├── 1.3.6.1.4.1.14519.5.2.1.10 + | | | ├── 1.3.6.1.4.1.14519.5.2.1.17 + │ ├── Breast_MRI_002 + │ | ├── ... + ├── metadata + | ├── Breast-Cancer-MRI-filepath_filename-mapping.xlsx + | ├── Clinical_and_Other_Features.xlsx + ``` + +## Step 2: Prepare Data ([DUKE](https://sites.duke.edu/mazurowski/resources/breast-cancer-mri-dataset/)) + +* Specify the path to the parent folder as `path_root=...` and `dataset=DUKE` in the following scripts +* Run [step1_dicom2nifti.py](preprocessing/duke/step1_dicom2nifti.py) - It will + store DICOM files as NIFTI files in a new folder `data` +* Run [scripts/preprocessing/step2_compute_sub.py](preprocessing/step2_compute_sub.py) - computes the + subtraction image +* Run [scripts/preprocessing/step3_unilateral.py](preprocessing/step3_unilateral.py) - splits breasts into left + and right side and resamples to uniform shape. The result is stored in a new folder `data_unilateral` +* Run [scripts/preprocessing/duke/step4_create_split.py](preprocessing/duke/step4_create_split.py) - creates a + stratified five-fold split and stores the result in `metadata/split.csv` + +
+ +## Step 3: Prepare Data ([ODELIA](https://odelia.ai/)) + +* Create a folder with the initials of your institution e.g. `ABC` +* Place your DICOM files in a subfolder `data_raw` +* Create a folder `metadata` with the following file inside: + * Challenge: `annotation.xlsx` + * Local Training: `ODELIA annotation scheme-2.0.xlsx` +* Overwrite [scripts/preprocessing/odelia/step1_dicom2nifti.py](preprocessing/odelia/step1_dicom2nifti.py). It + should create a subfolder `data` and subfolders with files named as `T2.nii.gz`, `Pre.nii.gz`, `Post_1.nii.gz`, + `Post_2.nii.gz`, etc. + The subfolder should be labeled as follows: + * Challenge: Folders must have the same name as the entries in the `ID` column of the `annotation.xlsx` file. + * Local Training: Folders must have the same name as the entries in the `StudyInstanceUID` column of the + `ODELIA annotation scheme-2.0.xlsx` file. +* Run [scripts/preprocessing/step2_compute_sub.py](preprocessing/step2_compute_sub.py) - computes the + subtraction image +* Run [scripts/preprocessing/step3_unilateral.py](preprocessing/step3_unilateral.py) - splits breasts into left + and right side and resamples to uniform shape. The result is stored in a new folder `data_unilateral` +* To create a five-fold stratified split and store the result in `metadata/split.csv`, run the following script: + * Local Training: [scripts/preprocessing/odelia/step4_create_split.py](preprocessing/odelia/step4_create_split.py) + +* The final folder structure should look like: + ```bash + ABC + ├── data_raw + ├── data + │ ├── ID_001 + │ │ ├── Pre.nii.gz + | | ├── Post_1.nii.gz + | | ├── Post_2.nii.gz + │ ├── ID_002 + │ | ├── ... + ├── data_unilateral + │ ├── ID_001_left + │ ├── ID_001_right + ├── metadata + | ├── annotation.xlsx + | ├── split.csv + ``` + +
+ +## Step 4: Run Training + +* Specify path to downloaded folder as `PATH_ROOT=` + in [dataset_3d_odelia.py](../custom/data/datasets/dataset_3d_odelia.py) +* Run Script: [main_train.py](main_train.py) + +## Step 5: Predict & Evaluate Performance + +* Run Script: [main_predict.py](main_predict.py) +* Set `path_run` to root directory of latest model diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/.gitignore b/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/.gitignore new file mode 100644 index 00000000..47d769c1 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/.gitignore @@ -0,0 +1 @@ +synthetic_dataset \ No newline at end of file diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/create_synthetic_dataset.py b/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/create_synthetic_dataset.py new file mode 100755 index 00000000..52e8ab4f --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/create_synthetic_dataset.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +import csv +from itertools import product +import numpy as np +import os +import pathlib +import shutil +import sys +import SimpleITK as sitk +from tqdm import tqdm + +np.random.seed(1) + +size = (32, 256, 256) +num_images_per_site = 15 +sites = ('client_A', 'client_B') # this must match the swarm project definition +metadata_folder = 'metadata_unilateral' +data_folder = 'data_unilateral' +other_unused_folders = ('data_raw', 'data') +folders = other_unused_folders + (metadata_folder, data_folder) +some_age = 42 * 365 +num_folds = 5 + + +def create_folder_structure(output_folder) -> None: + shutil.rmtree(output_folder, ignore_errors=True) + os.makedirs(output_folder, exist_ok=True) + for i, site in enumerate(sites): + os.mkdir(output_folder / site) + for folder in folders: + os.mkdir(output_folder / site / folder) + + +def get_image(i: int, j: int, lesion_class: int): + # create three different types of images depending on the class + array = np.random.randint(-10, 10, size=size, dtype=np.int16) + if lesion_class == 0: + array[:, i, j] = -50 + elif lesion_class == 1: + array[:, i, j] = 200 + else: + array[:size[2] // 2, i, j] = 200 + array[size[2] // 2:, i, j] = 50 + image = sitk.GetImageFromArray(array) + return image + + +def save_table(output_folder, site: str, table_data: dict) -> None: + def write_split_csv(output_folder, site: str, table_data: dict) -> None: + with open(output_folder / site / metadata_folder / 'split.csv', 'w') as output_csv: + split_fields = ('UID', 'Fold', 'Split') + writer = csv.DictWriter(output_csv, fieldnames=split_fields) + writer.writeheader() + for linedata in table_data: + writer.writerow({sf: linedata[sf] for sf in split_fields}) + + def _get_annotation_data(table_data: dict, annotation_fields: tuple) -> list: + annotation_data = [{af: linedata[af] for af in annotation_fields} for linedata in table_data] + entries = list({tuple(d.items()) for d in annotation_data}) + entries.sort() + annotation_data = [dict(t) for t in entries] + return annotation_data + + def write_annotation_csv(output_folder, site: str, table_data: dict) -> None: + with open(output_folder / site / metadata_folder / 'annotation.csv', 'w') as output_csv: + annotation_fields = ('UID', 'PatientID', 'Age', 'Lesion') + writer = csv.DictWriter(output_csv, fieldnames=annotation_fields) + writer.writeheader() + + annotation_data = _get_annotation_data(table_data, annotation_fields) + for linedata in annotation_data: + writer.writerow(linedata) + + write_split_csv(output_folder, site, table_data) + write_annotation_csv(output_folder, site, table_data) + + +def get_split(fold: int, num: int) -> str: + # mimic 60/20/20 split that slightly differs between folds + index = ((fold + num) % num_images_per_site) / num_images_per_site + if index < 0.6: + return 'train' + elif index < 0.8: + return 'val' + else: + return 'test' + + +if __name__ == '__main__': + if len(sys.argv) != 2: + print('usage: create_synthetic_dataset.py ') + exit(1) + + output_folder = pathlib.Path(sys.argv[1]) + create_folder_structure(output_folder) + + for i, site in enumerate(sites): + table_data = [] + for j in tqdm(range(num_images_per_site), f'Generating synthetic images for {site}'): + lesion_class = j % 3 + image = get_image(i, j, lesion_class) + for side in ('left', 'right'): + patientid = f'ID_{j:03d}' + uid = f'{patientid}_{side}' + side_folder = output_folder / site / data_folder / uid + os.mkdir(side_folder) + # sitk.WriteImage(image, side_folder/'Pre.nii.gz') + sitk.WriteImage(image, side_folder / 'Sub_1.nii.gz') + # sitk.WriteImage(image, side_folder/'T2.nii.gz') + for f in range(num_folds): + table_data.append( + {'UID': uid, 'PatientID': patientid, 'Lesion': lesion_class, 'Age': some_age + i + j, 'Fold': f, + 'Split': get_split(j, f)}) + + save_table(output_folder, site, table_data) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/main_predict.py b/application/jobs/ODELIA_ternary_classification/app/scripts/main_predict.py new file mode 100644 index 00000000..30d2f45e --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/main_predict.py @@ -0,0 +1,185 @@ +import argparse +from pathlib import Path +import logging +from tqdm import tqdm +import torch +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd +import ast +import torch.nn.functional as F +import torch.multiprocessing as mp +from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score, roc_auc_score, roc_curve + +from odelia.data.datasets import ODELIA_Dataset3D +from odelia.data.datamodules import DataModule +from odelia.models import MST, ResNet, MSTRegression, ResNetRegression +from odelia.utils.roc_curve import cm2x, plot_roc_curve, sensitivity_at_fixed_specificity, \ + specificity_at_fixed_sensitivity + + +def one_hot(y, num_classes): + return np.eye(num_classes, dtype=int)[y] + + +def evaluate(gt, nn, nn_prob, label, label_vals, path_out): + plt.rcParams.update({'font.size': 12}) + fontdict = {'fontsize': 12, 'fontweight': 'bold'} + colors = ['b', 'g', 'r'] + y_prob = np.asarray(nn_prob) + y_pred = np.asarray(nn) + y_true = np.asarray(gt) + labels = list(range(len(label_vals))) + + fig, axes = plt.subplots(ncols=2, figsize=(12, 6)) + + # ------------------------------- ROC-AUC --------------------------------- + y_true_hot = one_hot(y_true, len(label_vals)) + y_prob = np.stack([1 - y_prob, y_prob], axis=1) if binary else y_prob # Convert to one-hot + # fig, axis = plt.subplots(ncols=1, nrows=1, figsize=(6,6)) + axis = axes[0] + results = {'AUC': [], 'Sensitivity': [], 'Specificity': []} + for i in range(len(label_vals)): + if binary and i == 0: + continue + y_true_i = y_true_hot[:, i] + y_prob_i = y_prob[:, i] + fprs, tprs, auc_val, thrs, opt_idx = plot_roc_curve(y_true_i, y_prob_i, axis, color=colors[i], + name=f"AUC {label_vals[i]} {label} ", fontdict=fontdict) + # fprs, tprs, thrs = roc_curve(y_true_hot[:,i], y_prob[:, i], drop_intermediate=False) + sensitivity = sensitivity_at_fixed_specificity(y_true_i, y_prob_i, fprs, thrs, 0.9) + specificity = specificity_at_fixed_sensitivity(y_true_i, y_prob_i, tprs, thrs, 0.9) + print( + f"{label_vals[i]} {label}: AUC {auc_val:.2f} Sensitivity {sensitivity:.2f} Specificity: {specificity:.2f}") + results['AUC'].append(auc_val) + results['Sensitivity'].append(sensitivity) + results['Specificity'].append(specificity) + print( + f"{label}: AUC {np.mean(results['AUC']):.2f} Sensitivity {np.mean(results['Sensitivity']):.2f} Specificity: {np.mean(results['Specificity']):.2f}") + # fig.tight_layout() + # fig.savefig(path_out/f'roc_{label}.png', dpi=300) + + # -------------------------- Confusion Matrix ------------------------- + cm = confusion_matrix(y_true, y_pred, labels=labels) + acc = accuracy_score(y_true, y_pred) + metrics = cm2x(cm, "macro") + + print(f"Accuracy: {acc:.2f}") + print(f"Sensitivity: {metrics['TPR']:.2f}") + print(f"Specificity {metrics['TNR']:.2f}") + + df_cm = pd.DataFrame(data=cm, columns=label_vals, index=label_vals) + # fig, axis = plt.subplots(1, 1, figsize=(4,4)) + axis = axes[1] + sns.heatmap(df_cm, ax=axis, cbar=False, cmap="Blues", fmt='d', annot=True) + axis.set_title(f'{label}', fontdict=fontdict) # CM = [[TN, FP], [FN, TP]] + axis.set_xlabel('Neural Network', fontdict=fontdict) + axis.set_ylabel('Radiologist', fontdict=fontdict) + # fig.tight_layout() + # fig.savefig(path_out/f'confusion_matrix_{label}.png', dpi=300) + + fig.tight_layout() + fig.subplots_adjust(wspace=0.4) + fig.savefig(path_out / f'roc_conf_{label}.png', dpi=300) + + # -------------------------- Agreement ------------------------- + # kappa = cohen_kappa_score(y_true, y_pred, weights="linear") + # print(label, "Kappa", kappa) + + +if __name__ == "__main__": + # ------------ Get Arguments ---------------- + parser = argparse.ArgumentParser() + parser.add_argument('--path_run', + default='runs/ODELIA/MST_binary_unilateral_2025_05_13_170027/epoch=22-step=188922.ckpt', + type=str) + parser.add_argument('--test_institution', default='ODELIA', type=str) + args = parser.parse_args() + batch_size = 4 + + # ------------ Settings/Defaults ---------------- + path_run = Path(args.path_run) + train_institution = path_run.parent.parent.name + run_name = path_run.parent.name + path_out = Path().cwd() / 'results' / train_institution / run_name / args.test_institution + path_out.mkdir(parents=True, exist_ok=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ------------ Logging -------------------- + logger = logging.getLogger(__name__) + logging.basicConfig(level=logging.INFO) + + # ------------ Load Data ---------------- + split = None if args.test_institution == 'RUMC' else 'test' # Use all samples if testing on RUMC + binary = run_name.split('_')[1] == "binary" + config = run_name.split('_')[2] + ds_test = ODELIA_Dataset3D(split=split, institutions=args.test_institution, binary=binary, config=config) + + dm = DataModule( + ds_test=ds_test, + batch_size=batch_size, + num_workers=mp.cpu_count(), + # pin_memory=True, + ) + + # ------------ Initialize Model ------------ + model = run_name.split('_')[0] + model_map = { + 'ResNet': ResNet if binary else ResNetRegression, + 'MST': MST if binary else MSTRegression + } + MODEL = model_map.get(model, None) + model = MODEL.load_from_checkpoint(path_run) + model.to(device) + model.eval() + + # ------------ Predict ---------------- + results = [] + for batch in tqdm(dm.test_dataloader()): + uid, source, target = batch['uid'], batch['source'], batch['target'] + + with torch.no_grad(): + logits = model(source.to(device)).cpu() + + # Transfer logits to integer + pred_prob = model.logits2probabilities(logits) + pred = model.logits2labels(logits) + + for b in range(pred.size(0)): + results.append({ + 'UID': uid[b], + 'GT': target[b].tolist(), + 'NN': pred[b].tolist(), + 'NN_prob': pred_prob[b].tolist(), + }) + + # ------------ Save Results ---------------- + df = pd.DataFrame(results) + df.to_csv(path_out / 'results.csv', index=False) + + # ------------ Evaluate ---------------- + df = pd.read_csv(path_out / 'results.csv') + df['GT'] = df['GT'].apply(ast.literal_eval) + df['NN'] = df['NN'].apply(ast.literal_eval) + df['NN_prob'] = df['NN_prob'].apply(ast.literal_eval) + + gt = np.stack(df['GT'].values) + nn = np.stack(df['NN'].values) + nn_prob = np.stack(df['NN_prob'].values) + labels = ODELIA_Dataset3D.CLASS_LABELS[config] # {'Malignant Lesion': ['No', 'Yes']} if binary else + for i in range(gt.shape[1]): + label = list(labels.keys())[i] + label_vals = labels[label] + evaluate(gt[:, i], nn[:, i], nn_prob[:, i], label, label_vals, path_out) + + # If original(bilateral), evaluate for left and right together + if config == 'original': + gt = gt.reshape(-1, 1) + nn = nn.reshape(-1, 1) + nn_prob = nn_prob.reshape(-1, 1) + labels = ODELIA_Dataset3D.CLASS_LABELS['unilateral'] + for i in range(gt.shape[1]): + label = list(labels.keys())[i] + label_vals = labels[label] + evaluate(gt[:, i], nn[:, i], nn_prob[:, i], label, label_vals, path_out) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/main_train.py b/application/jobs/ODELIA_ternary_classification/app/scripts/main_train.py new file mode 100644 index 00000000..a2580f33 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/main_train.py @@ -0,0 +1,114 @@ +from pathlib import Path +from datetime import datetime + +import torch +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +import torch.multiprocessing as mp +from odelia.data.datasets import ODELIA_Dataset3D +from odelia.data.datamodules import DataModule +from odelia.models import ResNet, MST, ResNetRegression, MSTRegression +import argparse + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--institution', default='ODELIA', type=str) + parser.add_argument('--model', type=str, default='MST', choices=['ResNet', 'MST']) + parser.add_argument('--task', type=str, default="binary", choices=['binary', + 'ordinal']) # binary: malignant lesion yes/no, ordinal: no lesion, benign, malignant + parser.add_argument('--config', type=str, default="unilateral", choices=['original', 'unilateral']) + args = parser.parse_args() + binary = args.task == 'binary' + + # ------------ Settings/Defaults ---------------- + current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S") + run_name = f'{args.model}_{args.task}_{args.config}_{current_time}' + path_run_dir = Path.cwd() / 'runs' / args.institution / run_name + path_run_dir.mkdir(parents=True, exist_ok=True) + accelerator = 'gpu' if torch.cuda.is_available() else 'cpu' + torch.set_float32_matmul_precision('high') + + # ------------ Load Data ---------------- + ds_train = ODELIA_Dataset3D(institutions=args.institution, split='train', binary=binary, config=args.config, + random_flip=True, random_rotate=True, random_inverse=False, noise=True) + ds_val = ODELIA_Dataset3D(institutions=args.institution, split='val', binary=binary, config=args.config) + + samples = len(ds_train) + len(ds_val) + batch_size = 1 + accumulate_grad_batches = 1 + steps_per_epoch = samples / batch_size / accumulate_grad_batches + + # class_counts = ds_train.df["Lesion"].value_counts() + # class_weights = 1 / class_counts / len(class_counts) + # weights = ds_train.df["Lesion"].map(lambda x: class_weights[x]).values + + dm = DataModule( + ds_train=ds_train, + ds_val=ds_val, + ds_test=ds_val, + batch_size=batch_size, + pin_memory=True, + weights=None, # weights, + num_workers=mp.cpu_count(), + ) + + # ------------ Initialize Model ------------ + loss_kwargs = {} + out_ch = len(ds_train.labels) + if not binary: + out_ch = sum(ds_train.class_labels_num) + loss_kwargs = {'class_labels_num': ds_train.class_labels_num} + + model_map = { + 'ResNet': ResNet if binary else ResNetRegression, + 'MST': MST if binary else MSTRegression + } + MODEL = model_map.get(args.model, None) + model = MODEL( + in_ch=1, + out_ch=out_ch, + loss_kwargs=loss_kwargs + ) + + # Load pretrained model + # model = ResNet.load_from_checkpoint('runs/DUKE/2024_11_14_132823/epoch=41-step=17514.ckpt') + + # -------------- Training Initialization --------------- + to_monitor = "val/AUC_ROC" if binary else "val/MAE" + min_max = "max" if binary else "min" + log_every_n_steps = 50 + logger = WandbLogger(project='ODELIA', group=args.institution, name=run_name, log_model=False) + + early_stopping = EarlyStopping( + monitor=to_monitor, + min_delta=0.0, # minimum change in the monitored quantity to qualify as an improvement + patience=25, # number of checks with no improvement + mode=min_max + ) + checkpointing = ModelCheckpoint( + dirpath=str(path_run_dir), # dirpath + monitor=to_monitor, + # every_n_train_steps=log_every_n_steps, + save_last=True, + save_top_k=1, + mode=min_max, + ) + trainer = Trainer( + accelerator=accelerator, + accumulate_grad_batches=accumulate_grad_batches, + precision='16-mixed', + default_root_dir=str(path_run_dir), + callbacks=[checkpointing, early_stopping], + enable_checkpointing=True, + check_val_every_n_epoch=1, + log_every_n_steps=log_every_n_steps, + max_epochs=1000, + num_sanity_val_steps=2, + logger=logger + ) + # ---------------- Execute Training ---------------- + trainer.fit(model, datamodule=dm) + + # ------------- Save path to best model ------------- + model.save_best_checkpoint(path_run_dir, checkpointing.best_model_path) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step1_dicom2nifti.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step1_dicom2nifti.py new file mode 100644 index 00000000..5dcf8c4c --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step1_dicom2nifti.py @@ -0,0 +1,134 @@ +from pathlib import Path +import logging +import pandas as pd +from multiprocessing import Pool + +import pydicom +import pydicom.datadict +import pydicom.dataelem +import pydicom.sequence +import pydicom.valuerep +from tqdm import tqdm +import SimpleITK as sitk + +# Logging +# path_log_file = path_root/'preprocessing.log' +logger = logging.getLogger(__name__) + + +# s_handler = logging.StreamHandler(sys.stdout) +# f_handler = logging.FileHandler(path_log_file, 'w') +# logging.basicConfig(level=logging.DEBUG, +# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +# handlers=[s_handler, f_handler]) + + +def maybe_convert(x): + if isinstance(x, pydicom.sequence.Sequence): + # return [maybe_convert(item) for item in x] + return None # Don't store this type of data + elif isinstance(x, pydicom.dataset.Dataset): + # return dataset2dict(x) + return None # Don't store this type of data + elif isinstance(x, pydicom.multival.MultiValue): + return list(x) + elif isinstance(x, pydicom.valuerep.PersonName): + return str(x) + else: + return x + + +def dataset2dict(ds, exclude=['PixelData', '']): + return {keyword: value for key in ds.keys() + if ((keyword := ds[key].keyword) not in exclude) and ((value := maybe_convert(ds[key].value)) is not None)} + + +def series2nifti(series_info): + seq_name, path_series = series_info + path_series = path_root_data / Path(path_series) + if not path_series.is_dir(): + logger.warning(f"Directory not found: {path_series}:") + return + + try: + # Read DICOM + dicom_names = reader.GetGDCMSeriesFileNames(str(path_series)) + reader.SetFileNames(dicom_names) + img_nii = reader.Execute() + + # Read Metadata + ds = pydicom.dcmread(next(path_series.glob('*.dcm'), None), stop_before_pixels=True) + metadata = dataset2dict(ds) + + # Create output folder + path_out_dir = path_root_out_data / path_series.parts[-3] + path_out_dir.mkdir(exist_ok=True, parents=True) + + # Write + filename = seq_name + logger.info(f"Writing file: {filename}:") + path_file = path_out_dir / f'{seq_name}.nii.gz' + sitk.WriteImage(img_nii, path_file) + + metadata['_path_file'] = str(path_file.relative_to(path_root_out_data)) + return metadata + + + except Exception as e: + logger.warning(f"Error in: {path_series}") + logger.warning(str(e)) + + +if __name__ == "__main__": + # Setting + path_root = Path('/home/gustav/Documents/datasets/') + path_root_dataset = path_root / 'DUKE' + + path_root_data = path_root_dataset / 'data_raw/' + path_root_metadata = path_root_dataset / 'metadata' + + path_root_out_data = path_root_dataset / 'data' + path_root_out_data.mkdir(parents=True, exist_ok=True) + + # Init reader + reader = sitk.ImageSeriesReader() + + # Note: Contains path to every single dicom file + # WARNING: reading this .xlsx file takes some time + df_path2name = pd.read_excel(path_root_metadata / 'Breast-Cancer-MRI-filepath_filename-mapping.xlsx') + df_path2name = df_path2name[df_path2name.columns[:4]].copy() + seq_paths = df_path2name['original_path_and_filename'].str.split('/') + df_path2name['PatientID'] = seq_paths.apply(lambda x: int(x[1].rsplit('_', 1)[1])) + df_path2name['SequenceName'] = seq_paths.apply(lambda x: x[2]) + df_path2name['classic_path'] = df_path2name['classic_path'].str.rsplit('/', n=1).str[0] # remove xx.dcm + df_path2name['classic_path'] = df_path2name['classic_path'].str.split('/', n=1).str[ + 1] # remove Duke-Breast-Cancer-MRI/ + df_path2name = df_path2name.drop_duplicates(subset=['PatientID', 'SequenceName'], keep='first') + df_path2name['SequenceName'] = df_path2name['SequenceName'].str.capitalize() # Just convention + df_path2name.to_csv(path_root_metadata / 'Breast-Cancer-MRI-filepath_filename-mapping.csv', index=False) + + df_path2name = pd.read_csv(path_root_metadata / 'Breast-Cancer-MRI-filepath_filename-mapping.csv') + series = list(zip(df_path2name['SequenceName'], + df_path2name['classic_path'])) # NOTE: Only working with TCIA download strategy 'classic_path' + + # Validate + print("Number Series: ", len(series), "of 5034 (5034+127=5161) ") + + # Option 1: Multi-CPU + metadata_list = [] + with Pool() as pool: + for meta in tqdm(pool.imap_unordered(series2nifti, series), total=len(series)): + metadata_list.append(meta) + + # Option 2: Single-CPU (if you need a coffee break) + # metadata_list = [] + # for series_info in tqdm(series): + # meta = series2nifti(series_info) + # metadata_list.append(meta) + + df = pd.DataFrame(metadata_list) + df.to_csv(path_root_metadata / 'metadata.csv', index=False) + + # Check export + num_series = len([path for path in path_root_out_data.rglob('*.nii.gz')]) + print("Number Series: ", num_series, "of 5034 (5034+127=5161) ") diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step4_create_split.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step4_create_split.py new file mode 100644 index 00000000..c6645455 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/duke/step4_create_split.py @@ -0,0 +1,41 @@ +from pathlib import Path +import numpy as np +import pandas as pd + +from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold + +path_root = Path('/home/gustav/Documents/datasets/ODELIA/') +path_root_dataset = path_root / 'DUKE' +path_root_metadata = path_root_dataset / 'metadata' + +df = pd.read_excel(path_root_metadata / 'Clinical_and_Other_Features.xlsx', header=[0, 1, 2]) +df = df[df[df.columns[38]] != 'NC'] # check if cancer is bilateral=1, unilateral=0 or NC +df = df[ + [df.columns[0], df.columns[36], df.columns[38]]] # Only pick relevant columns: Patient ID, Tumor Side, Bilateral +df.columns = ['PatientID', 'Location', 'Bilateral'] # Simplify columns as: Patient ID, Tumor Side +dfs = [] +for side in ["left", 'right']: + dfs.append(pd.DataFrame({ + 'PatientID': df["PatientID"].str.split('_').str[2], + 'UID': df["PatientID"] + f"_{side}", + 'Class': df[["Location", "Bilateral"]].apply(lambda ds: int((ds[0] == side[0].upper()) | (ds[1] == 1)), + axis=1)})) +df = pd.concat(dfs, ignore_index=True) + +df = df.reset_index(drop=True) +splits = [] +sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) # StratifiedGroupKFold +sgkf2 = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) +for fold_i, (train_val_idx, test_idx) in enumerate(sgkf.split(df['UID'], df['Class'], groups=df['PatientID'])): + df_split = df.copy() + df_split['Fold'] = fold_i + df_trainval = df_split.loc[train_val_idx] + train_idx, val_idx = list(sgkf2.split(df_trainval['UID'], df_trainval['Class'], groups=df_trainval['PatientID']))[0] + train_idx, val_idx = df_trainval.iloc[train_idx].index, df_trainval.iloc[val_idx].index + df_split.loc[train_idx, 'Split'] = 'train' + df_split.loc[val_idx, 'Split'] = 'val' + df_split.loc[test_idx, 'Split'] = 'test' + splits.append(df_split) +df_splits = pd.concat(splits) + +df_splits.to_csv(path_root_metadata / 'split.csv', index=False) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step1_dicom2nifti.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step1_dicom2nifti.py new file mode 100644 index 00000000..8e5831bd --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step1_dicom2nifti.py @@ -0,0 +1,3 @@ +# ------------------ +# Add your code here +# ------------------- diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step4_create_split.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step4_create_split.py new file mode 100644 index 00000000..122832a9 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/odelia/step4_create_split.py @@ -0,0 +1,80 @@ +from pathlib import Path +import pandas as pd +from multiprocessing import Pool +from tqdm import tqdm + +from sklearn.model_selection import StratifiedGroupKFold + + +def create_split(df, uid_col='UID', label_col='Label', group_col='PatientID'): + df = df.reset_index(drop=True) + splits = [] + sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) # StratifiedGroupKFold + sgkf2 = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) + for fold_i, (train_val_idx, test_idx) in enumerate(sgkf.split(df[uid_col], df[label_col], groups=df[group_col])): + df_split = df.copy() + df_split['Fold'] = fold_i + df_trainval = df_split.iloc[train_val_idx] + train_idx, val_idx = \ + list(sgkf2.split(df_trainval[uid_col], df_trainval[label_col], groups=df_trainval[group_col]))[0] + train_idx, val_idx = df_trainval.iloc[train_idx].index, df_trainval.iloc[val_idx].index + df_split.loc[train_idx, 'Split'] = 'train' + df_split.loc[val_idx, 'Split'] = 'val' + df_split.loc[test_idx, 'Split'] = 'test' + splits.append(df_split) + df_splits = pd.concat(splits) + return df_splits + + +if __name__ == "__main__": + for dataset in ['UKA']: # 'CAM', 'MHA', 'RSH', 'RUMC', 'UKA', 'UMCU' + print(f"----------------- {dataset} ---------------") + + path_root = Path('/home/homesOnMaster/gfranzes/Documents/datasets/ODELIA/') / dataset + path_root_metadata = path_root / 'metadata' + + df = pd.read_excel(path_root_metadata / 'ODELIA annotation scheme-2.0.xlsx', dtype={'Patient ID': str}) + df = df[11:].reset_index(drop=True) # Remove rows with annotation hints + df = df.rename(columns={'Patient ID': 'PatientID', 'Type of Lesion': 'Lesion'}) + assert not df[['PatientID', 'StudyInstanceUID', 'Lesion']].isna().any().any(), "Missing values detected" + + # Define class mapping + class_mapping = { + 'No lesion': 0, + 'Benign lesion': 1, + 'DCIS': 2, + 'Proliferative with atypia': 2, + 'Invasive Cancer (no special type)': 2, # TODO should invasive cancer be separate class? + 'Invasive Cancer (lobular carcinoma)': 2, + 'Invasive Cancer (all other)': 2, + 'not provided': pd.NA + } + + df_left = df[df['Side'] == "left"] + df_left = df_left[['PatientID', 'StudyInstanceUID', 'Side', 'Lesion']] + df_left.insert(0, 'UID', df_left['StudyInstanceUID'].astype(str) + '_' + df_left['Side']) + + df_left['Class'] = df_left['Lesion'].map(class_mapping) + df_left = df_left.dropna(subset='Class').reset_index(drop=True) # TODO: Should the entire study be removed? + df_left = df_left.loc[df_left.groupby('StudyInstanceUID')['Class'].idxmax()] + + df_right = df[df['Side'] == "right"] + df_right = df_right[['PatientID', 'StudyInstanceUID', 'Side', 'Lesion']] + df_right.insert(0, 'UID', df_right['StudyInstanceUID'].astype(str) + '_' + df_right['Side']) + + df_right['Class'] = df_right['Lesion'].map(class_mapping) + df_right = df_right.dropna(subset='Class').reset_index(drop=True) # TODO: Should the entire study be removed? + df_right = df_right.loc[df_right.groupby('StudyInstanceUID')['Class'].idxmax()] + + # ------------------- Merge left and right ---------------------- + df = pd.concat([df_left, df_right]).reset_index(drop=True) + df['Class'] = df['Class'].astype(int) + + print("Patients", df['PatientID'].nunique()) + print("Studies", df['StudyInstanceUID'].nunique()) + print("Breasts", df['UID'].nunique()) + for class_name, count in df['Class'].value_counts().sort_index().items(): + print(f"Lesion Type {class_name}: {count}") + + df_splits = create_split(df, uid_col='UID', label_col='Class', group_col='PatientID') + df_splits.to_csv(path_root_metadata / 'split.csv', index=False) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step2_compute_sub.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step2_compute_sub.py new file mode 100644 index 00000000..ac490b41 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step2_compute_sub.py @@ -0,0 +1,37 @@ +from pathlib import Path +import SimpleITK as sitk +import numpy as np +from multiprocessing import Pool +from tqdm import tqdm + + +def process(path_patient): + # Compute subtraction image + # Note: if dtype not specified, data is read as uint16 -> subtraction wrong + dyn0_nii = sitk.ReadImage(str(path_patient / 'Pre.nii.gz'), sitk.sitkInt16) + dyn1_nii = sitk.ReadImage(str(path_patient / 'Post_1.nii.gz'), sitk.sitkInt16) + dyn0 = sitk.GetArrayFromImage(dyn0_nii) + dyn1 = sitk.GetArrayFromImage(dyn1_nii) + sub = dyn1 - dyn0 + sub = sub - sub.min() # Note: negative values causes overflow when using uint + sub = sub.astype(np.uint16) + sub_nii = sitk.GetImageFromArray(sub) + sub_nii.CopyInformation(dyn0_nii) + sitk.WriteImage(sub_nii, str(path_patient / 'Sub.nii.gz')) + + +if __name__ == "__main__": + path_root = Path('/home/gustav/Documents/datasets/ODELIA/') + for dataset in ['DUKE', ]: # 'CAM', 'MHA', 'RSH', 'RUMC', 'UKA', 'UMCU', 'DUKE' + path_data = path_root / dataset / 'data' + + files = path_data.iterdir() + + # Option 1: Multi-CPU + with Pool() as pool: + for _ in tqdm(pool.imap_unordered(process, files)): + pass + + # Option 2: Single-CPU + # for path_dir in tqdm(files): + # process(path_dir) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step3_unilateral.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step3_unilateral.py new file mode 100644 index 00000000..926c5347 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/step3_unilateral.py @@ -0,0 +1,89 @@ +from pathlib import Path +import torchio as tio +import torch +import numpy as np +from multiprocessing import Pool +from tqdm import tqdm + + +def crop_breast_height(image, margin_top=10): + "Crop height to 256 and try to cover breast based on intensity localization" + # threshold = int(image.data.float().quantile(0.9)) + threshold = int(np.quantile(image.data.float(), 0.9)) + foreground = image.data > threshold + fg_rows = foreground[0].sum(axis=(0, 2)) + top = min(max(512 - int(torch.argwhere(fg_rows).max()) - margin_top, 0), 256) + bottom = 256 - top + return tio.Crop((0, 0, bottom, top, 0, 0)) + + +def preprocess(path_dir): + # -------- Settings -------------- + ref_img = tio.ScalarImage(path_dir / 'Pre.nii.gz') + ref_img = tio.ToCanonical()(ref_img) + + # Spacing + target_spacing = (0.7, 0.7, 3) + ref_img = tio.Resample(target_spacing)(ref_img) + + # Crop + target_shape = (512, 512, 32) + + padding_constant = ref_img.data.min().item() # Ugly workaround: padding_mode='minimum' calculates the minimum per axis, not globally + transform = tio.Compose([ + tio.Resample(ref_img), # Resample to reference image to ensure that origin, direction, etc, fit + tio.CropOrPad(target_shape, padding_mode=padding_constant), + ]) + crop_height = crop_breast_height(transform(ref_img)) + split_side = { + 'right': tio.Crop((256, 0, 0, 0, 0, 0)), + 'left': tio.Crop((0, 256, 0, 0, 0, 0)), + } + + for n, path_img in enumerate(path_dir.glob('*.nii.gz')): + # Read image + img = tio.ScalarImage(path_img) + + # Preprocess (eg. Crop/Pad) + padding_constant = img.data.min().item() + transform = tio.Compose([ + tio.Resample(ref_img), + tio.CropOrPad(target_shape, padding_mode=padding_constant), + ]) + img = transform(img) + + # Crop bottom and top so that height is 256 and breast is preserved + img = crop_height(img) + + # Split left and right side + for side in ['left', 'right']: + # Create output directory + path_out_dir = path_root_out_data / f"{path_dir.relative_to(path_root_in_data)}_{side}" + path_out_dir.mkdir(exist_ok=True, parents=True) + + # Crop left/right side + img_side = split_side[side](img) + + # Save + img_side.save(path_out_dir / path_img.name) + + +if __name__ == "__main__": + for dataset in ['DUKE', ]: # 'CAM', 'MHA', 'RSH', 'RUMC', 'UKA', 'UMCU', 'DUKE' + + path_root = Path('/home/gustav/Documents/datasets/ODELIA/') / dataset + path_root_in_data = path_root / 'data' + + path_root_out_data = path_root / 'data_unilateral' + path_root_out_data.mkdir(parents=True, exist_ok=True) + + path_patients = list(path_root_in_data.iterdir()) # Convert the iterator to a list + + # Option 1: Multi-CPU + with Pool() as pool: + for _ in tqdm(pool.imap_unordered(preprocess, path_patients)): + pass + + # Option 2: Single-CPU + # for path_dir in tqdm(path_patients): + # preprocess(path_dir) diff --git a/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/uka/step4_create_split.py b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/uka/step4_create_split.py new file mode 100644 index 00000000..b94b6cb4 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/scripts/preprocessing/uka/step4_create_split.py @@ -0,0 +1,44 @@ +from pathlib import Path +import numpy as np +import pandas as pd + +from sklearn.model_selection import StratifiedGroupKFold, StratifiedKFold + +path_root = Path('/home/gustav/Documents/datasets/ODELIA/') +path_root_dataset = path_root / 'UKA_all' +path_root_metadata = path_root_dataset / 'metadata' + +df = pd.read_excel(path_root_metadata / 'annotation_regex.xlsx') +assert len(df[df.duplicated(subset='UID', keep=False)]) == 0, "Duplicates exist" + +df['DCISoderKarzinom'] = df[df.columns[-1]] | df[df.columns[-2]] +print(f"Text available for {len(df)} cases") + +# Include only examinations were MR image is available +uids = [path.name for path in (path_root_dataset / 'data_unilateral').iterdir()] +print(f"Image available for {len(uids)} cases") + +# Merge +df = df[df['UID'].isin(uids)].reset_index(drop=True) +print(f"Text and Image available for {len(df)} cases") + +# label = df.columns[6] +for label in df.columns[6:]: + print(label) + df = df.reset_index(drop=True) + splits = [] + sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) # StratifiedGroupKFold + sgkf2 = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0) + for fold_i, (train_val_idx, test_idx) in enumerate(sgkf.split(df['UID'], df[label], groups=df['PNR'])): + df_split = df.copy() + df_split['Fold'] = fold_i + df_trainval = df_split.loc[train_val_idx] + train_idx, val_idx = list(sgkf2.split(df_trainval['UID'], df_trainval[label], groups=df_trainval['PNR']))[0] + train_idx, val_idx = df_trainval.iloc[train_idx].index, df_trainval.iloc[val_idx].index + df_split.loc[train_idx, 'Split'] = 'train' + df_split.loc[val_idx, 'Split'] = 'val' + df_split.loc[test_idx, 'Split'] = 'test' + splits.append(df_split) + df_splits = pd.concat(splits) + + df_splits.to_csv(path_root_metadata / f'split_regex_{label}.csv', index=False) diff --git a/application/jobs/ODELIA_ternary_classification/app/tests/data/test_dataset_odelia.py b/application/jobs/ODELIA_ternary_classification/app/tests/data/test_dataset_odelia.py new file mode 100644 index 00000000..9c906dc9 --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/tests/data/test_dataset_odelia.py @@ -0,0 +1,43 @@ +from odelia.data.datasets import ODELIA_Dataset3D +import torch +from pathlib import Path +from torchvision.utils import save_image + + +def tensor2image(tensor, batch=0): + return (tensor if tensor.ndim < 5 else torch.swapaxes(tensor[batch], 0, 1).reshape(-1, *tensor.shape[-2:])[:, None]) + + +all_institutions = ODELIA_Dataset3D.ALL_INSTITUTIONS +for institution in all_institutions: + ds = ODELIA_Dataset3D( + institutions=institution, + random_flip=True, + random_rotate=True, + # random_inverse=True, + # noise=True + binary=False, + config='unilateral', + ) + + print(f" ------------- Dataset {institution} ------------") + df = ds.df + print("Number of exams: ", len(df)) + print("Number of patients: ", df['PatientID'].nunique()) + + for label in ds.labels: + print(f"Label {label}") + print(df[label].value_counts()) + + # ----------------- Print some examples ---------------- + item = ds[20] + uid = item["uid"] + img = item['source'] + label = item['target'] + + print("UID", uid, "Image Shape", list(img.shape), "Label", label) + + path_out = Path.cwd() / 'results/test' + path_out.mkdir(parents=True, exist_ok=True) + img = tensor2image(img[None]) + save_image(img, path_out / f'test_{institution}.png', normalize=True) diff --git a/application/jobs/ODELIA_ternary_classification/app/tests/model/test_model_step.py b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_model_step.py new file mode 100644 index 00000000..b0d7aa7a --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_model_step.py @@ -0,0 +1,50 @@ +import torch +from tqdm import tqdm + +from odelia.models import MST, MSTRegression +from odelia.models import ResNet, ResNetRegression +from odelia.data.datasets import ODELIA_Dataset3D +from odelia.data.datamodules import DataModule + +config = "unilateral" # original or unilateral +task = "ordinal" # binary or ordinal +model = "MST" # ResNet or MST +label = None + +binary = task == "binary" +ds_train = ODELIA_Dataset3D(split='train', institutions='ODELIA', binary=binary, config=config, labels=label) + +device = torch.device(f'cuda:5') + +loss_kwargs = {} +out_ch = len(ds_train.labels) +if task == "ordinal": + out_ch = sum(ds_train.class_labels_num) + loss_kwargs = {'class_labels_num': ds_train.class_labels_num} + +if label is not None: + class_counts = ds_train.df[label].value_counts() + class_weights = 1 / class_counts / len(class_counts) + weights = ds_train.df[label].map(lambda x: class_weights[x]).values + +model_map = { + 'ResNet': ResNet if binary else ResNetRegression, + 'MST': MST if binary else MSTRegression +} +MODEL = model_map.get(model, None) +model = MODEL( + in_ch=1, + out_ch=out_ch, + loss_kwargs=loss_kwargs +) + +model.to(device) +model.eval() + +dm = DataModule(ds_train=ds_train, batch_size=3, num_workers=0) +dl = dm.train_dataloader() + +for idx, batch in tqdm(enumerate(iter(dl))): + batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + loss = model._step(batch, batch_idx=idx, state="train", step=idx * dm.batch_size) + print("loss", loss) diff --git a/application/jobs/ODELIA_ternary_classification/app/tests/model/test_mst.py b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_mst.py new file mode 100644 index 00000000..f83b86da --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_mst.py @@ -0,0 +1,10 @@ +import torch +from odelia.models import MST, MSTRegression + +input = torch.randn((1, 1, 32, 224, 224)) +model = MST(in_ch=1, out_ch=2, spatial_dims=3) +model = MSTRegression(in_ch=1, out_ch=2 + 3, spatial_dims=3, loss_kwargs={"class_labels_num": [2, 3]}) + +pred = model(input) +print(pred.shape) +print(pred) diff --git a/application/jobs/ODELIA_ternary_classification/app/tests/model/test_resnet.py b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_resnet.py new file mode 100644 index 00000000..0ad9846e --- /dev/null +++ b/application/jobs/ODELIA_ternary_classification/app/tests/model/test_resnet.py @@ -0,0 +1,10 @@ +import torch +from odelia.models import ResNet, ResNetRegression + +input = torch.randn((1, 1, 32, 224, 224)) +model = ResNet(in_ch=1, out_ch=2, spatial_dims=3, model=18) +model = ResNetRegression(in_ch=1, out_ch=2 + 3, spatial_dims=3, loss_kwargs={"class_labels_num": [2, 3]}) + +pred = model(input) +print(pred.shape) +print(pred) diff --git a/application/jobs/3dcnn_ptl/meta.conf b/application/jobs/ODELIA_ternary_classification/meta.conf similarity index 71% rename from application/jobs/3dcnn_ptl/meta.conf rename to application/jobs/ODELIA_ternary_classification/meta.conf index 2675ccfd..15fc1b67 100644 --- a/application/jobs/3dcnn_ptl/meta.conf +++ b/application/jobs/ODELIA_ternary_classification/meta.conf @@ -1,4 +1,4 @@ -name = "3dcnn_ptl" +name = "ODELIA_ternary_classification" resource_spec {} deploy_map { app = [ diff --git a/application/jobs/minimal_training_pytorch_cnn/app/custom/models/base_model.py b/application/jobs/minimal_training_pytorch_cnn/app/custom/models/base_model.py index bf1d3519..3d933bb4 100644 --- a/application/jobs/minimal_training_pytorch_cnn/app/custom/models/base_model.py +++ b/application/jobs/minimal_training_pytorch_cnn/app/custom/models/base_model.py @@ -1,25 +1,16 @@ -from typing import List, Union +from typing import List, Union, Any from pathlib import Path import json import torch import torch.nn as nn import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.types import EPOCH_OUTPUT from torchmetrics import AUROC, Accuracy class VeryBasicModel(pl.LightningModule): """ A very basic model class extending LightningModule with basic functionality. - - Attributes: - _step_train (int): Counter for training steps. - _step_val (int): Counter for validation steps. - _step_test (int): Counter for test steps. """ - def __init__(self): super().__init__() self.save_hyperparameters() @@ -28,112 +19,58 @@ def __init__(self): self._step_test = -1 def forward(self, x_in): - """Forward pass. Must be implemented by subclasses.""" raise NotImplementedError def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): - """Step function for training, validation, and testing. Must be implemented by subclasses.""" raise NotImplementedError - def _epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]], state: str): - """Epoch end function.""" + def _epoch_end(self, outputs: List[Any], state: str): return - def training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): + def training_step(self, batch, batch_idx): self._step_train += 1 - return self._step(batch, batch_idx, "train", self._step_train, optimizer_idx) + return self._step(batch, batch_idx, "train", self._step_train, 0) - def validation_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): + def validation_step(self, batch: dict, batch_idx: int) -> Any: self._step_val += 1 - return self._step(batch, batch_idx, "val", self._step_val, optimizer_idx) + return self._step(batch, batch_idx, "val", self._step_val, 0) - def test_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0): + def test_step(self, batch: dict, batch_idx: int) -> Any: self._step_test += 1 - return self._step(batch, batch_idx, "test", self._step_test, optimizer_idx) + return self._step(batch, batch_idx, "test", self._step_test, 0) - def training_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "train") - return super().training_epoch_end(outputs) + def on_train_epoch_end(self): + self._epoch_end([], "train") - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "val") - return super().validation_epoch_end(outputs) + def on_validation_epoch_end(self): + self._epoch_end([], "val") - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - self._epoch_end(outputs, "test") - return super().test_epoch_end(outputs) + def on_test_epoch_end(self): + self._epoch_end([], "test") @classmethod def save_best_checkpoint(cls, path_checkpoint_dir, best_model_path): - """Saves the best model checkpoint path. - - Args: - path_checkpoint_dir (str): Directory to save the checkpoint. - best_model_path (str): Path to the best model. - """ with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'w') as f: json.dump({'best_model_epoch': Path(best_model_path).name}, f) @classmethod - def _get_best_checkpoint_path(cls, path_checkpoint_dir, version=0, **kwargs): - """Gets the best model checkpoint path. - - Args: - path_checkpoint_dir (str): Directory containing the checkpoint. - version (int, optional): Version of the checkpoint. Defaults to 0. - - Returns: - Path: Path to the best checkpoint. - """ - path_version = 'lightning_logs/version_' + str(version) - with open(Path(path_checkpoint_dir) / path_version / 'best_checkpoint.json', 'r') as f: + def _get_best_checkpoint_path(cls, path_checkpoint_dir, **kwargs): + with open(Path(path_checkpoint_dir) / 'best_checkpoint.json', 'r') as f: path_rel_best_checkpoint = Path(json.load(f)['best_model_epoch']) return Path(path_checkpoint_dir) / path_rel_best_checkpoint @classmethod - def load_best_checkpoint(cls, path_checkpoint_dir, version=0, **kwargs): - """Loads the best model checkpoint. - - Args: - path_checkpoint_dir (str): Directory containing the checkpoint. - version (int, optional): Version of the checkpoint. Defaults to 0. - - Returns: - LightningModule: The loaded model. - """ - path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir, version) + def load_best_checkpoint(cls, path_checkpoint_dir, **kwargs): + path_best_checkpoint = cls._get_best_checkpoint_path(path_checkpoint_dir) return cls.load_from_checkpoint(path_best_checkpoint, **kwargs) def load_pretrained(self, checkpoint_path, map_location=None, **kwargs): - """Loads pretrained weights from a checkpoint. - - Args: - checkpoint_path (str): Path to the checkpoint. - map_location (str, optional): Device to map the checkpoint. Defaults to None. - - Returns: - LightningModule: The model with loaded weights. - """ if checkpoint_path.is_dir(): checkpoint_path = self._get_best_checkpoint_path(checkpoint_path, **kwargs) - - with pl_legacy_patch(): - if map_location is not None: - checkpoint = pl_load(checkpoint_path, map_location=map_location) - else: - checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + checkpoint = torch.load(checkpoint_path, map_location=map_location) return self.load_weights(checkpoint["state_dict"], **kwargs) def load_weights(self, pretrained_weights, strict=True, **kwargs): - """Loads weights into the model. - - Args: - pretrained_weights (dict): Pretrained weights. - strict (bool, optional): Whether to strictly enforce that the keys in `state_dict` match the keys returned by this module’s `state_dict` function. Defaults to True. - - Returns: - LightningModule: The model with loaded weights. - """ filter_fn = kwargs.get('filter', lambda key: key in pretrained_weights) init_weights = self.state_dict() pretrained_weights = {key: value for key, value in pretrained_weights.items() if filter_fn(key)} @@ -143,126 +80,75 @@ def load_weights(self, pretrained_weights, strict=True, **kwargs): class BasicModel(VeryBasicModel): - """ - A basic model class with optimizer and learning rate scheduler configurations. - - Attributes: - optimizer (Optimizer): The optimizer to use. - optimizer_kwargs (dict): Keyword arguments for the optimizer. - lr_scheduler (Scheduler): The learning rate scheduler to use. - lr_scheduler_kwargs (dict): Keyword arguments for the learning rate scheduler. - """ - def __init__( self, optimizer=torch.optim.AdamW, - optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2}, + optimizer_kwargs=None, lr_scheduler=None, - lr_scheduler_kwargs={}, + lr_scheduler_kwargs=None, ): super().__init__() - self.save_hyperparameters() self.optimizer = optimizer - self.optimizer_kwargs = optimizer_kwargs + self.optimizer_kwargs = optimizer_kwargs or {'lr': 1e-3, 'weight_decay': 1e-2} self.lr_scheduler = lr_scheduler - self.lr_scheduler_kwargs = lr_scheduler_kwargs + self.lr_scheduler_kwargs = lr_scheduler_kwargs or {} + self.save_hyperparameters() def configure_optimizers(self): - """Configures the optimizers and learning rate schedulers. - - Returns: - list: List containing the optimizer and optionally the learning rate scheduler. - """ optimizer = self.optimizer(self.parameters(), **self.optimizer_kwargs) if self.lr_scheduler is not None: lr_scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs) - return [optimizer], [lr_scheduler] - else: - return [optimizer] + return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch", "frequency": 1}] + return [optimizer] class BasicClassifier(BasicModel): - """ - A basic classifier model with loss function and metrics. - - Attributes: - in_ch (int): Number of input channels. - out_ch (int): Number of output channels. - spatial_dims (int): Number of spatial dimensions. - loss (Loss): The loss function. - loss_kwargs (dict): Keyword arguments for the loss function. - auc_roc (ModuleDict): Dictionary of AUROC metrics. - acc (ModuleDict): Dictionary of Accuracy metrics. - """ - def __init__( self, in_ch: int, out_ch: int, spatial_dims: int, loss=torch.nn.CrossEntropyLoss, - loss_kwargs={}, + loss_kwargs=None, optimizer=torch.optim.AdamW, - optimizer_kwargs={'lr': 1e-3, 'weight_decay': 1e-2}, + optimizer_kwargs=None, lr_scheduler=None, - lr_scheduler_kwargs={}, - aucroc_kwargs={"task": "binary"}, - acc_kwargs={"task": "binary"} + lr_scheduler_kwargs=None, + aucroc_kwargs=None, + acc_kwargs=None, ): super().__init__(optimizer, optimizer_kwargs, lr_scheduler, lr_scheduler_kwargs) + self.in_ch = in_ch self.out_ch = out_ch self.spatial_dims = spatial_dims - self.loss = loss(**loss_kwargs) - self.loss_kwargs = loss_kwargs + self.loss_kwargs = loss_kwargs or {} + self.loss = loss(**self.loss_kwargs) + + aucroc_kwargs = aucroc_kwargs or {"task": "binary"} + acc_kwargs = acc_kwargs or {"task": "binary"} self.auc_roc = nn.ModuleDict({state: AUROC(**aucroc_kwargs) for state in ["train_", "val_", "test_"]}) self.acc = nn.ModuleDict({state: Accuracy(**acc_kwargs) for state in ["train_", "val_", "test_"]}) def _step(self, batch: dict, batch_idx: int, state: str, step: int, optimizer_idx: int): - """Step function for training, validation, and testing. - - Args: - batch (dict): Input batch. - batch_idx (int): Batch index. - state (str): State of the model ('train', 'val', 'test'). - step (int): Current step. - optimizer_idx (int): Index of the optimizer. - - Returns: - Tensor: Loss value. - """ source, target = batch['source'], batch['target'] target = target[:, None].float() + target_int = target.int() batch_size = source.shape[0] - # Run Model pred = self(source) + loss_val = self.loss(pred, target) - # Compute Loss - logging_dict = {} - logging_dict['loss'] = self.loss(pred, target) - - # Compute Metrics with torch.no_grad(): - self.acc[state + "_"].update(pred, target) - self.auc_roc[state + "_"].update(pred, target) - - # Log Scalars - for metric_name, metric_val in logging_dict.items(): - self.log(f"{state}/{metric_name}", metric_val.cpu() if hasattr(metric_val, 'cpu') else metric_val, - batch_size=batch_size, on_step=True, on_epoch=True) + prob = torch.sigmoid(pred) # logits -> probability + self.acc[state + "_"].update(prob, target_int) + self.auc_roc[state + "_"].update(prob, target_int) - return logging_dict['loss'] + self.log(f"{state}/loss", loss_val, batch_size=batch_size, on_step=True, on_epoch=True) + return loss_val def _epoch_end(self, outputs, state): - """Epoch end function. - - Args: - outputs (Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]): Outputs of the epoch. - state (str): State of the model ('train', 'val', 'test'). - """ - batch_size = len(outputs) - for name, value in [("ACC", self.acc[state + "_"]), ("AUC_ROC", self.auc_roc[state + "_"])]: - self.log(f"{state}/{name}", value.compute().cpu(), batch_size=batch_size, on_step=False, on_epoch=True) - value.reset() + for name, metric in [("ACC", self.acc[state + "_"]), ("AUC_ROC", self.auc_roc[state + "_"])]: + self.log(f"{state}/{name}", metric.compute(), on_epoch=True) + metric.reset() diff --git a/application/provision/project_HA.yml b/application/provision/project_HA.yml index 99cb5150..39e1deee 100644 --- a/application/provision/project_HA.yml +++ b/application/provision/project_HA.yml @@ -1,5 +1,5 @@ api_version: 3 -name: 3dcnn_ptl_HA +name: ODELIA_ternary_classification_HA description: > NVIDIA FLARE project YAML file for configuring a federated learning environment with High Availability (HA). focused on 3D convolutional neural networks (3D CNNs) using PyTorch Lightning (PTL). diff --git a/application/provision/project_MEVIS_test.yml b/application/provision/project_MEVIS_test.yml index 787a93d6..b4e962c1 100644 --- a/application/provision/project_MEVIS_test.yml +++ b/application/provision/project_MEVIS_test.yml @@ -4,15 +4,27 @@ description: > Test setup. participants: - - name: odeliatempvm.local + - name: odeliaswarmvm.local type: server org: MEVIS_Test fed_learn_port: 8002 admin_port: 8003 - - name: temporary_vm + - name: CAM type: client org: MEVIS_Test - - name: permanent_vm + - name: MHA + type: client + org: MEVIS_Test + - name: RUMC + type: client + org: MEVIS_Test + - name: UKA + type: client + org: MEVIS_Test + - name: UMCU + type: client + org: MEVIS_Test + - name: Centralized type: client org: MEVIS_Test - name: admin@mevis.odelia @@ -32,7 +44,7 @@ builders: config_folder: config # scheme for communication driver (currently supporting the default, grpc, only). - scheme: grpc + scheme: http # app_validator is used to verify if uploaded app has proper structures # if not set, no app_validator is included in fed_server.json @@ -54,7 +66,7 @@ builders: # overseer_exists: false args: - sp_end_point: odeliatempvm.local:8002:8003 + sp_end_point: odeliaswarmvm.local:8002:8003 - path: nvflare.lighter.impl.cert.CertBuilder - path: nvflare.lighter.impl.signature.SignatureBuilder diff --git a/application/provision/project_Odelia_allsites.yml b/application/provision/project_Odelia_allsites.yml new file mode 100644 index 00000000..530b356d --- /dev/null +++ b/application/provision/project_Odelia_allsites.yml @@ -0,0 +1,98 @@ +api_version: 3 +name: odelia___REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS___allsites_test +description: Odelia TUD server all collaborators clients on Odelia challenge dataset provision http based yaml file + +participants: + # change example.com to the FQDN of the server + - name: dl3.tud.de + type: server + org: TUD + fed_learn_port: 8002 + admin_port: 8003 + - name: TUD_1 + type: client + org: TUD + - name: TUD_2 + type: client + org: TUD + # Specifying listening_host will enable the creation of one pair of + # certificate/private key for this client, allowing the client to function + # as a server for 3rd-party integration. + # The value must be a hostname that the external trainer can reach via the network. + # listening_host: site-1-lh + - name: MEVIS_1 + type: client + org: MEVIS + - name: MEVIS_2 + type: client + org: MEVIS + - name: MEVIS_3 + type: client + org: MEVIS + - name: UKA_1 + type: client + org: UKA + - name: CAM_1 + type: client + org: Cambridge + - name: VHIO_1 + type: client + org: VHIO + - name: MHA_1 + type: client + org: MHA + - name: RSH_1 + type: client + org: RSH + - name: USZ_1 + type: client + org: USZ + - name: UMCU_1 + type: client + org: UMCU + - name: RUMC_1 + type: client + org: RUMC + - name: jiefu.zhu@tu-dresden.de + type: admin + org: TUD + role: project_admin + +# The same methods in all builders are called in their order defined in builders section +builders: + - path: nvflare.lighter.impl.workspace.WorkspaceBuilder + args: + template_file: master_template.yml + - path: nvflare.lighter.impl.template.TemplateBuilder + - path: nvflare.lighter.impl.static_file.StaticFileBuilder + args: + # config_folder can be set to inform NVIDIA FLARE where to get configuration + config_folder: config + + # scheme for communication driver (currently supporting the default, grpc, only). + scheme: http + + # app_validator is used to verify if uploaded app has proper structures + # if not set, no app_validator is included in fed_server.json + # app_validator: PATH_TO_YOUR_OWN_APP_VALIDATOR + + # when docker_image is set to a docker image name, docker.sh will be generated on server/client/admin + docker_image: jefftud/odelia:__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__ + + # download_job_url is set to http://download.server.com/ as default in fed_server.json. You can override this + # to different url. + # download_job_url: http://download.server.com/ + + overseer_agent: + path: nvflare.ha.dummy_overseer_agent.DummyOverseerAgent + # if overseer_exists is true, args here are ignored. Provisioning + # tool will fill role, name and other local parameters automatically. + # if overseer_exists is false, args in this section will be used and the sp_end_point + # must match the server defined above in the format of SERVER_NAME:FL_PORT:ADMIN_PORT + # + overseer_exists: false + args: + sp_end_point: dl3.tud.de:8002:8003 + + - path: nvflare.lighter.impl.cert.CertBuilder + - path: nvflare.lighter.impl.signature.SignatureBuilder diff --git a/application/provision/project_nonHA.yml b/application/provision/project_nonHA.yml index 811c5f3d..e00297c7 100644 --- a/application/provision/project_nonHA.yml +++ b/application/provision/project_nonHA.yml @@ -1,5 +1,5 @@ api_version: 3 -name: 3dcnn_ptl_nonHA +name: ODELIA_ternary_classification_nonHA description: > NVIDIA FLARE project YAML file for configuring a federated learning environment without High Availability (HA). focused on 3D convolutional neural networks (3D CNNs) using PyTorch Lightning (PTL). diff --git a/assets/VPN setup guide(CLI).md b/assets/VPN setup guide(CLI).md index d6e9bc56..e70fe26c 100644 --- a/assets/VPN setup guide(CLI).md +++ b/assets/VPN setup guide(CLI).md @@ -72,8 +72,20 @@ sh envsetup_scripts/setup_vpntunnel.sh The `.ovpn` file assigned to you by TUD is required for re-establishing the connection. -For further troubleshooting, refer to the **VPN Connect Guide**. +For further troubleshooting, refer to the VPN Connect Guide on the GoodAccess support page: +[GoodAccess VPN Connect Guide](https://support.goodaccess.com/configuration-guides/linux) ---- -This guide ensures a smooth setup and reconnection process for GoodAccess VPN via CLI. \ No newline at end of file + +## Step 6: Troubleshooting — Disconnecting Existing VPN Connections + +Some users have experienced that connecting to GoodAccess **disconnects an existing VPN or ssh connection**. +This may happen because OpenVPN is configured to redirect all network traffic through the GoodAccess tunnel, which overrides your local or other VPN routes and may make the machine inaccessible in its local network. + +If this occurs, you can prevent the redirection by starting OpenVPN with: +```sh +openvpn --config .ovpn --pull-filter ignore redirect-gateway +``` +This tells the OpenVPN client **not** to override your default gateway, allowing your other VPN or ssh connection to remain active. + +> **Note:** This behavior was observed by Aitor and Ole after certain OpenVPN updates. The above command has been effective in resolving the issue. \ No newline at end of file diff --git a/assets/openvpn_configs/good_access/.gitignore b/assets/openvpn_configs/good_access/.gitignore new file mode 100644 index 00000000..2e66e21c --- /dev/null +++ b/assets/openvpn_configs/good_access/.gitignore @@ -0,0 +1 @@ +*.ovpn \ No newline at end of file diff --git a/assets/readme/README.developer.md b/assets/readme/README.developer.md new file mode 100644 index 00000000..1215e574 --- /dev/null +++ b/assets/readme/README.developer.md @@ -0,0 +1,101 @@ +# Usage for MediSwarm and Application Code Developers + +## Versioning of ODELIA Docker Images + +If needed, update the version number in file [odelia_image.version](../../odelia_image.version). It will be used +automatically for the Docker image and startup kits. + +## Build the Docker Image and Startup Kits + +The Docker image contains all dependencies for administrative purposes (dashboard, command-line provisioning, admin +console, server) as well as for running the 3DCNN pipeline under the pytorch-lightning framework. +The project description specifies the swarm nodes etc. to be used for a swarm training. + + ```bash + cd MediSwarm + ./buildDockerImageAndStartupKits.sh -p application/provision/ + ``` + +1. Make sure you have no uncommitted changes. +2. If package versions are still not available, you may have to check what the current version is and update the + `Dockerfile` accordingly. Version numbers are hard-coded to avoid issues due to silently different versions being + installed. +3. After successful build (and after verifying that everything works as expected, i.e., local tests, building startup + kits, running local trainings in the startup kit), you can manually push the image to DockerHub, provided you have + the necessary rights. Make sure you are not re-using a version number for this purpose. + +## Running Tests + + ```bash + ./runIntegrationTests.sh + ``` + +You should see + +1. several expected errors and warnings printed from unit tests that should succeed overall, and a coverage report +2. output of a successful simulation run of a dummy training with two nodes +3. output of a successful proof-of-concept run of a dummy training with two nodes +4. output of a successful simulation run of a 3D CNN training using synthetic data with two nodes +5. output of a set of startup kits being generated +6. output of a Docker/GPU preflight check using one of the startup kits +7. output of a data access preflight check using one of the startup kits +8. output of a dummy training run in a swarm consisting of one server and two client nodes + +Optionally, uncomment running NVFlare unit tests. + +## Distributing Startup Kits + +Distribute the startup kits to the clients. + +## Running the Startup Kits + +See [README.participant.md](./README.participant.md). + +### Configurable Parameters for docker.sh + +* The `docker.sh` script run by the swarm participants passes the following environment variables into the container automatically. +* You can override them to customize training behavior. +* Only do this for testing and debugging purposes! The startup kits are designed to ensure that all sites run the same training code, manipulating `docker.sh` might break this. + +| Environment Variable | Default | Description | +|----------------------|-----------------|----------------------------------------------------------------------| +| `SITE_NAME` | *from flag* | Name of your local site, e.g. `TUD_1`, passed via `--start_client` | +| `DATA_DIR` | *from flag* | Path to the host folder that contains your local data | +| `SCRATCH_DIR` | *from flag* | Path for saving training outputs and temporary files | +| `GPU_DEVICE` | `device=0` | GPU identifier to use inside the container (or `all`) | +| `MODEL` | `MST` | Model architecture, choices: `MST`, `ResNet` | +| `INSTITUTION` | `ODELIA` | Institution name, used to group experiment logs | +| `CONFIG` | `unilateral` | Configuration schema for dataset (e.g. label scheme) | +| `NUM_EPOCHS` | `1` (test mode) | Number of training epochs (used in preflight/local training) | +| `TRAINING_MODE` | derived | Internal use. Automatically set based on flags like `--start_client` | + +These are injected into the container as `--env` variables. You can modify their defaults by editing `docker.sh` or exporting before run: + +```bash +export MODEL=ResNet +export CONFIG=original +./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=1 --start_client +``` + + +## Running the Application + +1. **CIFAR-10 example:** + See [README.md](../../application/jobs/cifar10/README.md) +2. **Minimal PyTorch CNN example:** + See [README.md](../../application/jobs/minimal_training_pytorch_cnn/README.md) +3. **3D CNN for classifying breast tumors:** + See [README.md](../../application/jobs/ODELIA_ternary_classification/README.md) + +## Contributing Application Code + +1. Take a look at application/jobs/minimal_training_pytorch_cnn for a minimal example how pytorch code can be adapted to + work with NVFlare +2. Take a look at application/jobs/ODELIA_ternary_classification for a more realistic example of pytorch code that can + run in the swarm +3. Use the local tests to check if the code is swarm-ready +4. TODO more detailed instructions + +## Continuous Integration + +Tests to be executed after pushing to github are defined in `.github/workflows/pr-test.yaml`. diff --git a/assets/readme/README.operator.md b/assets/readme/README.operator.md new file mode 100644 index 00000000..3c88b6a1 --- /dev/null +++ b/assets/readme/README.operator.md @@ -0,0 +1,93 @@ +# Usage for Swarm Operators + +## Setting up a Swarm + +Production mode is designed for secure, real-world deployments. It supports both local and remote setups, whether +on-premise or in the cloud. For more details, refer to +the [NVFLARE Production Mode](https://nvflare.readthedocs.io/en/2.4.1/real_world_fl.html). + +To set up production mode, follow these steps: + +## Edit `/etc/hosts` + +Ensure that your `/etc/hosts` file includes the correct host mappings. All hosts need to be able to communicate to the +server node. + +For example, add the following line (replace `` with the server's actual IP address): + +```plaintext + dl3.tud.de dl3 +``` + +TODO describe this in participant REAME if needed + +## Create Startup Kits + +### Via Script (recommended) + +1. Use, e.g., the file `application/provision/project_MEVIS_test.yml`, adapt as needed (network protocol etc.) +2. Call `buildDockerImageAndStartupKits.sh -p /path/to/project_configuration.yml -c /path/to/directory/with/VPN/credentials` to build the Docker image and the startup kits + - swarm nodes (admin, server, clients) are configured in `project_configuration.yml` + - the directory with VPN credentials should contain one `.ovpn` file per node + - use `-c tests/local_vpn/client_configs/` to build startup kits for the integration tests +3. Startup kits are generated to `workspace//prod_00/` +4. Deploy startup kits to the respective server/client operators +5. Push the Docker image to the registry + +### Via the Dashboard (not recommended) + +Build the Docker image as described above. + +```bash +docker run -d --rm \ + --ipc=host -p 8443:8443 \ + --name=odelia_swarm_admin \ + -v /var/run/docker.sock:/var/run/docker.sock \ + \ + /bin/bash -c "nvflare dashboard --start --local --cred :" +``` + +using some credentials chosen for the swarm admin account. + +Access the dashboard in a web browser at `https://localhost:8443` log in with these credentials, and configure the +project: + +1. enter project short name, name, description +2. enter docker download link: jefftud/odelia: +3. if needed, enter dates +4. click save +5. Server Configuration > Server (DNS name): +6. click make project public + +#### Register client per site + +Access the dashboard at `https://:8443`. + +1. register a user +2. enter organziation (corresponding to the site) +3. enter role (e.g., org admin) +4. add a site (note: must not contain spaces, best use alphanumerical name) +5. specify number of GPUs and their memory + +#### Approve clients and finish configuration + +Access the dashboard at `https://localhost:8443` log in with the admin credentials. + +1. Users Dashboard > approve client user +2. Client Sites > approve client sites +3. Project Home > freeze project + +#### Download startup kits + +After setting up the project admin configuration, server and clients can download their startup kits. Store the +passwords somewhere, they are only displayed once (or you can download them again). + +## Starting a Swarm Training + +1. Connect the *server* host to the VPN as described above. (TODO update documentation, this step is not needed if the Docker container handles the VPN connection.) +2. Start the *server* startup kit using the respective `startup/docker.sh` script with the option to start the server +3. Provide the *client* startup kits to the swarm participants (be aware that email providers or other channels may + prevent encrypted archives) +4. Make sure the participants have started their clients via the respective startup kits, see below +5. Start the *admin* startup kit using the respective `startup/docker.sh` script to start the admin console +6. Deploy a job by `submit_job ` diff --git a/assets/readme/README.participant.md b/assets/readme/README.participant.md new file mode 100644 index 00000000..be8d4759 --- /dev/null +++ b/assets/readme/README.participant.md @@ -0,0 +1,168 @@ +# MediSwarm Participant Guide + +This guide is for data scientists and medical research sites participating in a Swarm Learning project. + +## Prerequisites + +- Hardware: Min. 32GB RAM, 8 cores, NVIDIA GPU with 24GB VRAM, 4TB storage +- OS: Ubuntu 20.04 LTS +- Software: Docker, OpenVPN, Git + +## Setup +0. Add this line to your `/etc/hosts`: `172.24.4.65 dl3.tud.de dl3` +1. Make sure your compute node satisfies the specification and has the necessary software installed. +2. Set up the VPN. A VPN is necessary so that the swarm nodes can communicate with each other securely across firewalls. For that purpose, + 1. Install OpenVPN + ```bash + sudo apt-get install openvpn + ``` + 2. If you have a graphical user interface(GUI), follow this guide to connect to the + VPN: [VPN setup guide(GUI).pdf](../VPN%20setup%20guide%28GUI%29.pdf) + 3. If you have a command line interface(CLI), follow this guide to connect to the + VPN: [VPN setup guide(CLI).md](../VPN%20setup%20guide%28CLI%29.md) + 4. You may want to clone this repository or selectively download VPN-related scripts for this purpose. + +## Prepare Dataset + +The dataset must be in the following format. + +### Folder Structure + + ```bash + + ├── data_unilateral + │ ├── ID_001_left + │ │ └── Sub_1.nii.gz + │ ├── ID_001_right + │ │ └── Sub_1.nii.gz + │ ├── ID_002_left + │ │ └── Sub_1.nii.gz + │ ├── ID_002_right + │ │ └── Sub_1.nii.gz + │ └── ... + └── metadata_unilateral + ├── annotation.csv + └── split.csv + ``` + +* The name of your site should usually end in `_1`, e.g., `UKA_1`, unless you participate with multiple nodes. +* `ID_001`, `ID_002` need to be unique identifiers in your dataset, not specifically of this format +* You might have additional images in the folder like `Pre.nii.gz`, `Post_1.nii.gz`, `Post_2.nii.gz`, `T2.nii.gz`, and you might have additional folders like `data_raw`, `data`, `metadata` etc. These will be ignored and should not cause problems. +* If you clone the repository, you will find a script that generates a synthetic dataset as an example. + +### Table Format + +#### Annotation + +* `annotation.csv` defines the class labels +* The file contains the columns `UID`, `PatientID`, `Age`, `Lesion` + * `UID` is the identifier used in the folder name, e.g., `ID_001_left`. + * `PatientID` is the identifier of the patient, in this case, `ID_001`. + * `Age` is the age of the patient at the time of the scan in days. + This columns is ignored for our current technical tests and exists only for compatibility with the ODELIA challenge data format. Please ignore discrepancies if age is listed in other units than days. + * `Lesion` is 0 for no lesion, 1 for benign lesion, and 2 for malicious lesion. + +#### Split + +* `split.csv` defines the training/validation/test split. +* These splits are hard-coded rather than randomized during training in order to have consistent and documented splits. +* The file contains the columns `UID`, `Split`, and `Fold`. + * `UID` is the identifier used in the folder name, e.g., `ID_001_left`. + * `Split` is either `train`, `val`, or `test`. The test set is currently ignored. + * `Fold` is the 0-based index of the fold (for a potential cross-validation). + +## Prepare Training Participation + +1. Extract the startup kit provided by swarm operator for the current experiment. + +### Local Testing on Your Data + +1. Directories + ```bash + export SITE_NAME= + export DATADIR= + export SCRATCHDIR= + ``` +2. From the directory where you unpacked the startup kit, + ```bash + cd $SITE_NAME/startup + ``` +3. Verify that your Docker/GPU setup is working + ```bash + ./docker.sh --scratch_dir $SCRATCHDIR --GPU device=0 --dummy_training 2>&1 | tee dummy_training_console_output.txt + ``` + * This will pull the Docker image, which might take a while. + * If you have multiple GPUs and 0 is busy, use a different one. + * The “training” itself should take less than minute and does not yield a meaningful classification performance. +4. Verify that your local data can be accessed and the model can be trained locally + ```bash + ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --preflight_check 2>&1 | tee preflight_check_console_output.txt + ``` + * Training time depends on the size of the local dataset. + +### Run Local Training + +To have a baseline for swarm training, train the same model in a comparable way on the local data only. + +1. From the directory where you unpacked the startup kit (unless you just ran the pre-flight check) + ```bash + cd $SITE_NAME/startup + ``` +2. Start local training + ```bash + ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --local_training 2>&1 | tee local_training_console_output.txt + ``` + * This currently runs 100 epochs (somewhat comparable to 20 rounds with 5 epochs each in the swarm case). +3. Output files + * Same as for the swarm training (see below). + +### Start Swarm Node + +#### VPN + +1. Connect to VPN as described in [VPN setup guide(GUI).pdf](../VPN%20setup%20guide%28GUI%29.pdf) (GUI) or [VPN setup guide(CLI).md](../VPN%20setup%20guide%28CLI%29.md) (command line). + +#### Start the Client + +1. From the directory where you unpacked the startup kit: + ```bash + cd $SITE_NAME/startup # Skip this if you just ran the pre-flight check + ``` + +2. Start the client: + ```bash + ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --start_client + ``` + If you have multiple GPUs and 0 is busy, use a different one. + +3. Console output is captured in `nohup.out`, which may have been created with limited permissions in the container, so + make it readable if necessary: + ```bash + sudo chmod a+r nohup.out + ``` + +4. Output files: + - **Training logs and checkpoints** are saved under: + ``` + $SCRATCHDIR/runs/$SITE_NAME// + ``` + - **Best checkpoint** usually saved as `best.ckpt` or `last.ckpt` + - TODO describe prediction results once implemented + - **TensorBoard logs** are stored in their respective folders inside the run directory + +5. (Optional) You can verify that the container is running properly: + ```bash + docker ps # Check if odelia_swarm_client_$SITE_NAME is listed + nvidia-smi # Check if the GPU is busy training (it will be idling while waiting for model transfer) + tail -f nohup.out # Follow training log + ``` +For any issues, check if the commands above point to problems and contact your Swarm Operator. + +## Troubleshooting + +* Folders where files are located need to have the correct name +* Image files need to have the correct file name including capitalization +* The directories listed as identifiers in the tables `annotation.csv` and `split.csv` should all be present and named correctly (including capitalization), only those directories should be present +* The tables should not have additional or duplicate columns, entries need to have the correct captitalization +* Image and table folders and files need to be present in the folders specified via `--data_dir`. Symlinks to other locations do not work, they are not available in the Docker mount. +* The correct startup kit needs to be used. `SSLCertVerificationError` or `authentication failed` may indicate an incorrect startup kit incompatible with the current experiment. diff --git a/assets/readme/README_old.md b/assets/readme/README_old.md new file mode 100644 index 00000000..516d4d0b --- /dev/null +++ b/assets/readme/README_old.md @@ -0,0 +1,362 @@ +# Introduction + +MediSwarm is an open-source project dedicated to advancing medical deep learning through swarm intelligence, leveraging +the NVFlare platform. Developed in collaboration with the Odelia consortium, this repository aims to create a +decentralized and collaborative framework for medical research and applications. + +## Key Features + +- **Swarm Learning:** Utilizes swarm intelligence principles to improve model performance and adaptability. +- **NVFlare Integration:** Built on NVFlare, providing robust and scalable federated learning capabilities. +- **Data Privacy:** Ensures data security and compliance with privacy regulations by keeping data local to each + institution. +- **Collaborative Research:** Facilitates collaboration among medical researchers and institutions for enhanced + outcomes. +- **Extensible Framework:** Designed to support various medical applications and easily integrate with existing + workflows. + +## Prerequisites + +### Hardware recommendations + +* 64 GB of RAM (32 GB is the absolute minimum) +* 16 CPU cores (8 is the absolute minimum) +* an NVIDIA GPU with 48 GB of RAM (24 GB is the minimum) +* 8 TB of Storage (4 TB is the absolute minimum) + +We demonstrate that the system can run on lightweight hardware like this. For less than 10k EUR, you can configure +systems from suppliers like Lambda, Dell Precision, and Dell Alienware. + +### Operating System + +* Ubuntu 20.04 LTS + +### Software + +* Docker +* openvpn +* git + +### Cloning the repository + + ```bash + git clone https://github.com/KatherLab/MediSwarm.git --recurse-submodules + ``` + +* The last argument is necessary because we are using a git submodule for the (ODELIA fork of + NVFlare)[https://github.com/KatherLab/NVFlare_MediSwarm] +* If you have cloned it without this argument, use `git submodule update --init --recursive` + +### VPN + +A VPN is necessary so that the swarm nodes can communicate with each other securely across firewalls. For that purpose, + +1. Install OpenVPN + ```bash + sudo apt-get install openvpn + ``` +2. If you have a graphical user interface(GUI), follow this guide to connect to the + VPN: [VPN setup guide(GUI).pdf](assets/VPN%20setup%20guide%28GUI%29.pdf) +3. If you have a command line interface(CLI), follow this guide to connect to the + VPN: [VPN setup guide(CLI).md](assets/VPN%20setup%20guide%28CLI%29.md) + +# Usage for Swarm Participants + +## Setup + +1. Make sure your compute node satisfies the specification and has the necessary software installed. +2. Clone the repository and connect the client node to the VPN as described above. TODO is cloning the repository + necessary for swarm participants? +3. TODO anything else? + +## Prepare Dataset + +1. see Step 3: Prepare Data in (this document)[application/jobs/ODELIA_ternary_classification/app/scripts/README.md] + +## Prepare Training Participation + +1. Extract startup kit provided by swarm operator + +## Run Pre-Flight Check + +1. Directories + ```bash + export SITE_NAME= # TODO should be defined above, also needed for dataset location + export DATADIR= + export SCRATCHDIR= + ``` +2. From the directory where you unpacked the startup kit, + ```bash + cd $SITE_NAME/startup + ``` +3. Verify that your Docker/GPU setup is working + ```bash + ./docker.sh --scratch_dir $SCRATCHDIR --GPU device=0 --dummy_training + ``` + * This will pull the Docker image, which might take a while. + * If you have multiple GPUs and 0 is busy, use a different one. + * The “training” itself should take less than minute and does not yield a meaningful classification performance. +4. Verify that your local data can be accessed and the model can be trained locally + ```bash + ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --preflight_check + ``` + * Training time depends on the size of the local dataset. + +## Configurable Parameters for docker.sh + +TODO consider what should be described and recommended as configurable here, given that the goal of the startup kits is +to ensure everyone runs the same training + +When launching the client using `./docker.sh`, the following environment variables are automatically passed into the +container. You can override them to customize training behavior: + +| Environment Variable | Default | Description | +|----------------------|-----------------|----------------------------------------------------------------------| +| `SITE_NAME` | *from flag* | Name of your local site, e.g. `TUD_1`, passed via `--start_client` | +| `DATA_DIR` | *from flag* | Path to the host folder that contains your local data | +| `SCRATCH_DIR` | *from flag* | Path for saving training outputs and temporary files | +| `GPU_DEVICE` | `device=0` | GPU identifier to use inside the container (or `all`) | +| `MODEL` | `MST` | Model architecture, choices: `MST`, `ResNet` | +| `INSTITUTION` | `ODELIA` | Institution name, used to group experiment logs | +| `CONFIG` | `unilateral` | Configuration schema for dataset (e.g. label scheme) | +| `NUM_EPOCHS` | `1` (test mode) | Number of training epochs (used in preflight/local training) | +| `TRAINING_MODE` | derived | Internal use. Automatically set based on flags like `--start_client` | + +These are injected into the container as `--env` variables. You can modify their defaults by editing `docker.sh` or +exporting before run: + +```bash +export MODEL=ResNet +export CONFIG=original +./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=1 --start_client +``` + +## Start Swarm Node + +1. From the directory where you unpacked the startup kit: + ```bash + cd $SITE_NAME/startup # Skip this if you just ran the pre-flight check + ``` + +2. Start the client: + ```bash + ./docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU device=0 --start_client + ``` + If you have multiple GPUs and 0 is busy, use a different one. + +3. Console output is captured in `nohup.out`, which may have been created with limited permissions in the container, so + make it readable if necessary: + ```bash + sudo chmod a+r nohup.out + ``` + +4. Output files: + - **Training logs and checkpoints** are saved under: + ``` + $SCRATCHDIR/runs/$SITE_NAME// + ``` + - **Best checkpoint** usually saved as `best.ckpt` or `last.ckpt` + - **Prediction results**, if enabled, will appear in subfolders of the same directory + - **TensorBoard logs**, if activated, are stored in their respective folders inside the run directory + - TODO what is enabled/activated should be hard-coded, adapt accordingly + +5. (Optional) You can verify that the container is running properly: + ```bash + docker ps # Check if odelia_swarm_client_$SITE_NAME is listed + nvidia-smi # Check if the GPU is busy training (it will be idling while waiting for model transfer) + tail -f nohup.out # Follow training log + ``` + +## Run Local Training + +1. From the directory where you unpacked the startup kit + ```bash + cd $SITE_NAME/startup + ``` +2. Start local training + ```bash + /docker.sh --data_dir $DATADIR --scratch_dir $SCRATCHDIR --GPU all --local_training + ``` + * TODO update when handling of the number of epochs has been implemented +3. Output files + * TODO describe + +# Usage for MediSwarm and Application Code Developers + +## Versioning of ODELIA Docker Images + +If needed, update the version number in file (odelia_image.version)[odelia_image.version]. It will be used automatically +for the Docker image and startup kits. + +## Build the Docker Image and Startup Kits + +The Docker image contains all dependencies for administrative purposes (dashboard, command-line provisioning, admin +console, server) as well as for running the 3DCNN pipeline under the pytorch-lightning framework. +The project description specifies the swarm nodes etc. to be used for a swarm training. + +```bash +cd MediSwarm +./buildDockerImageAndStartupKits.sh -p application/provision/ +``` + +1. Make sure you have no uncommitted changes. +2. If package versions are still not available, you may have to check what the current version is and update the + `Dockerfile` accordingly. Version numbers are hard-coded to avoid issues due to silently different versions being + installed. +3. After successful build (and after verifying that everything works as expected, i.e., local tests, building startup + kits, running local trainings in the startup kit), you can manually push the image to DockerHub, provided you have + the necessary rights. Make sure you are not re-using a version number for this purpose. + +## Running Local Tests + + ```bash + ./runTestsInDocker.sh + ``` + +You should see + +1. several expected errors and warnings printed from unit tests that should succeed overall, and a coverage report +2. output of a successful simulation run with two nodes +3. output of a successful proof-of-concept run run with two nodes +4. output of a set of startup kits being generated +5. output of a dummy training run using one of the startup kits +6. TODO update this to what the tests output now + +Optionally, uncomment running NVFlare unit tests in `_runTestsInsideDocker.sh`. + +## Distributing Startup Kits + +Distribute the startup kits to the clients. + +## Running the Application + +1. **CIFAR-10 example:** + See [cifar10/README.md](application/jobs/cifar10/README.md) +2. **Minimal PyTorch CNN example:** + See [application/jobs/minimal_training_pytorch_cnn/README.md](application/jobs/minimal_training_pytorch_cnn/README.md) +3. **3D CNN for classifying breast tumors:** + See [ODELIA_ternary_classification/README.md](application/jobs/ODELIA_ternary_classification/README.md) + +## Contributing Application Code + +1. Take a look at application/jobs/minimal_training_pytorch_cnn for a minimal example how pytorch code can be adapted to + work with NVFlare +2. Take a look at application/jobs/ODELIA_ternary_classification for a more relastic example of pytorch code that can + run in the swarm +3. Use the local tests to check if the code is swarm-ready +4. TODO more detailed instructions + +# Usage for Swarm Operators + +## Setting up a Swarm + +Production mode is designed for secure, real-world deployments. It supports both local and remote setups, whether +on-premise or in the cloud. For more details, refer to +the [NVFLARE Production Mode](https://nvflare.readthedocs.io/en/2.4.1/real_world_fl.html). + +To set up production mode, follow these steps: + +## Edit `/etc/hosts` + +Ensure that your `/etc/hosts` file includes the correct host mappings. All hosts need to be able to communicate to the +server node. + +For example, add the following line (replace `` with the server's actual IP address): + +```plaintext + dl3.tud.de dl3 +``` + +## Create Startup Kits + +### Via Script (recommended) + +1. Use, e.g., the file `application/provision/project_MEVIS_test.yml`, adapt as needed (network protocol etc.) +2. Call `buildStartupKits.sh /path/to/project_configuration.yml` to build the startup kits +3. Startup kits are generated to `workspace//prod_00/` +4. Deploy startup kits to the respective server/clients + +### Via the Dashboard (not recommended) + +```bash +docker run -d --rm \ + --ipc=host -p 8443:8443 \ + --name=odelia_swarm_admin \ + -v /var/run/docker.sock:/var/run/docker.sock \ + \ + /bin/bash -c "nvflare dashboard --start --local --cred :" +``` + +using some credentials chosen for the swarm admin account. + +Access the dashboard in a web browser at `https://localhost:8443` log in with these credentials, and configure the +project: + +1. enter project short name, name, description +2. enter docker download link: jefftud/odelia: +3. if needed, enter dates +4. click save +5. Server Configuration > Server (DNS name): +6. click make project public + +#### Register client per site + +Access the dashboard at `https://:8443`. + +1. register a user +2. enter organziation (corresponding to the site) +3. enter role (e.g., org admin) +4. add a site (note: must not contain spaces, best use alphanumerical name) +5. specify number of GPUs and their memory + +#### Approve clients and finish configuration + +Access the dashboard at `https://localhost:8443` log in with the admin credentials. + +1. Users Dashboard > approve client user +2. Client Sites > approve client sites +3. Project Home > freeze project + +## Download startup kits + +After setting up the project admin configuration, server and clients can download their startup kits. Store the +passwords somewhere, they are only displayed once (or you can download them again). + +## Starting a Swarm Training + +1. Connect the *server* host to the VPN as described above. +2. Start the *server* startup kit using the respective `startup/docker.sh` script with the option to start the server +3. Provide the *client* startup kits to the swarm participants (be aware that email providers or other channels may + prevent encrypted archives) +4. Make sure the participants have started their clients via the respective startup kits, see below +5. Start the *admin* startup kit using the respective `startup/docker.sh` script to start the admin console +6. Deploy a job by `submit_job ` + +# License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. + +# Maintainers + +[Jeff](https://github.com/Ultimate-Storm) +[Ole Schwen](mailto:ole.schwen@mevis.fraunhofer.de) +[Steffen Renisch](mailto:steffen.renisch@mevis.fraunhofer.de) + +# Contributing + +Feel free to dive in! [Open an issue](https://github.com/KatherLab/MediSwarm/issues) or submit pull requests. + +# Credits + +This project utilizes platforms and resources from the following repositories: + +- **[NVFLARE](https://github.com/NVIDIA/NVFlare)**: NVFLARE (NVIDIA Federated Learning Application Runtime Environment) + is an open-source framework that provides a robust and scalable platform for federated learning applications. We have + integrated NVFLARE to efficiently handle the federated learning aspects of our project. + +Special thanks to the contributors and maintainers of these repositories for their valuable work and support. + +--- + +For more details about NVFLARE and its features, please visit +the [NVFLARE GitHub repository](https://github.com/NVIDIA/NVFlare). diff --git a/buildDockerImageAndStartupKits.sh b/buildDockerImageAndStartupKits.sh index 30f330c8..3786b95b 100755 --- a/buildDockerImageAndStartupKits.sh +++ b/buildDockerImageAndStartupKits.sh @@ -13,37 +13,75 @@ DOCKER_BUILD_ARGS="--no-cache --progress=plain"; while [[ "$#" -gt 0 ]]; do case $1 in -p) PROJECT_FILE="$2"; shift ;; + -c) VPN_CREDENTIALS_DIR="$2"; shift ;; --use-docker-cache) DOCKER_BUILD_ARGS="";; *) echo "Unknown parameter passed: $1"; exit 1 ;; esac shift done -if [ -z "$PROJECT_FILE" ]; then - echo "Usage: buildDockerImageAndStartupKits.sh -p [--use-docker-cache]" +if [[ -z "$PROJECT_FILE" || -z "$VPN_CREDENTIALS_DIR" ]]; then + echo "Usage: buildDockerImageAndStartupKits.sh -p -c [--use-docker-cache]" exit 1 fi VERSION=`./getVersionNumber.sh` -DOCKER_IMAGE=jefftud/odelia:$VERSION +CONTAINER_VERSION_ID=`git rev-parse --short HEAD` # prepare clean version of source code repository clone for building Docker image + CWD=`pwd` CLEAN_SOURCE_DIR=`mktemp -d` -cp -r . $CLEAN_SOURCE_DIR/ -cd $CLEAN_SOURCE_DIR +mkdir $CLEAN_SOURCE_DIR/MediSwarm +rsync -ax --exclude workspace . $CLEAN_SOURCE_DIR/MediSwarm/ +cd $CLEAN_SOURCE_DIR/MediSwarm git clean -x -q -f . cd docker_config/NVFlare git clean -x -q -f . cd ../.. rm .git -rf chmod a+rX . -R + +# replacements in copy of source code +sed -i 's#__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_DOCKER_IMAGE__#'$VERSION'#' docker_config/master_template.yml +sed -i 's#__REPLACED_BY_CONTAINER_VERSION_IDENTIFIER_WHEN_BUILDING_DOCKER_IMAGE__#'$CONTAINER_VERSION_ID'#' docker_config/master_template.yml + +# prepare pre-trained model weights for being included in Docker image + +MODEL_WEIGHTS_FILE=$CWD'/docker_config/torch_home_cache/hub/checkpoints/dinov2_vits14_pretrain.pth' +MODEL_LICENSE_FILE=$CWD'/docker_config/torch_home_cache/hub/facebookresearch_dinov2_main/LICENSE' +if [[ ! -f $MODEL_WEIGHTS_FILE || ! -f $MODEL_LICENSE_FILE ]]; then + echo "Pre-trained model not available. Attempting download" + HUBDIR=$(dirname $(dirname $MODEL_LICENSE_FILE)) + mkdir -p $(dirname $MODEL_WEIGHTS_FILE) + wget https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth -O $MODEL_WEIGHTS_FILE + wget https://github.com/facebookresearch/dinov2/archive/refs/heads/main.zip -O /tmp/dinov2.zip + unzip /tmp/dinov2.zip -d $HUBDIR + mv $HUBDIR/dinov2-main $HUBDIR/$(basename $(dirname $MODEL_LICENSE_FILE)) + touch $HUBDIR/trusted_list +fi + +if echo 2e405cee1bad14912278296d4f42e993 $MODEL_WEIGHTS_FILE | md5sum --check - && echo 153d2db1c329326a2d9f881317ea942e $MODEL_LICENSE_FILE | md5sum --check -; then + cp -r $CWD/docker_config/torch_home_cache $CLEAN_SOURCE_DIR/torch_home_cache +else + exit 1 +fi +chmod a+rX $CLEAN_SOURCE_DIR/torch_home_cache -R + cd $CWD -docker build $DOCKER_BUILD_ARGS -t $DOCKER_IMAGE $CLEAN_SOURCE_DIR -f docker_config/Dockerfile_ODELIA +# build and print follow-up steps +CONTAINER_NAME=`grep " docker_image: " $PROJECT_FILE | sed 's/ docker_image: //' | sed 's#__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__#'$VERSION'#'` +echo $CONTAINER_NAME -rm -rf $CLEAN_SOURCE_DIR +docker build $DOCKER_BUILD_ARGS -t $CONTAINER_NAME $CLEAN_SOURCE_DIR -f docker_config/Dockerfile_ODELIA + +echo "Docker image $CONTAINER_NAME built successfully" +echo "./_buildStartupKits.sh $PROJECT_FILE $VERSION $CONTAINER_NAME" +VPN_CREDENTIALS_DIR=$(realpath $VPN_CREDENTIALS_DIR) +./_buildStartupKits.sh $PROJECT_FILE $VERSION $CONTAINER_NAME $VPN_CREDENTIALS_DIR +echo "Startup kits built successfully" -./_buildStartupKits.sh $PROJECT_FILE $VERSION +rm -rf $CLEAN_SOURCE_DIR -echo "If you wish, manually push $DOCKER_IMAGE now" +echo "If you wish, manually push $CONTAINER_NAME now" diff --git a/docker.sh b/docker.sh deleted file mode 100755 index 50ff1bb5..00000000 --- a/docker.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -# docker run script for FL client -# local data directory -: ${MY_DATA_DIR:="/mnt/swarm_alpha/odelia_dataset_divided"} -# The syntax above is to set MY_DATA_DIR to /home/flcient/data if this -# environment variable is not set previously. -# Therefore, users can set their own MY_DATA_DIR with -# export MY_DATA_DIR=$SOME_DIRECTORY -# before running docker.sh - -# for all gpus use line below -GPU2USE='--gpus=all' -# for 2 gpus use line below -#GPU2USE='--gpus=2' -# for specific gpus as gpu#0 and gpu#2 use line below -#GPU2USE='--gpus="device=0,2"' -# to use host network, use line below -NETARG="--net=host" -# FL clients do not need to open ports, so the following line is not needed. -#NETARG="-p 443:443 -p 8003:8003" -DOCKER_IMAGE=jefftud/nvflare-pt-dev:3dcnn -echo "Starting docker with $DOCKER_IMAGE" -mode="${1:--r}" -if [ $mode = "-d" ] -then - docker run -d --rm --name=mediswarm_root $GPU2USE -u $(id -u):$(id -g) \ - -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v $DIR/..:/workspace/ \ - -v $MY_DATA_DIR:/data/:ro -w /workspace/ --ipc=host $NETARG $DOCKER_IMAGE \ - /bin/bash -c "python -u -m nvflare.private.fed.app.client.client_train -m /workspace -s fed_client.json --set uid=mediswarm_root secure_train=true config_folder=config org=tud" -else - docker run --rm -it --name=mediswarm_root $GPU2USE -u $(id -u):$(id -g) \ - -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v $DIR/..:/workspace/ \ - -v $MY_DATA_DIR:/data/:ro -w /workspace/ --ipc=host $NETARG $DOCKER_IMAGE /bin/bash -fi diff --git a/docker_config/Dockerfile_ODELIA b/docker_config/Dockerfile_ODELIA index fd6023e1..269889b7 100644 --- a/docker_config/Dockerfile_ODELIA +++ b/docker_config/Dockerfile_ODELIA @@ -1,5 +1,5 @@ # Use the specified PyTorch image as the base -ARG PYTORCH_IMAGE=pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime +ARG PYTORCH_IMAGE=pytorch/pytorch:2.2.2-cuda12.1-cudnn8-runtime FROM ${PYTORCH_IMAGE} # Specify the NVFlare version @@ -12,13 +12,100 @@ ENV PYTHON_VERSION=3.10.14 # Install updates of installed packages RUN apt update -RUN apt install -y apt=2.4.14 apt-utils=2.4.14 libapt-pkg6.0=2.4.14 +RUN apt install -y \ + apt=2.4.14 \ + apt-utils=2.4.14 \ + libapt-pkg6.0=2.4.14 # Update versions of installed packages -RUN apt install -y base-files=12ubuntu4.7 bash=5.1-6ubuntu1.1 bsdutils=1:2.37.2-4ubuntu3.4 ca-certificates=20240203~22.04.1 coreutils=8.32-4.1ubuntu1.2 dpkg=1.21.1ubuntu2.3 e2fsprogs=1.46.5-2ubuntu1.2 gpgv=2.2.27-3ubuntu2.3 libblkid1=2.37.2-4ubuntu3.4 libc-bin=2.35-0ubuntu3.10 libc-dev-bin=2.35-0ubuntu3.10 libc6-dev=2.35-0ubuntu3.10 libc6=2.35-0ubuntu3.10 libcap2=1:2.44-1ubuntu0.22.04.2 libcom-err2=1.46.5-2ubuntu1.2 libext2fs2=1.46.5-2ubuntu1.2 libgnutls30=3.7.3-4ubuntu1.6 libgssapi-krb5-2=1.19.2-2ubuntu0.7 libk5crypto3=1.19.2-2ubuntu0.7 libkrb5-3=1.19.2-2ubuntu0.7 libkrb5support0=1.19.2-2ubuntu0.7 libmount1=2.37.2-4ubuntu3.4 libpam-modules-bin=1.4.0-11ubuntu2.5 libpam-modules=1.4.0-11ubuntu2.5 libpam-runtime=1.4.0-11ubuntu2.5 libpam0g=1.4.0-11ubuntu2.5 libseccomp2=2.5.3-2ubuntu3~22.04.1 libsmartcols1=2.37.2-4ubuntu3.4 libss2=1.46.5-2ubuntu1.2 libssl3=3.0.2-0ubuntu1.19 libsystemd0=249.11-0ubuntu3.16 libtasn1-6=4.18.0-4ubuntu0.1 libudev1=249.11-0ubuntu3.16 libuuid1=2.37.2-4ubuntu3.4 linux-libc-dev=5.15.0-141.151 logsave=1.46.5-2ubuntu1.2 mount=2.37.2-4ubuntu3.4 openssl=3.0.2-0ubuntu1.19 util-linux=2.37.2-4ubuntu3.4 +RUN apt install -y \ + base-files=12ubuntu4.7 \ + bash=5.1-6ubuntu1.1 \ + bsdutils=1:2.37.2-4ubuntu3.4 \ + ca-certificates=20240203~22.04.1 \ + coreutils=8.32-4.1ubuntu1.2 \ + dpkg=1.21.1ubuntu2.6 \ + e2fsprogs=1.46.5-2ubuntu1.2 \ + gpgv=2.2.27-3ubuntu2.4 \ + libblkid1=2.37.2-4ubuntu3.4 \ + libc-bin=2.35-0ubuntu3.11 \ + libc-dev-bin=2.35-0ubuntu3.11 \ + libc6-dev=2.35-0ubuntu3.11 \ + libc6=2.35-0ubuntu3.11 \ + libcap2=1:2.44-1ubuntu0.22.04.2 \ + libcom-err2=1.46.5-2ubuntu1.2 \ + libext2fs2=1.46.5-2ubuntu1.2 \ + libgnutls30=3.7.3-4ubuntu1.7 \ + libgssapi-krb5-2=1.19.2-2ubuntu0.7 \ + libk5crypto3=1.19.2-2ubuntu0.7 \ + libkrb5-3=1.19.2-2ubuntu0.7 \ + libkrb5support0=1.19.2-2ubuntu0.7 \ + libmount1=2.37.2-4ubuntu3.4 \ + libpam-modules-bin=1.4.0-11ubuntu2.6 \ + libpam-modules=1.4.0-11ubuntu2.6 \ + libpam-runtime=1.4.0-11ubuntu2.6 \ + libpam0g=1.4.0-11ubuntu2.6 \ + libseccomp2=2.5.3-2ubuntu3~22.04.1 \ + libsmartcols1=2.37.2-4ubuntu3.4 \ + libss2=1.46.5-2ubuntu1.2 \ + libssl3=3.0.2-0ubuntu1.20 \ + libsystemd0=249.11-0ubuntu3.16 \ + libtasn1-6=4.18.0-4ubuntu0.1 \ + libudev1=249.11-0ubuntu3.16 \ + libuuid1=2.37.2-4ubuntu3.4 \ + linux-libc-dev=5.15.0-157.167 \ + logsave=1.46.5-2ubuntu1.2 \ + mount=2.37.2-4ubuntu3.4 \ + openssl=3.0.2-0ubuntu1.20 \ + util-linux=2.37.2-4ubuntu3.4 # Install apt-transport-https curl gnupg lsb-release zip and dependencies at defined versions -RUN apt install -y apt-transport-https=2.4.14 curl=7.81.0-1ubuntu1.20 dirmngr=2.2.27-3ubuntu2.3 distro-info-data=0.52ubuntu0.9 gnupg-l10n=2.2.27-3ubuntu2.3 gnupg-utils=2.2.27-3ubuntu2.3 gnupg=2.2.27-3ubuntu2.3 gpg-agent=2.2.27-3ubuntu2.3 gpg-wks-client=2.2.27-3ubuntu2.3 gpg-wks-server=2.2.27-3ubuntu2.3 gpg=2.2.27-3ubuntu2.3 gpgconf=2.2.27-3ubuntu2.3 gpgsm=2.2.27-3ubuntu2.3 libassuan0=2.5.5-1build1 libbrotli1=1.0.9-2build6 libcurl4=7.81.0-1ubuntu1.20 libexpat1=2.4.7-1ubuntu0.6 libksba8=1.6.0-2ubuntu0.2 libldap-2.5-0=2.5.19+dfsg-0ubuntu0.22.04.1 libldap-common=2.5.19+dfsg-0ubuntu0.22.04.1 libmpdec3=2.5.1-2build2 libnghttp2-14=1.43.0-1ubuntu0.2 libnpth0=1.6-3build2 libpsl5=0.21.0-1.2build2 libpython3-stdlib=3.10.6-1~22.04.1 libpython3.10-minimal=3.10.12-1~22.04.9 libpython3.10-stdlib=3.10.12-1~22.04.9 libreadline8=8.1.2-1 librtmp1=2.4+20151223.gitfa8646d.1-2build4 libsasl2-2=2.1.27+dfsg2-3ubuntu1.2 libsasl2-modules-db=2.1.27+dfsg2-3ubuntu1.2 libsasl2-modules=2.1.27+dfsg2-3ubuntu1.2 libsqlite3-0=3.37.2-2ubuntu0.4 libssh-4=0.9.6-2ubuntu0.22.04.3 lsb-release=11.1.0ubuntu4 media-types=7.0.0 pinentry-curses=1.1.1-1build2 publicsuffix=20211207.1025-1 python3-minimal=3.10.6-1~22.04.1 python3.10-minimal=3.10.12-1~22.04.9 python3.10=3.10.12-1~22.04.9 python3=3.10.6-1~22.04.1 readline-common=8.1.2-1 unzip=6.0-26ubuntu3.2 zip=3.0-12build2 +RUN apt install -y \ + apt-transport-https=2.4.14 \ + curl=7.81.0-1ubuntu1.21 \ + dirmngr=2.2.27-3ubuntu2.4 \ + distro-info-data=0.52ubuntu0.9 \ + gnupg-l10n=2.2.27-3ubuntu2.4 \ + gnupg-utils=2.2.27-3ubuntu2.4 \ + gnupg=2.2.27-3ubuntu2.4 \ + gpg-agent=2.2.27-3ubuntu2.4 \ + gpg-wks-client=2.2.27-3ubuntu2.4 \ + gpg-wks-server=2.2.27-3ubuntu2.4 \ + gpg=2.2.27-3ubuntu2.4 \ + gpgconf=2.2.27-3ubuntu2.4 \ + gpgsm=2.2.27-3ubuntu2.4 \ + libassuan0=2.5.5-1build1 \ + libbrotli1=1.0.9-2build6 \ + libcurl4=7.81.0-1ubuntu1.21 \ + libexpat1=2.4.7-1ubuntu0.6 \ + libksba8=1.6.0-2ubuntu0.2 \ + libldap-2.5-0=2.5.19+dfsg-0ubuntu0.22.04.1 \ + libldap-common=2.5.19+dfsg-0ubuntu0.22.04.1 \ + libmpdec3=2.5.1-2build2 \ + libnghttp2-14=1.43.0-1ubuntu0.2 \ + libnpth0=1.6-3build2 \ + libpsl5=0.21.0-1.2build2 \ + libpython3-stdlib=3.10.6-1~22.04.1 \ + libpython3.10-minimal=3.10.12-1~22.04.11 \ + libpython3.10-stdlib=3.10.12-1~22.04.11 \ + libreadline8=8.1.2-1 \ + librtmp1=2.4+20151223.gitfa8646d.1-2build4 \ + libsasl2-2=2.1.27+dfsg2-3ubuntu1.2 \ + libsasl2-modules-db=2.1.27+dfsg2-3ubuntu1.2 \ + libsasl2-modules=2.1.27+dfsg2-3ubuntu1.2 \ + libsqlite3-0=3.37.2-2ubuntu0.5 \ + libssh-4=0.9.6-2ubuntu0.22.04.4 \ + lsb-release=11.1.0ubuntu4 \ + media-types=7.0.0 \ + pinentry-curses=1.1.1-1build2 \ + publicsuffix=20211207.1025-1 \ + python3-minimal=3.10.6-1~22.04.1 \ + python3.10-minimal=3.10.12-1~22.04.11 \ + python3.10=3.10.12-1~22.04.11 \ + python3=3.10.6-1~22.04.1 \ + readline-common=8.1.2-1 \ + unzip=6.0-26ubuntu3.2 \ + zip=3.0-12build2 # Prepare Docker installation RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc \ @@ -27,7 +114,95 @@ RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings && apt update # Install docker-ce docker-ce-cli containerd.io and dependencies at fixed versions -RUN apt install -y apparmor=3.0.4-2ubuntu2.4 containerd.io=1.7.27-1 dbus-user-session=1.12.20-2ubuntu4.1 dbus=1.12.20-2ubuntu4.1 dmsetup=2:1.02.175-2.1ubuntu5 docker-buildx-plugin docker-ce-cli docker-ce-rootless-extras docker-ce docker-compose-plugin gir1.2-glib-2.0=1.72.0-1 git-man=1:2.34.1-1ubuntu1.12 git=1:2.34.1-1ubuntu1.12 iptables=1.8.7-1ubuntu5.2 less=590-1ubuntu0.22.04.3 libapparmor1=3.0.4-2ubuntu2.4 libargon2-1=0~20171227-0.3 libbsd0=0.11.5-1 libcbor0.8=0.8.0-2ubuntu1 libcryptsetup12=2:2.4.3-1ubuntu1.3 libcurl3-gnutls=7.81.0-1ubuntu1.20 libdbus-1-3=1.12.20-2ubuntu4.1 libdevmapper1.02.1=2:1.02.175-2.1ubuntu5 libedit2=3.1-20210910-1build1 liberror-perl=0.17029-1 libfido2-1=1.10.0-1 libgdbm-compat4=1.23-1 libgdbm6=1.23-1 libgirepository-1.0-1=1.72.0-1 libglib2.0-0=2.72.4-0ubuntu2.5 libglib2.0-data=2.72.4-0ubuntu2.5 libicu70=70.1-2 libip4tc2=1.8.7-1ubuntu5.2 libip6tc2=1.8.7-1ubuntu5.2 libjson-c5=0.15-3~ubuntu1.22.04.2 libkmod2=29-1ubuntu1 libltdl7=2.4.6-15build2 libmd0=1.0.4-1build1 libmnl0=1.0.4-3build2 libnetfilter-conntrack3=1.0.9-1 libnfnetlink0=1.0.1-3build3 libnftnl11=1.2.1-1build1 libnss-systemd=249.11-0ubuntu3.16 libpam-systemd=249.11-0ubuntu3.16 libperl5.34=5.34.0-3ubuntu1.4 libslirp0=4.6.1-1build1 libx11-6=2:1.7.5-1ubuntu0.3 libx11-data=2:1.7.5-1ubuntu0.3 libxau6=1:1.0.9-1build5 libxcb1=1.14-3ubuntu3 libxdmcp6=1:1.1.3-0ubuntu5 libxext6=2:1.3.4-1build1 libxml2=2.9.13+dfsg-1ubuntu0.7 libxmuu1=2:1.1.3-3 libxtables12=1.8.7-1ubuntu5.2 netbase=6.3 networkd-dispatcher=2.1-2ubuntu0.22.04.2 openssh-client=1:8.9p1-3ubuntu0.13 patch=2.7.6-7build2 perl-base=5.34.0-3ubuntu1.4 perl-modules-5.34=5.34.0-3ubuntu1.4 perl=5.34.0-3ubuntu1.4 pigz=2.6-1 python3-dbus=1.2.18-3build1 python3-gi=3.42.1-0ubuntu1 shared-mime-info=2.1-2 slirp4netns=1.0.1-2 systemd-sysv=249.11-0ubuntu3.16 systemd-timesyncd=249.11-0ubuntu3.16 systemd=249.11-0ubuntu3.16 xauth=1:1.1-1build2 xdg-user-dirs=0.17-2ubuntu4 xz-utils=5.2.5-2ubuntu1 +RUN apt install -y \ + apparmor=3.0.4-2ubuntu2.4 \ + containerd.io=1.7.28-0~ubuntu.22.04~jammy \ + dbus-user-session=1.12.20-2ubuntu4.1 \ + dbus=1.12.20-2ubuntu4.1 \ + dmsetup=2:1.02.175-2.1ubuntu5 \ + docker-buildx-plugin=0.29.0-0~ubuntu.22.04~jammy \ + docker-ce-cli=5:28.4.0-1~ubuntu.22.04~jammy \ + docker-ce-rootless-extras=5:28.4.0-1~ubuntu.22.04~jammy \ + docker-ce=5:28.4.0-1~ubuntu.22.04~jammy \ + docker-compose-plugin=2.39.4-0~ubuntu.22.04~jammy \ + gir1.2-glib-2.0=1.72.0-1 \ + git-man=1:2.34.1-1ubuntu1.15 \ + git=1:2.34.1-1ubuntu1.15 \ + iptables=1.8.7-1ubuntu5.2 \ + less=590-1ubuntu0.22.04.3 \ + libapparmor1=3.0.4-2ubuntu2.4 \ + libargon2-1=0~20171227-0.3 \ + libbsd0=0.11.5-1 \ + libcbor0.8=0.8.0-2ubuntu1 \ + libcryptsetup12=2:2.4.3-1ubuntu1.3 \ + libcurl3-gnutls=7.81.0-1ubuntu1.21 \ + libdbus-1-3=1.12.20-2ubuntu4.1 \ + libdevmapper1.02.1=2:1.02.175-2.1ubuntu5 \ + libedit2=3.1-20210910-1build1 \ + liberror-perl=0.17029-1 \ + libfido2-1=1.10.0-1 \ + libgdbm-compat4=1.23-1 \ + libgdbm6=1.23-1 \ + libgirepository-1.0-1=1.72.0-1 \ + libglib2.0-0=2.72.4-0ubuntu2.6 \ + libglib2.0-data=2.72.4-0ubuntu2.6 \ + libicu70=70.1-2 \ + libip4tc2=1.8.7-1ubuntu5.2 \ + libip6tc2=1.8.7-1ubuntu5.2 \ + libjson-c5=0.15-3~ubuntu1.22.04.2 \ + libkmod2=29-1ubuntu1 \ + libltdl7=2.4.6-15build2 \ + libmd0=1.0.4-1build1 \ + libmnl0=1.0.4-3build2 \ + libnetfilter-conntrack3=1.0.9-1 \ + libnfnetlink0=1.0.1-3build3 \ + libnftnl11=1.2.1-1build1 \ + libnss-systemd=249.11-0ubuntu3.16 \ + libpam-systemd=249.11-0ubuntu3.16 \ + libperl5.34=5.34.0-3ubuntu1.5 \ + libslirp0=4.6.1-1build1 \ + libx11-6=2:1.7.5-1ubuntu0.3 \ + libx11-data=2:1.7.5-1ubuntu0.3 \ + libxau6=1:1.0.9-1build5 \ + libxcb1=1.14-3ubuntu3 \ + libxdmcp6=1:1.1.3-0ubuntu5 \ + libxext6=2:1.3.4-1build1 \ + libxml2=2.9.13+dfsg-1ubuntu0.9 \ + libxmuu1=2:1.1.3-3 \ + libxtables12=1.8.7-1ubuntu5.2 \ + netbase=6.3 \ + networkd-dispatcher=2.1-2ubuntu0.22.04.2 \ + openssh-client=1:8.9p1-3ubuntu0.13 \ + patch=2.7.6-7build2 \ + perl-base=5.34.0-3ubuntu1.5 \ + perl-modules-5.34=5.34.0-3ubuntu1.5 \ + perl=5.34.0-3ubuntu1.5 \ + pigz=2.6-1 \ + python3-dbus=1.2.18-3build1 \ + python3-gi=3.42.1-0ubuntu1 \ + shared-mime-info=2.1-2 \ + slirp4netns=1.0.1-2 \ + systemd-sysv=249.11-0ubuntu3.16 \ + systemd-timesyncd=249.11-0ubuntu3.16 \ + systemd=249.11-0ubuntu3.16 \ + xauth=1:1.1-1build2 \ + xdg-user-dirs=0.17-2ubuntu4 \ + xz-utils=5.2.5-2ubuntu1 + +# openvpn iputils-ping net-tools sudo and dependencies at fixed versions +# TODO remove tools only needed for debugging +RUN apt install -y \ + iproute2=5.15.0-1ubuntu2 \ + iputils-ping=3:20211215-1ubuntu0.1 \ + libatm1=1:2.5.1-4build2 \ + libbpf0=1:0.5.0-1ubuntu22.04.1 \ + libcap2-bin=1:2.44-1ubuntu0.22.04.2 \ + libelf1=0.186-1ubuntu0.1 \ + liblzo2-2=2.10-2build3 \ + libpam-cap=1:2.44-1ubuntu0.22.04.2 \ + libpkcs11-helper1=1.28-1ubuntu0.22.04.1 \ + net-tools=1.60+git20181103.0eebece-1ubuntu5.4 \ + openvpn=2.5.11-0ubuntu0.22.04.1 # Clean up apt cache RUN rm -rf /var/lib/apt/lists/* @@ -36,36 +211,156 @@ RUN rm -rf /var/lib/apt/lists/* RUN python3 -m pip uninstall -y conda conda-package-handling conda_index # Install specific versions of pip and setuptools -RUN python3 -m pip install -U pip==23.3.1 setuptools==75.8.2 +RUN python3 -m pip install \ + -U \ + pip==25.1.1 \ + setuptools==80.8.0 # Install dependencies of NVFlare at fixed versions -RUN python3 -m pip install --upgrade psutil==7.0.0 -RUN python3 -m pip install Flask==3.0.2 Flask-JWT-Extended==4.6.0 Flask-SQLAlchemy==3.1.1 PyJWT==2.10.1 SQLAlchemy==2.0.16 Werkzeug==3.0.1 blinker==1.9.0 docker==7.1.0 greenlet==3.1.1 grpcio==1.62.1 gunicorn==23.0.0 itsdangerous==2.2.0 msgpack==1.1.0 protobuf==4.24.4 pyhocon==0.3.61 pyparsing==3.0.9 websockets==15.0 +RUN python3 -m pip install \ + --upgrade \ + psutil==7.0.0 +RUN python3 -m pip install \ + Flask==3.0.2 \ + Flask-JWT-Extended==4.6.0 \ + Flask-SQLAlchemy==3.1.1 \ + PyJWT==2.10.1 \ + SQLAlchemy==2.0.16 \ + Werkzeug==3.0.1 \ + blinker==1.9.0 \ + docker==7.1.0 \ + greenlet==3.2.2 \ + grpcio==1.62.1 \ + gunicorn==23.0.0 \ + itsdangerous==2.2.0 \ + msgpack==1.1.0 \ + protobuf==4.24.4 \ + pyhocon==0.3.61 \ + pyparsing==3.2.3 \ + websockets==15.0.1 -# Install additional Python packages for swarm training at defined versions -RUN python3 -m pip install Deprecated==1.2.14 SimpleITK==2.2.1 absl-py==2.1.0 aiohttp==3.9.5 aiosignal==1.3.1 async-timeout==4.0.3 cachetools==5.3.3 contourpy==1.2.1 cycler==0.12.1 et-xmlfile==1.1.0 fonttools==4.53.1 frozenlist==1.4.1 google-auth-oauthlib==1.0.0 google-auth==2.31.0 huggingface_hub==0.23.4 humanize==4.9.0 joblib==1.4.2 kiwisolver==1.4.5 lightning-utilities==0.11.3.post0 markdown-it-py==3.0.0 markdown==3.6 matplotlib==3.7.2 mdurl==0.1.2 monai==1.3.0 multidict==6.0.5 nibabel==5.2.1 oauthlib==3.2.2 openpyxl==3.1.0 pandas==2.2.2 pyasn1-modules==0.4.0 pyasn1==0.6.0 pydicom==2.4.4 python-dateutil==2.9.0.post0 pytorch-lightning==1.9.0 requests-oauthlib==2.0.0 rich==13.7.1 rsa==4.9 safetensors==0.4.3 scikit-learn==1.3.0 scipy==1.14.0 seaborn==0.12.2 shellingham==1.5.4 tensorboard-data-server==0.7.2 tensorboard-plugin-wit==1.8.1 tensorboard==2.12.1 threadpoolctl==3.5.0 timm==0.9.16 torchio==0.19.6 torchmetrics==1.4.0.post0 torchvision==0.17.0 tqdm==4.65.0 typer==0.12.3 tzdata==2024.1 wrapt==1.16.0 yarl==1.9.4 +# Install additional Python packages for application code at defined versions +RUN python3 -m pip install \ + Deprecated==1.2.18 \ + SimpleITK==2.5.0 \ + absl-py==2.2.2 \ + aiohttp==3.11.18 \ + aiosignal==1.3.2 \ + async-timeout==5.0.1 \ + cachetools==5.5.2 \ + contourpy==1.3.2 \ + cycler==0.12.1 \ + et-xmlfile==2.0.0 \ + fonttools==4.58.0 \ + frozenlist==1.6.0 \ + google-auth-oauthlib==1.2.2 \ + google-auth==2.40.2 \ + huggingface_hub==0.29.3 \ + datasets==3.4.1 \ + coral_pytorch==1.4.0 \ + humanize==4.12.3 \ + joblib==1.5.1 \ + kiwisolver==1.4.8 \ + lightning-utilities==0.14.3 \ + markdown-it-py==3.0.0 \ + markdown==3.8 \ + matplotlib==3.9.2 \ + mdurl==0.1.2 \ + monai==1.4.0 \ + multidict==6.4.4 \ + nibabel==5.3.2 \ + oauthlib==3.2.2 \ + openpyxl==3.1.5 \ + pandas==2.2.3 \ + numpy==1.26.4 \ + pyasn1-modules==0.4.2 \ + pyasn1==0.6.1 \ + pydicom==3.0.1 \ + python-dateutil==2.9.0.post0 \ + x-transformers==2.3.5 \ + pytorch-lightning==2.4.0 \ + requests==2.32.3 \ + requests-oauthlib==2.0.0 \ + rich==14.0.0 \ + rsa==4.9.1 \ + safetensors==0.5.3 \ + scikit-learn==1.5.2 \ + scipy==1.15.3 \ + seaborn==0.13.2 \ + wandb==0.18.6 \ + einops==0.8.0 \ + shellingham==1.5.4 \ + tensorboard-data-server==0.7.2 \ + tensorboard-plugin-wit==1.8.1 \ + tensorboard==2.19.0 \ + threadpoolctl==3.6.0 \ + timm==1.0.15 \ + torchio==0.20.1 \ + torchmetrics==1.7.1 \ + torchvision==0.17.2 \ + torchaudio==2.2.2 \ + tqdm==4.67.0 \ + typer==0.15.4 \ + tzdata==2025.2 \ + wrapt==1.17.2 \ + yarl==1.20.0 \ + aiohappyeyeballs==2.6.1 \ + annotated-types==0.7.0 \ + dill==0.3.8 \ + docker-pycreds==0.4.0 \ + einx==0.3.0 \ + frozendict==2.4.6 \ + gitdb==4.0.12 \ + gitpython==3.1.44 \ + hf-xet==1.1.2 \ + importlib-resources==6.5.2 \ + loguru==0.7.3 \ + multiprocess==0.70.16 \ + propcache==0.3.1 \ + pyarrow==20.0.0 \ + pydantic==2.11.5 \ + pydantic-core==2.33.2 \ + sentry-sdk==2.29.1 \ + setproctitle==1.3.6 \ + smmap==5.0.2 \ + typing-extensions==4.13.2 \ + typing-inspection==0.4.1 \ + xxhash==3.5.0 # Install packages needed for testing and for listing licenses of installed packages -RUN python3 -m pip install coverage==7.5.4 mock==5.1.0 -RUN python3 -m pip install pip-licenses==5.0.0 prettytable==3.14.0 +RUN python3 -m pip install \ + coverage==7.8.2 \ + mock==5.2.0 +RUN python3 -m pip install \ + pip-licenses==5.0.0 \ + prettytable==3.16.0 # Clean up pip cache RUN python3 -m pip cache purge # install ODELIA fork of NVFlare from local source WORKDIR /workspace/ -COPY ./docker_config/NVFlare /workspace/nvflare +COPY ./MediSwarm/docker_config/NVFlare /workspace/nvflare ## use startup kit template in the dashboard -COPY ./docker_config/master_template.yml /workspace/nvflare/nvflare/lighter/impl/ +COPY ./MediSwarm/docker_config/master_template.yml /workspace/nvflare/nvflare/lighter/impl/ RUN python3 -m pip install /workspace/nvflare RUN rm -rf /workspace/nvflare # Install the ODELIA controller package from local source -COPY ./controller /workspace/controller -RUN python3 -m pip install /workspace/controller +COPY ./MediSwarm/controller /workspace/controller +RUN python3 -m pip install /workspace/controller RUN rm -rf /workspace/controller # Copy the source code for local training and deploying to the swarm -COPY . /MediSwarm +COPY ./MediSwarm /MediSwarm RUN mkdir -p /fl_admin/transfer RUN ln -s /MediSwarm /fl_admin/transfer/MediSwarm + +# Copy pre-trained model weights to image +COPY ./torch_home_cache /torch_home + +# allow creating home directory for local user inside container if needed +RUN chmod a+rwx /home + +# allow starting VPN connection by non-root users +RUN chmod gu+s /usr/sbin/openvpn diff --git a/docker_config/master_template.yml b/docker_config/master_template.yml index 2c4b1170..c3f3d9d2 100644 --- a/docker_config/master_template.yml +++ b/docker_config/master_template.yml @@ -334,6 +334,9 @@ authz_def: | fl_admin_sh: | #!/usr/bin/env bash + + openvpn ./vpn_client.ovpn >> nohup_vpn.out 2>&1 & + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" python3 -m nvflare.fuel.hci.tools.admin -m $DIR/.. -s fed_admin.json @@ -367,6 +370,9 @@ start_ovsr_sh: | start_cln_sh: | #!/usr/bin/env bash + + openvpn ./vpn_client.ovpn >> nohup_vpn.out 2>&1 & + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" all_arguments="${@}" doCloud=false @@ -392,6 +398,9 @@ start_cln_sh: | start_svr_sh: | #!/usr/bin/env bash + + openvpn ./vpn_client.ovpn >> nohup_vpn.out 2>&1 & + DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" all_arguments="${@}" doCloud=false @@ -620,25 +629,34 @@ sub_start_svr_sh: | docker_cln_sh: | #!/usr/bin/env bash - # docker run script for FL client + # docker run script for FL client with proper env variable forwarding + # Auto disable TTY in non-interactive CI environments + if [ -t 1 ]; then + TTY_OPT="-it" + else + echo "[INFO] No interactive terminal detected, disabling TTY." + TTY_OPT="" + fi + # Parse command-line arguments while [[ "$#" -gt 0 ]]; do case $1 in --data_dir) MY_DATA_DIR="$2"; shift ;; --scratch_dir) MY_SCRATCH_DIR="$2"; shift ;; --GPU) GPU2USE="$2"; shift ;; - --no_pull) NOPULL="1";; - --dummy_training) DUMMY_TRAINING="1";; - --preflight_check) PREFLIGHT_CHECK="1";; - --local_training) LOCAL_TRAINING="1";; - --start_client) START_CLIENT="1";; - --interactive) INTERACTIVE="1";; + --no_pull) NOPULL="1" ;; + --dummy_training) DUMMY_TRAINING="1" ;; + --preflight_check) PREFLIGHT_CHECK="1" ;; + --local_training) LOCAL_TRAINING="1" ;; + --start_client) START_CLIENT="1" ;; + --interactive) INTERACTIVE="1" ;; + --run_script) SCRIPT_TO_RUN="$2"; shift ;; *) echo "Unknown parameter passed: $1"; exit 1 ;; esac shift done - # Ask user for required parameters not passed as command line arguments + # Prompt for parameters if missing if [[ -z "$DUMMY_TRAINING" && -z "$MY_DATA_DIR" ]]; then read -p "Enter the path to your data directory (default: /home/flclient/data): " user_data_dir : ${MY_DATA_DIR:="${user_data_dir:-/home/flclient/data}"} @@ -654,27 +672,29 @@ docker_cln_sh: | : ${GPU2USE:="${user_gpu:-device=0}"} fi - # Get the directory of the current script + # Resolve script directory DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" - sudo mkdir -p $MY_SCRATCH_DIR - sudo chown -R $(id -u):$(id -g) $MY_SCRATCH_DIR - sudo chmod -R 777 $MY_SCRATCH_DIR + mkdir -p "$MY_SCRATCH_DIR" + chmod -R 777 "$MY_SCRATCH_DIR" - # To use host network - NETARG="--net=host" + # Networking & Cleanup + NETARG="--cap-add=NET_ADMIN --device /dev/net/tun" + rm -rf ../pid.fl ../daemon_pid.fl - rm -rf ../pid.fl ../daemon_pid.fl # clean up potential leftovers from previous run - - # Docker image to use + # Docker image and container name DOCKER_IMAGE={~~docker_image~~} if [ -z "$NOPULL" ]; then echo "Updating docker image" - docker pull $DOCKER_IMAGE + docker pull "$DOCKER_IMAGE" fi - CONTAINER_NAME=odelia_swarm_client_{~~client_name~~} + CONTAINER_NAME=odelia_swarm_client_{~~client_name~~}___REPLACED_BY_CONTAINER_VERSION_IDENTIFIER_WHEN_BUILDING_DOCKER_IMAGE__ DOCKER_OPTIONS_A="--name=$CONTAINER_NAME --gpus=$GPU2USE -u $(id -u):$(id -g)" + DOCKER_OPTIONS_A+=" --add-host dl3.tud.de:72.24.4.65 --add-host dl3:72.24.4.65" + if [[ ! -z "$ODELIA_ADDITIONAL_DOCKER_OPTIONS" ]]; then + DOCKER_OPTIONS_A+=" ${ODELIA_ADDITIONAL_DOCKER_OPTIONS}" + fi DOCKER_MOUNTS="-v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v $DIR/..:/startupkit/ -v $MY_SCRATCH_DIR:/scratch/" if [[ ! -z "$MY_DATA_DIR" ]]; then DOCKER_MOUNTS+=" -v $MY_DATA_DIR:/data/:ro" @@ -682,44 +702,53 @@ docker_cln_sh: | DOCKER_OPTIONS_B="-w /startupkit/startup/ --ipc=host $NETARG" DOCKER_OPTIONS="${DOCKER_OPTIONS_A} ${DOCKER_MOUNTS} ${DOCKER_OPTIONS_B}" - echo "Starting docker with $DOCKER_IMAGE as $CONTAINER_NAME" - # Run docker with appropriate parameters + # Common ENV vars + ENV_VARS="--env SITE_NAME={~~client_name~~} \ + --env DATA_DIR=/data \ + --env SCRATCH_DIR=/scratch \ + --env TORCH_HOME=/torch_home \ + --env GPU_DEVICE=$GPU2USE \ + --env MODEL_NAME=MST \ + --env CONFIG=unilateral \ + --env MEDISWARM_VERSION=__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_DOCKER_IMAGE__" + + # Execution modes if [[ ! -z "$DUMMY_TRAINING" ]]; then - DOCKER_ENV_VAR="--env TRAINING_MODE=local_training" - docker run --rm -it \ - $DOCKER_OPTIONS $DOCKER_ENV_VAR $DOCKER_IMAGE \ + docker run --rm $TTY_OPT $DOCKER_OPTIONS $ENV_VARS --env TRAINING_MODE=local_training $DOCKER_IMAGE \ /bin/bash -c "/MediSwarm/application/jobs/minimal_training_pytorch_cnn/app/custom/main.py" + elif [[ ! -z "$PREFLIGHT_CHECK" ]]; then - DOCKER_ENV_VAR="--env TRAINING_MODE=preflight_check --env SITE_NAME={~~client_name~~} --env NUM_EPOCHS=1" - docker run --rm -it \ - $DOCKER_OPTIONS $DOCKER_ENV_VAR $DOCKER_IMAGE \ - /bin/bash -c "/MediSwarm/application/jobs/3dcnn_ptl/app/custom/main.py" + docker run --rm $TTY_OPT $DOCKER_OPTIONS $ENV_VARS --env TRAINING_MODE=preflight_check --env NUM_EPOCHS=1 $DOCKER_IMAGE \ + /bin/bash -c "/MediSwarm/application/jobs/ODELIA_ternary_classification/app/custom/main.py" + elif [[ ! -z "$LOCAL_TRAINING" ]]; then - # TODO how to set number of epochs - DOCKER_ENV_VAR="--env TRAINING_MODE=local_training --env SITE_NAME={~~client_name~~} --env NUM_EPOCHS=1" - docker run --rm -it \ - $DOCKER_OPTIONS $DOCKER_ENV_VAR $DOCKER_IMAGE \ - /bin/bash -c "/MediSwarm/application/jobs/3dcnn_ptl/app/custom/main.py" + docker run --rm $TTY_OPT $DOCKER_OPTIONS $ENV_VARS --env TRAINING_MODE=local_training --env NUM_EPOCHS=100 $DOCKER_IMAGE \ + /bin/bash -c "/MediSwarm/application/jobs/ODELIA_ternary_classification/app/custom/main.py" + elif [[ ! -z "$START_CLIENT" ]]; then - DOCKER_ENV_VAR="--env TRAINING_MODE=swarm" - docker run -d -t --rm \ - $DOCKER_OPTIONS $DOCKER_ENV_VAR $DOCKER_IMAGE \ + docker run -d -t --rm $DOCKER_OPTIONS $ENV_VARS --env TRAINING_MODE=swarm $DOCKER_IMAGE \ /bin/bash -c "nohup ./start.sh >> nohup.out 2>&1 && /bin/bash" + elif [[ ! -z "$INTERACTIVE" ]]; then - # start interactive container - DOCKER_ENV_VAR="" - docker run --rm -it --detach-keys="ctrl-x" \ - $DOCKER_OPTIONS $DOCKER_ENV_VAR $DOCKER_IMAGE \ - /bin/bash -c "/bin/bash" + docker run --rm $TTY_OPT --detach-keys="ctrl-x" $DOCKER_OPTIONS $DOCKER_IMAGE /bin/bash + + elif [[ ! -z "$SCRIPT_TO_RUN" ]]; then + docker run --rm $TTY_OPT $DOCKER_OPTIONS $ENV_VARS $DOCKER_IMAGE \ + /bin/bash -c "$SCRIPT_TO_RUN" + else - echo "One of the following options must be passed:" - echo "--dummy_training locally train a minimum example (to check if the Docker/GPU setup is working)" - echo "--preflight_check run a single epoch of local training (to check if your data can be accessed properly and if you are ready for swarm training)" - echo "--local_training run a local training (to train a local model on your data only)" - echo "--start_client start the swarm learning client" - echo "--interactive start the container with an interactive shell (for debugging purposes)" + echo "❗ One of the following options must be passed:" + echo "--dummy_training minimal sanity check for Docker/GPU" + echo "--preflight_check verify data access & local training" + echo "--local_training train a local model" + echo "--start_client launch FL client in swarm mode" + echo "--interactive drop into interactive container (for debugging)" + echo "--run_script execute script in container (for testing)" + exit 1 fi + + docker_svr_sh: | #!/usr/bin/env bash # docker run script for FL server @@ -738,33 +767,42 @@ docker_svr_sh: | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # to use host network, use line below - NETARG="--net=host" + NETARG="--cap-add=NET_ADMIN --device /dev/net/tun" # or to expose specific ports, use line below #NETARG="-p {~~admin_port~~}:{~~admin_port~~} -p {~~fed_learn_port~~}:{~~fed_learn_port~~}" + # TODO check if admin rights are needed and make sure output files are readable and deletable by non-root users on the host + DOCKER_IMAGE={~~docker_image~~} if [ -z "$NOPULL" ]; then echo "Updating docker image" docker pull $DOCKER_IMAGE fi svr_name="${SVR_NAME:-flserver}" - CONTAINER_NAME=odelia_swarm_server_$svr_name + CONTAINER_NAME=odelia_swarm_server_${svr_name}___REPLACED_BY_CONTAINER_VERSION_IDENTIFIER_WHEN_BUILDING_DOCKER_IMAGE__ rm -rf ../pid.fl ../daemon_pid.fl # clean up potential leftovers from previous run + ADDITIONAL_DOCKER_OPTIONS=" --add-host dl3.tud.de:72.24.4.65 --add-host dl3:72.24.4.65" + if [[ ! -z "$ODELIA_ADDITIONAL_DOCKER_OPTIONS" ]]; then + ADDITIONAL_DOCKER_OPTIONS+=" ${ODELIA_ADDITIONAL_DOCKER_OPTIONS}" + fi + echo "Starting docker with $DOCKER_IMAGE as $CONTAINER_NAME" # Run docker with appropriate parameters if [ ! -z "$START_SERVER" ]; then docker run -d -t --rm --name=$CONTAINER_NAME \ + ${ADDITIONAL_DOCKER_OPTIONS} \ -v $DIR/..:/startupkit/ -w /startupkit/startup/ \ --ipc=host $NETARG $DOCKER_IMAGE \ /bin/bash -c "nohup ./start.sh >> nohup.out 2>&1 && chmod a+r nohup.out && /bin/bash" elif [ ! -z "$LIST_LICENSES" ]; then - docker run -it --rm --name=$CONTAINER_NAME \ + docker run --rm --name=$CONTAINER_NAME \ $DOCKER_IMAGE \ /bin/bash -c "pip-licenses -s -u --order=license" elif [ ! -z "$INTERACTIVE" ]; then docker run --rm -it --detach-keys="ctrl-x" --name=$CONTAINER_NAME \ + ${ADDITIONAL_DOCKER_OPTIONS} \ -v $DIR/..:/startupkit/ -w /startupkit/startup/ \ --ipc=host $NETARG $DOCKER_IMAGE \ /bin/bash -c "/bin/bash" @@ -788,17 +826,28 @@ docker_adm_sh: | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # To use host network - NETARG="--net=host" + NETARG="--cap-add=NET_ADMIN --device /dev/net/tun" + + # TODO check if admin rights are needed and make sure output files are readable and deletable by non-root users on the host DOCKER_IMAGE={~~docker_image~~} if [ -z "$NOPULL" ]; then echo "Updating docker image" docker pull $DOCKER_IMAGE fi - CONTAINER_NAME=odelia_swarm_admin + CONTAINER_NAME=odelia_swarm_admin___REPLACED_BY_CONTAINER_VERSION_IDENTIFIER_WHEN_BUILDING_DOCKER_IMAGE__ + + ADDITIONAL_DOCKER_OPTIONS=" --add-host dl3.tud.de:72.24.4.65 --add-host dl3:72.24.4.65" + if [[ ! -z "$ODELIA_ADDITIONAL_DOCKER_OPTIONS" ]]; then + ADDITIONAL_DOCKER_OPTIONS+=" ${ODELIA_ADDITIONAL_DOCKER_OPTIONS}" + fi echo "Starting docker with $DOCKER_IMAGE as $CONTAINER_NAME" - docker run --rm -it --name=fladmin -v $DIR/../local/:/fl_admin/local/ -v $DIR/../startup/:/fl_admin/startup/ -w /fl_admin/startup/ $NETARG $DOCKER_IMAGE /bin/bash -c "./fl_admin.sh" + docker run --rm -it --name=$CONTAINER_NAME \ + ${ADDITIONAL_DOCKER_OPTIONS} \ + -v $DIR/../local/:/fl_admin/local/ -v $DIR/../startup/:/fl_admin/startup/ \ + -w /fl_admin/startup/ $NETARG $DOCKER_IMAGE \ + /bin/bash -c "./fl_admin.sh" compose_yaml: | services: diff --git a/odelia_image.version b/odelia_image.version index ecc84fdd..c9a1f1c3 100644 --- a/odelia_image.version +++ b/odelia_image.version @@ -1,2 +1,2 @@ # version of the ODELIA Docker image, read by different scripts -0.9 \ No newline at end of file +1.0.1 diff --git a/runIntegrationTests.sh b/runIntegrationTests.sh new file mode 100755 index 00000000..3433195b --- /dev/null +++ b/runIntegrationTests.sh @@ -0,0 +1,543 @@ +#!/usr/bin/env bash + +set -e + +VERSION=$(./getVersionNumber.sh) +CONTAINER_VERSION_SUFFIX=$(git rev-parse --short HEAD) +DOCKER_IMAGE=localhost:5000/odelia:$VERSION +PROJECT_DIR="workspace/odelia_${VERSION}_dummy_project_for_testing" +SYNTHETIC_DATA_DIR=$(mktemp -d) +SCRATCH_DIR=$(mktemp -d) +CWD=$(pwd) +PROJECT_FILE="tests/provision/dummy_project_for_testing.yml" +if [ -z "$GPU_FOR_TESTING" ]; then + export GPU_FOR_TESTING="all" +fi + + +check_files_on_github () { + echo "[Run] Test whether expected content is available on github" + + LICENSE_ON_GITHUB=$(curl -L https://github.com/KatherLab/MediSwarm/raw/refs/heads/main/LICENSE) + if echo "$LICENSE_ON_GITHUB" | grep -q "MIT License" ; then + echo "Downloaded and verified license from github" + else + echo "Could not download and verify license" + exit 1 + fi + + MAIN_README=$(curl -L https://github.com/KatherLab/MediSwarm/raw/refs/heads/main/README.md) + for ROLE in 'Swarm Participant' 'Developer' 'Swarm Operator'; + do + if echo "$MAIN_README" | grep -qie "$ROLE" ; then + echo "Instructions for $ROLE found" + else + echo "Instructions for role $ROLE missing" + exit 1 + fi + done + + PARTICIPANT_README=$(curl -L https://github.com/KatherLab/MediSwarm/raw/refs/heads/main/assets/readme/README.participant.md) + for EXPECTED_KEYWORDS in 'Prerequisites' 'RAM' 'Ubuntu' 'VPN' 'Prepare Dataset' './docker.sh' 'Local Training' 'Start Swarm Node'; + do + if echo "$PARTICIPANT_README" | grep -qie "$EXPECTED_KEYWORDS" ; then + echo "Instructions on $EXPECTED_KEYWORDS found" + else + echo "Instructions on $EXPECTED_KEYWORDS missing" + exit 1 + fi + done + + SWARM_OPERATOR_README=$(curl -L https://github.com/KatherLab/MediSwarm/raw/refs/heads/main/assets/readme/README.operator.md) + for EXPECTED_KEYWORDS in 'Create Startup Kits' 'Starting a Swarm Training'; + do + if echo "$SWARM_OPERATOR_README" | grep -qie "$EXPECTED_KEYWORDS" ; then + echo "Instructions on $EXPECTED_KEYWORDS found" + else + echo "Instructions on $EXPECTED_KEYWORDS missing" + exit 1 + fi + done + + APC_DEVELOPER_README=$(curl -L https://github.com/KatherLab/MediSwarm/raw/refs/heads/main/assets/readme/README.developer.md) + for EXPECTED_KEYWORDS in 'Contributing Application Code'; + do + if echo "$APC_DEVELOPER_README" | grep -qie "$EXPECTED_KEYWORDS" ; then + echo "Instructions on $EXPECTED_KEYWORDS found" + else + echo "Instructions on $EXPECTED_KEYWORDS missing" + exit 1 + fi + done + + DUMMY_TRAINING_APC=$(curl -L https://raw.githubusercontent.com/KatherLab/MediSwarm/refs/heads/main/application/jobs/minimal_training_pytorch_cnn/app/custom/main.py) + for EXPECTED_KEYWORDS in 'python3'; + do + if echo "$DUMMY_TRAINING_APC" | grep -qie "$EXPECTED_KEYWORDS" ; then + echo "Dummy Training ApC: $EXPECTED_KEYWORDS found" + else + echo "Dummy Training ApC: $EXPECTED_KEYWORDS missing" + exit 1 + fi + done +} + + +_run_test_in_docker() { + echo "[Run]" $1 "inside Docker ..." + docker run --rm \ + --shm-size=16g \ + --ipc=host \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + -u $(id -u):$(id -g) \ + -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group \ + -v "$SYNTHETIC_DATA_DIR":/data \ + -v "$SCRATCH_DIR":/scratch \ + --gpus="$GPU_FOR_TESTING" \ + --entrypoint=/MediSwarm/$1 \ + "$DOCKER_IMAGE" +} + + +run_unit_tests_controller(){ + echo "[Run] Controller unit tests" + _run_test_in_docker tests/integration_tests/_run_controller_unit_tests_with_coverage.sh +} + +run_dummy_training_standalone(){ + echo "[Run] Minimal example, standalone" + _run_test_in_docker tests/integration_tests/_run_minimal_example_standalone.sh +} + +run_dummy_training_simulation_mode(){ + echo "[Run] Minimal example, simulation mode" + _run_test_in_docker tests/integration_tests/_run_minimal_example_simulation_mode.sh +} + +run_dummy_training_poc_mode(){ + echo "[Run] Minimal example, proof-of-concept mode" + _run_test_in_docker tests/integration_tests/_run_minimal_example_proof_of_concept_mode.sh +} + +run_nvflare_unit_tests(){ + echo "[Run] NVFlare unit tests" + _run_test_in_docker tests/unit_tests/_run_nvflare_unit_tests.sh +} + + +create_startup_kits_and_check_contained_files () { + echo "[Prepare] Startup kits for test project ..." + + if [ ! -d "$PROJECT_DIR"/prod_00 ]; then + ./_buildStartupKits.sh $PROJECT_FILE $VERSION $DOCKER_IMAGE + fi + if [ -d "$PROJECT_DIR"/prod_01 ]; then + echo '$PROJECT_DIR/prod_01 exists, please remove/rename it' + exit 1 + fi + ./_buildStartupKits.sh $PROJECT_FILE $VERSION $DOCKER_IMAGE + + for FILE in 'client.crt' 'client.key' 'docker.sh' 'rootCA.pem'; + do + if [ -f "$PROJECT_DIR/prod_01/client_A/startup/$FILE" ] ; then + echo "$FILE found" + else + echo "$FILE missing" + exit 1 + fi + done + + ZIP_CONTENT=$(unzip -tv "$PROJECT_DIR/prod_01/client_B_${VERSION}.zip") + for FILE in 'client.crt' 'client.key' 'docker.sh' 'rootCA.pem'; + do + if echo "$ZIP_CONTENT" | grep -q "$FILE" ; then + echo "$FILE found in zip" + else + echo "$FILE missing in zip" + exit 1 + fi + done +} + + +create_synthetic_data () { + echo "[Prepare] Synthetic data ..." + docker run --rm \ + -u $(id -u):$(id -g) \ + -v /etc/passwd:/etc/passwd -v /etc/group:/etc/group \ + -v "$SYNTHETIC_DATA_DIR":/synthetic_data \ + -w /MediSwarm \ + $DOCKER_IMAGE \ + /bin/bash -c "python3 application/jobs/ODELIA_ternary_classification/app/scripts/create_synthetic_dataset/create_synthetic_dataset.py /synthetic_data" +} + + +run_list_licenses () { + cd "$PROJECT_DIR"/prod_00 + cd testserver.local/startup + LICENSES_LISTED=$(./docker.sh --list_licenses --no_pull) + + for EXPECTED_KEYWORDS in 'scikit-learn' 'torch' 'nvflare_mediswarm' 'BSD License' 'MIT License'; + do + if echo "$LICENSES_LISTED" | grep -qie "$EXPECTED_KEYWORDS" ; then + echo "Instructions on $EXPECTED_KEYWORDS found" + else + echo "Instructions on $EXPECTED_KEYWORDS missing" + exit 1 + fi + done + + cd "$CWD" +} + + +run_docker_gpu_preflight_check () { + # requires having built a startup kit + echo "[Run] Docker/GPU preflight check (local dummy training via startup kit) ..." + cd "$PROJECT_DIR/prod_00/client_A/startup/" + CONSOLE_OUTPUT=docker_gpu_preflight_check_console_output.txt + ./docker.sh --scratch_dir "$SCRATCH_DIR"/client_A --GPU device=$GPU_FOR_TESTING --dummy_training --no_pull 2>&1 | tee "$CONSOLE_OUTPUT" + + if grep -q "Epoch 1: 100%" "$CONSOLE_OUTPUT" && grep -q "Training completed successfully" "$CONSOLE_OUTPUT"; then + echo "Expected output of Docker/GPU preflight check found" + else + echo "Missing expected output of Docker/GPU preflight check" + exit 1 + fi + + cd "$CWD" +} + + +run_data_access_preflight_check () { + # requires having built a startup kit and synthetic dataset + echo "[Run] Data access preflight check..." + cd "$PROJECT_DIR"/prod_00 + cd client_A/startup + CONSOLE_OUTPUT=data_access_preflight_check_console_output.txt + ./docker.sh --data_dir "$SYNTHETIC_DATA_DIR" --scratch_dir "$SCRATCH_DIR"/client_A --GPU device=$GPU_FOR_TESTING --preflight_check --no_pull 2>&1 | tee $CONSOLE_OUTPUT + + if grep -q "Train set: 18, Val set: 6" "$CONSOLE_OUTPUT" && grep -q "Epoch 0: 100%" "$CONSOLE_OUTPUT"; then + echo "Expected output of Docker/GPU preflight check found" + else + echo "Missing expected output of Docker/GPU preflight check" + exit 1 + fi + + cd "$CWD" +} + + +run_3dcnn_simulation_mode () { + # requires having built a startup kit and synthetic dataset + echo "[Run] Simulation mode of 3DCNN training in Docker" + _run_test_in_docker tests/integration_tests/_run_3dcnn_simulation_mode.sh +} + + +start_testing_vpn () { + echo "[Prepare] Start local VPN server for testing ..." + + cp -r tests/local_vpn "$PROJECT_DIR"/prod_00/ + chmod a+rX "$PROJECT_DIR"/prod_00/local_vpn -R + cd "$PROJECT_DIR"/prod_00/local_vpn + ./run_docker_openvpnserver.sh + cd "$CWD" +} + + +kill_testing_vpn () { + echo "[Cleanup] Kill local VPN server Docker container ..." + docker kill odelia_testing_openvpnserver +} + + +start_server_and_clients () { + echo "[Run] Start server and client Docker containers ..." + export ODELIA_ADDITIONAL_DOCKER_OPTIONS="--add-host testserver.local:10.8.0.4" + cd "$PROJECT_DIR"/prod_00 + cd testserver.local/startup + ./docker.sh --no_pull --start_server + cd ../.. + sleep 10 + + cd client_A/startup + ./docker.sh --no_pull --data_dir "$SYNTHETIC_DATA_DIR" --scratch_dir "$SCRATCH_DIR"/client_A --GPU device=$GPU_FOR_TESTING --start_client + cd ../.. + cd client_B/startup + ./docker.sh --no_pull --data_dir "$SYNTHETIC_DATA_DIR" --scratch_dir "$SCRATCH_DIR"/client_B --GPU device=$GPU_FOR_TESTING --start_client + sleep 8 + + cd "$CWD" +} + + +start_registry_docker_and_push () { + docker run -d --rm -p 5000:5000 --name local_test_registry_$CONTAINER_VERSION_SUFFIX registry:3 + sleep 3 + docker push localhost:5000/odelia:$VERSION +} + + +run_container_with_pulling () { + docker rmi localhost:5000/odelia:$VERSION + cd "$PROJECT_DIR"/prod_00 + cd testserver.local/startup + OUTPUT=$(./docker.sh --list_licenses) + + if echo "$OUTPUT" | grep -qie "Status: Downloaded newer image for localhost:5000/odelia:$VERSION" ; then + echo "Image pulled successfully" + else + echo "Instructions on $EXPECTED_KEYWORDS missing" + exit 1 + fi + + cd "$CWD" +} + + +kill_registry_docker () { + docker kill local_test_registry_$CONTAINER_VERSION_SUFFIX +} + + +verify_wrong_client_does_not_connect () { + echo "[Run] Verify that client with outdated startup kit does not connect ..." + + cp -r "$PROJECT_DIR"/prod_01 "$PROJECT_DIR"/prod_wrong_client + cd "$PROJECT_DIR"/prod_wrong_client + cd testserver.local/startup + ./docker.sh --no_pull --start_server + cd ../.. + sleep 10 + + rm client_A -rf + tar xvf "$CWD"/tests/integration_tests/outdated_startup_kit.tar.gz + sed -i 's#DOCKER_IMAGE=localhost:5000/odelia:1.0.1-dev.250919.095c1b7#DOCKER_IMAGE='$DOCKER_IMAGE'#' client_A/startup/docker.sh + sed -i 's#CONTAINER_NAME=odelia_swarm_client_client_A_095c1b7#CONTAINER_NAME=odelia_swarm_client_client_A_'$CONTAINER_VERSION_SUFFIX'#' client_A/startup/docker.sh + + cd client_A/startup + ./docker.sh --no_pull --data_dir "$SYNTHETIC_DATA_DIR" --scratch_dir "$SCRATCH_DIR"/client_A --GPU device=$GPU_FOR_TESTING --start_client + cd ../.. + + sleep 20 + + CONSOLE_OUTPUT_SERVER=testserver.local/startup/nohup.out + CONSOLE_OUTPUT_CLIENT=client_A/startup/nohup.out + + if grep -q "Total clients: 1" $CONSOLE_OUTPUT_SERVER; then + echo "Connection with non-authorized client" + exit 1 + else + echo "Connection rejected successfully by server" + fi + + if grep -q "SSLCertVerificationError" $CONSOLE_OUTPUT_CLIENT; then + echo "Connection rejected successfully by client" + else + echo "Could not verify that connection was rejected" + exit 1 + fi + + docker kill odelia_swarm_server_flserver_$CONTAINER_VERSION_SUFFIX odelia_swarm_client_client_A_$CONTAINER_VERSION_SUFFIX + rm -rf "$PROJECT_DIR"/prod_wrong_client + + cd "$CWD" +} + + +run_dummy_training_in_swarm () { + echo "[Run] Dummy training in swarm ..." + + cd "$PROJECT_DIR"/prod_00 + cd admin@test.odelia/startup + "$CWD"/tests/integration_tests/_submitDummyTraining.exp + docker kill odelia_swarm_admin_$CONTAINER_VERSION_SUFFIX + sleep 60 + cd "$CWD" + + cd "$PROJECT_DIR"/prod_00/testserver.local/startup + CONSOLE_OUTPUT=nohup.out + for EXPECTED_OUTPUT in 'Total clients: 2' 'updated status of client client_A on round 4' 'updated status of client client_B on round 4' 'all_done=True' 'Server runner finished.' \ + 'Start to the run Job: [0-9a-f]\{8\}-[0-9a-f]\{4\}-[0-9a-f]\{4\}-[0-9a-f]\{4\}-[0-9a-f]\{12\}' 'updated status of client client_B on round 4'; + do + if grep -q --regexp="$EXPECTED_OUTPUT" "$CONSOLE_OUTPUT"; then + echo "Expected output $EXPECTED_OUTPUT found" + else + echo "Expected output $EXPECTED_OUTPUT missing" + exit 1 + fi + done + cd "$CWD" + + cd "$PROJECT_DIR"/prod_00/client_A/startup + CONSOLE_OUTPUT=nohup.out + for EXPECTED_OUTPUT in 'Sending training result to aggregation client' 'Epoch 9: 100%' 'val/AUC_ROC'; + do + if grep -q "$EXPECTED_OUTPUT" "$CONSOLE_OUTPUT"; then + echo "Expected output $EXPECTED_OUTPUT found" + else + echo "Expected output $EXPECTED_OUTPUT missing" + exit 1 + fi + done + cd "$CWD" + + for EXPECTED_OUTPUT in 'validation metric .* from client' 'aggregating [0-9]* update(s) at round [0-9]*'; + do + if grep -q --regexp="$EXPECTED_OUTPUT" "$PROJECT_DIR"/prod_00/client_?/startup/nohup.out; then + echo "Expected output $EXPECTED_OUTPUT found" + else + echo "Expected output $EXPECTED_OUTPUT missing" + exit 1 + fi + done + + cd "$PROJECT_DIR"/prod_00/client_A/ + FILES_PRESENT=$(find . -type f -name "*.*") + for EXPECTED_FILE in 'custom/minimal_training.py' 'best_FL_global_model.pt' 'FL_global_model.pt' ; + do + if echo "$FILES_PRESENT" | grep -q "$EXPECTED_FILE" ; then + echo "Expected file $EXPECTED_FILE found" + else + echo "Expected file $EXPECTED_FILE missing" + exit 1 + fi + done + + actualsize=$(wc -c <*/app_client_A/best_FL_global_model.pt) + if [ $actualsize -le 1048576 ]; then + echo "Checkpoint file size OK" + else + echo "Checkpoint too large: " $actualsize + exit 1 + fi + + cd "$CWD" +} + + +kill_server_and_clients () { + echo "[Cleanup] Kill server and client Docker containers ..." + docker kill odelia_swarm_server_flserver_$CONTAINER_VERSION_SUFFIX odelia_swarm_client_client_A_$CONTAINER_VERSION_SUFFIX odelia_swarm_client_client_B_$CONTAINER_VERSION_SUFFIX +} + + +cleanup_temporary_data () { + echo "[Cleanup] Removing synthetic data, scratch directory, dummy workspace ..." + rm -rf "$SYNTHETIC_DATA_DIR" + rm -rf "$SCRATCH_DIR" + rm -rf "$PROJECT_DIR" +} + + +case "$1" in + check_files_on_github) + check_files_on_github + ;; + + run_unit_tests_controller) + run_unit_tests_controller + cleanup_temporary_data + ;; + + run_dummy_training_standalone) + run_dummy_training_standalone + cleanup_temporary_data + ;; + + run_dummy_training_simulation_mode) + run_dummy_training_simulation_mode + cleanup_temporary_data + ;; + + run_dummy_training_poc_mode) + run_dummy_training_poc_mode + cleanup_temporary_data + ;; + + run_3dcnn_simulation_mode) + create_synthetic_data + run_3dcnn_simulation_mode + cleanup_temporary_data + ;; + + create_startup_kits) + create_startup_kits_and_check_contained_files + cleanup_temporary_data + ;; + + run_list_licenses) + create_startup_kits_and_check_contained_files + run_list_licenses + cleanup_temporary_data + ;; + + run_docker_gpu_preflight_check) + create_startup_kits_and_check_contained_files + run_docker_gpu_preflight_check + cleanup_temporary_data + ;; + + run_data_access_preflight_check) + create_startup_kits_and_check_contained_files + create_synthetic_data + run_data_access_preflight_check + cleanup_temporary_data + ;; + + push_pull_image) + create_startup_kits_and_check_contained_files + start_registry_docker_and_push + run_container_with_pulling + kill_registry_docker + # TODO add to CI if we want this (takes several minutes) + ;; + + check_wrong_startup_kit) + create_startup_kits_and_check_contained_files + create_synthetic_data + verify_wrong_client_does_not_connect + cleanup_temporary_data + # TODO add to CI if we want this + ;; + + run_dummy_training_in_swarm) + create_startup_kits_and_check_contained_files + create_synthetic_data + start_testing_vpn + start_server_and_clients + run_dummy_training_in_swarm + kill_server_and_clients + kill_testing_vpn + cleanup_temporary_data + # TODO add to CI if we want this (currently not working) + ;; + + all | "") + check_files_on_github + run_unit_tests_controller + run_dummy_training_standalone + run_dummy_training_simulation_mode + run_dummy_training_poc_mode + # run_nvflare_unit_tests # uncomment to enable NVFlare unit tests + create_synthetic_data + run_3dcnn_simulation_mode + create_startup_kits_and_check_contained_files + start_registry_docker_and_push + run_container_with_pulling + kill_registry_docker + run_docker_gpu_preflight_check + run_data_access_preflight_check + start_testing_vpn + start_server_and_clients + kill_testing_vpn + verify_wrong_client_does_not_connect + run_dummy_training_in_swarm + kill_server_and_clients + cleanup_temporary_data + ;; + + *) echo "Unknown argument: $1"; exit 1 ;; +esac diff --git a/runTestsInDocker.sh b/runTestsInDocker.sh deleted file mode 100755 index 4e5a5f5e..00000000 --- a/runTestsInDocker.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash - -VERSION=`./getVersionNumber.sh` -DOCKER_IMAGE=jefftud/odelia:$VERSION - -docker run -it --rm \ - --shm-size=16g \ - --ipc=host \ - --ulimit memlock=-1 \ - --ulimit stack=67108864 \ - -v /tmp:/scratch \ - --gpus=all \ - --entrypoint=/MediSwarm/_runTestsInsideDocker.sh \ - $DOCKER_IMAGE - -./_buildStartupKits.sh tests/provision/dummy_project_for_testing.yml $VERSION - -PROJECT_DIR=workspace/odelia_${VERSION}_dummy_project_for_testing -cd $PROJECT_DIR/prod_00/client_A/startup/ -./docker.sh --data_dir /tmp/ --scratch_dir /tmp/scratch --GPU all --no_pull --dummy_training -cd ../../../../../ -rm -rf $PROJECT_DIR diff --git a/scripts/ci/update_apt_versions.sh b/scripts/ci/update_apt_versions.sh index c3041f83..0d1d00ec 100755 --- a/scripts/ci/update_apt_versions.sh +++ b/scripts/ci/update_apt_versions.sh @@ -1,10 +1,10 @@ #!/usr/bin/env bash - set -e DOCKERFILE_PATH="docker_config/Dockerfile_ODELIA" LOG_PATH=$(mktemp) PROJECT_YML="tests/provision/dummy_project_for_testing.yml" +VPN_TEST_CREDENTIALS="tests/local_vpn/client_configs/" echo "[INFO] Removing APT version pins from Dockerfile..." scripts/dev_utils/dockerfile_update_removeVersionApt.py "$DOCKERFILE_PATH" @@ -15,34 +15,25 @@ git config user.name "GitHub CI" git commit "$DOCKERFILE_PATH" -m "WIP: remove apt versions for rebuild" || echo "[INFO] No version pin removal change to commit." echo "[INFO] Rebuilding Docker image and capturing logs..." -if ! ./buildDockerImageAndStartupKits.sh -p "$PROJECT_YML" 2>&1 | tee "$LOG_PATH"; then - echo "[WARNING] Docker build failed. Proceeding to clean invalid versions..." +if ! ./buildDockerImageAndStartupKits.sh -p "$PROJECT_YML" -c "$VPN_TEST_CREDENTIALS" > "$LOG_PATH" 2>&1; then + echo "Build failed. Output:" + cat "$LOG_PATH" + exit 1 fi +echo "[DEBUG] First 20 lines of build log:" +head -n 20 "$LOG_PATH" + +echo "[DEBUG] Checking for apt install commands:" +grep "apt install" "$LOG_PATH" || echo "[WARN] No apt install command found in log!" + echo "[INFO] Re-adding updated APT version pins to Dockerfile..." scripts/dev_utils/dockerfile_update_addAptVersionNumbers.py "$DOCKERFILE_PATH" "$LOG_PATH" rm "$LOG_PATH" -echo "[INFO] Validating all pinned versions, removing invalid ones..." -has_invalid_versions=0 -while IFS= read -r match; do - pkg="$(echo "$match" | cut -d= -f1)" - ver="$(echo "$match" | cut -d= -f2)" - echo -n "Checking $pkg=$ver... " - if ! apt-cache madison "$pkg" | grep -q "$ver"; then - echo "NOT FOUND – removing pin" - sed -i "s|\b$pkg=$ver\b|$pkg|" "$DOCKERFILE_PATH" - has_invalid_versions=1 - else - echo "OK" - fi -done < <(grep -oP '\b[a-z0-9\.\-]+=[a-zA-Z0-9:~.+-]+\b' "$DOCKERFILE_PATH") - -if git diff --quiet; then - echo "[INFO] No changes to apt versions found. Skipping commit." +git fetch origin main +if git diff --quiet origin/main..HEAD; then echo "NO_CHANGES=true" >> "$GITHUB_ENV" else - echo "[INFO] Committing updated apt versions..." - git commit "$DOCKERFILE_PATH" -m "chore: update apt versions based on rebuild" echo "NO_CHANGES=false" >> "$GITHUB_ENV" fi diff --git a/scripts/dev_utils/dockerfile_update_addAptVersionNumbers.py b/scripts/dev_utils/dockerfile_update_addAptVersionNumbers.py index cca37ddd..cd9c94c7 100755 --- a/scripts/dev_utils/dockerfile_update_addAptVersionNumbers.py +++ b/scripts/dev_utils/dockerfile_update_addAptVersionNumbers.py @@ -3,16 +3,12 @@ import re import sys -def load_file(filename: str) -> str: - with open(filename, 'r') as infile: - return infile.read() +from dockerfile_update_removeVersionApt import LINE_BREAK_IN_COMMAND, LINE_BREAK_REPLACEMENT, load_file, save_file -def save_file(contents: str, filename: str) -> None: - with open(filename, 'w') as outfile: - outfile.write(contents) +APT_INSTALL_COMMAND = 'RUN apt install -y' +APT_INSTALL_REPLACEMENT = 'ΡΥΝ απτ ινσταλλ -υ' - -def parse_apt_versions(installlog: str) -> str: +def parse_apt_versions(installlog: str) -> dict: versions = {} for line in installlog.splitlines(): if re.match('.*Get:[0-9]* http.*', line): @@ -27,10 +23,11 @@ def parse_apt_versions(installlog: str) -> str: def add_apt_versions(dockerfile: str, versions: dict) -> str: - dockerfile = dockerfile.replace('RUN apt install', 'RUN_apt_install') + dockerfile = dockerfile.replace(LINE_BREAK_IN_COMMAND, LINE_BREAK_REPLACEMENT) + dockerfile = dockerfile.replace(APT_INSTALL_COMMAND, APT_INSTALL_REPLACEMENT) outlines = [] for line in dockerfile.splitlines(): - if line.startswith('RUN_apt_install'): + if line.startswith(APT_INSTALL_REPLACEMENT): outline = '' + line for package, version in versions.items(): outline = outline.replace(f' {package} ', f' {package}={version} ') @@ -39,7 +36,8 @@ def add_apt_versions(dockerfile: str, versions: dict) -> str: else: outlines.append(line) dockerfile = '\n'.join(outlines) + '\n' - dockerfile = dockerfile.replace('RUN_apt_install', 'RUN apt install') + dockerfile = dockerfile.replace(APT_INSTALL_REPLACEMENT, APT_INSTALL_COMMAND) + dockerfile = dockerfile.replace(LINE_BREAK_REPLACEMENT, LINE_BREAK_IN_COMMAND) return dockerfile @@ -52,6 +50,9 @@ def report_non_fixed_versions(dockerfile: str, versions: dict) -> None: if __name__ == '__main__': dockerfile = load_file(sys.argv[1]) installlog = load_file(sys.argv[2]) + if LINE_BREAK_REPLACEMENT in dockerfile or APT_INSTALL_REPLACEMENT in dockerfile: + raise Exception('Line break replacement {LINE_BREAK_REPLACEMENT} or apt command replacement {APT_INSTALL_REPLACEMENT} in Dockerfile, cannot process it.') + versions = parse_apt_versions(installlog) report_non_fixed_versions(dockerfile, versions) dockerfile = add_apt_versions(dockerfile, versions) diff --git a/scripts/dev_utils/dockerfile_update_removeVersionApt.py b/scripts/dev_utils/dockerfile_update_removeVersionApt.py index 15055b7f..7f87aa21 100755 --- a/scripts/dev_utils/dockerfile_update_removeVersionApt.py +++ b/scripts/dev_utils/dockerfile_update_removeVersionApt.py @@ -3,6 +3,9 @@ import re import sys +LINE_BREAK_IN_COMMAND = ' \\\n ' +LINE_BREAK_REPLACEMENT = ' λινε βρεακ ρεπλαζεμεντ ' + def load_file(filename: str) -> str: with open(filename, 'r') as infile: return infile.read() @@ -12,18 +15,22 @@ def save_file(contents: str, filename: str) -> None: outfile.write(contents) -def remove_apt_versions(dockerfile: str) -> str: +def remove_apt_versions(contents: str) -> str: + contents = contents.replace(LINE_BREAK_IN_COMMAND, LINE_BREAK_REPLACEMENT) output = [] - for line in dockerfile.splitlines(): - if line.startswith('RUN apt install'): + for line in contents.splitlines(): + if line.startswith('RUN apt install -y'): out_line = re.sub('=[^ ]*', '', line) output.append(out_line) else: output.append(line) - return '\n'.join(output) - + output = '\n'.join(output) + '\n' + output = output.replace(LINE_BREAK_REPLACEMENT, LINE_BREAK_IN_COMMAND) + return output if __name__ == '__main__': - dockerfile = load_file(sys.argv[1]) - dockerfile = remove_apt_versions(dockerfile) - save_file(dockerfile, sys.argv[1]) + contents = load_file(sys.argv[1]) + if LINE_BREAK_REPLACEMENT in contents: + raise Exception('Line break replacement {LINE_BREAK_REPLACEMENT} in Dockerfile, cannot process it.') + contents = remove_apt_versions(contents) + save_file(contents, sys.argv[1]) diff --git a/scripts/dev_utils/remove_old_odelia_docker_images.sh b/scripts/dev_utils/remove_old_odelia_docker_images.sh new file mode 100755 index 00000000..5f25f6d3 --- /dev/null +++ b/scripts/dev_utils/remove_old_odelia_docker_images.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +export OLD_ODELIA_DOCKER_IMAGES=$(docker image list | grep jefftud/odelia | sed 's|jefftud/odelia *[0-9a-z.-]* *||' | sed 's| *.*||' | tail -n +2) +export OLD_ODELIA_DOCKER_IMAGES_LOCAL=$(docker image list | grep localhost:5000/odelia | sed 's|localhost:5000/odelia *[0-9a-z.-]* *||' | sed 's| *.*||' | tail -n +2) + +echo "All docker images:" + +docker image list + +echo "The following Docker images are old ODELIA docker images:" + +echo "$OLD_ODELIA_DOCKER_IMAGES" "$OLD_ODELIA_DOCKER_IMAGES_LOCAL" + +read -p "Delete these Docker images, unless they have additional tags? (y/n): " answer + +if [[ "$answer" == "y" ]]; then + for image in $OLD_ODELIA_DOCKER_IMAGES $OLD_ODELIA_DOCKER_IMAGES_LOCAL; do + docker rmi $image + done +fi diff --git a/scripts/pr_validation.py b/scripts/pr_validation.py new file mode 100644 index 00000000..79a56346 --- /dev/null +++ b/scripts/pr_validation.py @@ -0,0 +1,51 @@ +# scripts/pr_validation.py + +import os +import subprocess +from pathlib import Path +import logging + +logging.basicConfig(level=logging.INFO) +print("Script is running") + +import os + +print("PWD:", os.getcwd()) +print("Files in current dir:", os.listdir()) + + +def get_latest_workspace(): + root = Path.cwd() + candidates = list(root.rglob("odelia_0.9-dev.*_MEVIS_test")) + if not candidates: + raise RuntimeError("No workspace found matching pattern 'odelia_0.9-dev.*_MEVIS_test'") + return sorted(candidates, reverse=True)[0] + + +def run_command(cmd, cwd=None): + print(f"\n>>> Running: {' '.join(cmd)} in {cwd}") + subprocess.run(cmd, cwd=cwd, check=True) + + +def main(): + site = os.environ.get("SITE_NAME", "UKA") + datadir = os.environ["DATADIR"] + scratchdir = os.environ["SCRATCHDIR"] + + workspace_version = get_latest_workspace() + startup_dir = workspace_version / "prod_00" / site / "startup" + + print(f"Using workspace: {workspace_version}") + print(f"Startup directory: {startup_dir}") + + # Run dummy training + run_command(["./docker.sh", "--scratch_dir", scratchdir, "--GPU", "device=0", "--dummy_training"], cwd=startup_dir) + + # Run preflight check + run_command( + ["./docker.sh", "--data_dir", datadir, "--scratch_dir", scratchdir, "--GPU", "device=0", "--preflight_check"], + cwd=startup_dir) + + +if __name__ == "__main__": + main() diff --git a/tests/integration_tests/_run_3dcnn_simulation_mode.sh b/tests/integration_tests/_run_3dcnn_simulation_mode.sh new file mode 100755 index 00000000..a39da49d --- /dev/null +++ b/tests/integration_tests/_run_3dcnn_simulation_mode.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -e + +run_3dcnn_simulation_mode () { + # both clients use the same data according to SITE_NAME, there are no separate env variables from which the code could read which client it is + # change training configuration to run 2 rounds + cd /MediSwarm + export TMPDIR=$(mktemp -d) + cp -R application/jobs/ODELIA_ternary_classification ${TMPDIR}/ODELIA_ternary_classification + sed -i 's/num_rounds = .*/num_rounds = 2/' ${TMPDIR}/ODELIA_ternary_classification/app/config/config_fed_server.conf + export TRAINING_MODE="swarm" + export SITE_NAME="client_A" + export DATA_DIR=/data + export SCRATCH_DIR=/scratch + export TORCH_HOME=/torch_home + export MODEL_NAME=MST + export CONFIG=unilateral + nvflare simulator -w /tmp/ODELIA_ternary_classification -n 2 -t 2 ${TMPDIR}/ODELIA_ternary_classification -c client_A,client_B + rm -rf ${TMPDIR} +} + +run_3dcnn_simulation_mode diff --git a/tests/integration_tests/_run_controller_unit_tests_with_coverage.sh b/tests/integration_tests/_run_controller_unit_tests_with_coverage.sh new file mode 100755 index 00000000..46e6e11c --- /dev/null +++ b/tests/integration_tests/_run_controller_unit_tests_with_coverage.sh @@ -0,0 +1,15 @@ +#!/usr/bin/env bash + +set -e + +run_controller_unit_tests_with_coverage () { + # run unit tests of ODELIA swarm learning and report coverage + export MPLCONFIGDIR=/tmp + export COVERAGE_FILE=/tmp/.MediSwarm_coverage + cd /MediSwarm/tests/unit_tests/controller + PYTHONPATH=/MediSwarm/controller/controller python3 -m coverage run --source=/MediSwarm/controller/controller -m unittest discover + coverage report -m + rm "$COVERAGE_FILE" +} + +run_controller_unit_tests_with_coverage diff --git a/tests/integration_tests/_run_minimal_example_proof_of_concept_mode.sh b/tests/integration_tests/_run_minimal_example_proof_of_concept_mode.sh new file mode 100755 index 00000000..ee26a4d0 --- /dev/null +++ b/tests/integration_tests/_run_minimal_example_proof_of_concept_mode.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash + +set -e + +run_minimal_example_proof_of_concept_mode () { + # run proof-of-concept mode for minimal example + mkdir -p ~/.nvflare + cd /MediSwarm + export TRAINING_MODE="swarm" + nvflare poc prepare -c poc_client_0 poc_client_1 + nvflare poc prepare-jobs-dir -j application/jobs/ + nvflare poc start -ex admin@nvidia.com + sleep 15 + echo "Will submit job now after sleeping 15 seconds to allow the background process to complete" + nvflare job submit -j application/jobs/minimal_training_pytorch_cnn + sleep 60 + echo "Will shut down now after sleeping 60 seconds to allow the background process to complete" + sleep 2 + nvflare poc stop +} + +run_minimal_example_proof_of_concept_mode diff --git a/tests/integration_tests/_run_minimal_example_simulation_mode.sh b/tests/integration_tests/_run_minimal_example_simulation_mode.sh new file mode 100755 index 00000000..e1fd931f --- /dev/null +++ b/tests/integration_tests/_run_minimal_example_simulation_mode.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -e + +run_minimal_example_simulation_mode () { + # run simulation mode for minimal example + cd /MediSwarm + export TRAINING_MODE="swarm" + nvflare simulator -w /tmp/minimal_training_pytorch_cnn -n 2 -t 2 application/jobs/minimal_training_pytorch_cnn -c simulated_node_0,simulated_node_1 +} + +run_minimal_example_simulation_mode diff --git a/tests/integration_tests/_run_minimal_example_standalone.sh b/tests/integration_tests/_run_minimal_example_standalone.sh new file mode 100755 index 00000000..f0106342 --- /dev/null +++ b/tests/integration_tests/_run_minimal_example_standalone.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -e + +run_minimal_example_standalone () { + # run standalone version of minimal example + cd /MediSwarm/application/jobs/minimal_training_pytorch_cnn/app/custom/ + export TRAINING_MODE="local_training" + ./main.py +} + +run_minimal_example_standalone diff --git a/tests/integration_tests/_submitDummyTraining.exp b/tests/integration_tests/_submitDummyTraining.exp new file mode 100755 index 00000000..7d69997c --- /dev/null +++ b/tests/integration_tests/_submitDummyTraining.exp @@ -0,0 +1,15 @@ +#!/usr/bin/env expect + +spawn ./docker.sh --no_pull +expect "User Name: " +send "admin@test.odelia\r" +expect "> " +send "submit_job MediSwarm/application/jobs/minimal_training_pytorch_cnn\r" +expect "> " +send "sys_info client\r" +expect "> " +send "sys_info server\r" +expect "> " +send "list_jobs\r" +expect "> " +send "list_jobs\r" diff --git a/tests/integration_tests/outdated_startup_kit.tar.gz b/tests/integration_tests/outdated_startup_kit.tar.gz new file mode 100644 index 00000000..ba3a984e Binary files /dev/null and b/tests/integration_tests/outdated_startup_kit.tar.gz differ diff --git a/tests/local_vpn/Dockerfile_openvpnserver b/tests/local_vpn/Dockerfile_openvpnserver new file mode 100644 index 00000000..8270f8fa --- /dev/null +++ b/tests/local_vpn/Dockerfile_openvpnserver @@ -0,0 +1,11 @@ +FROM ubuntu:22.04 + +RUN apt update +RUN apt install -y easy-rsa openvpn openssl ufw joe patch +RUN apt install -y openssh-server net-tools + +RUN useradd ca_user + +COPY _openvpn_certificate_creation.sh / +COPY _openvpn_start.sh / +RUN chmod u+x /*.sh diff --git a/tests/local_vpn/README.txt b/tests/local_vpn/README.txt new file mode 100644 index 00000000..5cc3e826 --- /dev/null +++ b/tests/local_vpn/README.txt @@ -0,0 +1,17 @@ +# Following https://www.digitalocean.com/community/tutorials/how-to-set-up-and-configure-an-openvpn-server-on-ubuntu-20-04 +# but on 22.04 + +Setup +----- +./create_openvpn_certificates.sh builds a docker image and creates certificates and .ovpn config files for the clients specified in _openvpn_certificate_creation.sh +Modify server_config/server.conf and client_configs/client.conf to modify network configuration. +Files to use on the server and client are created in server_config/ and client_configs/ + +Usage +----- +./openvpn_start.sh builds a docker image and starts OpenVPN server in the docker container. +Modify _openvpn_start.sh for further firewall etc. configuration. + +Disclaimer +---------- +This configuration is not necessarily secure and should not be re-used unless you know what you are doing. diff --git a/tests/local_vpn/_build_docker.sh b/tests/local_vpn/_build_docker.sh new file mode 100755 index 00000000..0df1ce0f --- /dev/null +++ b/tests/local_vpn/_build_docker.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +# TODO should this be named "latest"? Do we need to pin versions? +# TODO think about splitting building certificates from running the VPN container + +docker build -t odelia_testing_openvpnserver:latest . -f Dockerfile_openvpnserver diff --git a/tests/local_vpn/_openvpn_certificate_creation.sh b/tests/local_vpn/_openvpn_certificate_creation.sh new file mode 100644 index 00000000..a815f001 --- /dev/null +++ b/tests/local_vpn/_openvpn_certificate_creation.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# Roughly following https://www.digitalocean.com/community/tutorials/how-to-set-up-and-configure-an-openvpn-server-on-ubuntu-20-04 +# but on 22.04 + +chown ca_user:ca_user /home/ca_user/ -R +chmod a+rwX /home/ca_user/ -R +/bin/su - -c '/home/ca_user/ca_setup.sh' ca_user + +mkdir ~/easy-rsa +ln -s /usr/share/easy-rsa/* ~/easy-rsa/ +cd ~/easy-rsa + +echo 'set_var EASYRSA_ALGO "ec"' > vars +echo 'set_var EASYRSA_DIGEST "sha512"' >> vars + +./easyrsa init-pki + +rm /server_config/ca.crt \ + /server_config/server.crt \ + /server_config/server.key \ + /server_config/ta.key -f + +rm -rf /client_configs/keys +mkdir -p /client_configs/keys/ + +export EASYRSA_BATCH=1 +./easyrsa gen-req server nopass + +cp ~/easy-rsa/pki/reqs/server.req /tmp/ +chmod a+r /tmp/server.req +/bin/su - -c "export EASYRSA_BATCH=1 && cd ~/easy-rsa/ && ./easyrsa import-req /tmp/server.req server && ./easyrsa sign-req server server" ca_user + +cd ~/easy-rsa +openvpn --genkey secret ta.key +cp ta.key /client_configs/keys/ +cp /home/ca_user/easy-rsa/pki/ca.crt /client_configs/keys/ + +# copy/create files to where they are needed +cp /home/ca_user/easy-rsa/pki/ca.crt /server_config/ +cp /home/ca_user/easy-rsa/pki/issued/server.crt /server_config/ +cp ~/easy-rsa/pki/private/server.key /server_config/ +cp ~/easy-rsa/ta.key /server_config/ + +mkdir /server_config/ccd + +i=4 +for client in testserver.local admin@test.odelia client_A client_B; do + cd ~/easy-rsa + EASYRSA_BATCH=1 EASYRSA_REQ_CN=$client ./easyrsa gen-req $client nopass + cp pki/private/$client.key /client_configs/keys/ + + cp ~/easy-rsa/pki/reqs/$client.req /tmp/ + chmod a+r /tmp/$client.req + /bin/su - -c "export EASYRSA_BATCH=1 && cd ~/easy-rsa/ && ./easyrsa import-req /tmp/$client.req $client && ./easyrsa sign-req client $client" ca_user + cp /home/ca_user/easy-rsa/pki/issued/$client.crt /client_configs/keys/ + + cd /client_configs + ./make_ovpn.sh $client + + echo "ifconfig-push 10.8.0."$i" 255.0.0.0" > /server_config/ccd/$client + i=$((i+1)) +done + +chmod a+rwX /client_configs -R +chmod a+rwX /server_config -R +chmod a+rwX /home/ca_user -R diff --git a/tests/local_vpn/_openvpn_start.sh b/tests/local_vpn/_openvpn_start.sh new file mode 100644 index 00000000..62d1a864 --- /dev/null +++ b/tests/local_vpn/_openvpn_start.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +echo "net.ipv4.ip_forward = 1" >> /etc/sysctl.conf +sysctl -p + +echo "MTBhMTEsMTkKPiAjIFNUQVJUIE9QRU5WUE4gUlVMRVMKPiAjIE5BVCB0YWJsZSBydWxlcwo+ICpuYXQKPiA6UE9TVFJPVVRJTkcgQUNDRVBUIFswOjBdCj4gIyBBbGxvdyB0cmFmZmljIGZyb20gT3BlblZQTiBjbGllbnQgdG8gZXRoMCAoY2hhbmdlIHRvIHRoZSBpbnRlcmZhY2UgeW91IGRpc2NvdmVyZWQhKQo+IC1BIFBPU1RST1VUSU5HIC1zIDEwLjguMC4wLzggLW8gZXRoMCAtaiBNQVNRVUVSQURFCj4gQ09NTUlUCj4gIyBFTkQgT1BFTlZQTiBSVUxFUwo+IAo=" | base64 -d > before.rules.patch +patch /etc/ufw/before.rules before.rules.patch +rm before.rules.patch + +echo "MTljMTkKPCBERUZBVUxUX0ZPUldBUkRfUE9MSUNZPSJEUk9QIgotLS0KPiBERUZBVUxUX0ZPUldBUkRfUE9MSUNZPSJBQ0NFUFQiCg==" | base64 -d > ufw.patch +patch /etc/default/ufw ufw.patch +rm ufw.patch + +ufw allow 9194/udp +ufw allow OpenSSH +ufw disable +ufw enable + +cp /server_config/ca.crt /etc/openvpn/server/ +cp /server_config/server.conf /etc/openvpn/server/ +cp /server_config/server.crt /etc/openvpn/server/ +cp /server_config/server.key /etc/openvpn/server/ +cp /server_config/ta.key /etc/openvpn/server/ +cp /server_config/ccd /etc/openvpn/ccd -r + +# write log to folder on host +cd server_config + +nohup openvpn --duplicate-cn --client-to-client --config /etc/openvpn/server/server.conf & +sleep 2 +chmod a+r /server_config/nohup.out diff --git a/tests/local_vpn/client_configs/.gitignore b/tests/local_vpn/client_configs/.gitignore new file mode 100644 index 00000000..38156aad --- /dev/null +++ b/tests/local_vpn/client_configs/.gitignore @@ -0,0 +1 @@ +keys \ No newline at end of file diff --git a/tests/local_vpn/client_configs/admin@test.odelia_client.ovpn b/tests/local_vpn/client_configs/admin@test.odelia_client.ovpn new file mode 100644 index 00000000..8b9a87ee --- /dev/null +++ b/tests/local_vpn/client_configs/admin@test.odelia_client.ovpn @@ -0,0 +1,299 @@ +############################################## +# Sample client-side OpenVPN 2.0 config file # +# for connecting to multi-client server. # +# # +# This configuration can be used by multiple # +# clients, however each client should have # +# its own cert and key files. # +# # +# On Windows, you might want to rename this # +# file so it has a .ovpn extension # +############################################## + +# Specify that we are a client and that we +# will be pulling certain config file directives +# from the server. +client + +# Use the same setting as you are using on +# the server. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel +# if you have more than one. On XP SP2, +# you may need to disable the firewall +# for the TAP adapter. +;dev-node MyTap + +# Are we connecting to a TCP or +# UDP server? Use the same setting as +# on the server. +;proto tcp +proto udp + +# The hostname/IP and port of the server. +# You can have multiple remote entries +# to load balance between the servers. +remote 172.17.0.1 9194 + +# Choose a random host from the remote +# list for load-balancing. Otherwise +# try hosts in the order specified. +;remote-random + +# Keep trying indefinitely to resolve the +# host name of the OpenVPN server. Very useful +# on machines which are not permanently connected +# to the internet such as laptops. +resolv-retry infinite + +# Most clients don't need to bind to +# a specific local port number. +nobind + +# Downgrade privileges after initialization (non-Windows only) +user nobody +group nogroup + +# Try to preserve some state across restarts. +persist-key +persist-tun + +# If you are connecting through an +# HTTP proxy to reach the actual OpenVPN +# server, put the proxy server/IP and +# port number here. See the man page +# if your proxy server requires +# authentication. +;http-proxy-retry # retry on connection failures +;http-proxy [proxy server] [proxy port #] + +# Wireless networks often produce a lot +# of duplicate packets. Set this flag +# to silence duplicate packet warnings. +;mute-replay-warnings + +# SSL/TLS parms. +# See the server config file for more +# description. It's best to use +# a separate .crt/.key file pair +# for each client. A single ca +# file can be used for all clients. + +# Verify server certificate by checking that the +# certificate has the correct key usage set. +# This is an important precaution to protect against +# a potential attack discussed here: +# http://openvpn.net/howto.html#mitm +# +# To use this feature, you will need to generate +# your server certificates with the keyUsage set to +# digitalSignature, keyEncipherment +# and the extendedKeyUsage to +# serverAuth +# EasyRSA can do this for you. +remote-cert-tls server + +# If a tls-auth key is used on the server +# then every client must also have the key. +;tls-auth ta.key 1 + +# Select a cryptographic cipher. +# If the cipher option is used on the server +# then you must also specify it here. +# Note that v2.4 client/server will automatically +# negotiate AES-256-GCM in TLS mode. +# See also the data-ciphers option in the manpage +;cipher AES-256-CBC +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# Don't enable this unless it is also +# enabled in the server config file. +#comp-lzo + +# Set log file verbosity. +verb 3 + +# Silence repeating messages +;mute 20 + +key-direction 1 + +; script-security 2 +; up /etc/openvpn/update-resolv-conf +; down /etc/openvpn/update-resolv-conf + +; script-security 2 +; up /etc/openvpn/update-systemd-resolved +; down /etc/openvpn/update-systemd-resolved +; down-pre +; dhcp-option DOMAIN-ROUTE . + +-----BEGIN CERTIFICATE----- +MIIDQjCCAiqgAwIBAgIUBwqUYD1oxBKeImaMZfm44TsTAF0wDQYJKoZIhvcNAQEL +BQAwEzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQyWhcNMzUwOTIx +MTI0NjQyWjATMREwDwYDVQQDDAhDaGFuZ2VNZTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAKGt+8oRY7cWPg1SahfIV3XAeeH1SQEFq4f2q+E9ZbWVnCg9 +b59hMzwYr84/j4V73Hlv2udLrkguvnT9KqqJY/0wo3Bd1swH2WLej1fo0+rVo24w +hzeLfeH1e4erZbzQk8XG68U7yNDHKYo+LIz9syBzZA4Bq12bHxDsZbJF7HUANzFR +j9Xg3dR7utPtG8ktmD83rV9/E97whblMpLmjmf2sbCqdLOKTkZnwp5mI47TTkhMj +9K0q7irHmbtZcPZQH5Z59GtqaCaRt8DKfeYniyoPnGVfzFberHHQ4C11pcRrdvgY +n14/W5myh6HESQD6umyCYooyXG7wfqIKujROQCMCAwEAAaOBjTCBijAdBgNVHQ4E +FgQUtMsHbl94qRV7OW5UNNjk2mJ+/U8wTgYDVR0jBEcwRYAUtMsHbl94qRV7OW5U +NNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1lghQHCpRgPWjEEp4iZoxl ++bjhOxMAXTAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIBBjANBgkqhkiG9w0BAQsF +AAOCAQEAGeryP/2JuOp7tzi7Ww9lFUx2DRcgq/FwnU4biotUfuLejHQt/IeIwRYs +dW6AToUYJak8Uy/AFffMootwLcC8z8FATBnxtokWNpxtscpbTSHbeS0HvXnXFaU8 +xxlzp9l5k+46MrrvdzFsjoRfVxs0FUHzWifBnObBziTLfHt+J71509uqRWX6JuTa +PDAT8CMcLKxxS4BcorWtAmc51lW/dQQ41HDJ8a6acltDAprmlnhd8ksWzpTjUDNR +/cfSMcVTpPxPSW/WchR5NlJKQEAf9B/xC+LQgDRSDLaZ8CvzRDgosllzJ+aIS7GK +GPec69LiKqpirZ7enwDM67R4DwIHKA== +-----END CERTIFICATE----- + + +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 45:fc:0b:2c:a3:b7:9c:b6:f1:56:fd:47:cb:b2:12:12 + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=ChangeMe + Validity + Not Before: Sep 23 12:46:43 2025 GMT + Not After : Dec 27 12:46:43 2027 GMT + Subject: CN=admin@test.odelia + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:ca:33:a9:8e:be:5d:00:9e:ff:72:43:e9:e4:8b: + 8f:09:6e:56:38:7e:f8:57:e1:5f:e7:df:af:e1:22: + 69:1e:7a:9c:a3:43:84:8f:f8:cc:61:4e:61:dc:3a: + 56:02:77:13:65:09:4e:25:02:94:a9:94:3f:76:4f: + b8:6c:98:36:0c:52:cc:22:e7:16:97:2b:c2:c1:7c: + 14:db:f8:45:7a:b7:c8:b0:5c:a9:a1:d8:0c:ca:b0: + 4f:b3:a6:f3:05:f2:e7:43:ac:90:2c:32:4b:ae:b8: + d8:67:c0:0f:46:e2:e1:a7:d9:a4:cd:c7:5b:29:4e: + c4:38:aa:6b:43:c5:31:8e:a4:be:68:73:82:72:ca: + a4:df:81:80:c7:13:df:b7:e1:53:07:04:c0:d6:78: + 66:22:9a:fe:ba:95:0e:e5:cc:93:47:1f:f1:e9:86: + 77:3d:c4:54:cd:b8:c9:8a:2b:02:eb:84:0b:68:22: + 50:8f:16:7a:e5:d7:ec:3f:3f:25:f0:79:74:42:3a: + bb:2e:a3:dc:c0:d4:d3:05:8b:4e:01:a7:e8:ff:6d: + 94:1e:4d:de:f7:76:10:cc:62:66:d9:b4:1e:58:0c: + 52:de:46:1c:26:bc:71:ef:82:bb:25:f6:d7:14:19: + e6:3d:a1:e4:cc:0b:94:1f:c6:bb:37:81:4d:5c:76: + 6b:2b + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + X509v3 Subject Key Identifier: + E5:EC:27:21:94:02:06:AC:C6:C7:DB:B6:13:25:9C:C3:60:1E:47:FE + X509v3 Authority Key Identifier: + keyid:B4:CB:07:6E:5F:78:A9:15:7B:39:6E:54:34:D8:E4:DA:62:7E:FD:4F + DirName:/CN=ChangeMe + serial:07:0A:94:60:3D:68:C4:12:9E:22:66:8C:65:F9:B8:E1:3B:13:00:5D + X509v3 Extended Key Usage: + TLS Web Client Authentication + X509v3 Key Usage: + Digital Signature + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 73:cb:0e:63:bf:1d:f5:04:37:d3:cc:9c:c8:d2:21:60:f0:ae: + 23:08:38:0b:77:31:9b:6f:b3:89:5f:5c:69:86:e8:69:47:b8: + da:04:56:8b:a2:f4:25:2b:48:c6:4f:1d:a2:8a:b3:b8:7c:a8: + d2:9e:89:9a:20:71:69:fb:9f:4d:39:8d:cb:9c:f2:58:bc:58: + 19:10:cd:be:1f:bd:6e:e4:af:fd:c6:eb:2f:83:39:e7:4b:2c: + bf:23:e1:9d:9e:81:80:86:41:df:9f:fc:3b:d3:29:7f:dc:fb: + a6:45:5c:38:0b:80:de:27:ef:23:f8:53:80:48:69:37:c9:9b: + aa:24:cc:ff:54:80:77:2b:ab:51:c7:02:4d:e7:49:01:af:f4: + d3:d1:89:09:4a:96:99:44:e2:0d:13:b1:9d:4b:47:73:70:22: + fc:a7:4f:20:90:00:a3:5b:96:c9:59:e7:0e:e1:25:e0:00:3c: + 66:a8:32:62:f1:42:bc:84:32:32:46:b7:ac:b9:ed:e7:45:47: + 3b:26:b7:2b:f2:ce:04:e9:64:9c:52:5d:e4:08:11:32:ff:e0: + ff:a9:d8:e5:1a:e7:f0:cc:21:25:f8:04:40:6a:e3:ed:5f:fc: + b2:15:0a:b7:cf:85:db:82:29:e2:27:ed:e8:94:f4:c3:01:77: + 04:d0:bf:7d +-----BEGIN CERTIFICATE----- +MIIDWTCCAkGgAwIBAgIQRfwLLKO3nLbxVv1Hy7ISEjANBgkqhkiG9w0BAQsFADAT +MREwDwYDVQQDDAhDaGFuZ2VNZTAeFw0yNTA5MjMxMjQ2NDNaFw0yNzEyMjcxMjQ2 +NDNaMBwxGjAYBgNVBAMMEWFkbWluQHRlc3Qub2RlbGlhMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAyjOpjr5dAJ7/ckPp5IuPCW5WOH74V+Ff59+v4SJp +Hnqco0OEj/jMYU5h3DpWAncTZQlOJQKUqZQ/dk+4bJg2DFLMIucWlyvCwXwU2/hF +erfIsFypodgMyrBPs6bzBfLnQ6yQLDJLrrjYZ8APRuLhp9mkzcdbKU7EOKprQ8Ux +jqS+aHOCcsqk34GAxxPft+FTBwTA1nhmIpr+upUO5cyTRx/x6YZ3PcRUzbjJiisC +64QLaCJQjxZ65dfsPz8l8Hl0Qjq7LqPcwNTTBYtOAafo/22UHk3e93YQzGJm2bQe +WAxS3kYcJrxx74K7JfbXFBnmPaHkzAuUH8a7N4FNXHZrKwIDAQABo4GfMIGcMAkG +A1UdEwQCMAAwHQYDVR0OBBYEFOXsJyGUAgasxsfbthMlnMNgHkf+ME4GA1UdIwRH +MEWAFLTLB25feKkVezluVDTY5Npifv1PoRekFTATMREwDwYDVQQDDAhDaGFuZ2VN +ZYIUBwqUYD1oxBKeImaMZfm44TsTAF0wEwYDVR0lBAwwCgYIKwYBBQUHAwIwCwYD +VR0PBAQDAgeAMA0GCSqGSIb3DQEBCwUAA4IBAQBzyw5jvx31BDfTzJzI0iFg8K4j +CDgLdzGbb7OJX1xphuhpR7jaBFaLovQlK0jGTx2iirO4fKjSnomaIHFp+59NOY3L +nPJYvFgZEM2+H71u5K/9xusvgznnSyy/I+GdnoGAhkHfn/w70yl/3PumRVw4C4De +J+8j+FOASGk3yZuqJMz/VIB3K6tRxwJN50kBr/TT0YkJSpaZROINE7GdS0dzcCL8 +p08gkACjW5bJWecO4SXgADxmqDJi8UK8hDIyRresue3nRUc7Jrcr8s4E6WScUl3k +CBEy/+D/qdjlGufwzCEl+ARAauPtX/yyFQq3z4XbginiJ+3olPTDAXcE0L99 +-----END CERTIFICATE----- + + +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDKM6mOvl0Anv9y +Q+nki48JblY4fvhX4V/n36/hImkeepyjQ4SP+MxhTmHcOlYCdxNlCU4lApSplD92 +T7hsmDYMUswi5xaXK8LBfBTb+EV6t8iwXKmh2AzKsE+zpvMF8udDrJAsMkuuuNhn +wA9G4uGn2aTNx1spTsQ4qmtDxTGOpL5oc4JyyqTfgYDHE9+34VMHBMDWeGYimv66 +lQ7lzJNHH/Hphnc9xFTNuMmKKwLrhAtoIlCPFnrl1+w/PyXweXRCOrsuo9zA1NMF +i04Bp+j/bZQeTd73dhDMYmbZtB5YDFLeRhwmvHHvgrsl9tcUGeY9oeTMC5Qfxrs3 +gU1cdmsrAgMBAAECggEARJO+9vWjLzm9ozBbXDLk4Sr1vRV6/rsmPssDqJR2GGs3 +Qrv8cqPMFVhzIjD6yL0/6617PlwgtV7dIzAoeVQqSIWwTEoZxE2IMPz3Sh9q2RMJ +0p6qvYQ72mZvsZt2otbeDnqxLvwj5O82HgHqbH04USgkl9H8Hgdjz2AlHwf7Jcgo +pwD48OtE8YFMof3/SFzKUJDPdsCsjGlWnDDJMjXrIR6BZdE7xxEX7L2VRcmVBQXR +lhAeNwYJNQ1qWGuXaSPx2BNa9BuTd66PwypsyPKwI63CJ6WkUh1bWsAviTBzr5Oz +u27eym4uK6mfXS6Pgv4VcM0kNUjnPd8p/XSaGQCfUQKBgQDLLuV9bhUyHPRbJbHC +WcXxNMiwUOpoyQY+KTj+p8mgXc9tB1TvL+dNi+vhc1mgHNactFnmyzx1S8eM2Wn9 +Aw1fxw42APUTw4rJh+l3UsuTBMZwQ3s6CeNFgX+PNHqHK/47xyXdmgipMB4Uz0JI +EPEe1avbLTymDCmbfJ3mFqLMuQKBgQD+w3VZhhI5OJSEQUkJs/sURQdBoD05rX/B +afz4ZqfRfLJscI3oG1oV8ZwMkoeA1ou5ovtxr/4XlGPhsopDa7sUMWbhEkm5Rssw +gPVmGE9HnM8tSG/8So7fbIXHCGcKRDD5JKnjPyqpP78wTzJeOrCfhbFP+dpiRJrJ +mEhn3V2tAwKBgQDDXHYgIkaTBrAVK6s9eeAPSndkwIiC9DbicfRxNpdxcIHPDWun +B+JY9554Cdc1UkUwK2D9vpCFH7XhQfLc6aBkZRrO5iC/PhcmK15Z8uv2knLS4q+L +YJJ79EXYRddCPRSYGaXY6xBEzRU/YQEUFeYhhcVWWqqj5bHj5PBVmZIzUQKBgQC3 +AOa6ETHkAr3EpzT1EGFqtQ86WAXC+duMrzr1oKAqPl3YwZ1ePs+edbk32sYViYhD +KE1g5CAtBf4dsWfaeHehUL9rK/zjZ3Qr+mbNGOdSNNUp3R/8Zf5thgIu7908pbFc +NrcGs2hMvarz49/1ikk3vgyZu4vhDRD3gTl5yq0wywKBgQCddStrC6gtOOaZofSL +bU6Le9TXyDbBfGiVpDxaD1rxdWylN3jQY9JSznmq7RGTR2TUVlqeFoCunaLc4VJi +N+np08niR1T/Mm+8HqLzYRROmLIozETrPomdgj1Ewa83lmI4/JSiNZbkFs+Jh6J1 +sGSrPFifkIAVP/C6PbVqj1Nn7A== +-----END PRIVATE KEY----- + + +# +# 2048 bit OpenVPN static key +# +-----BEGIN OpenVPN Static key V1----- +488b61084812969fe8ad0f9dd40f56a2 +6cdadddfe345daef6b5c6d3c3e779fc5 +1f7d236966953482d2af085e3f8581b7 +d216f2d891972a463bbb22ca6c104b9d +f99dcb19d7d575a1d46e7918bb2556c6 +db9f51cd792c5e89e011586214692b95 +2a32a7fe85e4538c40e1d0aa2a9f8e15 +fcc0ce5d31974e3c2041b127776f7658 +878cb8245ed235ec996c2370c0fc0023 +699bc028b3412bc40209cba8233bc111 +fa1438095f99052d799fa718f3b04499 +472254d0286b4b2ce99db49e98a4cc25 +fd948bddcdcf08006a6d7bff40354e7b +5e93ea753a8ecc05de41ae34d280e7eb +99220e436bf8b7693a00667485631e28 +edba3e33b6f558dfa50b92eec6ac8b44 +-----END OpenVPN Static key V1----- + diff --git a/tests/local_vpn/client_configs/client.conf b/tests/local_vpn/client_configs/client.conf new file mode 100755 index 00000000..49669b71 --- /dev/null +++ b/tests/local_vpn/client_configs/client.conf @@ -0,0 +1,138 @@ +############################################## +# Sample client-side OpenVPN 2.0 config file # +# for connecting to multi-client server. # +# # +# This configuration can be used by multiple # +# clients, however each client should have # +# its own cert and key files. # +# # +# On Windows, you might want to rename this # +# file so it has a .ovpn extension # +############################################## + +# Specify that we are a client and that we +# will be pulling certain config file directives +# from the server. +client + +# Use the same setting as you are using on +# the server. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel +# if you have more than one. On XP SP2, +# you may need to disable the firewall +# for the TAP adapter. +;dev-node MyTap + +# Are we connecting to a TCP or +# UDP server? Use the same setting as +# on the server. +;proto tcp +proto udp + +# The hostname/IP and port of the server. +# You can have multiple remote entries +# to load balance between the servers. +remote 172.17.0.1 9194 + +# Choose a random host from the remote +# list for load-balancing. Otherwise +# try hosts in the order specified. +;remote-random + +# Keep trying indefinitely to resolve the +# host name of the OpenVPN server. Very useful +# on machines which are not permanently connected +# to the internet such as laptops. +resolv-retry infinite + +# Most clients don't need to bind to +# a specific local port number. +nobind + +# Downgrade privileges after initialization (non-Windows only) +user nobody +group nogroup + +# Try to preserve some state across restarts. +persist-key +persist-tun + +# If you are connecting through an +# HTTP proxy to reach the actual OpenVPN +# server, put the proxy server/IP and +# port number here. See the man page +# if your proxy server requires +# authentication. +;http-proxy-retry # retry on connection failures +;http-proxy [proxy server] [proxy port #] + +# Wireless networks often produce a lot +# of duplicate packets. Set this flag +# to silence duplicate packet warnings. +;mute-replay-warnings + +# SSL/TLS parms. +# See the server config file for more +# description. It's best to use +# a separate .crt/.key file pair +# for each client. A single ca +# file can be used for all clients. + +# Verify server certificate by checking that the +# certificate has the correct key usage set. +# This is an important precaution to protect against +# a potential attack discussed here: +# http://openvpn.net/howto.html#mitm +# +# To use this feature, you will need to generate +# your server certificates with the keyUsage set to +# digitalSignature, keyEncipherment +# and the extendedKeyUsage to +# serverAuth +# EasyRSA can do this for you. +remote-cert-tls server + +# If a tls-auth key is used on the server +# then every client must also have the key. +;tls-auth ta.key 1 + +# Select a cryptographic cipher. +# If the cipher option is used on the server +# then you must also specify it here. +# Note that v2.4 client/server will automatically +# negotiate AES-256-GCM in TLS mode. +# See also the data-ciphers option in the manpage +;cipher AES-256-CBC +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# Don't enable this unless it is also +# enabled in the server config file. +#comp-lzo + +# Set log file verbosity. +verb 3 + +# Silence repeating messages +;mute 20 + +key-direction 1 + +; script-security 2 +; up /etc/openvpn/update-resolv-conf +; down /etc/openvpn/update-resolv-conf + +; script-security 2 +; up /etc/openvpn/update-systemd-resolved +; down /etc/openvpn/update-systemd-resolved +; down-pre +; dhcp-option DOMAIN-ROUTE . diff --git a/tests/local_vpn/client_configs/client_A_client.ovpn b/tests/local_vpn/client_configs/client_A_client.ovpn new file mode 100644 index 00000000..1506b75d --- /dev/null +++ b/tests/local_vpn/client_configs/client_A_client.ovpn @@ -0,0 +1,299 @@ +############################################## +# Sample client-side OpenVPN 2.0 config file # +# for connecting to multi-client server. # +# # +# This configuration can be used by multiple # +# clients, however each client should have # +# its own cert and key files. # +# # +# On Windows, you might want to rename this # +# file so it has a .ovpn extension # +############################################## + +# Specify that we are a client and that we +# will be pulling certain config file directives +# from the server. +client + +# Use the same setting as you are using on +# the server. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel +# if you have more than one. On XP SP2, +# you may need to disable the firewall +# for the TAP adapter. +;dev-node MyTap + +# Are we connecting to a TCP or +# UDP server? Use the same setting as +# on the server. +;proto tcp +proto udp + +# The hostname/IP and port of the server. +# You can have multiple remote entries +# to load balance between the servers. +remote 172.17.0.1 9194 + +# Choose a random host from the remote +# list for load-balancing. Otherwise +# try hosts in the order specified. +;remote-random + +# Keep trying indefinitely to resolve the +# host name of the OpenVPN server. Very useful +# on machines which are not permanently connected +# to the internet such as laptops. +resolv-retry infinite + +# Most clients don't need to bind to +# a specific local port number. +nobind + +# Downgrade privileges after initialization (non-Windows only) +user nobody +group nogroup + +# Try to preserve some state across restarts. +persist-key +persist-tun + +# If you are connecting through an +# HTTP proxy to reach the actual OpenVPN +# server, put the proxy server/IP and +# port number here. See the man page +# if your proxy server requires +# authentication. +;http-proxy-retry # retry on connection failures +;http-proxy [proxy server] [proxy port #] + +# Wireless networks often produce a lot +# of duplicate packets. Set this flag +# to silence duplicate packet warnings. +;mute-replay-warnings + +# SSL/TLS parms. +# See the server config file for more +# description. It's best to use +# a separate .crt/.key file pair +# for each client. A single ca +# file can be used for all clients. + +# Verify server certificate by checking that the +# certificate has the correct key usage set. +# This is an important precaution to protect against +# a potential attack discussed here: +# http://openvpn.net/howto.html#mitm +# +# To use this feature, you will need to generate +# your server certificates with the keyUsage set to +# digitalSignature, keyEncipherment +# and the extendedKeyUsage to +# serverAuth +# EasyRSA can do this for you. +remote-cert-tls server + +# If a tls-auth key is used on the server +# then every client must also have the key. +;tls-auth ta.key 1 + +# Select a cryptographic cipher. +# If the cipher option is used on the server +# then you must also specify it here. +# Note that v2.4 client/server will automatically +# negotiate AES-256-GCM in TLS mode. +# See also the data-ciphers option in the manpage +;cipher AES-256-CBC +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# Don't enable this unless it is also +# enabled in the server config file. +#comp-lzo + +# Set log file verbosity. +verb 3 + +# Silence repeating messages +;mute 20 + +key-direction 1 + +; script-security 2 +; up /etc/openvpn/update-resolv-conf +; down /etc/openvpn/update-resolv-conf + +; script-security 2 +; up /etc/openvpn/update-systemd-resolved +; down /etc/openvpn/update-systemd-resolved +; down-pre +; dhcp-option DOMAIN-ROUTE . + +-----BEGIN CERTIFICATE----- +MIIDQjCCAiqgAwIBAgIUBwqUYD1oxBKeImaMZfm44TsTAF0wDQYJKoZIhvcNAQEL +BQAwEzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQyWhcNMzUwOTIx +MTI0NjQyWjATMREwDwYDVQQDDAhDaGFuZ2VNZTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAKGt+8oRY7cWPg1SahfIV3XAeeH1SQEFq4f2q+E9ZbWVnCg9 +b59hMzwYr84/j4V73Hlv2udLrkguvnT9KqqJY/0wo3Bd1swH2WLej1fo0+rVo24w +hzeLfeH1e4erZbzQk8XG68U7yNDHKYo+LIz9syBzZA4Bq12bHxDsZbJF7HUANzFR +j9Xg3dR7utPtG8ktmD83rV9/E97whblMpLmjmf2sbCqdLOKTkZnwp5mI47TTkhMj +9K0q7irHmbtZcPZQH5Z59GtqaCaRt8DKfeYniyoPnGVfzFberHHQ4C11pcRrdvgY +n14/W5myh6HESQD6umyCYooyXG7wfqIKujROQCMCAwEAAaOBjTCBijAdBgNVHQ4E +FgQUtMsHbl94qRV7OW5UNNjk2mJ+/U8wTgYDVR0jBEcwRYAUtMsHbl94qRV7OW5U +NNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1lghQHCpRgPWjEEp4iZoxl ++bjhOxMAXTAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIBBjANBgkqhkiG9w0BAQsF +AAOCAQEAGeryP/2JuOp7tzi7Ww9lFUx2DRcgq/FwnU4biotUfuLejHQt/IeIwRYs +dW6AToUYJak8Uy/AFffMootwLcC8z8FATBnxtokWNpxtscpbTSHbeS0HvXnXFaU8 +xxlzp9l5k+46MrrvdzFsjoRfVxs0FUHzWifBnObBziTLfHt+J71509uqRWX6JuTa +PDAT8CMcLKxxS4BcorWtAmc51lW/dQQ41HDJ8a6acltDAprmlnhd8ksWzpTjUDNR +/cfSMcVTpPxPSW/WchR5NlJKQEAf9B/xC+LQgDRSDLaZ8CvzRDgosllzJ+aIS7GK +GPec69LiKqpirZ7enwDM67R4DwIHKA== +-----END CERTIFICATE----- + + +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 54:bc:c2:64:c6:73:20:54:74:58:b8:6a:6e:10:38:76 + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=ChangeMe + Validity + Not Before: Sep 23 12:46:43 2025 GMT + Not After : Dec 27 12:46:43 2027 GMT + Subject: CN=client_A + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:c8:62:8c:53:02:a0:a3:8b:17:bc:80:97:f8:0f: + 63:35:7d:75:1d:b4:36:bd:75:17:ac:36:35:0b:6a: + ec:38:b3:7f:d6:1f:ef:c2:90:dc:b3:d5:1e:11:65: + 36:5c:63:b8:ef:7c:d2:eb:05:4c:61:54:02:93:8b: + 84:6b:8b:1c:ca:3e:6e:d5:b4:b0:2c:6f:a4:36:db: + fc:d4:a3:8c:23:da:f0:be:cf:d3:16:dd:44:4d:77: + ce:53:1d:5e:14:e2:c3:67:b1:9a:25:44:f9:b3:b1: + f6:13:a6:0d:5e:16:49:cc:cd:52:b8:8c:2c:8e:ac: + 87:17:ff:ff:c1:8a:e3:f5:3c:71:69:9f:14:a2:85: + 37:0e:4b:16:24:83:08:4e:58:b7:60:36:98:c7:2e: + 4b:bb:d7:b2:e0:aa:95:bb:22:7d:a6:bf:da:71:95: + c0:fe:d6:bb:93:06:27:2f:b9:4c:47:85:f5:80:2b: + f1:1b:c8:03:bb:5a:8d:13:e9:0e:1a:23:c1:92:7a: + 7a:41:43:93:f3:3a:ca:36:0b:a2:dc:b8:fc:61:7d: + 7b:af:3e:7a:fc:ad:ac:d4:04:f4:ec:57:18:ae:c8: + 4d:c3:ec:5c:bd:72:c0:b0:8e:24:fe:13:44:93:b0: + c3:78:3c:99:23:74:dd:44:8f:e3:ac:1b:12:8d:d8: + 74:e9 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + X509v3 Subject Key Identifier: + DF:E6:D3:15:9C:F9:C3:F9:4E:C9:60:28:FA:6B:38:CA:1C:72:F7:B2 + X509v3 Authority Key Identifier: + keyid:B4:CB:07:6E:5F:78:A9:15:7B:39:6E:54:34:D8:E4:DA:62:7E:FD:4F + DirName:/CN=ChangeMe + serial:07:0A:94:60:3D:68:C4:12:9E:22:66:8C:65:F9:B8:E1:3B:13:00:5D + X509v3 Extended Key Usage: + TLS Web Client Authentication + X509v3 Key Usage: + Digital Signature + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 6c:de:92:45:ed:e7:01:63:d9:63:29:65:b7:75:e6:ed:31:44: + 8b:a7:7c:06:0c:02:87:15:bd:f2:e3:3e:e0:8b:74:87:44:d3: + 8a:f6:86:6d:3e:2f:1c:e7:b9:1d:b5:42:4d:60:76:1c:4f:8d: + a7:9c:81:a6:57:8b:62:85:76:15:f8:f8:0d:ef:2c:85:27:f5: + 2a:1d:36:84:88:77:72:f7:52:85:93:b8:0f:0b:97:54:e9:23: + 76:d6:1d:44:09:57:3e:ee:33:72:87:02:91:2e:50:fc:a2:88: + 42:88:6d:de:26:21:cc:79:96:61:9f:d9:1e:12:54:7c:96:f7: + 49:4a:08:f9:72:26:d7:40:59:fc:ab:8b:01:3d:b6:e2:4d:19: + fc:ff:1a:39:78:65:e0:13:9a:33:be:99:d6:fb:30:ea:a4:0b: + 41:32:eb:0e:f8:1c:95:e7:16:a0:3f:8e:2c:43:17:10:3c:f7: + b3:98:71:59:2d:17:94:32:a1:9b:85:39:2f:fa:2e:f9:45:dc: + 6e:c9:11:de:94:e1:10:52:87:04:43:e1:9b:4e:39:7b:c6:1e: + 55:a8:82:7c:77:d1:4a:cb:4c:8f:cb:ee:3f:b6:c7:6f:8a:3d: + 1a:a9:9e:9a:16:a4:3e:10:c0:49:95:5a:7c:c0:13:35:15:e8: + 1f:1f:f8:1a +-----BEGIN CERTIFICATE----- +MIIDUDCCAjigAwIBAgIQVLzCZMZzIFR0WLhqbhA4djANBgkqhkiG9w0BAQsFADAT +MREwDwYDVQQDDAhDaGFuZ2VNZTAeFw0yNTA5MjMxMjQ2NDNaFw0yNzEyMjcxMjQ2 +NDNaMBMxETAPBgNVBAMMCGNsaWVudF9BMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAyGKMUwKgo4sXvICX+A9jNX11HbQ2vXUXrDY1C2rsOLN/1h/vwpDc +s9UeEWU2XGO473zS6wVMYVQCk4uEa4scyj5u1bSwLG+kNtv81KOMI9rwvs/TFt1E +TXfOUx1eFOLDZ7GaJUT5s7H2E6YNXhZJzM1SuIwsjqyHF///wYrj9TxxaZ8UooU3 +DksWJIMITli3YDaYxy5Lu9ey4KqVuyJ9pr/acZXA/ta7kwYnL7lMR4X1gCvxG8gD +u1qNE+kOGiPBknp6QUOT8zrKNgui3Lj8YX17rz56/K2s1AT07FcYrshNw+xcvXLA +sI4k/hNEk7DDeDyZI3TdRI/jrBsSjdh06QIDAQABo4GfMIGcMAkGA1UdEwQCMAAw +HQYDVR0OBBYEFN/m0xWc+cP5TslgKPprOMoccveyME4GA1UdIwRHMEWAFLTLB25f +eKkVezluVDTY5Npifv1PoRekFTATMREwDwYDVQQDDAhDaGFuZ2VNZYIUBwqUYD1o +xBKeImaMZfm44TsTAF0wEwYDVR0lBAwwCgYIKwYBBQUHAwIwCwYDVR0PBAQDAgeA +MA0GCSqGSIb3DQEBCwUAA4IBAQBs3pJF7ecBY9ljKWW3debtMUSLp3wGDAKHFb3y +4z7gi3SHRNOK9oZtPi8c57kdtUJNYHYcT42nnIGmV4tihXYV+PgN7yyFJ/UqHTaE +iHdy91KFk7gPC5dU6SN21h1ECVc+7jNyhwKRLlD8oohCiG3eJiHMeZZhn9keElR8 +lvdJSgj5cibXQFn8q4sBPbbiTRn8/xo5eGXgE5ozvpnW+zDqpAtBMusO+ByV5xag +P44sQxcQPPezmHFZLReUMqGbhTkv+i75RdxuyRHelOEQUocEQ+GbTjl7xh5VqIJ8 +d9FKy0yPy+4/tsdvij0aqZ6aFqQ+EMBJlVp8wBM1FegfH/ga +-----END CERTIFICATE----- + + +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDIYoxTAqCjixe8 +gJf4D2M1fXUdtDa9dResNjULauw4s3/WH+/CkNyz1R4RZTZcY7jvfNLrBUxhVAKT +i4RrixzKPm7VtLAsb6Q22/zUo4wj2vC+z9MW3URNd85THV4U4sNnsZolRPmzsfYT +pg1eFknMzVK4jCyOrIcX///BiuP1PHFpnxSihTcOSxYkgwhOWLdgNpjHLku717Lg +qpW7In2mv9pxlcD+1ruTBicvuUxHhfWAK/EbyAO7Wo0T6Q4aI8GSenpBQ5PzOso2 +C6LcuPxhfXuvPnr8razUBPTsVxiuyE3D7Fy9csCwjiT+E0STsMN4PJkjdN1Ej+Os +GxKN2HTpAgMBAAECggEAWxr3KryclY6lTZZ3wZgZZpXyO/2WD8BfcXQ53MWRvdva +iNt/Ukozle2U3JrUQuAyEmyBpsoDZpLgEv4RSCX5AnitQquCl8lwc2LEilcLXbfq +0g5CLniOV9xbKc3F2yAYcJo+d6hrEQid1WQfWsIubpeBfxd4IKwPRdmmCfRgXTv+ +a7TVI9pRmFNg7J9Cs2VEqf7SdMX8U+7bPJfvHZ+aWYO5d9ZWhMSW5EB43QlgcVg2 +Eof1AjvkBY4NOOsb2uWkw7HiKloT95L8PR6I9bSCesJU58oGDPJyQKG58ANk5alh +9qPgzK5RnkMxzO+aEEzZ5x8NYacx51JwcScI6r5/ewKBgQDgMQSg+h+JZmGY1nuY +5OM6OiGoyHq8PAogPzWEO4N5I26kmkiTiLzyr4dzvPNx+1uOCuSQpt/qBB9rli1w +y1PQkrXMtfrHv83AWep1bFgripgwsGTKRTq0t9Obl5zzkV2OaBlGJP+gaBnfEbM4 +htchBFEyTMfoobFz9+Xv8mvHBwKBgQDk0NcZ7xoqx1PY4f8bbpAIOj8VhJopsZBm +Jv/jzJq8JREMXU54y8VkaT3ihY5tq/7DhPvpeVy87UI6urEaxoK5D3xAdoMVsBy0 +SVkfAMTjqU5PpZahPTL1vyvrH9EvJSfW1/qhtdyxZzw5p0T2Ro1ZEEC/GTRAZZuV +LgUHt594jwKBgQCtzJJgEUedhucmSzAqCVc2bpZleHXds1XORfJA/rofkR5XMNwO +s7R3FyiUyuiXdls1tLAYi6WOj3+kMhosFRR23yVc+77cV480DQC74zA/IQR2ymh4 +fk7ShqffORwNnqW+nmjpfglF2y4jRl9/9NiV2fjwW6GmcKNW2dlBuNdgxQKBgEKu +pfEV4D9VRZcwDWNWLj1nlBjWQwMhjx5mAS7G4tUvzC8ZRhQn9keT8AgCugY2GJGs +QKnCx4b7cdChtZlC/relzqUOpJb+cu8LbSB+3eIm5f6KGEK3DhHV+5uS8yhVIK4Y +1R6pXD6LAl8e4xcOaoTpGqVWWAboVZX9ClQ8bAn7AoGADM9OTA+hc/LicOO/oqp+ +lJ3XKBbQMvWZY0fhGvm0DSYLZi7cBOBCBwJXvfq278Cq1u+i2QHW9hV64Dcbt0TQ +l74cqQpoXZ7ZYFUUmYsEh3smL8K1u176Yig9LbVjUBD2eF02J+OXGWJtDQwyI696 +04gCGQhFI98vaM11YlS+skk= +-----END PRIVATE KEY----- + + +# +# 2048 bit OpenVPN static key +# +-----BEGIN OpenVPN Static key V1----- +488b61084812969fe8ad0f9dd40f56a2 +6cdadddfe345daef6b5c6d3c3e779fc5 +1f7d236966953482d2af085e3f8581b7 +d216f2d891972a463bbb22ca6c104b9d +f99dcb19d7d575a1d46e7918bb2556c6 +db9f51cd792c5e89e011586214692b95 +2a32a7fe85e4538c40e1d0aa2a9f8e15 +fcc0ce5d31974e3c2041b127776f7658 +878cb8245ed235ec996c2370c0fc0023 +699bc028b3412bc40209cba8233bc111 +fa1438095f99052d799fa718f3b04499 +472254d0286b4b2ce99db49e98a4cc25 +fd948bddcdcf08006a6d7bff40354e7b +5e93ea753a8ecc05de41ae34d280e7eb +99220e436bf8b7693a00667485631e28 +edba3e33b6f558dfa50b92eec6ac8b44 +-----END OpenVPN Static key V1----- + diff --git a/tests/local_vpn/client_configs/client_B_client.ovpn b/tests/local_vpn/client_configs/client_B_client.ovpn new file mode 100644 index 00000000..6229c033 --- /dev/null +++ b/tests/local_vpn/client_configs/client_B_client.ovpn @@ -0,0 +1,299 @@ +############################################## +# Sample client-side OpenVPN 2.0 config file # +# for connecting to multi-client server. # +# # +# This configuration can be used by multiple # +# clients, however each client should have # +# its own cert and key files. # +# # +# On Windows, you might want to rename this # +# file so it has a .ovpn extension # +############################################## + +# Specify that we are a client and that we +# will be pulling certain config file directives +# from the server. +client + +# Use the same setting as you are using on +# the server. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel +# if you have more than one. On XP SP2, +# you may need to disable the firewall +# for the TAP adapter. +;dev-node MyTap + +# Are we connecting to a TCP or +# UDP server? Use the same setting as +# on the server. +;proto tcp +proto udp + +# The hostname/IP and port of the server. +# You can have multiple remote entries +# to load balance between the servers. +remote 172.17.0.1 9194 + +# Choose a random host from the remote +# list for load-balancing. Otherwise +# try hosts in the order specified. +;remote-random + +# Keep trying indefinitely to resolve the +# host name of the OpenVPN server. Very useful +# on machines which are not permanently connected +# to the internet such as laptops. +resolv-retry infinite + +# Most clients don't need to bind to +# a specific local port number. +nobind + +# Downgrade privileges after initialization (non-Windows only) +user nobody +group nogroup + +# Try to preserve some state across restarts. +persist-key +persist-tun + +# If you are connecting through an +# HTTP proxy to reach the actual OpenVPN +# server, put the proxy server/IP and +# port number here. See the man page +# if your proxy server requires +# authentication. +;http-proxy-retry # retry on connection failures +;http-proxy [proxy server] [proxy port #] + +# Wireless networks often produce a lot +# of duplicate packets. Set this flag +# to silence duplicate packet warnings. +;mute-replay-warnings + +# SSL/TLS parms. +# See the server config file for more +# description. It's best to use +# a separate .crt/.key file pair +# for each client. A single ca +# file can be used for all clients. + +# Verify server certificate by checking that the +# certificate has the correct key usage set. +# This is an important precaution to protect against +# a potential attack discussed here: +# http://openvpn.net/howto.html#mitm +# +# To use this feature, you will need to generate +# your server certificates with the keyUsage set to +# digitalSignature, keyEncipherment +# and the extendedKeyUsage to +# serverAuth +# EasyRSA can do this for you. +remote-cert-tls server + +# If a tls-auth key is used on the server +# then every client must also have the key. +;tls-auth ta.key 1 + +# Select a cryptographic cipher. +# If the cipher option is used on the server +# then you must also specify it here. +# Note that v2.4 client/server will automatically +# negotiate AES-256-GCM in TLS mode. +# See also the data-ciphers option in the manpage +;cipher AES-256-CBC +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# Don't enable this unless it is also +# enabled in the server config file. +#comp-lzo + +# Set log file verbosity. +verb 3 + +# Silence repeating messages +;mute 20 + +key-direction 1 + +; script-security 2 +; up /etc/openvpn/update-resolv-conf +; down /etc/openvpn/update-resolv-conf + +; script-security 2 +; up /etc/openvpn/update-systemd-resolved +; down /etc/openvpn/update-systemd-resolved +; down-pre +; dhcp-option DOMAIN-ROUTE . + +-----BEGIN CERTIFICATE----- +MIIDQjCCAiqgAwIBAgIUBwqUYD1oxBKeImaMZfm44TsTAF0wDQYJKoZIhvcNAQEL +BQAwEzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQyWhcNMzUwOTIx +MTI0NjQyWjATMREwDwYDVQQDDAhDaGFuZ2VNZTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAKGt+8oRY7cWPg1SahfIV3XAeeH1SQEFq4f2q+E9ZbWVnCg9 +b59hMzwYr84/j4V73Hlv2udLrkguvnT9KqqJY/0wo3Bd1swH2WLej1fo0+rVo24w +hzeLfeH1e4erZbzQk8XG68U7yNDHKYo+LIz9syBzZA4Bq12bHxDsZbJF7HUANzFR +j9Xg3dR7utPtG8ktmD83rV9/E97whblMpLmjmf2sbCqdLOKTkZnwp5mI47TTkhMj +9K0q7irHmbtZcPZQH5Z59GtqaCaRt8DKfeYniyoPnGVfzFberHHQ4C11pcRrdvgY +n14/W5myh6HESQD6umyCYooyXG7wfqIKujROQCMCAwEAAaOBjTCBijAdBgNVHQ4E +FgQUtMsHbl94qRV7OW5UNNjk2mJ+/U8wTgYDVR0jBEcwRYAUtMsHbl94qRV7OW5U +NNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1lghQHCpRgPWjEEp4iZoxl ++bjhOxMAXTAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIBBjANBgkqhkiG9w0BAQsF +AAOCAQEAGeryP/2JuOp7tzi7Ww9lFUx2DRcgq/FwnU4biotUfuLejHQt/IeIwRYs +dW6AToUYJak8Uy/AFffMootwLcC8z8FATBnxtokWNpxtscpbTSHbeS0HvXnXFaU8 +xxlzp9l5k+46MrrvdzFsjoRfVxs0FUHzWifBnObBziTLfHt+J71509uqRWX6JuTa +PDAT8CMcLKxxS4BcorWtAmc51lW/dQQ41HDJ8a6acltDAprmlnhd8ksWzpTjUDNR +/cfSMcVTpPxPSW/WchR5NlJKQEAf9B/xC+LQgDRSDLaZ8CvzRDgosllzJ+aIS7GK +GPec69LiKqpirZ7enwDM67R4DwIHKA== +-----END CERTIFICATE----- + + +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + e0:1a:9b:9d:b6:2e:8a:b3:15:ba:a5:92:33:3d:75:01 + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=ChangeMe + Validity + Not Before: Sep 23 12:46:43 2025 GMT + Not After : Dec 27 12:46:43 2027 GMT + Subject: CN=client_B + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:bf:f1:4b:16:3d:95:5e:bd:9f:34:53:d6:a0:80: + 7c:0c:3b:36:65:32:0c:b5:a2:98:12:92:81:66:73: + 68:dd:ec:e3:b4:86:f8:7c:32:c1:1b:01:3b:47:07: + 61:fb:e4:d4:40:cf:9e:b6:1f:b8:10:8d:ac:39:f6: + 76:5d:84:5c:fb:38:f6:5d:cd:fe:60:dd:58:b9:fa: + ee:6b:61:62:53:e1:aa:31:b0:b8:36:8e:6b:b1:7c: + 08:8a:5f:1c:f3:03:29:3b:4f:bc:12:74:60:af:97: + 39:63:c2:77:f1:73:8d:b1:f5:80:f2:a2:e9:6b:4d: + 83:bf:7a:95:ee:30:6b:e1:e0:a4:6c:b4:e6:75:f9: + 92:3c:17:a0:17:1d:37:4b:5f:b3:2d:7a:ab:20:5e: + 27:22:82:31:5d:67:bb:58:3e:53:06:02:d9:17:84: + fa:2a:56:48:10:12:d8:5f:c2:00:f0:8c:d8:29:09: + ed:bf:d1:c2:30:74:2f:33:3f:7e:38:88:3a:fc:13: + f1:ed:5b:90:30:8e:7a:c5:b2:89:0f:21:e6:ad:8d: + a4:ca:30:e3:f8:5f:52:8e:cb:eb:13:6d:ce:cb:7c: + 21:ae:ab:b5:58:cd:85:1f:93:98:7f:ad:3f:1f:b0: + 95:14:74:20:ed:82:be:28:47:77:80:a8:8b:a7:33: + 41:7f + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + X509v3 Subject Key Identifier: + 6B:AB:8A:5C:11:80:8B:38:1F:B9:4B:7E:DC:AF:5A:B4:CF:41:74:4F + X509v3 Authority Key Identifier: + keyid:B4:CB:07:6E:5F:78:A9:15:7B:39:6E:54:34:D8:E4:DA:62:7E:FD:4F + DirName:/CN=ChangeMe + serial:07:0A:94:60:3D:68:C4:12:9E:22:66:8C:65:F9:B8:E1:3B:13:00:5D + X509v3 Extended Key Usage: + TLS Web Client Authentication + X509v3 Key Usage: + Digital Signature + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 49:2f:45:4f:07:f9:cf:26:0a:0c:a6:45:a9:cc:ca:e6:be:1f: + 24:47:b5:a7:5b:f0:00:e3:6d:15:b7:cd:1f:98:33:7a:dd:b4: + 2d:1a:c0:fe:34:84:ec:53:f8:b0:88:7c:30:9f:f3:43:5b:19: + 5b:dc:57:e4:18:fe:d7:cf:eb:50:03:8a:bf:03:d5:9c:79:92: + ad:5f:fe:12:a5:39:74:4e:e1:e0:48:af:31:62:a7:e8:e6:9a: + e9:e2:d7:40:52:d5:ab:22:e3:0b:9c:78:18:83:76:ba:5e:fe: + 6f:aa:96:f4:76:0f:88:ac:56:18:bc:e6:da:b7:55:ab:42:b7: + 74:2b:94:00:c8:e5:a1:66:63:41:b5:a9:48:7d:15:ce:d1:eb: + 14:50:3e:d0:a7:78:f4:92:0f:e3:ee:0d:df:5d:2c:ce:85:bf: + 73:39:32:dc:17:39:d4:39:11:11:f4:0b:ad:4d:af:88:1a:d4: + c4:bf:b9:1c:ed:e8:21:d4:b7:48:01:55:ff:a7:2b:86:b4:dd: + b4:54:fb:1f:0d:96:2b:da:15:c7:13:d2:1d:34:d5:13:dd:f4: + 6a:20:5a:e8:00:b8:60:88:5c:76:7e:77:82:6f:1b:a7:4c:41: + fb:4f:0f:1a:df:46:1f:09:79:a0:1c:16:c1:cd:7a:48:1c:91: + 1f:db:06:92 +-----BEGIN CERTIFICATE----- +MIIDUTCCAjmgAwIBAgIRAOAam522LoqzFbqlkjM9dQEwDQYJKoZIhvcNAQELBQAw +EzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQzWhcNMjcxMjI3MTI0 +NjQzWjATMREwDwYDVQQDDAhjbGllbnRfQjCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBAL/xSxY9lV69nzRT1qCAfAw7NmUyDLWimBKSgWZzaN3s47SG+Hwy +wRsBO0cHYfvk1EDPnrYfuBCNrDn2dl2EXPs49l3N/mDdWLn67mthYlPhqjGwuDaO +a7F8CIpfHPMDKTtPvBJ0YK+XOWPCd/FzjbH1gPKi6WtNg796le4wa+HgpGy05nX5 +kjwXoBcdN0tfsy16qyBeJyKCMV1nu1g+UwYC2ReE+ipWSBAS2F/CAPCM2CkJ7b/R +wjB0LzM/fjiIOvwT8e1bkDCOesWyiQ8h5q2NpMow4/hfUo7L6xNtzst8Ia6rtVjN +hR+TmH+tPx+wlRR0IO2CvihHd4Coi6czQX8CAwEAAaOBnzCBnDAJBgNVHRMEAjAA +MB0GA1UdDgQWBBRrq4pcEYCLOB+5S37cr1q0z0F0TzBOBgNVHSMERzBFgBS0ywdu +X3ipFXs5blQ02OTaYn79T6EXpBUwEzERMA8GA1UEAwwIQ2hhbmdlTWWCFAcKlGA9 +aMQSniJmjGX5uOE7EwBdMBMGA1UdJQQMMAoGCCsGAQUFBwMCMAsGA1UdDwQEAwIH +gDANBgkqhkiG9w0BAQsFAAOCAQEASS9FTwf5zyYKDKZFqczK5r4fJEe1p1vwAONt +FbfNH5gzet20LRrA/jSE7FP4sIh8MJ/zQ1sZW9xX5Bj+18/rUAOKvwPVnHmSrV/+ +EqU5dE7h4EivMWKn6Oaa6eLXQFLVqyLjC5x4GIN2ul7+b6qW9HYPiKxWGLzm2rdV +q0K3dCuUAMjloWZjQbWpSH0VztHrFFA+0Kd49JIP4+4N310szoW/czky3Bc51DkR +EfQLrU2viBrUxL+5HO3oIdS3SAFV/6crhrTdtFT7Hw2WK9oVxxPSHTTVE930aiBa +6AC4YIhcdn53gm8bp0xB+08PGt9GHwl5oBwWwc16SByRH9sGkg== +-----END CERTIFICATE----- + + +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/8UsWPZVevZ80 +U9aggHwMOzZlMgy1opgSkoFmc2jd7OO0hvh8MsEbATtHB2H75NRAz562H7gQjaw5 +9nZdhFz7OPZdzf5g3Vi5+u5rYWJT4aoxsLg2jmuxfAiKXxzzAyk7T7wSdGCvlzlj +wnfxc42x9YDyoulrTYO/epXuMGvh4KRstOZ1+ZI8F6AXHTdLX7MteqsgXicigjFd +Z7tYPlMGAtkXhPoqVkgQEthfwgDwjNgpCe2/0cIwdC8zP344iDr8E/HtW5AwjnrF +sokPIeatjaTKMOP4X1KOy+sTbc7LfCGuq7VYzYUfk5h/rT8fsJUUdCDtgr4oR3eA +qIunM0F/AgMBAAECggEAKNI2d+ptBBMr8sMJ2GS6/RbywJ7eWRrVYM3Lu3A8E0a4 +PsKdwjxBGW8vnjGRwzKteYMua+lfChY3VLR4A/eMltlMfDK9MPiiUBtv7WJuuQw7 +WAoPg3rSqJKKdnM4Au7fLAAPLZWWooF08SSAwdcjgX+HBxNitTFtHaIClP+zUfxI +av/bwDbUj928Lo/WZ/UtS0v+Bq8C+B4c/udYN7k4VDTuKvVv0KqJTn0deQ1fGBxt +a61HcLPOjBO9wnakcZMtmcz9bi9ziKIsOvoontTPTNP9M2p11mJMdndZvBW6sue0 +zb31Kd1QLlk6LkLEbp32SwA265QofOvc2Xf8Gr0G8QKBgQDjIhdhvkZqP3cl4jCw +IlPR7Y7TCWECXh9v76MLKIvmLXo5mO8b/DeQBTMQ5N+PW7/6eB5GDFi+B+foqEbk +NbpawtvDSglhGjyj0X6XqYHMpBSBLQuEvw03BiOgJddkE1BA2HVBvfr5JS7eRoZZ +sjOX+OBmpb1ie6hH7QIFbHbrzQKBgQDYVkFD6tKyUI8QhGqLf6xPvLha9gTwCG8m +uQe+fjVFZ2f/Cru5/sNl/xMiW2y1Sq37L8mLmY1hdxSGfkDzfcKdF6I1//woZmfK +cXWFTpqEBYTbVQGktZb37KasNdp4hREavWc3xKdiJOfbxVk+9cO4zSPyfMIpN/Km +YxwCApXOewKBgQDe8W+R+XqUf4csIEE6Ife0b0Fp1CLseAbTkJyxLzNi0/DM6FiL +V54SN4hQZNcrmBtwdscAas4QeSIhNEuhZTtuKyYbImjibyZmhhOEOlW10Lhvsw9D +VWRbRiNh5sLs8Cgt/knaJeha9Sxz8TWehVQvL5LULoseR9J+Bx2cxUJVYQKBgQCh +nqb5l3g7ESYgf9ydRQ+1LldIVV3Q+WwYsMkRTnZ72EoAZsNiq+rMy2g/FbA8LIOY +EdZvfZL7CpyB8daSUhTPibV8xDZc9Ex8GJFkuxmCoiDkPziQFb2okNrf8we5XCgw +Iun25urpzoqNTH1lJPRInrFJWl0vsAWOuqJU+hty+wKBgBDc7Ym9zMUaB1rLatwd +ECejBcvAPpD7rwEqmdzj9DTfzCOaUwsHwsQAEwg1tFrhuK5W44FtAP8y4eWn3Krt +ExPgrA5JxWnmI297Pa9YDuB6eczSdxKH2AxE0vz552ZPnO5eTZIQZAgIuVGZZxmR +KcXzTlbubo5w1jJpvbczHhA5 +-----END PRIVATE KEY----- + + +# +# 2048 bit OpenVPN static key +# +-----BEGIN OpenVPN Static key V1----- +488b61084812969fe8ad0f9dd40f56a2 +6cdadddfe345daef6b5c6d3c3e779fc5 +1f7d236966953482d2af085e3f8581b7 +d216f2d891972a463bbb22ca6c104b9d +f99dcb19d7d575a1d46e7918bb2556c6 +db9f51cd792c5e89e011586214692b95 +2a32a7fe85e4538c40e1d0aa2a9f8e15 +fcc0ce5d31974e3c2041b127776f7658 +878cb8245ed235ec996c2370c0fc0023 +699bc028b3412bc40209cba8233bc111 +fa1438095f99052d799fa718f3b04499 +472254d0286b4b2ce99db49e98a4cc25 +fd948bddcdcf08006a6d7bff40354e7b +5e93ea753a8ecc05de41ae34d280e7eb +99220e436bf8b7693a00667485631e28 +edba3e33b6f558dfa50b92eec6ac8b44 +-----END OpenVPN Static key V1----- + diff --git a/tests/local_vpn/client_configs/make_ovpn.sh b/tests/local_vpn/client_configs/make_ovpn.sh new file mode 100755 index 00000000..6a73d7f7 --- /dev/null +++ b/tests/local_vpn/client_configs/make_ovpn.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# First argument: Client identifier + +KEY_DIR=./keys +BASE_CONFIG=./client.conf + +cat ${BASE_CONFIG} \ + <(echo -e '') \ + ${KEY_DIR}/ca.crt \ + <(echo -e '\n') \ + ${KEY_DIR}/${1}.crt \ + <(echo -e '\n') \ + ${KEY_DIR}/${1}.key \ + <(echo -e '\n') \ + ${KEY_DIR}/ta.key \ + <(echo -e '') \ + > ${1}_client.ovpn diff --git a/tests/local_vpn/client_configs/testserver.local_client.ovpn b/tests/local_vpn/client_configs/testserver.local_client.ovpn new file mode 100644 index 00000000..4d11e13f --- /dev/null +++ b/tests/local_vpn/client_configs/testserver.local_client.ovpn @@ -0,0 +1,299 @@ +############################################## +# Sample client-side OpenVPN 2.0 config file # +# for connecting to multi-client server. # +# # +# This configuration can be used by multiple # +# clients, however each client should have # +# its own cert and key files. # +# # +# On Windows, you might want to rename this # +# file so it has a .ovpn extension # +############################################## + +# Specify that we are a client and that we +# will be pulling certain config file directives +# from the server. +client + +# Use the same setting as you are using on +# the server. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel +# if you have more than one. On XP SP2, +# you may need to disable the firewall +# for the TAP adapter. +;dev-node MyTap + +# Are we connecting to a TCP or +# UDP server? Use the same setting as +# on the server. +;proto tcp +proto udp + +# The hostname/IP and port of the server. +# You can have multiple remote entries +# to load balance between the servers. +remote 172.17.0.1 9194 + +# Choose a random host from the remote +# list for load-balancing. Otherwise +# try hosts in the order specified. +;remote-random + +# Keep trying indefinitely to resolve the +# host name of the OpenVPN server. Very useful +# on machines which are not permanently connected +# to the internet such as laptops. +resolv-retry infinite + +# Most clients don't need to bind to +# a specific local port number. +nobind + +# Downgrade privileges after initialization (non-Windows only) +user nobody +group nogroup + +# Try to preserve some state across restarts. +persist-key +persist-tun + +# If you are connecting through an +# HTTP proxy to reach the actual OpenVPN +# server, put the proxy server/IP and +# port number here. See the man page +# if your proxy server requires +# authentication. +;http-proxy-retry # retry on connection failures +;http-proxy [proxy server] [proxy port #] + +# Wireless networks often produce a lot +# of duplicate packets. Set this flag +# to silence duplicate packet warnings. +;mute-replay-warnings + +# SSL/TLS parms. +# See the server config file for more +# description. It's best to use +# a separate .crt/.key file pair +# for each client. A single ca +# file can be used for all clients. + +# Verify server certificate by checking that the +# certificate has the correct key usage set. +# This is an important precaution to protect against +# a potential attack discussed here: +# http://openvpn.net/howto.html#mitm +# +# To use this feature, you will need to generate +# your server certificates with the keyUsage set to +# digitalSignature, keyEncipherment +# and the extendedKeyUsage to +# serverAuth +# EasyRSA can do this for you. +remote-cert-tls server + +# If a tls-auth key is used on the server +# then every client must also have the key. +;tls-auth ta.key 1 + +# Select a cryptographic cipher. +# If the cipher option is used on the server +# then you must also specify it here. +# Note that v2.4 client/server will automatically +# negotiate AES-256-GCM in TLS mode. +# See also the data-ciphers option in the manpage +;cipher AES-256-CBC +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# Don't enable this unless it is also +# enabled in the server config file. +#comp-lzo + +# Set log file verbosity. +verb 3 + +# Silence repeating messages +;mute 20 + +key-direction 1 + +; script-security 2 +; up /etc/openvpn/update-resolv-conf +; down /etc/openvpn/update-resolv-conf + +; script-security 2 +; up /etc/openvpn/update-systemd-resolved +; down /etc/openvpn/update-systemd-resolved +; down-pre +; dhcp-option DOMAIN-ROUTE . + +-----BEGIN CERTIFICATE----- +MIIDQjCCAiqgAwIBAgIUBwqUYD1oxBKeImaMZfm44TsTAF0wDQYJKoZIhvcNAQEL +BQAwEzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQyWhcNMzUwOTIx +MTI0NjQyWjATMREwDwYDVQQDDAhDaGFuZ2VNZTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAKGt+8oRY7cWPg1SahfIV3XAeeH1SQEFq4f2q+E9ZbWVnCg9 +b59hMzwYr84/j4V73Hlv2udLrkguvnT9KqqJY/0wo3Bd1swH2WLej1fo0+rVo24w +hzeLfeH1e4erZbzQk8XG68U7yNDHKYo+LIz9syBzZA4Bq12bHxDsZbJF7HUANzFR +j9Xg3dR7utPtG8ktmD83rV9/E97whblMpLmjmf2sbCqdLOKTkZnwp5mI47TTkhMj +9K0q7irHmbtZcPZQH5Z59GtqaCaRt8DKfeYniyoPnGVfzFberHHQ4C11pcRrdvgY +n14/W5myh6HESQD6umyCYooyXG7wfqIKujROQCMCAwEAAaOBjTCBijAdBgNVHQ4E +FgQUtMsHbl94qRV7OW5UNNjk2mJ+/U8wTgYDVR0jBEcwRYAUtMsHbl94qRV7OW5U +NNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1lghQHCpRgPWjEEp4iZoxl ++bjhOxMAXTAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIBBjANBgkqhkiG9w0BAQsF +AAOCAQEAGeryP/2JuOp7tzi7Ww9lFUx2DRcgq/FwnU4biotUfuLejHQt/IeIwRYs +dW6AToUYJak8Uy/AFffMootwLcC8z8FATBnxtokWNpxtscpbTSHbeS0HvXnXFaU8 +xxlzp9l5k+46MrrvdzFsjoRfVxs0FUHzWifBnObBziTLfHt+J71509uqRWX6JuTa +PDAT8CMcLKxxS4BcorWtAmc51lW/dQQ41HDJ8a6acltDAprmlnhd8ksWzpTjUDNR +/cfSMcVTpPxPSW/WchR5NlJKQEAf9B/xC+LQgDRSDLaZ8CvzRDgosllzJ+aIS7GK +GPec69LiKqpirZ7enwDM67R4DwIHKA== +-----END CERTIFICATE----- + + +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 0b:44:23:c7:c0:5f:a4:2c:ee:c7:77:80:f9:48:36:04 + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=ChangeMe + Validity + Not Before: Sep 23 12:46:42 2025 GMT + Not After : Dec 27 12:46:42 2027 GMT + Subject: CN=testserver.local + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:b4:ef:06:2e:ca:f4:e5:1f:b4:1e:d0:ca:d4:a1: + ef:03:4d:14:b6:e8:4e:e9:26:e0:c5:96:d7:0a:36: + a5:4c:6d:92:5b:05:e8:0e:57:14:64:c1:84:1f:7c: + f4:99:3a:c7:4a:41:92:5a:c1:99:c1:0c:33:d6:81: + f2:49:e3:7a:10:d1:2e:24:b8:3e:d1:00:a6:c0:a4: + 56:a5:17:7d:70:df:74:e5:0c:97:5e:67:2f:05:0a: + 81:8b:24:5b:22:b5:87:62:12:4a:92:b2:e2:b7:3b: + d6:39:20:dc:22:76:58:61:5c:a4:6d:d5:33:4b:a6: + 54:00:7f:43:69:ce:0a:d6:3a:21:d2:8c:59:1e:e7: + 66:ad:77:6b:fe:56:d3:12:ca:bd:18:55:c9:71:e4: + 8b:da:67:28:b3:63:6b:6f:31:e2:b5:89:15:af:ea: + 1a:9a:7f:31:b3:f1:ba:32:21:59:96:81:71:9f:69: + 13:86:d2:db:c5:aa:0c:a7:95:3b:68:a3:9d:46:a9: + 61:c9:04:13:53:44:3e:60:81:5e:da:54:43:b2:90: + 75:33:dc:4a:9a:ed:2e:f0:82:ef:1f:e6:72:7f:6b: + 20:64:67:9b:d3:66:e4:99:64:6a:62:7f:47:83:c3: + 50:f3:bc:fe:e2:7a:c8:65:99:82:2c:89:3b:2c:78: + 32:e3 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + X509v3 Subject Key Identifier: + D0:06:2F:3A:D8:9B:F3:9D:7D:B5:8A:F6:5D:CE:8A:83:89:5D:AB:B0 + X509v3 Authority Key Identifier: + keyid:B4:CB:07:6E:5F:78:A9:15:7B:39:6E:54:34:D8:E4:DA:62:7E:FD:4F + DirName:/CN=ChangeMe + serial:07:0A:94:60:3D:68:C4:12:9E:22:66:8C:65:F9:B8:E1:3B:13:00:5D + X509v3 Extended Key Usage: + TLS Web Client Authentication + X509v3 Key Usage: + Digital Signature + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 01:81:34:a1:ad:9d:9f:e7:cf:a1:ae:e5:8b:6f:d8:b3:eb:ac: + f7:8f:09:8c:f5:ad:64:96:a5:45:58:c6:92:6e:f8:e2:21:06: + 2d:2a:89:fb:61:d5:eb:6b:56:78:d7:28:31:f7:58:2c:52:bf: + b2:ed:48:92:0c:49:b1:70:30:78:14:41:76:d4:c4:be:3c:15: + b8:4f:27:6d:a9:87:3b:45:b9:a4:76:3d:23:51:6a:9d:ca:24: + 63:ba:50:ed:4c:b9:ad:8f:c8:57:54:44:16:53:35:0a:c6:c8: + 25:2e:57:7c:32:28:57:bd:e4:6d:98:a8:96:31:d9:42:bb:65: + 25:0e:2a:d9:a5:94:17:2c:6c:bb:f7:c6:d6:e9:b2:df:a2:66: + f6:cb:73:43:97:dc:5c:b5:34:a3:0a:8b:84:ba:71:4e:81:83: + 8d:5e:2c:99:7f:12:89:b3:90:27:1a:0c:e8:c6:d5:51:8f:9f: + ea:49:b9:24:64:68:64:40:98:21:82:eb:52:7c:8b:10:48:61: + b5:01:d4:42:6c:2e:13:f1:07:52:0d:cf:05:cd:06:70:0c:63: + aa:e1:dc:93:2b:bb:8e:eb:11:3e:59:6f:12:90:37:29:d8:45: + fc:d3:52:87:b4:a2:55:54:f2:17:d8:f4:32:52:39:3a:cf:0d: + 2c:a0:d4:e3 +-----BEGIN CERTIFICATE----- +MIIDWDCCAkCgAwIBAgIQC0Qjx8BfpCzux3eA+Ug2BDANBgkqhkiG9w0BAQsFADAT +MREwDwYDVQQDDAhDaGFuZ2VNZTAeFw0yNTA5MjMxMjQ2NDJaFw0yNzEyMjcxMjQ2 +NDJaMBsxGTAXBgNVBAMMEHRlc3RzZXJ2ZXIubG9jYWwwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQC07wYuyvTlH7Qe0MrUoe8DTRS26E7pJuDFltcKNqVM +bZJbBegOVxRkwYQffPSZOsdKQZJawZnBDDPWgfJJ43oQ0S4kuD7RAKbApFalF31w +33TlDJdeZy8FCoGLJFsitYdiEkqSsuK3O9Y5INwidlhhXKRt1TNLplQAf0NpzgrW +OiHSjFke52atd2v+VtMSyr0YVclx5IvaZyizY2tvMeK1iRWv6hqafzGz8boyIVmW +gXGfaROG0tvFqgynlTtoo51GqWHJBBNTRD5ggV7aVEOykHUz3Eqa7S7wgu8f5nJ/ +ayBkZ5vTZuSZZGpif0eDw1DzvP7ieshlmYIsiTsseDLjAgMBAAGjgZ8wgZwwCQYD +VR0TBAIwADAdBgNVHQ4EFgQU0AYvOtib8519tYr2Xc6Kg4ldq7AwTgYDVR0jBEcw +RYAUtMsHbl94qRV7OW5UNNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1l +ghQHCpRgPWjEEp4iZoxl+bjhOxMAXTATBgNVHSUEDDAKBggrBgEFBQcDAjALBgNV +HQ8EBAMCB4AwDQYJKoZIhvcNAQELBQADggEBAAGBNKGtnZ/nz6Gu5Ytv2LPrrPeP +CYz1rWSWpUVYxpJu+OIhBi0qifth1etrVnjXKDH3WCxSv7LtSJIMSbFwMHgUQXbU +xL48FbhPJ22phztFuaR2PSNRap3KJGO6UO1Mua2PyFdURBZTNQrGyCUuV3wyKFe9 +5G2YqJYx2UK7ZSUOKtmllBcsbLv3xtbpst+iZvbLc0OX3Fy1NKMKi4S6cU6Bg41e +LJl/EomzkCcaDOjG1VGPn+pJuSRkaGRAmCGC61J8ixBIYbUB1EJsLhPxB1INzwXN +BnAMY6rh3JMru47rET5ZbxKQNynYRfzTUoe0olVU8hfY9DJSOTrPDSyg1OM= +-----END CERTIFICATE----- + + +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC07wYuyvTlH7Qe +0MrUoe8DTRS26E7pJuDFltcKNqVMbZJbBegOVxRkwYQffPSZOsdKQZJawZnBDDPW +gfJJ43oQ0S4kuD7RAKbApFalF31w33TlDJdeZy8FCoGLJFsitYdiEkqSsuK3O9Y5 +INwidlhhXKRt1TNLplQAf0NpzgrWOiHSjFke52atd2v+VtMSyr0YVclx5IvaZyiz +Y2tvMeK1iRWv6hqafzGz8boyIVmWgXGfaROG0tvFqgynlTtoo51GqWHJBBNTRD5g +gV7aVEOykHUz3Eqa7S7wgu8f5nJ/ayBkZ5vTZuSZZGpif0eDw1DzvP7ieshlmYIs +iTsseDLjAgMBAAECggEADhFrsfsL3D4r0Mg74m6OLqaGpCDJ9JeWRdFlpPX8wkWJ +af8P56reULs53QI9OSYjKriQevKvVB2bxjj9BWu6KPvqvOYqftuwcNgWTeiBU2O8 +gLa1lPHW68BrtCLpMc4FhHBBph23QmH/qm94o6FUsTVKf6kNFPu4xP/K1mYz3NYv +ejQGXFtmi1bJFo+wf5KhUOg1devz4gWYodGPZlJ2M3tbFLv3Xaaj9k776rSkXmD8 +DQvP5yND0j1x9N6hT/tE2f0pSZmO1iu2782ER2LN2C12FEQKtEReGi9Pm+DkPl/u +KqgxUeIAQazmppP8cfIJH6SK7RXNvHZjCnXKigaJmQKBgQDVR6Gh+mArDXFfalqg +Me2V13On4exe3zwIqHOYxHLcEHqWyLsSKa+xa+CUCfJpc0Nux51SnDaxBOwYBNqT +rYRLxXyN5ocJWpdguiBP8nXdTFVC8XwZtLC2QH+2UK322AUTBmFV85xIVofeLgY/ +H/GOqdi7wIGfg/vdyJUxMnhFdQKBgQDZLMcVo62FEgyPB90ZE3KnGdJJlFHHKkj8 +AC0R20Rd6Y3oDFuoHmKaV1vo5ePthjHhyMgJ2VHIPih3+jt5mQf/zaveHKfrwg4F +rlPbqsY08tWM51qQ1wKgyKi4ASZKWzYQUZBhZrd2YXLyN0EQzMtTLTjaPHIpDVnP +r+w37/+T9wKBgQCRJ0Y3Ekr/IhAF60Ewg6p5739UQ+t2CiI2lkbOMu0lHsX/9y9y +RhLAAnZ+6mIkKIE9VPeacJy8T2hLVIpaNZ6zXv3NKZa/4/rgpuw03QQgj8H7ZJSc +fiBCeZUxxKkRNaYGc7ItKDY1+UZRDSvNLHVfLfNGnNbbdJ0nLUt0hy/ZvQKBgHsj +0J6MeE8DtOtE4jDdvhzRn1LpLpVnfIqm7uc5FMLLMxNoLnBdCjvJXOvprht4A8Cq +QAKVnrGTzQ56bE6+XrLEw7blOLGNDrZZ6mKbqldLeZqzc768q1jPbhsnS7bNkRIf +rWYM/+m3x51fhx0nggJfmeTkcTalw07nyWDOTHRxAoGAQ7poaI25mTTwyX4J4adK +n1BMrrFns7ztHdbWD+P2T1MnJ/ibURPwXCdFuxKCEBtyoEiFTBZRLuB4N4UUwGe3 +pFmgWL4d+qPrCfOyksn0YyTjBtoBPrxSccWrNeKBqePewmnQk0SRLN8w8hoiJFgV +zGLzeXsIRbDfvPT3ZUT3zgI= +-----END PRIVATE KEY----- + + +# +# 2048 bit OpenVPN static key +# +-----BEGIN OpenVPN Static key V1----- +488b61084812969fe8ad0f9dd40f56a2 +6cdadddfe345daef6b5c6d3c3e779fc5 +1f7d236966953482d2af085e3f8581b7 +d216f2d891972a463bbb22ca6c104b9d +f99dcb19d7d575a1d46e7918bb2556c6 +db9f51cd792c5e89e011586214692b95 +2a32a7fe85e4538c40e1d0aa2a9f8e15 +fcc0ce5d31974e3c2041b127776f7658 +878cb8245ed235ec996c2370c0fc0023 +699bc028b3412bc40209cba8233bc111 +fa1438095f99052d799fa718f3b04499 +472254d0286b4b2ce99db49e98a4cc25 +fd948bddcdcf08006a6d7bff40354e7b +5e93ea753a8ecc05de41ae34d280e7eb +99220e436bf8b7693a00667485631e28 +edba3e33b6f558dfa50b92eec6ac8b44 +-----END OpenVPN Static key V1----- + diff --git a/tests/local_vpn/create_openvpn_certificates.sh b/tests/local_vpn/create_openvpn_certificates.sh new file mode 100755 index 00000000..91aef0dc --- /dev/null +++ b/tests/local_vpn/create_openvpn_certificates.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +./_build_docker.sh + +docker run --rm -v ./ca_user:/home/ca_user -v ./client_configs:/client_configs -v ./server_config:/server_config -p 9194:9194/udp --cap-add=NET_ADMIN --privileged --name odelia_testing_openvpnserver odelia_testing_openvpnserver:latest /bin/bash -c "./_openvpn_certificate_creation.sh" diff --git a/tests/local_vpn/run_docker_openvpnserver.sh b/tests/local_vpn/run_docker_openvpnserver.sh new file mode 100755 index 00000000..f501f811 --- /dev/null +++ b/tests/local_vpn/run_docker_openvpnserver.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +./_build_docker.sh + +docker run -d -t --rm -v ./ca_user:/home/ca_user -v ./server_config:/server_config -p 9194:9194/udp --cap-add=NET_ADMIN --privileged --name odelia_testing_openvpnserver odelia_testing_openvpnserver:latest /bin/bash -c "./_openvpn_start.sh && /bin/bash" diff --git a/tests/local_vpn/server_config/.gitignore b/tests/local_vpn/server_config/.gitignore new file mode 100644 index 00000000..23de1ea2 --- /dev/null +++ b/tests/local_vpn/server_config/.gitignore @@ -0,0 +1,2 @@ +nohup.out +ipp.txt \ No newline at end of file diff --git a/tests/local_vpn/server_config/ca.crt b/tests/local_vpn/server_config/ca.crt new file mode 100644 index 00000000..02ee2179 --- /dev/null +++ b/tests/local_vpn/server_config/ca.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDQjCCAiqgAwIBAgIUBwqUYD1oxBKeImaMZfm44TsTAF0wDQYJKoZIhvcNAQEL +BQAwEzERMA8GA1UEAwwIQ2hhbmdlTWUwHhcNMjUwOTIzMTI0NjQyWhcNMzUwOTIx +MTI0NjQyWjATMREwDwYDVQQDDAhDaGFuZ2VNZTCCASIwDQYJKoZIhvcNAQEBBQAD +ggEPADCCAQoCggEBAKGt+8oRY7cWPg1SahfIV3XAeeH1SQEFq4f2q+E9ZbWVnCg9 +b59hMzwYr84/j4V73Hlv2udLrkguvnT9KqqJY/0wo3Bd1swH2WLej1fo0+rVo24w +hzeLfeH1e4erZbzQk8XG68U7yNDHKYo+LIz9syBzZA4Bq12bHxDsZbJF7HUANzFR +j9Xg3dR7utPtG8ktmD83rV9/E97whblMpLmjmf2sbCqdLOKTkZnwp5mI47TTkhMj +9K0q7irHmbtZcPZQH5Z59GtqaCaRt8DKfeYniyoPnGVfzFberHHQ4C11pcRrdvgY +n14/W5myh6HESQD6umyCYooyXG7wfqIKujROQCMCAwEAAaOBjTCBijAdBgNVHQ4E +FgQUtMsHbl94qRV7OW5UNNjk2mJ+/U8wTgYDVR0jBEcwRYAUtMsHbl94qRV7OW5U +NNjk2mJ+/U+hF6QVMBMxETAPBgNVBAMMCENoYW5nZU1lghQHCpRgPWjEEp4iZoxl ++bjhOxMAXTAMBgNVHRMEBTADAQH/MAsGA1UdDwQEAwIBBjANBgkqhkiG9w0BAQsF +AAOCAQEAGeryP/2JuOp7tzi7Ww9lFUx2DRcgq/FwnU4biotUfuLejHQt/IeIwRYs +dW6AToUYJak8Uy/AFffMootwLcC8z8FATBnxtokWNpxtscpbTSHbeS0HvXnXFaU8 +xxlzp9l5k+46MrrvdzFsjoRfVxs0FUHzWifBnObBziTLfHt+J71509uqRWX6JuTa +PDAT8CMcLKxxS4BcorWtAmc51lW/dQQ41HDJ8a6acltDAprmlnhd8ksWzpTjUDNR +/cfSMcVTpPxPSW/WchR5NlJKQEAf9B/xC+LQgDRSDLaZ8CvzRDgosllzJ+aIS7GK +GPec69LiKqpirZ7enwDM67R4DwIHKA== +-----END CERTIFICATE----- diff --git a/tests/local_vpn/server_config/ccd/admin@test.odelia b/tests/local_vpn/server_config/ccd/admin@test.odelia new file mode 100644 index 00000000..3e8f368e --- /dev/null +++ b/tests/local_vpn/server_config/ccd/admin@test.odelia @@ -0,0 +1 @@ +ifconfig-push 10.8.0.5 255.0.0.0 diff --git a/tests/local_vpn/server_config/ccd/client_A b/tests/local_vpn/server_config/ccd/client_A new file mode 100644 index 00000000..2009193e --- /dev/null +++ b/tests/local_vpn/server_config/ccd/client_A @@ -0,0 +1 @@ +ifconfig-push 10.8.0.6 255.0.0.0 diff --git a/tests/local_vpn/server_config/ccd/client_B b/tests/local_vpn/server_config/ccd/client_B new file mode 100644 index 00000000..da607617 --- /dev/null +++ b/tests/local_vpn/server_config/ccd/client_B @@ -0,0 +1 @@ +ifconfig-push 10.8.0.7 255.0.0.0 diff --git a/tests/local_vpn/server_config/ccd/testserver.local b/tests/local_vpn/server_config/ccd/testserver.local new file mode 100644 index 00000000..75bd4873 --- /dev/null +++ b/tests/local_vpn/server_config/ccd/testserver.local @@ -0,0 +1 @@ +ifconfig-push 10.8.0.4 255.0.0.0 diff --git a/tests/local_vpn/server_config/server.conf b/tests/local_vpn/server_config/server.conf new file mode 100755 index 00000000..8d90cd74 --- /dev/null +++ b/tests/local_vpn/server_config/server.conf @@ -0,0 +1,304 @@ +################################################# +# Sample OpenVPN 2.0 config file for # +# multi-client server. # +# # +# This file is for the server side # +# of a many-clients <-> one-server # +# OpenVPN configuration. # +# # +# OpenVPN also supports # +# single-machine <-> single-machine # +# configurations (See the Examples page # +# on the web site for more info). # +# # +# This config should work on Windows # +# or Linux/BSD systems. Remember on # +# Windows to quote pathnames and use # +# double backslashes, e.g.: # +# "C:\\Program Files\\OpenVPN\\config\\foo.key" # +# # +# Comments are preceded with '#' or ';' # +################################################# + +# Which local IP address should OpenVPN +# listen on? (optional) +;local a.b.c.d + +# Which TCP/UDP port should OpenVPN listen on? +# If you want to run multiple OpenVPN instances +# on the same machine, use a different port +# number for each one. You will need to +# open up this port on your firewall. +port 9194 + +# TCP or UDP server? +;proto tcp +proto udp + +# "dev tun" will create a routed IP tunnel, +# "dev tap" will create an ethernet tunnel. +# Use "dev tap0" if you are ethernet bridging +# and have precreated a tap0 virtual interface +# and bridged it with your ethernet interface. +# If you want to control access policies +# over the VPN, you must create firewall +# rules for the the TUN/TAP interface. +# On non-Windows systems, you can give +# an explicit unit number, such as tun0. +# On Windows, use "dev-node" for this. +# On most systems, the VPN will not function +# unless you partially or fully disable +# the firewall for the TUN/TAP interface. +;dev tap +dev tun + +# Windows needs the TAP-Win32 adapter name +# from the Network Connections panel if you +# have more than one. On XP SP2 or higher, +# you may need to selectively disable the +# Windows firewall for the TAP adapter. +# Non-Windows systems usually don't need this. +;dev-node MyTap + +# SSL/TLS root certificate (ca), certificate +# (cert), and private key (key). Each client +# and the server must have their own cert and +# key file. The server and all clients will +# use the same ca file. +# +# See the "easy-rsa" directory for a series +# of scripts for generating RSA certificates +# and private keys. Remember to use +# a unique Common Name for the server +# and each of the client certificates. +# +# Any X509 key management system can be used. +# OpenVPN can also use a PKCS #12 formatted key file +# (see "pkcs12" directive in man page). +ca /etc/openvpn/server/ca.crt +cert /etc/openvpn/server/server.crt +key /etc/openvpn/server/server.key # This file should be kept secret + +# Diffie hellman parameters. +# Generate your own with: +# openssl dhparam -out dh1024.pem 1024 +# Substitute 2048 for 1024 if you are using +# 2048 bit keys. +;dh dh1024.pem +dh none + +# Configure server mode and supply a VPN subnet +# for OpenVPN to draw client addresses from. +# The server will take 10.8.0.1 for itself, +# the rest will be made available to clients. +# Each client will be able to reach the server +# on 10.8.0.1. Comment this line out if you are +# ethernet bridging. See the man page for more info. +server 10.8.0.0 255.255.255.0 + +# Maintain a record of client <-> virtual IP address +# associations in this file. If OpenVPN goes down or +# is restarted, reconnecting clients can be assigned +# the same virtual IP address from the pool that was +# previously assigned. +ifconfig-pool-persist ipp.txt + +# Configure server mode for ethernet bridging. +# You must first use your OS's bridging capability +# to bridge the TAP interface with the ethernet +# NIC interface. Then you must manually set the +# IP/netmask on the bridge interface, here we +# assume 10.8.0.4/255.255.255.0. Finally we +# must set aside an IP range in this subnet +# (start=10.8.0.50 end=10.8.0.100) to allocate +# to connecting clients. Leave this line commented +# out unless you are ethernet bridging. +;server-bridge 10.8.0.4 255.255.255.0 10.8.0.50 10.8.0.100 + +# Configure server mode for ethernet bridging +# using a DHCP-proxy, where clients talk +# to the OpenVPN server-side DHCP server +# to receive their IP address allocation +# and DNS server addresses. You must first use +# your OS's bridging capability to bridge the TAP +# interface with the ethernet NIC interface. +# Note: this mode only works on clients (such as +# Windows), where the client-side TAP adapter is +# bound to a DHCP client. +;server-bridge + +# Push routes to the client to allow it +# to reach other private subnets behind +# the server. Remember that these +# private subnets will also need +# to know to route the OpenVPN client +# address pool (10.8.0.0/255.255.255.0) +# back to the OpenVPN server. +;push "route 192.168.10.0 255.255.255.0" +;push "route 192.168.20.0 255.255.255.0" + +# To assign specific IP addresses to specific +# clients or if a connecting client has a private +# subnet behind it that should also have VPN access, +# use the subdirectory "ccd" for client-specific +# configuration files (see man page for more info). + +# EXAMPLE: Suppose the client +# having the certificate common name "Thelonious" +# also has a small subnet behind his connecting +# machine, such as 192.168.40.128/255.255.255.248. +# First, uncomment out these lines: +;client-config-dir ccd +;route 192.168.40.128 255.255.255.248 +# Then create a file ccd/Thelonious with this line: +# iroute 192.168.40.128 255.255.255.248 +# This will allow Thelonious' private subnet to +# access the VPN. This example will only work +# if you are routing, not bridging, i.e. you are +# using "dev tun" and "server" directives. + +# EXAMPLE: Suppose you want to give +# Thelonious a fixed VPN IP address of 10.9.0.1. +# First uncomment out these lines: +client-config-dir /server_config/ccd +;route 10.9.0.0 255.255.255.252 +# Then add this line to ccd/Thelonious: +# ifconfig-push 10.9.0.1 10.9.0.2 + +# Suppose that you want to enable different +# firewall access policies for different groups +# of clients. There are two methods: +# (1) Run multiple OpenVPN daemons, one for each +# group, and firewall the TUN/TAP interface +# for each group/daemon appropriately. +# (2) (Advanced) Create a script to dynamically +# modify the firewall in response to access +# from different clients. See man +# page for more info on learn-address script. +;learn-address ./script + +# If enabled, this directive will configure +# all clients to redirect their default +# network gateway through the VPN, causing +# all IP traffic such as web browsing and +# and DNS lookups to go through the VPN +# (The OpenVPN server machine may need to NAT +# or bridge the TUN/TAP interface to the internet +# in order for this to work properly). +;push "redirect-gateway def1 bypass-dhcp" + +# Certain Windows-specific network settings +# can be pushed to clients, such as DNS +# or WINS server addresses. CAVEAT: +# http://openvpn.net/faq.html#dhcpcaveats +# The addresses below refer to the public +# DNS servers provided by opendns.com. +;push "dhcp-option DNS 208.67.222.222" +;push "dhcp-option DNS 208.67.220.220" + +# Uncomment this directive to allow different +# clients to be able to "see" each other. +# By default, clients will only see the server. +# To force clients to only see the server, you +# will also need to appropriately firewall the +# server's TUN/TAP interface. +;client-to-client + +# Uncomment this directive if multiple clients +# might connect with the same certificate/key +# files or common names. This is recommended +# only for testing purposes. For production use, +# each client should have its own certificate/key +# pair. +# +# IF YOU HAVE NOT GENERATED INDIVIDUAL +# CERTIFICATE/KEY PAIRS FOR EACH CLIENT, +# EACH HAVING ITS OWN UNIQUE "COMMON NAME", +# UNCOMMENT THIS LINE OUT. +;duplicate-cn + +# The keepalive directive causes ping-like +# messages to be sent back and forth over +# the link so that each side knows when +# the other side has gone down. +# Ping every 10 seconds, assume that remote +# peer is down if no ping received during +# a 120 second time period. +keepalive 2 10 + +# For extra security beyond that provided +# by SSL/TLS, create an "HMAC firewall" +# to help block DoS attacks and UDP port flooding. +# +# Generate with: +# openvpn --genkey --secret ta.key +# +# The server and each client must have +# a copy of this key. +# The second parameter should be '0' +# on the server and '1' on the clients. +;tls-auth ta.key 0 # This file is secret +tls-crypt /etc/openvpn/server/ta.key + +# Select a cryptographic cipher. +# This config item must be copied to +# the client config file as well. +;cipher BF-CBC # Blowfish (default) +;cipher AES-128-CBC # AES +;cipher DES-EDE3-CBC # Triple-DES +cipher AES-256-GCM + +auth SHA256 + +# Enable compression on the VPN link. +# If you enable it here, you must also +# enable it in the client config file. +;comp-lzo + +# The maximum number of concurrently connected +# clients we want to allow. +;max-clients 100 + +# It's a good idea to reduce the OpenVPN +# daemon's privileges after initialization. +# +# You can uncomment this out on +# non-Windows systems. +user nobody +group nogroup + +# The persist options will try to avoid +# accessing certain resources on restart +# that may no longer be accessible because +# of the privilege downgrade. +persist-key +persist-tun + +# Output a short status file showing +# current connections, truncated +# and rewritten every minute. +status openvpn-status.log + +# By default, log messages will go to the syslog (or +# on Windows, if running as a service, they will go to +# the "\Program Files\OpenVPN\log" directory). +# Use log or log-append to override this default. +# "log" will truncate the log file on OpenVPN startup, +# while "log-append" will append to it. Use one +# or the other (but not both). +;log openvpn.log +;log-append openvpn.log + +# Set the appropriate level of log +# file verbosity. +# +# 0 is silent, except for fatal errors +# 4 is reasonable for general usage +# 5 and 6 can help to debug connection problems +# 9 is extremely verbose +verb 3 + +# Silence repeating messages. At most 20 +# sequential messages of the same message +# category will be output to the log. +;mute 20 diff --git a/tests/local_vpn/server_config/server.crt b/tests/local_vpn/server_config/server.crt new file mode 100644 index 00000000..8a6bcc20 --- /dev/null +++ b/tests/local_vpn/server_config/server.crt @@ -0,0 +1,87 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 61:74:e5:68:11:63:be:bb:fa:fa:4d:63:12:ad:fa:6a + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN=ChangeMe + Validity + Not Before: Sep 23 12:46:42 2025 GMT + Not After : Dec 27 12:46:42 2027 GMT + Subject: CN=ChangeMe + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:ae:66:65:3b:39:6e:aa:39:39:7f:f1:be:18:c4: + 52:60:c3:3c:63:77:2a:fd:d0:79:22:6a:5f:b7:ab: + 9d:94:27:89:9a:5c:2d:b7:ea:66:91:f7:06:57:24: + 38:bd:55:71:2d:ff:9a:dd:b3:ed:0c:bf:1b:8c:93: + 27:63:d4:a1:a7:00:55:68:c5:a0:c4:9e:d3:51:d7: + ec:f8:9d:7e:b1:a4:84:80:78:9b:76:58:61:b9:89: + c9:94:e5:ad:ca:61:33:e0:f7:f3:35:0a:fc:6c:28: + b5:53:57:52:01:0a:e1:60:f1:42:f0:a4:d3:e1:4e: + 25:12:83:01:ba:f5:1a:96:44:33:17:b8:69:bc:a4: + b1:2e:b1:e0:e3:50:c6:6f:dc:f7:12:16:40:21:63: + db:14:b1:b1:fe:6f:76:84:f7:ef:a0:bb:0b:dc:03: + 44:b6:2a:f0:61:7b:7c:4a:7a:51:9b:ab:01:8f:10: + a8:db:10:62:c3:72:3b:2c:fc:b5:03:e2:73:e6:1d: + d0:3e:a5:83:f5:ae:30:4c:d8:79:28:d1:d1:5c:61: + 84:2d:8c:0d:8d:39:ce:a6:15:21:0b:4b:cd:29:28: + 72:ed:9e:63:7d:73:bd:70:f3:29:4f:c5:c4:95:ef: + dc:a7:28:27:af:36:91:e0:53:ef:4e:7d:ba:50:34: + 83:51 + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Basic Constraints: + CA:FALSE + X509v3 Subject Key Identifier: + 5C:FC:F9:A1:E6:BE:75:F7:02:23:72:3B:F2:09:B5:A9:D2:8C:A1:3B + X509v3 Authority Key Identifier: + keyid:B4:CB:07:6E:5F:78:A9:15:7B:39:6E:54:34:D8:E4:DA:62:7E:FD:4F + DirName:/CN=ChangeMe + serial:07:0A:94:60:3D:68:C4:12:9E:22:66:8C:65:F9:B8:E1:3B:13:00:5D + X509v3 Extended Key Usage: + TLS Web Server Authentication + X509v3 Key Usage: + Digital Signature, Key Encipherment + X509v3 Subject Alternative Name: + DNS:ChangeMe + Signature Algorithm: sha256WithRSAEncryption + Signature Value: + 6c:df:63:30:de:ae:e7:4a:07:be:c3:c6:78:fe:91:f4:89:c1: + 41:fc:58:d3:52:e8:bd:ab:6b:a1:68:d5:8a:36:4f:6f:21:68: + 2a:07:c6:cd:56:7f:8b:f9:0d:00:f7:9f:ba:2f:84:79:08:a2: + 53:8b:4b:76:6b:49:59:bb:9a:51:45:63:c3:25:ce:d2:46:61: + fe:2c:86:d4:ae:f7:bb:de:c2:f1:4f:8d:46:6e:a6:f3:cb:25: + 72:75:e7:eb:c6:a2:10:34:8a:a9:ca:9c:b4:ba:9c:e0:50:6d: + cd:91:a9:97:37:be:d7:40:e1:21:ba:a8:fe:8f:0d:96:2d:19: + a0:10:41:8b:cf:16:4a:a3:83:24:96:62:11:0f:e1:76:5d:46: + 1e:60:1d:2f:9d:1c:87:de:b0:1b:f7:26:61:13:af:41:44:01: + b6:dd:40:de:94:20:04:5e:68:42:79:7b:13:03:b0:6c:5f:d2: + ff:3c:15:6b:ca:21:57:69:61:de:05:68:b1:9e:e5:f8:be:c2: + 38:c7:1f:53:2e:da:7b:fd:26:fa:83:8e:5d:06:70:d9:7d:9e: + c1:75:99:70:f7:3e:66:e4:95:8e:43:ec:4a:9d:bd:0f:d7:08: + 64:f1:5f:f8:94:46:6e:46:20:44:5f:71:0b:2e:e2:0d:87:eb: + 69:cb:86:af +-----BEGIN CERTIFICATE----- +MIIDZTCCAk2gAwIBAgIQYXTlaBFjvrv6+k1jEq36ajANBgkqhkiG9w0BAQsFADAT +MREwDwYDVQQDDAhDaGFuZ2VNZTAeFw0yNTA5MjMxMjQ2NDJaFw0yNzEyMjcxMjQ2 +NDJaMBMxETAPBgNVBAMMCENoYW5nZU1lMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEArmZlOzluqjk5f/G+GMRSYMM8Y3cq/dB5Impft6udlCeJmlwtt+pm +kfcGVyQ4vVVxLf+a3bPtDL8bjJMnY9ShpwBVaMWgxJ7TUdfs+J1+saSEgHibdlhh +uYnJlOWtymEz4PfzNQr8bCi1U1dSAQrhYPFC8KTT4U4lEoMBuvUalkQzF7hpvKSx +LrHg41DGb9z3EhZAIWPbFLGx/m92hPfvoLsL3ANEtirwYXt8SnpRm6sBjxCo2xBi +w3I7LPy1A+Jz5h3QPqWD9a4wTNh5KNHRXGGELYwNjTnOphUhC0vNKShy7Z5jfXO9 +cPMpT8XEle/cpygnrzaR4FPvTn26UDSDUQIDAQABo4G0MIGxMAkGA1UdEwQCMAAw +HQYDVR0OBBYEFFz8+aHmvnX3AiNyO/IJtanSjKE7ME4GA1UdIwRHMEWAFLTLB25f +eKkVezluVDTY5Npifv1PoRekFTATMREwDwYDVQQDDAhDaGFuZ2VNZYIUBwqUYD1o +xBKeImaMZfm44TsTAF0wEwYDVR0lBAwwCgYIKwYBBQUHAwEwCwYDVR0PBAQDAgWg +MBMGA1UdEQQMMAqCCENoYW5nZU1lMA0GCSqGSIb3DQEBCwUAA4IBAQBs32Mw3q7n +Sge+w8Z4/pH0icFB/FjTUui9q2uhaNWKNk9vIWgqB8bNVn+L+Q0A95+6L4R5CKJT +i0t2a0lZu5pRRWPDJc7SRmH+LIbUrve73sLxT41GbqbzyyVydefrxqIQNIqpypy0 +upzgUG3NkamXN77XQOEhuqj+jw2WLRmgEEGLzxZKo4MklmIRD+F2XUYeYB0vnRyH +3rAb9yZhE69BRAG23UDelCAEXmhCeXsTA7BsX9L/PBVryiFXaWHeBWixnuX4vsI4 +xx9TLtp7/Sb6g45dBnDZfZ7BdZlw9z5m5JWOQ+xKnb0P1whk8V/4lEZuRiBEX3EL +LuINh+tpy4av +-----END CERTIFICATE----- diff --git a/tests/local_vpn/server_config/server.key b/tests/local_vpn/server_config/server.key new file mode 100644 index 00000000..b73e742e --- /dev/null +++ b/tests/local_vpn/server_config/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCuZmU7OW6qOTl/ +8b4YxFJgwzxjdyr90Hkial+3q52UJ4maXC236maR9wZXJDi9VXEt/5rds+0MvxuM +kydj1KGnAFVoxaDEntNR1+z4nX6xpISAeJt2WGG5icmU5a3KYTPg9/M1CvxsKLVT +V1IBCuFg8ULwpNPhTiUSgwG69RqWRDMXuGm8pLEuseDjUMZv3PcSFkAhY9sUsbH+ +b3aE9++guwvcA0S2KvBhe3xKelGbqwGPEKjbEGLDcjss/LUD4nPmHdA+pYP1rjBM +2Hko0dFcYYQtjA2NOc6mFSELS80pKHLtnmN9c71w8ylPxcSV79ynKCevNpHgU+9O +fbpQNINRAgMBAAECggEAQBXzfBxiLJ4joX7hRnOZ++GyZrCLNUKuyLVDIBipqqAO +whC+Yhd6Aog+JbZzPSvRD8CeFXsBEE6HnpQShO5FSrtmJz58EdR1Pd11QHSLclbM +s/Ld2dKncokN8K+nubcXW8NxdRvo3wvkedAcG7L2V+vAF/LRwzi2icNnVt6rmuyw +Z/W5/HERlt4IAikKDQhBZrtGx5Cbjun5ekjN2sWFVB7TT2u6o/BsYQ7ljGUZt9uQ +DStfOURAv5BE8eYyWQIxd7fPCfY3UNxpJUPxvuDxpeCwITzD5v8qoVSBBH1lvQ7s +i61/Cr7dfwNsAtlMzrERxRmMR5WQzsfxPvfqhb3IuwKBgQDZo8zBAEXiTnNSl3W5 +1bs1ab8AFTfTzeY2Th2SxDZLdcy5I3dirwfusyQkv5eoRWi5Vx/fNXIh5OJos2Fu +M0CxuuJVP2dkXzBJrazAkzlDhEsG/MGMaIE/p3aFQeyID1EZcN3u3Z9StEbW4B8y +I/8dTgJCnBzfHs2HH80VbQHsZwKBgQDNI44tqevSW6XwwXz1sYJy5NBPrksRLEcU +rhm6rsLMhKXHUJa0KDeOeM9sjiBBrCL/pOkqwcnLUsqZ8pIQIEhwaIBfHznblgxZ +jCho3ZjYm4/Is9XD/lcS2yU5ialRI9kFz6qTkOlO0XonIwJs4NiITbuzfopr2BGh +IlzXcrC/hwKBgQCUl6IvT4lnJprUE/bbx1JG+IjgfJweLyDziMfmMbLEOIxrBwz2 +wnwO/B48PNdFmwYSLKrlEa939raiN37Y54NPFUJ8Y4qq29azJzGgVaQuNb+n6KAY +xi0gkax49PaSOqrrTMUp1gR2SgFnqaOC71K55k3ivoVzzKsUi6DQ9RjwFwKBgEiD +/BuaSJmo+iT8UO8NW96/kf/IzhJ5A3uE++VpJ8ViUrP9gfiXiuQbQr/OEgsFDa4v +HpmVvX7ZenMnM4jt0I2j81Us1agRB7aT/CjtxL01aIN7RuKswx0QSL1pM2hScsJC +Ibtea4sIM9Un5BCW/xRX3jVaUxZCYCEE46rpiR97AoGBAJdQQnQAxYS2Ua7Jj0go +0SgG99w7ONZjSupTTr4VpMaXmh6CBke44RulMUA+PwB1XtfVpxx8xhuPq2d6y79T +o5OLbEjdLPq8A8S0n5eXMD7FXXG8TYPpcqoO2Hqhgu9q1vRgqPopIcRuhhp5wdCp +iIGJHhwsI9sYN6wnGydeOH9U +-----END PRIVATE KEY----- diff --git a/tests/local_vpn/server_config/ta.key b/tests/local_vpn/server_config/ta.key new file mode 100644 index 00000000..2bf036ac --- /dev/null +++ b/tests/local_vpn/server_config/ta.key @@ -0,0 +1,21 @@ +# +# 2048 bit OpenVPN static key +# +-----BEGIN OpenVPN Static key V1----- +488b61084812969fe8ad0f9dd40f56a2 +6cdadddfe345daef6b5c6d3c3e779fc5 +1f7d236966953482d2af085e3f8581b7 +d216f2d891972a463bbb22ca6c104b9d +f99dcb19d7d575a1d46e7918bb2556c6 +db9f51cd792c5e89e011586214692b95 +2a32a7fe85e4538c40e1d0aa2a9f8e15 +fcc0ce5d31974e3c2041b127776f7658 +878cb8245ed235ec996c2370c0fc0023 +699bc028b3412bc40209cba8233bc111 +fa1438095f99052d799fa718f3b04499 +472254d0286b4b2ce99db49e98a4cc25 +fd948bddcdcf08006a6d7bff40354e7b +5e93ea753a8ecc05de41ae34d280e7eb +99220e436bf8b7693a00667485631e28 +edba3e33b6f558dfa50b92eec6ac8b44 +-----END OpenVPN Static key V1----- diff --git a/tests/provision/dummy_project_for_testing.yml b/tests/provision/dummy_project_for_testing.yml index 1ab98c45..ea544be7 100644 --- a/tests/provision/dummy_project_for_testing.yml +++ b/tests/provision/dummy_project_for_testing.yml @@ -4,11 +4,11 @@ description: > Test setup. participants: - - name: server.local + - name: testserver.local type: server org: Test_Org - fed_learn_port: 8002 - admin_port: 8003 + fed_learn_port: 8012 + admin_port: 8013 - name: client_A type: client org: Test_Org @@ -28,13 +28,13 @@ builders: - path: nvflare.lighter.impl.static_file.StaticFileBuilder args: config_folder: config - scheme: grpc - docker_image: jefftud/odelia:__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__ + scheme: http + docker_image: localhost:5000/odelia:__REPLACED_BY_CURRENT_VERSION_NUMBER_WHEN_BUILDING_STARTUP_KITS__ overseer_agent: path: nvflare.ha.dummy_overseer_agent.DummyOverseerAgent overseer_exists: false args: - sp_end_point: odeliatempvm.local:8002:8003 + sp_end_point: testserver.local:8012:8013 - path: nvflare.lighter.impl.cert.CertBuilder - path: nvflare.lighter.impl.signature.SignatureBuilder diff --git a/tests/unit_tests/_run_nvflare_unit_tests.sh b/tests/unit_tests/_run_nvflare_unit_tests.sh new file mode 100755 index 00000000..890406c2 --- /dev/null +++ b/tests/unit_tests/_run_nvflare_unit_tests.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +set -e + +run_nvflare_unit_tests () { + cd /MediSwarm/docker_config/NVFlare + ./runtest.sh -c -r + coverage report -m + cd .. +} + +run_nvflare_unit_tests