diff --git a/setup.py b/setup.py index ce6c5f3a395..697973086c0 100644 --- a/setup.py +++ b/setup.py @@ -131,6 +131,12 @@ def get_macros_and_flags(): if sysconfig.get_config_var("Py_GIL_DISABLED"): extra_compile_args["cxx"].append("-DPy_GIL_DISABLED") + if sys.platform == "darwin": + extra_compile_args["cxx"].append("-Xpreprocessor") + extra_compile_args["cxx"].append("-fopenmp") + elif sys.platform != "win32": + extra_compile_args["cxx"].append("-fopenmp") + if DEBUG: extra_compile_args["cxx"].append("-g") extra_compile_args["cxx"].append("-O0") @@ -182,12 +188,22 @@ def make_C_extension(): sources += mps_sources define_macros, extra_compile_args = get_macros_and_flags() + + extra_link_args = [] + if sys.platform == "darwin": + # Link against libomp shipped with PyTorch for at::parallel_for support + torch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib") + extra_link_args = [f"-L{torch_lib_dir}", "-lomp"] + elif sys.platform != "win32": + extra_link_args = ["-lgomp"] + return Extension( name="torchvision._C", sources=sorted(str(s) for s in sources), include_dirs=[CSRS_DIR], define_macros=define_macros, extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, )