Skip to content
Snippets Groups Projects
Commit 1d6a9324 authored by Seungmin Kim's avatar Seungmin Kim
Browse files

Fix PyTorch

parent 7e0fe44f
No related branches found
No related tags found
No related merge requests found
Pipeline #55752 passed
...@@ -15,27 +15,27 @@ RUN mamba install -c conda-forge --yes \ ...@@ -15,27 +15,27 @@ RUN mamba install -c conda-forge --yes \
'rasterio' \ 'rasterio' \
'gdal' \ 'gdal' \
'libgdal-arrow-parquet' \ 'libgdal-arrow-parquet' \
&& mamba install -c pytorch --yes \
'pytorch' \
'torchvision' \
'torchaudio' \
'cpuonly' \
&& mamba clean --all -f -y \ && mamba clean --all -f -y \
&& pip install --no-cache-dir --index-url "https://download.pytorch.org/whl/cpu" \
# PyTorch with CPU
torch \
torchvision \
torchaudio \
&& if [ "$(uname -m)" = "x86_64" ]; then TF_POSTFIX="-cpu"; else TF_POSTFIX=""; fi \ && if [ "$(uname -m)" = "x86_64" ]; then TF_POSTFIX="-cpu"; else TF_POSTFIX=""; fi \
&& pip install --no-cache-dir \ && pip install --no-cache-dir \
"tensorflow${TF_POSTFIX}==$(python3 -c 'import tensorflow as tf; print(tf.__version__)' 2>/dev/null)" \ "tensorflow${TF_POSTFIX}==$(python3 -c 'import tensorflow as tf; print(tf.__version__)' 2>/dev/null)" \
# JAX and Flax with CPU
jax \
flax \
# Keras and other packages # Keras and other packages
keras \ keras \
keras-hub \ keras-hub \
keras-nlp \ keras-nlp \
keras-tuner \ keras-tuner \
# JAX and Flax with CPU
jax \
flax \
# https://github.com/explosion/cython-blis/issues/117#issuecomment-2599094860
'spacy<3.8.0' \
accelerate \ accelerate \
deepspeed \ deepspeed \
# https://github.com/explosion/cython-blis/issues/117#issuecomment-2599094860
'spacy<3.8.0' \
fastai \ fastai \
lightning \ lightning \
transformers \ transformers \
......
...@@ -15,32 +15,30 @@ RUN mamba install -c conda-forge --yes \ ...@@ -15,32 +15,30 @@ RUN mamba install -c conda-forge --yes \
'rasterio' \ 'rasterio' \
'gdal' \ 'gdal' \
'libgdal-arrow-parquet' \ 'libgdal-arrow-parquet' \
&& mamba install -c pytorch -c nvidia --yes \
# PyTorch with CUDA, upgrade pytorch-cuda when available
'pytorch' \
'torchvision' \
'torchaudio' \
'pytorch-cuda=12.4' \
&& mamba install -c conda-forge -c nvidia --yes \
# NVIDIA cuQuantum SDK
'cuquantum' \
'cuquantum-python' \
'cuda-version=12.4' \
&& mamba clean --all -f -y \ && mamba clean --all -f -y \
&& pip install --no-cache-dir --extra-index-url="https://pypi.nvidia.com" --index-url "https://download.pytorch.org/whl/cu124" \
# PyTorch with CUDA, upgrade index-url when available
torch \
torchvision \
torchaudio \
&& pip install --no-cache-dir --extra-index-url="https://pypi.nvidia.com" \
# NVIDIA cuQuantum SDK
'cuquantum-cu12' \
'cuquantum-python-cu12' \
&& pip install --no-cache-dir --extra-index-url="https://pypi.nvidia.com" \ && pip install --no-cache-dir --extra-index-url="https://pypi.nvidia.com" \
"tensorflow[and-cuda]==$(python3 -c 'import tensorflow as tf; print(tf.__version__)' 2>/dev/null)" \ "tensorflow[and-cuda]==$(python3 -c 'import tensorflow as tf; print(tf.__version__)' 2>/dev/null)" \
# JAX and Flax with CUDA
'jax[cuda12]' \
flax \
# Keras and other packages # Keras and other packages
keras \ keras \
keras-hub \ keras-hub \
keras-nlp \ keras-nlp \
keras-tuner \ keras-tuner \
# JAX and Flax with CUDA
'jax[cuda12]' \
flax \
# https://github.com/explosion/cython-blis/issues/117#issuecomment-2599094860
'spacy<3.8.0' \
accelerate \ accelerate \
deepspeed \ deepspeed \
# https://github.com/explosion/cython-blis/issues/117#issuecomment-2599094860
'spacy<3.8.0' \
fastai \ fastai \
lightning \ lightning \
transformers \ transformers \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment