Skip to content

Add raw pointer upload support for Matrix data#42

Open
moulin1024 wants to merge 1 commit into
shwina:mainfrom
moulin1024:feature/raw_matrix_copy
Open

Add raw pointer upload support for Matrix data#42
moulin1024 wants to merge 1 commit into
shwina:mainfrom
moulin1024:feature/raw_matrix_copy

Conversation

@moulin1024

Copy link
Copy Markdown

Summary

This PR adds raw pointer upload support for pyamgx.Matrix, enabling matrix data to be copied directly from raw host or device pointers into AMGX.

The main motivation is interoperability with GPU array frameworks such as JAX, CuPy, Numba, or other libraries that can expose raw device pointers but may not provide NumPy-compatible host arrays or __cuda_array_interface__ objects in the form expected by the existing Matrix.upload() method.

This PR adds two new methods:

Matrix.upload_raw(
    row_ptrs_addr,
    col_indices_addr,
    data_addr,
    nrows,
    nnz,
    block_dims=[1, 1],
    shape=None,
)

and

Matrix.replace_coefficients_raw(data_addr)

These methods mirror the existing vector-level Vector.upload_raw() functionality, but for CSR matrix data and matrix coefficient replacement.

Motivation

pyamgx.Vector already supports uploading from raw pointers through:

Vector.upload_raw(ptr, n)

This allows efficient GPU-to-GPU copies from arrays that already live on the device.

However, pyamgx.Matrix previously only exposed:

Matrix.upload(row_ptrs, col_indices, data)
Matrix.upload_CSR(csr_matrix)

The existing Matrix.upload() path expects Python array-like objects and performs host-side operations such as calling .max(). This means it does not work with raw GPU buffers or Numba DeviceNDArray objects in all cases.

For example, attempting to upload CSR arrays through Numba device arrays can fail with:

AttributeError: 'DeviceNDArray' object has no attribute 'max'

This PR enables applications to bypass that limitation by passing explicit raw pointer addresses and matrix metadata directly.

This is particularly useful for workflows where sparse matrix data is already generated or stored on the GPU, for example:

JAX GPU CSR arrays
  -> raw device pointers
  -> pyamgx.Matrix.upload_raw()
  -> AMGX solve

New API

Matrix.upload_raw(...)

M.upload_raw(
    row_ptrs_addr,
    col_indices_addr,
    data_addr,
    nrows,
    nnz,
    block_dims=[1, 1],
    shape=None,
)

Copies CSR matrix data into an AMGX matrix directly from raw pointers.

Parameters:

  • row_ptrs_addr: address of the CSR row pointer buffer

    • expected type: int32*
    • expected length: nrows + 1
  • col_indices_addr: address of the CSR column index buffer

    • expected type: int32*
    • expected length: nnz
  • data_addr: address of the CSR values buffer

    • expected type depends on the matrix mode, for example float64* for dDDI
    • expected length: nnz * block_dimx * block_dimy
  • nrows: number of matrix rows, in block units

  • nnz: number of nonzero blocks

  • block_dims: block dimensions, default [1, 1]

  • shape: optional matrix shape (nrows, ncols)

    • if omitted, the matrix is assumed to be square

The method calls AMGX directly through:

AMGX_matrix_upload_all(...)

Matrix.replace_coefficients_raw(...)

M.replace_coefficients_raw(data_addr)

Replaces matrix coefficients in-place from a raw pointer while preserving the existing sparsity pattern.

This is useful for workflows where the CSR structure is fixed but values change between solves.

The method calls AMGX directly through:

AMGX_matrix_replace_coefficients(...)

Example

import jax
import jax.numpy as jnp
import pyamgx

jax.config.update("jax_enable_x64", True)

gpu = jax.devices("gpu")[0]

with jax.default_device(gpu):
    row_ptrs = jnp.array([0, 2, 4], dtype=jnp.int32)
    col_indices = jnp.array([0, 1, 0, 1], dtype=jnp.int32)
    values = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float64)

row_ptrs.block_until_ready()
col_indices.block_until_ready()
values.block_until_ready()

row_ptrs_addr = row_ptrs.addressable_data(0).unsafe_buffer_pointer()
col_indices_addr = col_indices.addressable_data(0).unsafe_buffer_pointer()
values_addr = values.addressable_data(0).unsafe_buffer_pointer()

pyamgx.initialize()

cfg = pyamgx.Config().create_from_dict({
    "config_version": 2,
    "solver": {
        "solver": "BICGSTAB",
        "preconditioner": {
            "solver": "NOSOLVER"
        }
    }
})

rsc = pyamgx.Resources().create_simple(cfg)

A = pyamgx.Matrix().create(rsc, mode="dDDI")

A.upload_raw(
    row_ptrs_addr=row_ptrs_addr,
    col_indices_addr=col_indices_addr,
    data_addr=values_addr,
    nrows=2,
    nnz=4,
    block_dims=[1, 1],
    shape=(2, 2),
)

A.destroy()
rsc.destroy()
cfg.destroy()
pyamgx.finalize()

Validation

This feature was tested with a GPU-resident JAX workflow.

The test constructs CSR arrays on the GPU with JAX:

row_ptrs: int32
col_indices: int32
values: float64
rhs: float64
x0: float64

Then it passes the raw JAX device pointers directly to PyAMGX:

A.upload_raw(row_ptrs_ptr, col_indices_ptr, values_ptr, nrows, nnz)
b.upload_raw(rhs_ptr, n)
x.upload_raw(x0_ptr, n)

The AMGX solution was compared against SciPy’s host-side sparse solver.

Result for the initial matrix upload:

AMGX solution:
[0.18826665 0.35720229 0.56913282 0.7659675  0.97973771]

SciPy solution:
[0.18826665 0.35720229 0.56913282 0.7659675  0.97973771]

max abs error:
6.339373470609644e-14

The coefficient replacement path was also tested using:

A.replace_coefficients_raw(values_ptr_2)

Result after coefficient replacement:

AMGX solution:
[0.06299548 0.14891831 0.25417129 0.35292031 0.4589302 ]

SciPy solution:
[0.06299548 0.14891831 0.25417129 0.35292031 0.4589302 ]

max abs error:
1.0330625244137082e-13

Both tests passed.

Notes on memory ownership

These methods do not make PyAMGX alias the external buffers permanently.

They copy from the provided raw pointers into AMGX-owned matrix storage, consistent with AMGX upload semantics.

The pointer memory space must match the matrix mode:

  • mode="dDDI" expects device pointers
  • mode="hDDI" expects host pointers

The caller is responsible for ensuring that:

  • the pointer addresses are valid
  • the buffers remain alive for the duration of the upload call
  • the dtypes match the selected AMGX mode
  • the CSR metadata is correct

For example, with mode="dDDI":

  • row pointers must be int32
  • column indices must be int32
  • values must be float64

Backward compatibility

This PR is additive.

Existing APIs are unchanged:

Matrix.upload(...)
Matrix.upload_CSR(...)
Vector.upload(...)
Vector.upload_raw(...)

Existing user code should continue to work as before.

Why this is useful

This enables efficient integration with modern GPU-native Python workflows, including:

  • JAX-generated sparse systems
  • CuPy sparse data
  • Numba CUDA device arrays
  • custom CUDA kernels that produce CSR arrays
  • workflows where the sparsity pattern is fixed but coefficients change repeatedly

In particular, replace_coefficients_raw() makes it possible to update matrix values without re-uploading or reconstructing the sparsity structure, which is useful for repeated solves over the same matrix pattern.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant