Fix predict_as_dataframe: single-pass, shuffle-safe, multi-dim #879

Open
sevmag wants to merge 4 commits into graphnet-team:main from sevmag:fix/predict_as_dataframe
Conversation

@sevmag (Collaborator) commented Apr 30, 2026

Closes #880

Summary

  • Refactor EasySyntax.predict_as_dataframe so that additional_attributes are gathered inside predict_step, during the same trainer pass that produces the model predictions, instead of in a separate second loop over the dataloader (for batch in dataloader). This keeps predictions and attributes aligned by construction and removes the implicit assumption that the dataloader replays batches in the same order on every iteration.
  • Drop the SequentialSampler guard: with attributes now collected in-pass, a shuffled dataloader is safe to use with additional_attributes.
  • Add an additional_attributes parameter to EasySyntax.predict so callers that don't need a DataFrame can still get aligned attributes back; the return type becomes List[Union[Tensor, np.ndarray]] (task tensors first, then attribute arrays).
  • Support multi-dimensional attributes (e.g. direction with x/y/z components): a 2-D attribute array is flattened to one column per component (<name>_0, <name>_1, ...) so the resulting DataFrame stays tabular.
  • Preserve the pulse-level repeat behavior (np.repeat(value, batch.n_pulses)) and the warn-and-skip path for length mismatches.
  • Add tests/models/test_easy_model.py covering event-level, pulse-level, multi-dim, and shuffled-dataloader cases.
  • Drop trailing whitespace in .github/workflows/build.yml (no behavior change).
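
The column layout for multi-dimensional attributes and the pulse-level repeat described above can be sketched as follows. This is a minimal illustration, not the code in this PR: the helper names attribute_columns and repeat_per_pulse are hypothetical, and plain lists and dicts stand in for the NumPy arrays and the DataFrame so the sketch has no dependencies.

```python
from typing import Dict, List, Sequence, Union

# One value per row: either a scalar or a fixed-length vector (e.g. x/y/z).
Value = Union[float, Sequence[float]]


def attribute_columns(name: str, rows: Sequence[Value]) -> Dict[str, List[float]]:
    """Turn one attribute into one or more table columns (hypothetical helper).

    Scalar rows give a single column called `name`; vector rows give one
    column per component, called `name_0`, `name_1`, ...
    """
    if not rows:
        return {name: []}
    first = rows[0]
    if isinstance(first, (int, float)):
        return {name: [float(v) for v in rows]}
    width = len(first)
    return {f"{name}_{i}": [float(row[i]) for row in rows] for i in range(width)}


def repeat_per_pulse(values: Sequence[float], n_pulses: Sequence[int]) -> List[float]:
    """Pure-Python analogue of np.repeat(values, n_pulses) (hypothetical helper)."""
    out: List[float] = []
    for value, count in zip(values, n_pulses):
        out.extend([value] * count)
    return out


# Two events: a scalar attribute and a 3-component attribute.
energy = [1.5, 2.5]
direction = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]

table: Dict[str, List[float]] = {}
table.update(attribute_columns("energy", energy))
table.update(attribute_columns("direction", direction))

# Event-level values expanded to pulse level: 2 pulses, then 3 pulses.
per_pulse_energy = repeat_per_pulse(energy, [2, 3])
```

Here table ends up with the columns energy, direction_0, direction_1 and direction_2, so every column holds exactly one scalar per row and the result stays tabular.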

@sevmag changed the title from "Fix predict_as_dataframe: single-pass attribute gather, shuffle-safe, multi-dim support" to "Fix predict_as_dataframe: single-pass, shuffle-safe, multi-dim" on Apr 30, 2026
@sevmag marked this pull request as ready for review on April 30, 2026, 22:36

Development

Successfully merging this pull request may close these issues.

predict_as_dataframe iterates the dataloader twice, causing wasted I/O and silent misalignment risk
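
The misalignment risk named in this issue can be illustrated with a toy loader that reshuffles on every iteration. This is a sketch under stated assumptions, not GraphNeT code: ShuffledLoader, model, two_pass and single_pass are hypothetical stand-ins written in plain Python.

```python
import random
from typing import Iterator, List, Sequence, Tuple

Event = Tuple[int, float]  # (event_id, feature)


class ShuffledLoader:
    """Toy loader that draws a fresh random order on every iteration."""

    def __init__(self, events: Sequence[Event], batch_size: int, seed: int) -> None:
        self.events = list(events)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[Event]]:
        order = self.events[:]
        self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]


def model(feature: float) -> float:
    """Stand-in for the model: the correct prediction is 2 * feature."""
    return 2.0 * feature


def two_pass(loader: ShuffledLoader) -> List[Tuple[int, float]]:
    """Predict in one pass, read ids in a second pass (the old pattern)."""
    preds = [model(feature) for batch in loader for _, feature in batch]
    ids = [event_id for batch in loader for event_id, _ in batch]
    return list(zip(ids, preds))


def single_pass(loader: ShuffledLoader) -> List[Tuple[int, float]]:
    """Collect id and prediction from the same batch, in the same pass."""
    rows = []
    for batch in loader:
        for event_id, feature in batch:
            rows.append((event_id, model(feature)))
    return rows


# feature == event_id, so a row is aligned exactly when pred == 2 * event_id.
events = [(i, float(i)) for i in range(8)]
rows_two = two_pass(ShuffledLoader(events, batch_size=3, seed=0))
rows_one = single_pass(ShuffledLoader(events, batch_size=3, seed=0))

mismatched_two = sum(pred != 2.0 * event_id for event_id, pred in rows_two)
mismatched_one = sum(pred != 2.0 * event_id for event_id, pred in rows_one)
```

In the single-pass variant mismatched_one is zero by construction, because each id is read from the same batch that produced its prediction. The two-pass variant carries no such guarantee once the second iteration draws a different order.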
