1. Stop the X server:
sudo service lightdm stop
2. Uninstall any existing driver:
sudo nvidia-uninstall
3. Disable nouveau. Open the config file:
sudo vim /etc/modprobe.d/blacklist-nouveau.conf
and add the following lines:
blacklist nouveau
blacklist lbm-nouveau
options nouveau modeset=0
alias nouveau off
alias lbm-nouveau off
4. Rebuild the initramfs, reboot, and stop the X server again:
echo options nouveau modeset=0 | sudo tee -a /etc/modprobe.d/nouveau-kms.conf
sudo update-initramfs -u
sudo reboot
sudo service lightdm stop
5. Download the driver:
wget http://us.download.nvidia.com/XFree86/Linux-x86_64/390.48/NVIDIA-Linux-x86_64-390.48.run
6. Install the driver:
chmod +x NVIDIA-Linux-x86_64-390.48.run
sudo ./NVIDIA-Linux-x86_64-390.48.run
sudo reboot

These instructions are from this webpage, and they worked for me.
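After the final reboot, one way to confirm that steps 3 and 4 took effect is to look at the kernel's list of loaded modules: `nouveau` should be absent and `nvidia` present. The Python sketch below (written to run under both Python 2.7 and 3) parses `/proc/modules`, the same file `lsmod` reads; the helper names are hypothetical and not part of the original instructions.

```python
"""Sketch: check which GPU kernel modules are loaded (Linux only)."""
from __future__ import print_function

import os

PROC_MODULES = "/proc/modules"


def loaded_modules(text):
    """Return the set of module names from /proc/modules-style text.

    Each non-empty line starts with the module name, followed by its size,
    use count, dependents, state and load address.
    """
    return set(line.split()[0] for line in text.splitlines() if line.strip())


def gpu_module_status(text):
    """Report whether the nouveau and nvidia modules appear in the text."""
    modules = loaded_modules(text)
    return {"nouveau": "nouveau" in modules, "nvidia": "nvidia" in modules}


def main():
    if not os.path.exists(PROC_MODULES):
        print(PROC_MODULES, "not found; this check only works on Linux")
        return
    with open(PROC_MODULES) as f:
        status = gpu_module_status(f.read())
    if status["nouveau"]:
        print("nouveau is still loaded: re-check the blacklist file and "
              "re-run 'sudo update-initramfs -u'")
    else:
        print("nouveau is not loaded")
    print("nvidia loaded:", status["nvidia"])


if __name__ == "__main__":
    main()
```

Only the first token of each line is treated as a module name, so a module that merely appears in another module's dependents column (for example `nvidia_modeset`) is not counted.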
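1. Download the CUDA 9.0 toolkit (and its patches) from here. Run the installer, then the patches. Do not re-install the NVIDIA driver: decline the installer's offer to install the 384.81 driver it bundles, since 390.48 is already installed.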
sudo chmod +x cuda_9.0.176_384.81_linux.run
sudo ./cuda_9.0.176_384.81_linux.run --override
2. Install cuDNN 7 for CUDA 9.0 from here. I downloaded version 7.1.4.
# Unpack the archive
tar -zxvf cudnn-9.0-linux-x64-v7.tgz
# Copy the unpacked contents to your CUDA directory
sudo cp -P cuda/lib64/libcudnn* /usr/local/cuda-9.0/lib64/
sudo cp cuda/include/cudnn.h /usr/local/cuda-9.0/include/
# Give read access to all users
sudo chmod a+r /usr/local/cuda-9.0/include/cudnn.h /usr/local/cuda-9.0/lib64/libcudnn*
3. Install libcupti:
sudo apt-get install libcupti-dev
4. Update the environment variables:
export PATH=/usr/local/cuda-9.0/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-9.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

Follow the instructions on this website. I installed PyTorch locally in a new Anaconda environment cloned from base. I use Python 2.7, but I believe the process is similar for Python 3.
pip install torch torchvision

# Check whether PyTorch can use CUDA
In [1]: import torch
In [2]: torch.cuda.is_available()
Out[2]: True
In [3]: a = torch.rand(5,5)
In [4]: a
Out[4]:
tensor([[0.0941, 0.9865, 0.4618, 0.0944, 0.6364],
        [0.3783, 0.9316, 0.9661, 0.0253, 0.6062],
        [0.2219, 0.7945, 0.8306, 0.6733, 0.4095],
        [0.0396, 0.6149, 0.6605, 0.2157, 0.0762],
        [0.0193, 0.2906, 0.5172, 0.6390, 0.7789]])
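If `torch.cuda.is_available()` returns False, it can help to check the cuDNN and environment-variable steps separately. The sketch below is a hypothetical helper, not part of the original instructions: it reads the `CUDNN_MAJOR`, `CUDNN_MINOR` and `CUDNN_PATCHLEVEL` macros from the `cudnn.h` copied during the cuDNN step and checks whether the CUDA 9.0 `lib64` directory is on `LD_LIBRARY_PATH`. It assumes the default `/usr/local/cuda-9.0` prefix; newer cuDNN releases (8 and later) keep these macros in `cudnn_version.h` instead, so the header path would need changing there.

```python
"""Sketch: sanity-check the cuDNN header and LD_LIBRARY_PATH.

Assumes the default CUDA 9.0 prefix; adjust CUDA_HOME if yours differs.
"""
from __future__ import print_function

import os
import re

CUDA_HOME = "/usr/local/cuda-9.0"  # assumption: default install prefix
CUDNN_HEADER = os.path.join(CUDA_HOME, "include", "cudnn.h")
CUDA_LIB = os.path.join(CUDA_HOME, "lib64")

# Matches lines such as "#define CUDNN_MAJOR 7" (but not CUDNN_VERSION).
_DEFINE = re.compile(
    r"^[ \t]*#[ \t]*define[ \t]+CUDNN_(MAJOR|MINOR|PATCHLEVEL)[ \t]+(\d+)",
    re.MULTILINE)


def cudnn_version_from_header(text):
    """Return (major, minor, patch) parsed from cudnn.h text, or None."""
    found = dict((name, int(value)) for name, value in _DEFINE.findall(text))
    try:
        return found["MAJOR"], found["MINOR"], found["PATCHLEVEL"]
    except KeyError:
        return None


def path_contains(path_value, directory):
    """True if `directory` is one of the colon-separated entries.

    realpath() resolves symlinks such as /usr/local/cuda -> cuda-9.0
    and strips trailing slashes before comparing.
    """
    wanted = os.path.realpath(directory)
    return any(os.path.realpath(entry) == wanted
               for entry in path_value.split(":") if entry)


def main():
    if os.path.exists(CUDNN_HEADER):
        with open(CUDNN_HEADER) as f:
            print("cuDNN version in header:",
                  cudnn_version_from_header(f.read()))
    else:
        print("cuDNN header not found at", CUDNN_HEADER)
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    print(CUDA_LIB, "on LD_LIBRARY_PATH:", path_contains(ld_path, CUDA_LIB))


if __name__ == "__main__":
    main()
```

Run it from the same terminal where the `export` lines were executed, because `export` only affects the current shell session and its child processes.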