--- a +++ b/setup.py @@ -0,0 +1,33 @@ +from setuptools import find_packages, setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +def get_extensions(): + + return extensions + + +if __name__ == "__main__": + extensions = [ + CUDAExtension( + "broadcast", + sources=[ + "Cluster-ViT/models/extensions/broadcast.cu" + ], + extra_compile_args=["-arch=compute_50"] + ), + CUDAExtension( + "weighted_sum", + sources=[ + "Cluster-ViT/models/extensions/weighted_sum.cu" + ], + extra_compile_args=["-arch=compute_50"] + ) + ] + + setup( + name="clutering-Transformer", + packages=find_packages(), + ext_modules=extensions, + cmdclass={"build_ext": BuildExtension}, + install_requires=["torch"] + )