11#! /bin/bash -ex
22
3+ # install system packages
34export DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC
45sed -i ' s|http://archive.ubuntu.com|http://azure.archive.ubuntu.com|g' /etc/apt/sources.list
56apt-get update -y
67apt-get install -y --no-install-recommends \
7- tzdata wget curl ssh sudo git-core libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1
8+ tzdata wget curl ssh sudo git-core vim libibverbs1 ibverbs-providers ibverbs-utils librdmacm1 libibverbs-dev rdma-core libmlx5-1
89
910if [[ ${PYTHON_VERSION} != " 3.10" ]]; then
1011 apt-get install -y --no-install-recommends software-properties-common
1112 add-apt-repository -y ppa:deadsnakes/ppa
1213 apt-get update -y
1314fi
1415
16+ # install python, create virtual env
1517apt-get install -y --no-install-recommends \
1618 python${PYTHON_VERSION} python${PYTHON_VERSION} -dev python${PYTHON_VERSION} -venv
1719
1820pushd /opt > /dev/null
1921 python${PYTHON_VERSION} -m venv py3
2022popd > /dev/null
2123
24+ # install CUDA build tools
2225if [[ " ${CUDA_VERSION_SHORT} " = " cu118" ]]; then
2326 apt-get install -y --no-install-recommends cuda-minimal-build-11-8
2427elif [[ " ${CUDA_VERSION_SHORT} " = " cu124" ]]; then
25- apt-get install -y --no-install-recommends cuda-minimal-build-12-4
28+ apt-get install -y --no-install-recommends cuda-minimal-build-12-4 dkms
2629elif [[ " ${CUDA_VERSION_SHORT} " = " cu128" ]]; then
27- apt-get install -y --no-install-recommends cuda-minimal-build-12-8
30+ apt-get install -y --no-install-recommends cuda-minimal-build-12-8 dkms
2831elif [[ " ${CUDA_VERSION_SHORT} " = " cu130" ]]; then
29- apt-get install -y --no-install-recommends cuda-minimal-build-13-0
32+ apt-get install -y --no-install-recommends cuda-minimal-build-13-0 dkms
3033fi
3134
3235apt-get clean -y
3336rm -rf /var/lib/apt/lists/*
3437
38+ # install GDRCopy debs
39+ if [ -d " /debs" ] && [ " $( ls -A /debs/* .deb 2> /dev/null) " ]; then
40+ dpkg -i /debs/* .deb
41+ fi
42+
43+ # install python packages
3544export PATH=/opt/py3/bin:$PATH
3645
3746if [[ " ${CUDA_VERSION_SHORT} " = " cu118" ]]; then
3847 FA_VERSION=2.7.3
3948 TORCH_VERSION=" <2.7"
49+ elif [[ " ${CUDA_VERSION_SHORT} " = " cu130" ]]; then
50+ FA_VERSION=2.8.3
51+ TORCH_VERSION=" ==2.9.0"
4052else
4153 FA_VERSION=2.8.3
42- TORCH_VERSION=" "
54+ # pin torch version to avoid build and runtime mismatch, o.w. deep_gemm undefined symbol error
55+ TORCH_VERSION=" ==2.8.0"
4356fi
4457
4558pip install -U pip wheel setuptools
@@ -50,13 +63,15 @@ elif [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then
5063 pip install nvidia-nvshmem-cu12
5164fi
5265
53- pip install /wheels/* .whl torch${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT}
66+ pip install torch${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/${CUDA_VERSION_SHORT}
67+ pip install /wheels/* .whl
5468
5569
5670if [[ " ${CUDA_VERSION_SHORT} " != " cu118" ]] && [[ " ${PYTHON_VERSION} " != " 3.9" ]]; then
57- pip install cuda-python dlblas
71+ pip install cuda-python dlblas==0.0.6
5872fi
5973
74+ # install pre-compiled flash attention wheel
6075PLATFORM=" linux_x86_64"
6176PY_VERSION=$( python3 - << 'PY '
6277import torch, sys
0 commit comments