|
a |
|
b/setup.py |
|
|
1 |
from setuptools import find_packages, setup |
|
|
2 |
from torch.utils.cpp_extension import CUDAExtension, BuildExtension |
|
|
3 |
|
|
|
4 |
def get_extensions(): |
|
|
5 |
|
|
|
6 |
return extensions |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
if __name__ == "__main__": |
|
|
10 |
extensions = [ |
|
|
11 |
CUDAExtension( |
|
|
12 |
"broadcast", |
|
|
13 |
sources=[ |
|
|
14 |
"Cluster-ViT/models/extensions/broadcast.cu" |
|
|
15 |
], |
|
|
16 |
extra_compile_args=["-arch=compute_50"] |
|
|
17 |
), |
|
|
18 |
CUDAExtension( |
|
|
19 |
"weighted_sum", |
|
|
20 |
sources=[ |
|
|
21 |
"Cluster-ViT/models/extensions/weighted_sum.cu" |
|
|
22 |
], |
|
|
23 |
extra_compile_args=["-arch=compute_50"] |
|
|
24 |
) |
|
|
25 |
] |
|
|
26 |
|
|
|
27 |
setup( |
|
|
28 |
name="clutering-Transformer", |
|
|
29 |
packages=find_packages(), |
|
|
30 |
ext_modules=extensions, |
|
|
31 |
cmdclass={"build_ext": BuildExtension}, |
|
|
32 |
install_requires=["torch"] |
|
|
33 |
) |