Geodiff Molecule Conformation

Introduction

This Colab is designed to run the pretrained models from GeoDiff. The visualization code is inspired by this PyMol colab.

The goal is to generate physically accurate molecules. The input is a molecular graph (atoms and bonds with their connectivity, in the form of a 2D graph), and the output we want to generate is a stable 3D structure of the molecule.

This Colab uses the GEOM datasets, which have multiple 3D targets per configuration and therefore provide more compelling targets for generative methods.

Colab made by natolambert.

Installations

Install Conda

First, check the CUDA version available in Colab. When this notebook was built, the version was always 11.1, which affects some of the installation decisions below.

[ ]
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Feb_14_21:12:58_PST_2021
Cuda compilation tools, release 11.2, V11.2.152
Build cuda_11.2.r11.2/compiler.29618528_0
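
The same check can be scripted so that later installation steps can branch on the result. The following is a minimal sketch, not the notebook's own cell: the helper names `parse_cuda_release` and `detect_cuda_release` are hypothetical, and the regular expression assumes the `release X.Y` wording seen in the captured output above (which reports release 11.2).

```python
import re
import shutil
import subprocess


def parse_cuda_release(nvcc_output):
    """Return the CUDA release string (e.g. '11.2') found in `nvcc --version` text, or None."""
    match = re.search(r"release\s+(\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def detect_cuda_release():
    """Run nvcc if it is on PATH; return None when it is missing."""
    if shutil.which("nvcc") is None:
        return None
    result = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=False
    )
    return parse_cuda_release(result.stdout)


# Parse one line copied from the captured output above.
sample_line = "Cuda compilation tools, release 11.2, V11.2.152"
print(parse_cuda_release(sample_line))
```

On a machine without the CUDA toolkit, `detect_cuda_release()` returns `None` instead of raising, which keeps the check safe to run on CPU-only runtimes.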

Install Conda for some more complex dependencies for geometric networks.

[ ]
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Setup Conda

[ ]
✨🍰✨ Everything looks OK!

Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)

[ ]
Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - cudatoolkit=11.1
    - pytorch
    - torchaudio
    - torchvision


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    conda-22.9.0               |   py37h89c1867_1         960 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         960 KB

The following packages will be UPDATED:

  conda                               4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1



Downloading and Extracting Packages
conda-22.9.0         | 960 KB    | : 100% 1.0/1 [00:00<00:00,  4.15it/s]
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Retrieving notices: ...working... done

Remove the pinned-spec file from Colab's conda installation, because it specifies the incorrect CUDA version.

[ ]
rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory

Install torch geometric (used in the model later)

[ ]
Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - pytorch-geometric=1.7.2


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    decorator-4.4.2            |             py_0          11 KB  conda-forge
    googledrivedownloader-0.4  |     pyhd3deb0d_1           7 KB  conda-forge
    jinja2-3.1.2               |     pyhd8ed1ab_1          99 KB  conda-forge
    joblib-1.2.0               |     pyhd8ed1ab_0         205 KB  conda-forge
    markupsafe-2.1.1           |   py37h540881e_1          22 KB  conda-forge
    networkx-2.5.1             |     pyhd8ed1ab_0         1.2 MB  conda-forge
    pandas-1.2.3               |   py37hdc94413_0        11.8 MB  conda-forge
    pyparsing-3.0.9            |     pyhd8ed1ab_0          79 KB  conda-forge
    python-dateutil-2.8.2      |     pyhd8ed1ab_0         240 KB  conda-forge
    python-louvain-0.15        |     pyhd8ed1ab_1          13 KB  conda-forge
    pytorch-cluster-1.5.9      |py37_torch_1.8.0_cu111         1.2 MB  rusty1s
    pytorch-geometric-1.7.2    |py37_torch_1.8.0_cu111         445 KB  rusty1s
    pytorch-scatter-2.0.8      |py37_torch_1.8.0_cu111         6.1 MB  rusty1s
    pytorch-sparse-0.6.12      |py37_torch_1.8.0_cu111         2.9 MB  rusty1s
    pytorch-spline-conv-1.2.1  |py37_torch_1.8.0_cu111         736 KB  rusty1s
    pytz-2022.4                |     pyhd8ed1ab_0         232 KB  conda-forge
    scikit-learn-1.0.2         |   py37hf9e9bfc_0         7.8 MB  conda-forge
    scipy-1.7.3                |   py37hf2a6cf1_0        21.8 MB  conda-forge
    setuptools-59.8.0          |   py37h89c1867_1         1.0 MB  conda-forge
    threadpoolctl-3.1.0        |     pyh8a188c0_0          18 KB  conda-forge
    ------------------------------------------------------------
                                           Total:        55.9 MB

The following NEW packages will be INSTALLED:

  decorator          conda-forge/noarch::decorator-4.4.2-py_0 None
  googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None
  jinja2             conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None
  joblib             conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None
  markupsafe         conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None
  networkx           conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None
  pandas             conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None
  pyparsing          conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None
  python-dateutil    conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None
  python-louvain     conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None
  pytorch-cluster    rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None
  pytorch-geometric  rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None
  pytorch-scatter    rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None
  pytorch-sparse     rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None
  pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None
  pytz               conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None
  scikit-learn       conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None
  scipy              conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None
  threadpoolctl      conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None

The following packages will be DOWNGRADED:

  setuptools                          65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None



Downloading and Extracting Packages
scikit-learn-1.0.2   | 7.8 MB    | : 100% 1.0/1 [00:01<00:00,  1.37s/it]              
pytorch-scatter-2.0. | 6.1 MB    | : 100% 1.0/1 [00:06<00:00,  6.18s/it]
pytorch-geometric-1. | 445 KB    | : 100% 1.0/1 [00:02<00:00,  2.53s/it]
scipy-1.7.3          | 21.8 MB   | : 100% 1.0/1 [00:03<00:00,  3.06s/it]
python-dateutil-2.8. | 240 KB    | : 100% 1.0/1 [00:00<00:00, 21.48it/s]
pytorch-spline-conv- | 736 KB    | : 100% 1.0/1 [00:01<00:00,  1.00s/it]
pytorch-sparse-0.6.1 | 2.9 MB    | : 100% 1.0/1 [00:07<00:00,  7.51s/it]
pyparsing-3.0.9      | 79 KB     | : 100% 1.0/1 [00:00<00:00, 26.32it/s]
pytorch-cluster-1.5. | 1.2 MB    | : 100% 1.0/1 [00:02<00:00,  2.78s/it]
jinja2-3.1.2         | 99 KB     | : 100% 1.0/1 [00:00<00:00, 20.28it/s]
decorator-4.4.2      | 11 KB     | : 100% 1.0/1 [00:00<00:00, 21.57it/s]
joblib-1.2.0         | 205 KB    | : 100% 1.0/1 [00:00<00:00, 15.04it/s]
pytz-2022.4          | 232 KB    | : 100% 1.0/1 [00:00<00:00, 10.21it/s]
python-louvain-0.15  | 13 KB     | : 100% 1.0/1 [00:00<00:00,  3.34it/s]
googledrivedownloade | 7 KB      | : 100% 1.0/1 [00:00<00:00,  3.33it/s]
threadpoolctl-3.1.0  | 18 KB     | : 100% 1.0/1 [00:00<00:00, 29.40it/s]
markupsafe-2.1.1     | 22 KB     | : 100% 1.0/1 [00:00<00:00, 28.62it/s]
pandas-1.2.3         | 11.8 MB   | : 100% 1.0/1 [00:02<00:00,  2.08s/it]               
networkx-2.5.1       | 1.2 MB    | : 100% 1.0/1 [00:01<00:00,  1.39s/it]
setuptools-59.8.0    | 1.0 MB    | : 100% 1.0/1 [00:00<00:00,  4.25it/s]
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Retrieving notices: ...working... done

Install Diffusers

[ ]
/content
Cloning into 'diffusers'...
remote: Enumerating objects: 9298, done.
remote: Counting objects: 100% (40/40), done.
remote: Compressing objects: 100% (23/23), done.
remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258
Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.
Resolving deltas: 100% (6168/6168), done.
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 757.0/757.0 kB 52.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 163.5/163.5 kB 21.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 40.8/40.8 kB 5.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 596.3/596.3 kB 51.7 MB/s eta 0:00:00
  Building wheel for diffusers (pyproject.toml) ... done
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 432.7/432.7 kB 36.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.3/5.3 MB 90.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.3/35.3 MB 39.7 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.1/115.1 kB 16.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 948.0/948.0 kB 63.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 212.2/212.2 kB 21.5 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 95.8/95.8 kB 12.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 140.8/140.8 kB 18.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.6/7.6 MB 104.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 148.0/148.0 kB 20.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 231.3/231.3 kB 30.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 94.8/94.8 kB 14.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.8/58.8 kB 8.0 MB/s eta 0:00:00
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Check that torch is installed correctly and utilizing the GPU in the colab

[ ]
True
'1.8.2'
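
A check like this can be written defensively so that it also runs where PyTorch is absent. This is a sketch under that assumption; `torch_status` is a hypothetical helper name, and only `torch.__version__` and `torch.cuda.is_available()` are taken from the PyTorch API.

```python
import importlib.util


def torch_status():
    """Report whether torch is importable, its version, and whether CUDA is usable."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False, "version": None, "cuda": False}
    import torch  # imported lazily so the function also works without torch

    return {
        "installed": True,
        "version": torch.__version__,
        "cuda": torch.cuda.is_available(),
    }


print(torch_status())
```

In the Colab session captured above, the corresponding values were `True` for CUDA availability and `'1.8.2'` for the version.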

Install Chemistry-specific Dependencies

Install RDKit, a tool for working with and visualizing chemistry in Python (you will use it to visualize the generated molecules later).

[ ]
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rdkit
  Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 36.8/36.8 MB 34.6 MB/s eta 0:00:00
Requirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)
Installing collected packages: rdkit
Successfully installed rdkit-2022.3.5
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv

Get viewer from nglview

The model you will use outputs a tensor of atom positions. The data it works with is a PyTorch Geometric data object with many fields (positions, known features, edge features), all stored as tensors. The data we give to the model also carries an rdmol object, an RDKit molecule from which features can be extracted for PyTorch Geometric if needed. The rdmol in this object is a source of ground truth for the generated molecules.

You will use one rendering function from nglview later!

[ ]
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting nglview
  Downloading nglview-3.0.3.tar.gz (5.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.7/5.7 MB 91.2 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)
Collecting jupyterlab-widgets
  Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 384.1/384.1 kB 40.6 MB/s eta 0:00:00
Collecting ipywidgets>=7
  Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.4/134.4 kB 21.2 MB/s eta 0:00:00
Collecting widgetsnbextension~=4.0
  Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 84.4 MB/s eta 0:00:00
Collecting ipython>=6.1.0
  Downloading ipython-7.34.0-py3-none-any.whl (793 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 793.8/793.8 kB 60.7 MB/s eta 0:00:00
Collecting ipykernel>=4.5.1
  Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.4/138.4 kB 20.9 MB/s eta 0:00:00
Collecting traitlets>=4.3.1
  Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 107.1/107.1 kB 17.0 MB/s eta 0:00:00
Requirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)
Collecting pyzmq>=17
  Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 68.4 MB/s eta 0:00:00
Collecting matplotlib-inline>=0.1
  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)
Collecting tornado>=6.1
  Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 424.0/424.0 kB 41.2 MB/s eta 0:00:00
Collecting nest-asyncio
  Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)
Collecting debugpy>=1.0
  Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 83.4 MB/s eta 0:00:00
Collecting psutil
  Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 281.3/281.3 kB 33.1 MB/s eta 0:00:00
Collecting jupyter-client>=6.1.12
  Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 132.2/132.2 kB 19.7 MB/s eta 0:00:00
Collecting pickleshare
  Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)
Collecting backcall
  Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)
Collecting pexpect>4.3
  Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 59.0/59.0 kB 7.3 MB/s eta 0:00:00
Collecting pygments
  Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 70.9 MB/s eta 0:00:00
Collecting jedi>=0.16
  Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 83.5 MB/s eta 0:00:00
Collecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0
  Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 382.3/382.3 kB 40.0 MB/s eta 0:00:00
Requirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)
Collecting parso<0.9.0,>=0.8.0
  Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.8/100.8 kB 14.7 MB/s eta 0:00:00
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)
Collecting entrypoints
  Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)
Collecting jupyter-core>=4.9.2
  Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 88.4/88.4 kB 14.0 MB/s eta 0:00:00
Collecting ptyprocess>=0.5
  Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)
Collecting wcwidth
  Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)
Building wheels for collected packages: nglview
  Building wheel for nglview (pyproject.toml) ... done
  Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514
  Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5
Successfully built nglview
Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview
Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
 

Create a diffusion model

Model class(es)

Imports

[ ]

Helper classes

[ ]

Main model class!

[ ]

Load pretrained model

Load a model

The model is built around an equivariant convolutional layer called the graph field network (GFN).

The warning about betas and alphas can be ignored; those were moved to the scheduler.

[ ]
Downloading:   0%|          | 0.00/3.27M [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/401 [00:00<?, ?B/s]
The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.
Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']
- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

The warnings above are because the pre-trained model was uploaded before cleaning the code!

Create scheduler

Note, other schedulers are used in the paper for slightly improved performance over DDPM.

[ ]
[ ]
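
To make the noise schedule concrete, here is a NumPy sketch of a sigmoid beta schedule using the values printed in the model config above (`beta_start=1e-07`, `beta_end=0.002`, `num_diffusion_timesteps=5000`). The sigmoid range of [-6, 6] and the function name `sigmoid_beta_schedule` are assumptions for illustration; the scheduler used in the notebook may differ in detail.

```python
import numpy as np


def sigmoid_beta_schedule(num_steps=5000, beta_start=1e-7, beta_end=2e-3):
    """Betas that rise along a sigmoid curve from about beta_start to about beta_end."""
    # Assumption: the sigmoid is evaluated on an evenly spaced grid over [-6, 6].
    grid = np.linspace(-6.0, 6.0, num_steps)
    sigmoid = 1.0 / (1.0 + np.exp(-grid))
    return sigmoid * (beta_end - beta_start) + beta_start


betas = sigmoid_beta_schedule()
# Cumulative product of (1 - beta_t): the fraction of signal left after t noising steps.
alphas_cumprod = np.cumprod(1.0 - betas)

print(betas[0], betas[-1], alphas_cumprod[-1])
```

Because the betas stay tiny for the early steps and grow slowly, most of the signal is destroyed only in the later part of the 5000 steps.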

Get a dataset

Grab a Google tool so we can upload our data directly. Note that you need to download the data from this file.

(direct downloading from the hub does not yet work for this datatype)

[ ]
[ ]

Load the dataset with torch.

[ ]
--2022-10-12 18:32:19--  https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl
Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...
Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 127774 (125K) [application/octet-stream]
Saving to: ‘molecules.pkl’

molecules.pkl       100%[===================>] 124.78K   180KB/s    in 0.7s    

2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]

Print out one entry of the dataset. It contains molecular formulas, atom types, positions, and more.

[ ]
Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=<rdkit.Chem.rdchem.Mol object at 0x7f707d2cb130>, smiles="CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1")
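
The shapes in this entry can be related with a little arithmetic: 51 atoms with `pos_ref=[255, 3]` gives 255 / 51 = 5 reference conformations, and `bond_edge_index=[2, 108]` corresponds to 54 bonds if each bond is stored in both directions. The sketch below uses a zero-filled placeholder array with the same shape and assumes the reference conformers are stacked one after another along the first axis; both the placeholder and that layout are assumptions, not something the printed entry confirms.

```python
import numpy as np

num_atoms = 51  # from atom_type=[51]
pos_ref = np.zeros((255, 3))  # placeholder with the shape of pos_ref=[255, 3]

# Assumption: conformers are stored contiguously, num_atoms rows each.
num_conformers, remainder = divmod(pos_ref.shape[0], num_atoms)
if remainder != 0:
    raise ValueError("pos_ref rows are not a multiple of the atom count")

conformers = pos_ref.reshape(num_conformers, num_atoms, 3)
num_bonds = 108 // 2  # assumption: bond_edge_index lists every bond twice

print(conformers.shape, num_bonds)  # (5, 51, 3) 54
```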

Run the diffusion process

Helper Functions

[ ]

Constants

[ ]

Generate samples!

Note that the 3D representation of a molecule is referred to as its conformation.

[ ]
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  after removing the cwd from sys.path.
100%|██████████| 5/5 [00:55<00:00, 11.06s/it]
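
For intuition about what a reverse diffusion loop over atom positions looks like, here is a toy NumPy sketch of generic DDPM ancestral sampling. It is not the notebook's sampling code: `toy_noise_predictor` is a hypothetical stand-in for the pretrained network (it always predicts zero noise), the schedule is shortened to 100 steps, and re-centering the positions after each step is an assumption made so the sample has no net translation.

```python
import numpy as np


def center(pos):
    """Subtract the mean position so the sample carries no net translation."""
    return pos - pos.mean(axis=0, keepdims=True)


def toy_noise_predictor(pos, t):
    """Hypothetical placeholder for the trained model: always predicts zero noise."""
    return np.zeros_like(pos)


def ddpm_sample(num_atoms, betas, predict_noise, seed=0):
    """Generic DDPM ancestral sampling over a (num_atoms, 3) position array."""
    rng = np.random.default_rng(seed)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    pos = center(rng.standard_normal((num_atoms, 3)))  # start from pure noise
    for t in reversed(range(len(betas))):
        eps = predict_noise(pos, t)
        # Posterior mean of x_{t-1} given x_t and the predicted noise.
        mean = (pos - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps) / np.sqrt(alphas[t])
        # Fresh noise is added on every step except the final one.
        noise = rng.standard_normal(pos.shape) if t > 0 else np.zeros_like(pos)
        pos = center(mean + np.sqrt(betas[t]) * noise)
    return pos


grid = np.linspace(-6.0, 6.0, 100)  # shortened schedule for the demo
betas = 1.0 / (1.0 + np.exp(-grid)) * (2e-3 - 1e-7) + 1e-7
sample = ddpm_sample(51, betas, toy_noise_predictor)
print(sample.shape)
```

With a real model, `predict_noise` would be replaced by a call to the network on the molecular graph, and the result would be one candidate conformation per run.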

Render the results!

This function allows us to render 3d in colab.

[ ]

Helper functions

Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer.

[ ]
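
As an illustration of turning a position array into something a viewer can read, the sketch below writes coordinates to XYZ text. This is a swapped-in, dependency-free technique rather than the notebook's helper, which targets RDKit and NGLView objects. The `to_xyz` name, the small `SYMBOLS` lookup, and the assumption that atom types are atomic numbers are all illustrative.

```python
import numpy as np

# Small illustrative lookup from atomic number to element symbol; extend as needed.
SYMBOLS = {1: "H", 6: "C", 7: "N", 8: "O", 9: "F", 16: "S", 17: "Cl"}


def to_xyz(atomic_numbers, positions, comment=""):
    """Format atoms and 3D positions as XYZ text: atom count, comment line, one atom per line."""
    positions = np.asarray(positions, dtype=float)
    if positions.shape != (len(atomic_numbers), 3):
        raise ValueError("positions must have shape (num_atoms, 3)")
    lines = [str(len(atomic_numbers)), comment]
    for number, (x, y, z) in zip(atomic_numbers, positions):
        lines.append(f"{SYMBOLS[int(number)]} {x:.4f} {y:.4f} {z:.4f}")
    return "\n".join(lines) + "\n"


# Usage with a hand-written three-atom example (approximate water geometry).
water = to_xyz(
    [8, 1, 1],
    [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    comment="water",
)
print(water)
```

For a generated sample, the positions would come from the model output converted to a NumPy array, and the atomic numbers from the dataset entry.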

Process the generated data to make it easy to view.

[ ]
collect 5 generated molecules in `mols`

Import tools to visualize the 2d chemical diagram of the molecule.

[ ]

Select molecule to visualize

[ ]

Viewing

This 2D rendering is the equivalent of the input to the model!

[ ]

Generate the 3d molecule!

[ ]
 
[ ]
NGLWidget()
[ ]