Diff of /setup.py [000000] .. [15fc01]

Switch to unified view

a b/setup.py
1
from setuptools import find_packages, setup
2
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
3
4
def get_extensions():
5
 
6
    return extensions
7
8
9
if __name__ == "__main__":
10
    extensions = [
11
        CUDAExtension(
12
            "broadcast",
13
            sources=[
14
                "Cluster-ViT/models/extensions/broadcast.cu"
15
            ],
16
            extra_compile_args=["-arch=compute_50"]
17
        ),
18
        CUDAExtension(
19
            "weighted_sum",
20
            sources=[
21
                "Cluster-ViT/models/extensions/weighted_sum.cu"
22
            ],
23
            extra_compile_args=["-arch=compute_50"]
24
        )
25
    ]
26
27
    setup(
28
        name="clutering-Transformer",
29
        packages=find_packages(),
30
        ext_modules=extensions,
31
        cmdclass={"build_ext": BuildExtension},
32
        install_requires=["torch"]
33
    )