diff --git a/.gitignore b/.gitignore
index 4737cb7..b672352 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ dist
 gpuctypes.egg-info
 __pycache__
 .*.swp
+venv/
\ No newline at end of file
diff --git a/generate_hip.sh b/generate_hip.sh
index ea84963..d1e34a3 100755
--- a/generate_hip.sh
+++ b/generate_hip.sh
@@ -1,8 +1,65 @@
 #!/bin/bash -e
-clang2py /opt/rocm/include/hip/hiprtc.h /opt/rocm/include/hip/hip_runtime_api.h /opt/rocm/include/hip/driver_types.h --clang-args="-D__HIP_PLATFORM_AMD__ -I/opt/rocm/include" -o gpuctypes/hip.py -l /opt/rocm/lib/libhiprtc.so -l /opt/rocm/lib/libamdhip64.so
-grep FIXME_STUB gpuctypes/hip.py || true
-# we can trust HIP is always at /opt/rocm/lib
-#sed -i "s\import ctypes\import ctypes, ctypes.util\g" gpuctypes/hip.py
+
+out_file="gpuctypes/hip.py"
+if command -v clang2py &> /dev/null; then
+  clang2py /opt/rocm/include/hip/hiprtc.h /opt/rocm/include/hip/hip_runtime_api.h /opt/rocm/include/hip/driver_types.h --clang-args="-D__HIP_PLATFORM_AMD__ -I/opt/rocm/include" -o "$out_file" -l /opt/rocm/lib/libhiprtc.so -l /opt/rocm/lib/libamdhip64.so
+else
+  # abort here: running the hot patches on a stale, already patched file would corrupt it
+  echo "error: clang2py was not found..." >&2
+  exit 1
+fi
+
+# grep FIXME_STUB $out_file || true
+# hot patches
+get_hiprtc_code="
+def get_hiprtc():
+    '''Load and return the hiprtc shared library for the current platform.'''
+    try:
+        # Exact platform match: a substring test for 'win' would also match 'darwin'.
+        if sys.platform.startswith('linux'):
+            return ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')
+        elif sys.platform in ('win32', 'cygwin'):
+            hip_path = os.getenv('HIP_PATH', None)
+            if not hip_path:
+                raise RuntimeError('HIP_PATH is not set')
+            return ctypes.CDLL(os.path.join(hip_path, 'bin', 'hiprtc0505.dll'))
+        else:
+            raise RuntimeError('Only windows and linux are supported')
+    except Exception as err:
+        raise Exception('Error: {0}'.format(err)) from err
+"
+
+get_hip_code="
+def get_hip():
+    '''Load and return the amdhip64 shared library for the current platform.'''
+    try:
+        # Exact platform match: a substring test for 'win' would also match 'darwin'.
+        if sys.platform.startswith('linux'):
+            return ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')
+        elif sys.platform in ('win32', 'cygwin'):
+            return ctypes.cdll.LoadLibrary('amdhip64')
+        else:
+            raise RuntimeError('Only windows and linux are supported')
+    except Exception as err:
+        raise Exception('Error: {0}'.format(err)) from err
+"
+
+declare -A patches
+patches=(
+  ["import ctypes"]="import ctypes, sys, os"
+  ["ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')"]="get_hiprtc()"
+  ["ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')"]="get_hip()"
+)
+for key in "${!patches[@]}"; do
+  sed -i "s@${key}@${patches[${key}]}@g" "$out_file"
+done
+
+# get the first import line; both helpers are inserted one line below it (get_hip ends up on top)
+import_line=$(grep -n -m 1 "^import ctypes" "$out_file" | cut -d ":" -f 1)
+import_line=$((import_line + 1))
+sed -i "${import_line}r /dev/stdin" "$out_file" <<< "$get_hiprtc_code"
+sed -i "${import_line}r /dev/stdin" "$out_file" <<< "$get_hip_code"
+# sed -i "s\import ctypes\import ctypes, ctypes.util\g" gpuctypes/hip.py
 #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')\ctypes.CDLL(ctypes.util.find_library('hiprtc'))\g" gpuctypes/hip.py
 #sed -i "s\ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')\ctypes.CDLL(ctypes.util.find_library('amdhip64'))\g" gpuctypes/hip.py
 python3 -c "import gpuctypes.hip"
diff --git a/gpuctypes/hip.py b/gpuctypes/hip.py
index 3975fa8..765317a 100644
--- a/gpuctypes/hip.py
+++ b/gpuctypes/hip.py
@@ -5,7 +5,38 @@
 # POINTER_SIZE is: 8
 # LONGDOUBLE_SIZE is: 16
 #
-import ctypes
+import ctypes, sys, os
+
+
+def get_hip():
+    '''Load and return the amdhip64 shared library for the current platform.'''
+    try:
+        # Exact platform match: a substring test for 'win' would also match 'darwin'.
+        if sys.platform.startswith('linux'):
+            return ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')
+        elif sys.platform in ('win32', 'cygwin'):
+            return ctypes.cdll.LoadLibrary('amdhip64')
+        else:
+            raise RuntimeError('Only windows and linux are supported')
+    except Exception as err:
+        raise Exception('Error: {0}'.format(err)) from err
+
+
+def get_hiprtc():
+    '''Load and return the hiprtc shared library for the current platform.'''
+    try:
+        # Exact platform match: a substring test for 'win' would also match 'darwin'.
+        if sys.platform.startswith('linux'):
+            return ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')
+        elif sys.platform in ('win32', 'cygwin'):
+            hip_path = os.getenv('HIP_PATH', None)
+            if not hip_path:
+                raise RuntimeError('HIP_PATH is not set')
+            return ctypes.CDLL(os.path.join(hip_path, 'bin', 'hiprtc0505.dll'))
+        else:
+            raise RuntimeError('Only windows and linux are supported')
+    except Exception as err:
+        raise Exception('Error: {0}'.format(err)) from err
 
 
 class AsDictMixin:
@@ -117,7 +148,7 @@ class Union(ctypes.Union, AsDictMixin):
 
 
 _libraries = {}
-_libraries['libhiprtc.so'] = ctypes.CDLL('/opt/rocm/lib/libhiprtc.so')
+_libraries['libhiprtc.so'] = get_hiprtc()
 def string_cast(char_pointer, encoding='utf-8', errors='strict'):
     value = ctypes.cast(char_pointer, ctypes.c_char_p).value
     if value is not None and encoding is not None:
@@ -155,7 +186,7 @@ def __getattr__(self, _):
 # You can either re-run clan2py with -l /path/to/library.so
 # Or manually fix this by comment the ctypes.CDLL loading
 _libraries['FIXME_STUB'] = FunctionFactoryStub() #  ctypes.CDLL('FIXME_STUB')
-_libraries['libamdhip64.so'] = ctypes.CDLL('/opt/rocm/lib/libamdhip64.so')
+_libraries['libamdhip64.so'] = get_hip()
 
 
 
diff --git a/test/helpers.py b/test/helpers.py
index 6e407d1..e7f4b63 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -22,8 +22,8 @@ def to_char_p_p(options: List[str]):
   c_options[:] = [ctypes.cast(ctypes.create_string_buffer(o.encode("utf-8")), ctypes.POINTER(ctypes.c_char)) for o in options]
   return c_options
 
-def cuda_compile(prg, options, f, check):
-  check(f.create(ctypes.pointer(prog := f.new()), prg.encode(), "".encode(), 0, None, None))
+def cuda_compile(prg, options, f, check, filename=""):
+  check(f.create(ctypes.pointer(prog := f.new()), prg.encode(), filename.encode() if filename else None, 0, None, None))
   status = f.compile(prog, len(options), to_char_p_p(options))
   if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, f.getLogSize, f.getLog, check)}")
   return get_bytes(prog, f.getCodeSize, f.getCode, check)
diff --git a/test/test_hip.py b/test/test_hip.py
index 57aca33..4fd5ab8 100644
--- a/test/test_hip.py
+++ b/test/test_hip.py
@@ -26,7 +26,7 @@ def test_compile_fail(self):
       cuda_compile("void test() { {", ["--offload-arch=gfx1100"], HIPCompile, check)
 
   def test_compile(self):
-    prg = cuda_compile("int test() { return 42; }", ["--offload-arch=gfx1100"], HIPCompile, check)
+    prg = cuda_compile("int test() { return 42; }", ["--offload-arch=gfx1100"], HIPCompile, check, filename=None)
     assert len(prg) > 10
 
 class TestHIPDevice(unittest.TestCase):
@@ -48,7 +48,6 @@ def test_get_device_properties(self) -> hip.hipDeviceProp_t:
     device_properties = hip.hipDeviceProp_t()
     check(hip.hipGetDeviceProperties(device_properties, 0))
     print(device_properties.gcnArchName)
-    return device_properties
 
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file