Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 278 additions & 5 deletions iltm/inference_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,11 @@ def __init__(
gpu_oom_retries: int = 2,
auto_disable_retrieval_on_low_mem: bool = False,
auto_amp_on_low_mem: bool = False,
auto_stop_on_low_cpu_memory: bool = True,
cpu_memory_limit_gb: float | None = None,
cpu_memory_safety_margin_gb: float = 2.0,
corr_select_k: int = 300,
inference_storage_dtype: str | None = "float16",
) -> None:

# Logging
Expand Down Expand Up @@ -379,9 +383,14 @@ def __init__(
self.gpu_oom_retries = int(gpu_oom_retries)
self.auto_disable_retrieval_on_low_mem = bool(auto_disable_retrieval_on_low_mem)
self.auto_amp_on_low_mem = bool(auto_amp_on_low_mem)
self.auto_stop_on_low_cpu_memory = bool(auto_stop_on_low_cpu_memory)
self.cpu_memory_limit_gb = None if cpu_memory_limit_gb is None else float(cpu_memory_limit_gb)
self.cpu_memory_safety_margin_gb = float(cpu_memory_safety_margin_gb)

# Correlation-based feature selection
self.corr_select_k = int(corr_select_k)
self.inference_storage_dtype = inference_storage_dtype
self._inference_storage_torch_dtype = self._resolve_inference_storage_dtype(inference_storage_dtype)

# Placeholders
self.tr_: TreeEmbedding | List[TreeEmbedding] | None = None
Expand All @@ -393,6 +402,228 @@ def __init__(
self.predictors_: List[dict] = []
self.preprocessors_: List[dict] = []

@staticmethod
def _resolve_inference_storage_dtype(dtype: str | torch.dtype | None) -> torch.dtype | None:
if dtype is None:
return None
if isinstance(dtype, torch.dtype):
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(
"inference_storage_dtype must be one of None, 'float32', "
"'float16', or 'bfloat16'."
)
return None if dtype == torch.float32 else dtype

normalized = str(dtype).lower()
aliases = {
"none": None,
"float32": None,
"fp32": None,
"torch.float32": None,
"float16": torch.float16,
"fp16": torch.float16,
"half": torch.float16,
"torch.float16": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"torch.bfloat16": torch.bfloat16,
}
if normalized not in aliases:
raise ValueError(
f"Unsupported inference_storage_dtype={dtype!r}. Expected one of "
"None, 'float32', 'float16', or 'bfloat16'."
)
return aliases[normalized]

def _module_to_device_and_dtype(
self,
module: nn.Module,
*,
device: str | torch.device,
dtype: torch.dtype | None,
) -> nn.Module:
if dtype is None:
return module.to(device)
return module.to(device=device, dtype=dtype)

@staticmethod
def _read_int_file(path: str) -> int | None:
try:
with open(path) as f:
raw = f.read().strip()
except OSError:
return None
if raw == "" or raw.lower() == "max":
return None
try:
value = int(raw)
except ValueError:
return None
if value <= 0 or value >= 1 << 60:
return None
return value

@classmethod
def _get_cgroup_memory_limit_bytes(cls) -> int | None:
candidates = [
"/sys/fs/cgroup/memory.max",
"/sys/fs/cgroup/memory/memory.limit_in_bytes",
]
values = [value for path in candidates if (value := cls._read_int_file(path)) is not None]
return min(values) if values else None

@classmethod
def _get_cgroup_memory_usage_bytes(cls) -> int | None:
candidates = [
"/sys/fs/cgroup/memory.current",
"/sys/fs/cgroup/memory/memory.usage_in_bytes",
]
values = [value for path in candidates if (value := cls._read_int_file(path)) is not None]
return min(values) if values else None

@staticmethod
def _get_autogluon_memory_limit_bytes() -> int | None:
raw = os.environ.get("AG_MEMORY_LIMIT_IN_GB")
if raw is None:
return None
try:
value_gb = float(raw)
except ValueError:
return None
if value_gb <= 0:
return None
return int(value_gb * (1024 ** 3))

@staticmethod
def _get_slurm_memory_limit_bytes() -> int | None:
mem_per_node = os.environ.get("SLURM_MEM_PER_NODE")
if mem_per_node:
try:
value_mb = int(mem_per_node)
except ValueError:
value_mb = 0
if value_mb > 0:
return value_mb * 1024 * 1024

mem_per_cpu = os.environ.get("SLURM_MEM_PER_CPU")
if mem_per_cpu:
try:
value_mb = int(mem_per_cpu)
except ValueError:
value_mb = 0
cpus_raw = (
os.environ.get("SLURM_CPUS_PER_TASK")
or os.environ.get("SLURM_CPUS_ON_NODE")
or os.environ.get("SLURM_JOB_CPUS_PER_NODE", "").split("(", 1)[0]
)
try:
cpus = int(cpus_raw)
except (TypeError, ValueError):
cpus = 1
if value_mb > 0 and cpus > 0:
return value_mb * cpus * 1024 * 1024
return None

def _get_effective_cpu_memory_limit_bytes(self) -> int | None:
limits: list[int] = []
if self.cpu_memory_limit_gb is not None and self.cpu_memory_limit_gb > 0:
limits.append(int(self.cpu_memory_limit_gb * (1024 ** 3)))
ag_limit = self._get_autogluon_memory_limit_bytes()
if ag_limit is not None:
limits.append(ag_limit)
cgroup_limit = self._get_cgroup_memory_limit_bytes()
if cgroup_limit is not None:
limits.append(cgroup_limit)
slurm_limit = self._get_slurm_memory_limit_bytes()
if slurm_limit is not None:
limits.append(slurm_limit)
return min(limits) if limits else None

@classmethod
def _get_process_rss_bytes(cls) -> int | None:
try:
with open("/proc/self/status") as f:
for line in f:
if line.startswith("VmRSS:"):
return int(line.split()[1]) * 1024
except OSError:
return None
return None

def _get_cpu_memory_usage_bytes(self, limit_bytes: int | None = None) -> int | None:
cgroup_limit = self._get_cgroup_memory_limit_bytes()
if cgroup_limit is not None and limit_bytes is not None and cgroup_limit <= int(limit_bytes * 1.05):
return self._get_cgroup_memory_usage_bytes() or self._get_process_rss_bytes()
return self._get_process_rss_bytes() or self._get_cgroup_memory_usage_bytes()

@classmethod
def _estimate_object_storage_bytes(cls, obj: Any, seen: set[int] | None = None) -> int:
if seen is None:
seen = set()
if obj is None or isinstance(obj, (str, bytes, int, float, bool)):
return 0

obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)

if isinstance(obj, torch.Tensor):
return obj.numel() * obj.element_size()
if isinstance(obj, np.ndarray):
return int(obj.nbytes)
if isinstance(obj, nn.Module):
total = 0
for param in obj.parameters(recurse=True):
total += param.numel() * param.element_size()
for buffer in obj.buffers(recurse=True):
total += buffer.numel() * buffer.element_size()
return total
if isinstance(obj, dict):
return sum(cls._estimate_object_storage_bytes(value, seen) for value in obj.values())
if isinstance(obj, (list, tuple, set)):
return sum(cls._estimate_object_storage_bytes(value, seen) for value in obj)
return 0

def _estimate_predictors_storage_bytes(self) -> int:
return sum(self._estimate_object_storage_bytes(predictor) for predictor in self.predictors_)

def _should_stop_for_cpu_memory_before_predictor(self, next_predictor_index: int) -> tuple[bool, dict[str, float | int] | None]:
if not self.auto_stop_on_low_cpu_memory or not self.predictors_:
return False, None

limit_bytes = self._get_effective_cpu_memory_limit_bytes()
usage_bytes = self._get_cpu_memory_usage_bytes(limit_bytes=limit_bytes)
if limit_bytes is None or usage_bytes is None:
return False, None

current_model_bytes = self._estimate_predictors_storage_bytes()
if current_model_bytes <= 0:
return False, None

avg_predictor_bytes = current_model_bytes / len(self.predictors_)
projected_model_bytes = current_model_bytes + avg_predictor_bytes
margin_bytes = max(0, int(self.cpu_memory_safety_margin_gb * (1024 ** 3)))
# During fold return or serialization, a serialized copy can coexist
# with the live fitted model. Project that copy without materializing it.
projected_peak_bytes = usage_bytes + avg_predictor_bytes + projected_model_bytes + margin_bytes

if projected_peak_bytes <= limit_bytes:
return False, None

return True, {
"next_predictor_index": next_predictor_index,
"limit_gb": limit_bytes / (1024 ** 3),
"usage_gb": usage_bytes / (1024 ** 3),
"available_gb": (limit_bytes - usage_bytes) / (1024 ** 3),
"avg_predictor_gb": avg_predictor_bytes / (1024 ** 3),
"projected_model_gb": projected_model_bytes / (1024 ** 3),
"projected_peak_gb": projected_peak_bytes / (1024 ** 3),
"margin_gb": margin_bytes / (1024 ** 3),
"predictors_fit": len(self.predictors_),
"n_ensemble": self.n_ensemble,
}

@classmethod
def __sklearn_tags__(cls, estimator=None):
try:
Expand Down Expand Up @@ -824,20 +1055,34 @@ def _sample_data(
def _move_predictor_to_device(self, predictor: dict, device: str | torch.device | None = None) -> dict:
if device is None:
device = self.device
target = torch.device(device)
for key in predictor:
if key in ['rf', 'pca', 'norm']:
if predictor[key] is not None:
predictor[key] = predictor[key].to(device)
dtype = torch.float32 if key == "rf" else None
predictor[key] = self._module_to_device_and_dtype(predictor[key], device=target, dtype=dtype)
elif key == 'main_network':
for i, layer in enumerate(predictor[key]):
predictor[key][i] = layer.to(device)
predictor[key][i] = layer.to(target)
elif isinstance(predictor[key], (nn.Module, nn.Sequential)):
predictor[key] = predictor[key].to(device)
predictor[key] = predictor[key].to(target)
gc.collect()
return predictor

def _move_predictor_to_cpu(self, predictor: dict) -> dict:
return self._move_predictor_to_device(predictor, device='cpu')
storage_dtype = self._inference_storage_torch_dtype
for key in predictor:
if key in ['rf', 'pca', 'norm']:
if predictor[key] is not None:
dtype = storage_dtype if key == "rf" else None
predictor[key] = self._module_to_device_and_dtype(predictor[key], device="cpu", dtype=dtype)
elif key == 'main_network':
for i, layer in enumerate(predictor[key]):
predictor[key][i] = layer.to("cpu")
elif isinstance(predictor[key], (nn.Module, nn.Sequential)):
predictor[key] = predictor[key].to("cpu")
gc.collect()
return predictor

# -----------------------------
# Predictor generation and forward
Expand Down Expand Up @@ -1306,7 +1551,27 @@ def _as_numpy_1d(y, *, dtype=None) -> np.ndarray:
len(self.predictors_), self.n_ensemble
)
break


stop_for_memory, memory_details = self._should_stop_for_cpu_memory_before_predictor(i + 1)
if stop_for_memory and memory_details is not None:
logger.warning(
"Early return: CPU memory budget nearly exhausted before predictor %d. "
"Stopping at %d/%d predictors. "
"usage=%.2fGB, available=%.2fGB, avg_predictor=%.2fGB, "
"projected_model=%.2fGB, projected_peak=%.2fGB, limit=%.2fGB, margin=%.2fGB.",
memory_details["next_predictor_index"],
memory_details["predictors_fit"],
memory_details["n_ensemble"],
memory_details["usage_gb"],
memory_details["available_gb"],
memory_details["avg_predictor_gb"],
memory_details["projected_model_gb"],
memory_details["projected_peak_gb"],
memory_details["limit_gb"],
memory_details["margin_gb"],
)
break

logger.info(f"Generating predictor {i + 1} of {self.n_ensemble}...")
t_pred_start = time.time()
if self.tree_embedding and self.tree_for_each_predictor:
Expand Down Expand Up @@ -1589,7 +1854,11 @@ def __init__(
gpu_oom_retries: int = 2,
auto_disable_retrieval_on_low_mem: bool = False,
auto_amp_on_low_mem: bool = False,
auto_stop_on_low_cpu_memory: bool = True,
cpu_memory_limit_gb: float | None = None,
cpu_memory_safety_margin_gb: float = 2.0,
corr_select_k: int = 300,
inference_storage_dtype: str | None = "float16",
) -> None:
params = locals().copy()
params.pop("self")
Expand Down Expand Up @@ -1798,7 +2067,11 @@ def __init__(
gpu_oom_retries: int = 2,
auto_disable_retrieval_on_low_mem: bool = False,
auto_amp_on_low_mem: bool = False,
auto_stop_on_low_cpu_memory: bool = True,
cpu_memory_limit_gb: float | None = None,
cpu_memory_safety_margin_gb: float = 2.0,
corr_select_k: int = 300,
inference_storage_dtype: str | None = "float16",
) -> None:
params = locals().copy()
params.pop("self")
Expand Down
Loading