--- a
+++ b/torchdrug/datasets/uspto50k.py
@@ -0,0 +1,264 @@
+import os
+import copy
+from collections import defaultdict
+
+import numpy as np
+import networkx as nx
+from tqdm import tqdm
+from rdkit import Chem
+
+import torch
+from torch.utils import data as torch_data
+from torch_scatter import scatter_max
+
+from torchdrug import data, utils
+from torchdrug.core import Registry as R
+
+
+@R.register("datasets.USPTO50k")
+@utils.copy_args(data.ReactionDataset.load_csv, ignore=("smiles_field", "target_fields"))
+class USPTO50k(data.ReactionDataset):
+    """
+    Chemical reactions extracted from USPTO patents.
+
+    Statistics:
+        - #Reaction: 50,017
+        - #Reaction class: 10
+
+    Each raw reaction is turned into one or more samples. With ``as_synthon=False`` a sample is a
+    (reactant, product) pair annotated with its reaction center. With ``as_synthon=True`` a sample is a
+    (reactant, synthon) pair obtained by breaking the reaction center of the product. Reactions for which
+    no reaction center is found by the rules in ``_get_reaction_center`` / ``_get_synthon`` are dropped.
+    The fraction of raw reactions that were kept is stored in ``valid_rate``.
+
+    Parameters:
+        path (str): path to store the dataset
+        as_synthon (bool, optional): whether to decompose (reactant, product) pairs into (reactant, synthon) pairs
+        verbose (int, optional): output verbose level
+        **kwargs
+    """
+
+    target_fields = ["class"]
+    # raw CSV field -> name used in ``self.targets`` after processing
+    target_alias = {"class": "reaction"}
+
+    reaction_names = ["Heteroatom alkylation and arylation",
+                      "Acylation and related processes",
+                      "C-C bond formation",
+                      "Heterocycle formation",
+                      "Protections",
+                      "Deprotections",
+                      "Reductions",
+                      "Oxidations",
+                      "Functional group interconversion (FGI)",
+                      "Functional group addition (FGA)"]
+
+    url = "https://raw.githubusercontent.com/connorcoley/retrosim/master/retrosim/data/data_processed.csv"
+    md5 = "404c361dd1568fbdb4d16ca588953749"
+
+    def __init__(self, path, as_synthon=False, verbose=1, **kwargs):
+        path = os.path.expanduser(path)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        self.path = path
+        self.as_synthon = as_synthon
+
+        file_name = utils.download(self.url, path, md5=self.md5)
+
+        self.load_csv(file_name, smiles_field="rxn_smiles", target_fields=self.target_fields, verbose=verbose,
+                      **kwargs)
+
+        if as_synthon:
+            prefix = "Computing synthons"
+            process_fn = self._get_synthon
+        else:
+            prefix = "Computing reaction centers"
+            process_fn = self._get_reaction_center
+
+        # Rebuild ``self.data`` / ``self.targets`` from the processed samples. One raw reaction may yield
+        # several samples (one per reactant component in synthon mode). "sample id" records which raw
+        # reaction each sample came from.
+        raw_data = self.data
+        raw_targets = self.targets
+        self.data = []
+        self.targets = defaultdict(list)
+        indexes = range(len(raw_data))
+        if verbose:
+            indexes = tqdm(indexes, prefix)
+        invalid = 0
+        for i in indexes:
+            reactant, product = raw_data[i]
+            # clear bond stereo annotations on both graphs (in place)
+            reactant.bond_stereo[:] = 0
+            product.bond_stereo[:] = 0
+
+            reactants, products = process_fn(reactant, product)
+            if not reactants:
+                invalid += 1
+                continue
+
+            self.data += zip(reactants, products)
+            for k in raw_targets:
+                new_k = self.target_alias.get(k, k)
+                # shift targets down by one (presumably the CSV stores classes 1..10 -- TODO confirm)
+                self.targets[new_k] += [raw_targets[k][i] - 1] * len(reactants)
+            self.targets["sample id"] += [i] * len(reactants)
+
+        # fraction of raw reactions for which a reaction center was found
+        self.valid_rate = 1 - invalid / len(raw_data)
+
+    def _get_difference(self, reactant, product):
+        """
+        Compare the bonds of the product with the bonds of the reactant through the atom mapping.
+
+        Returns:
+            (Tensor, Tensor, Tensor): node index pairs (in ``product``) of product edges with no counterpart
+            in the reactant; node index pairs of product edges whose endpoints are connected in the reactant
+            but with a different relation (third ``edge_list`` column); and, for each product node, the index
+            of the reactant node sharing its atom map ID
+        """
+        product2id = product.atom_map
+        # Cover the atom map IDs of both sides. Otherwise a reactant atom whose map ID exceeds every
+        # product ID would index past the end of the lookup table.
+        num_id = max(product2id.max().item(), reactant.atom_map.max().item()) + 1
+        id2reactant = torch.zeros(num_id, dtype=torch.long)
+        id2reactant[reactant.atom_map] = torch.arange(reactant.num_node)
+        prod2react = id2reactant[product2id]
+
+        # check edges in the product
+        product = product.directed()
+        # O(n^2) brute-force match is faster than O(nlogn) data.Graph.match for small molecules
+        mapped_edge = product.edge_list.clone()
+        mapped_edge[:, :2] = prod2react[mapped_edge[:, :2]]
+        # is_same_index[i, j, c]: column c of reactant edge i equals column c of mapped product edge j
+        is_same_index = mapped_edge.unsqueeze(0) == reactant.edge_list.unsqueeze(1)
+        has_typed_edge = is_same_index.all(dim=-1).any(dim=0)
+        has_edge = is_same_index[:, :, :2].all(dim=-1).any(dim=0)
+        is_added = ~has_edge
+        is_modified = has_edge & ~has_typed_edge
+        edge_added = product.edge_list[is_added, :2]
+        edge_modified = product.edge_list[is_modified, :2]
+
+        return edge_added, edge_modified, prod2react
+
+    @staticmethod
+    def _pick_center_atom(product, h, t):
+        """
+        Pick the endpoint of a modified bond ``(h, t)`` that is treated as the reaction center.
+
+        An endpoint with in-degree 1 is preferred, and ``h`` is checked before ``t``. If neither endpoint has
+        in-degree 1, ``h`` is used, which means pretending the reaction center is ``h``.
+        """
+        if product.degree_in[h] != 1 and product.degree_in[t] == 1:
+            return t
+        return h
+
+    @staticmethod
+    def _get_atom_modified(reactant, product, prod2react):
+        """
+        Return the product node indices whose RDKit total hydrogen count differs from that of the mapped
+        reactant atom.
+        """
+        product_hs = torch.tensor([atom.GetTotalNumHs() for atom in product.to_molecule().GetAtoms()])
+        reactant_hs = torch.tensor([atom.GetTotalNumHs() for atom in reactant.to_molecule().GetAtoms()])
+        return (product_hs != reactant_hs[prod2react]).nonzero().flatten()
+
+    def _get_reaction_center(self, reactant, product):
+        """
+        Locate the reaction center of a (reactant, product) pair and annotate both graphs with it.
+
+        Recognized cases:
+            - exactly one product edge is absent from the reactant: that edge is labeled in ``edge_label``
+              and the center is the atom map IDs of its two atoms
+            - no edge is added and exactly one edge changes relation: one of its atoms
+              (see ``_pick_center_atom``) is labeled in ``node_label`` and the center is ``[atom map ID, 0]``
+            - no edge is added, the number of changed edges is not one, and exactly one atom changes its
+              hydrogen count: that atom is labeled in ``node_label`` and the center is ``[atom map ID, 0]``
+
+        Returns:
+            (list, list): ``([reactant], [product])`` with ``edge_label`` / ``node_label`` set on the product
+            and ``reaction_center`` set on both graphs, or ``([], [])`` if no case applies
+        """
+        edge_added, edge_modified, prod2react = self._get_difference(reactant, product)
+
+        edge_label = torch.zeros(product.num_edge, dtype=torch.long)
+        node_label = torch.zeros(product.num_node, dtype=torch.long)
+        reaction_center = None
+
+        if len(edge_added) > 0:
+            if len(edge_added) == 1:  # add a single edge
+                # -1 in the relation column acts as a wildcard for Graph.match
+                wildcard = -torch.ones(1, 1, dtype=torch.long)
+                pattern = torch.cat([edge_added, wildcard], dim=-1)
+                index, num_match = product.match(pattern)
+                assert num_match.item() == 1
+                edge_label[index] = 1
+                h, t = edge_added[0]
+                reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
+        elif len(edge_modified) == 1:  # modify a single edge
+            h, t = edge_modified[0]
+            center = self._pick_center_atom(product, h, t)
+            node_label[center] = 1
+            reaction_center = torch.tensor([product.atom_map[center], 0])
+        else:
+            atom_modified = self._get_atom_modified(reactant, product, prod2react)
+            if len(atom_modified) == 1:  # modify single node
+                node_label[atom_modified] = 1
+                reaction_center = torch.tensor([product.atom_map[atom_modified[0]], 0])
+
+        # a center is found exactly when some edge or node label has been set
+        if reaction_center is None:
+            return [], []
+
+        with product.edge():
+            product.edge_label = edge_label
+        with product.node():
+            product.node_label = node_label
+        with reactant.graph():
+            reactant.reaction_center = reaction_center
+        with product.graph():
+            product.reaction_center = reaction_center
+        return [reactant], [product]
+
+    def _get_synthon(self, reactant, product):
+        """
+        Decompose a (reactant, product) pair into (reactant, synthon) pairs at the reaction center.
+
+        The same cases as ``_get_reaction_center`` are recognized. If a single edge is added, that edge
+        (both directions) is removed from the product. Each connected component of the reactant is then paired
+        with the product component that has the same maximal atom map ID; if none matches, argmax falls back to
+        the first component. Otherwise the reactant, which is asserted to be a single connected component, is
+        paired with the whole product.
+
+        Returns:
+            (list, list): reactants and synthons, each carrying ``reaction_center``; both lists are empty
+            if no case applies
+        """
+        edge_added, edge_modified, prod2react = self._get_difference(reactant, product)
+
+        reactants = []
+        synthons = []
+
+        if len(edge_added) > 0:
+            if len(edge_added) == 1:  # add a single edge
+                # match the added edge in both directions; -1 in the relation column acts as a wildcard
+                reverse_edge = edge_added.flip(1)
+                wildcard = -torch.ones(2, 1, dtype=torch.long)
+                pattern = torch.cat([edge_added, reverse_edge])
+                pattern = torch.cat([pattern, wildcard], dim=-1)
+                index, num_match = product.match(pattern)
+                edge_mask = torch.ones(product.num_edge, dtype=torch.bool)
+                edge_mask[index] = 0
+                product = product.edge_mask(edge_mask)
+                _reactants = reactant.connected_components()[0]
+                _synthons = product.connected_components()[0]
+                assert len(_synthons) >= len(_reactants)  # because a few samples contain multiple products
+
+                h, t = edge_added[0]
+                reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
+                with _reactants.graph():
+                    _reactants.reaction_center = reaction_center.expand(len(_reactants), -1)
+                with _synthons.graph():
+                    _synthons.reaction_center = reaction_center.expand(len(_synthons), -1)
+                # reactant / synthon can be uniquely indexed by their maximal atom mapping ID
+                reactant_id = scatter_max(_reactants.atom_map, _reactants.node2graph, dim_size=len(_reactants))[0]
+                synthon_id = scatter_max(_synthons.atom_map, _synthons.node2graph, dim_size=len(_synthons))[0]
+                react2synthon = (reactant_id.unsqueeze(-1) == synthon_id.unsqueeze(0)).long().argmax(-1)
+                for r, s in enumerate(react2synthon.tolist()):
+                    reactants.append(_reactants[r])
+                    synthons.append(_synthons[s])
+        else:
+            num_cc = reactant.connected_components()[1]
+            assert num_cc == 1
+
+            reaction_center = None
+            if len(edge_modified) == 1:  # modify a single edge
+                h, t = edge_modified[0]
+                center = self._pick_center_atom(product, h, t)
+                reaction_center = torch.tensor([product.atom_map[center], 0])
+            else:
+                atom_modified = self._get_atom_modified(reactant, product, prod2react)
+                if len(atom_modified) == 1:  # modify single node
+                    reaction_center = torch.tensor([product.atom_map[atom_modified[0]], 0])
+
+            if reaction_center is not None:
+                # no bond is broken, so the whole product serves as the synthon
+                synthon = product
+                with reactant.graph():
+                    reactant.reaction_center = reaction_center
+                with synthon.graph():
+                    synthon.reaction_center = reaction_center
+                reactants.append(reactant)
+                synthons.append(synthon)
+
+        return reactants, synthons
+
+    def split(self, ratios=(0.8, 0.1, 0.1)):
+        """
+        Split the dataset separately within each reaction type.
+
+        Within a reaction type, samples are partitioned by ``data.key_split`` using their "sample id" as key,
+        presumably so that all samples derived from one raw reaction fall into the same split (TODO confirm
+        against ``key_split``).
+
+        Parameters:
+            ratios (sequence of float, optional): fraction of distinct raw reactions assigned to each split;
+                the last split takes whatever remains after rounding the others
+
+        Returns:
+            list of torch.utils.data.Subset: one subset of this dataset per ratio
+        """
+        react2index = defaultdict(list)
+        react2sample = defaultdict(list)
+        for i in range(len(self)):
+            reaction = self.targets["reaction"][i]
+            sample_id = self.targets["sample id"][i]
+            react2index[reaction].append(i)
+            react2sample[reaction].append(sample_id)
+
+        indexes = [[] for _ in ratios]
+        for reaction in react2index:
+            num_sample = len(set(react2sample[reaction]))
+            key_lengths = [int(round(num_sample * ratio)) for ratio in ratios]
+            key_lengths[-1] = num_sample - sum(key_lengths[:-1])
+            react_indexes = data.key_split(react2index[reaction], react2sample[reaction], key_lengths=key_lengths)
+            for index, react_index in zip(indexes, react_indexes):
+                index.extend(react_index)
+
+        return [torch_data.Subset(self, index) for index in indexes]
+
+    @property
+    def num_reaction_type(self):
+        """Number of distinct reaction types in the dataset."""
+        return len(self.reaction_types)
+
+    @utils.cached_property
+    def reaction_types(self):
+        """All reaction types, i.e. the sorted distinct values of the "reaction" target."""
+        # __init__ stores the raw "class" field under its alias "reaction" (see ``target_alias``) in ``self.targets``
+        return sorted(set(self.targets["reaction"]))
\ No newline at end of file