Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[codespell]
skip = [setup.cfg]
ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo
ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo, ans
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ AFQ/version.py
docs/_build
docs/build
docs/source/auto_examples/
docs/source/tutorials/tutorial_examples/
docs/source/howto/howto_examples/
docs/source/reference/config.rst
examples/**/*.nii.gz
examples/**/*.trk
Expand Down
83 changes: 60 additions & 23 deletions AFQ/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from math import radians

import imageio
import numpy as np
from dipy.align import vector_fields as vfu
from dipy.align.imwarp import DiffeomorphicMap, mult_aff
Expand All @@ -11,7 +12,7 @@
from PIL import Image
from scipy.linalg import blas, pinvh
from scipy.special import gammaln, lpmv
from tqdm import tqdm
from tqdm.auto import tqdm

logger = logging.getLogger("AFQ")

Expand Down Expand Up @@ -361,38 +362,53 @@ def _weighting_failed():
return w


def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
def make_mp4(show_m, out_path, n_frames=720, az_ang=-0.5, fps=30, crf=35, verbose=True):
"""
Make a video from a Fury Show Manager.
Make an MP4 video from a Fury Show Manager with auto-cropping.

Parameters
----------
show_m : Fury Show Manager
The Fury Show Manager to use for rendering.

out_path : str
The name of the output file.
The name of the output file

n_frames : int
The number of frames to render.
Default: 36
Default: 720

az_ang : float
The angle to rotate the camera around the
z-axis for each frame, in degrees.
Default: -10
Default: -0.5

Comment thread
36000 marked this conversation as resolved.
duration : int
The duration of each frame in the output GIF, in milliseconds.
Default: 150
fps : float
The frames per second for the output video.
Default: 30

crf : int
The Constant Rate Factor for the output video, which controls the
quality and file size. Lower values result in
higher quality and larger file sizes.
Default: 35 (very low quality, small file size)

verbose : bool
Whether to show a progress bar while generating the video.
Default: True
"""
if not out_path.lower().endswith(".mp4"):
out_path += ".mp4"

video = []

show_m.render()
show_m.window.draw()

with tempfile.TemporaryDirectory() as tmp_dir:
for ii in tqdm(range(n_frames), desc="Generating GIF", leave=False):
for ii in tqdm(
range(n_frames), desc="Generating MP4", leave=False, disable=not verbose
):
frame_fname = f"{tmp_dir}/{ii}.png"
show_m.screens[0].controller.rotate((radians(az_ang), 0), None)
show_m.render()
Expand All @@ -406,7 +422,6 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
for img in video:
arr = np.array(img)
bg_color = arr[0, 0]

mask = np.any(arr != bg_color, axis=-1)

if np.any(mask):
Expand All @@ -421,20 +436,42 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
all_lower = max(all_lower, ymax)

if all_left < all_right:
crop_box = (
max(0, all_left),
max(0, all_upper),
min(video[0].width, all_right),
min(video[0].height, all_lower),
)

def align_down(x, multiple=16):
return (x // multiple) * multiple

def align_up(x, multiple=16):
return ((x + multiple - 1) // multiple) * multiple

left = align_up(max(0, all_left), 16)
upper = align_up(max(0, all_upper), 16)
right = align_down(min(video[0].width, all_right), 16)
lower = align_down(min(video[0].height, all_lower), 16)

crop_box = (left, upper, right, lower)
cropped_video = [img.crop(crop_box) for img in video]
else:
cropped_video = video

cropped_video[0].save(
width, height = cropped_video[0].size
with imageio.get_writer(
out_path,
save_all=True,
append_images=cropped_video[1:],
duration=duration,
loop=1,
)
fps=fps,
format="ffmpeg",
codec="libx264",
pixelformat="yuv420p",
output_params=[
"-crf",
f"{str(int(crf))}",
"-preset",
"veryslow",
"-movflags",
"+faststart",
],
) as writer:
for img in cropped_video:
if img.size != (width, height):
img = img.crop((0, 0, width, height))

frame_arr = np.array(img)
writer.append_data(frame_arr)
4 changes: 2 additions & 2 deletions AFQ/api/bundle_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def default_bd():
"primary_axis": "I/S",
"ORG_spectral_subbundles": SpectralSubbundleDict(
{
Comment thread
36000 marked this conversation as resolved.
"Left V1V3": {
"Left Early Visual": {
"cluster_IDs": [78],
"Left Optic Radiation": {
"core": "Anterior",
Expand Down Expand Up @@ -449,7 +449,7 @@ def default_bd():
"primary_axis": "I/S",
"ORG_spectral_subbundles": SpectralSubbundleDict(
{
"Right V1V3": {
"Right Early Visual": {
"cluster_IDs": [78],
"Right Optic Radiation": {
"core": "Anterior",
Expand Down
6 changes: 3 additions & 3 deletions AFQ/api/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dipy.io.streamline import save_tractogram
from dipy.utils.parallel import paramap
from PIL import Image
from tqdm import tqdm
from tqdm.auto import tqdm

import AFQ.api.bundle_dict as abd
import AFQ.definitions.image as afm
Expand Down Expand Up @@ -555,7 +555,7 @@ def load_next_subject():
ses
]
seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img)
seg_sft.sft.to_rasmm()
seg_sft.to_rasmm()
subses_info.append((seg_sft, this_mapping, this_img, this_reg_template))

bundle_dict = self.export("bundle_dict", collapse=False)[
Expand All @@ -567,7 +567,7 @@ def load_next_subject():
for b in bundle_dict.bundle_names:
for i in range(len(self.valid_sub_list)):
seg_sft, mapping, img, reg_template = subses_info[i]
idx = seg_sft.bundle_idxs[b]
idx = seg_sft.get_bundle_idxs(b)
# use the first subses that works
# otherwise try each successive subses
if len(idx) == 0:
Expand Down
2 changes: 1 addition & 1 deletion AFQ/api/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import nibabel as nib
from dipy.align import resample
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from tqdm.auto import tqdm

import AFQ.utils.streamlines as aus
from AFQ.api.utils import (
Expand Down
2 changes: 1 addition & 1 deletion AFQ/data/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dipy.segment.clustering import QuickBundles
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metric import AveragePointwiseEuclideanMetric
from tqdm import tqdm
from tqdm.auto import tqdm

from AFQ._fixes import get_simplified_transform
from AFQ.data.utils import aws_import_msg_error
Expand Down
2 changes: 1 addition & 1 deletion AFQ/models/asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dipy.direction import peak_directions
from dipy.reconst.shm import sh_to_sf, sh_to_sf_matrix, sph_harm_ind_list
from numba import config, njit, prange, set_num_threads
from tqdm import tqdm
from tqdm.auto import tqdm

logger = logging.getLogger("AFQ")

Expand Down
2 changes: 1 addition & 1 deletion AFQ/models/msmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import osqp
from dipy.reconst.mcsd import MSDeconvFit, MultiShellDeconvModel
from scipy.sparse import csr_matrix
from tqdm import tqdm
from tqdm.auto import tqdm

__all__ = ["MultiShellDeconvModel"]

Expand Down
2 changes: 1 addition & 1 deletion AFQ/nn/multiaxial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import nibabel as nib
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

from AFQ.data.fetch import afq_home, fetch_multiaxial_models
from AFQ.nn.utils import prepare_t1_for_nn, resample_output
Expand Down
2 changes: 1 addition & 1 deletion AFQ/recognition/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm
from tqdm.auto import tqdm
from trx.io import load as load_trx

import AFQ.recognition.cleaning as abc
Expand Down
2 changes: 1 addition & 1 deletion AFQ/recognition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dipy.io.streamline import save_tractogram
from dipy.tracking import Streamlines
from dipy.tracking.distances import bundles_distances_mdf
from tqdm import tqdm
from tqdm.auto import tqdm

axes_dict = {
"L/R": 0,
Expand Down
25 changes: 8 additions & 17 deletions AFQ/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@

@immlib.calc("bundles")
@as_file("_desc-bundles_tractography")
def segment(
structural_imap, data_imap, mapping_imap, tractography_imap, segmentation_params
):
def segment(data_imap, mapping_imap, tractography_imap, segmentation_params):
"""
full path to a trk/trx file containing containing
segmented streamlines, labeled by bundle
Expand Down Expand Up @@ -93,16 +91,15 @@ def segment(
**segmentation_params,
)

seg_sft = aus.SegmentedSFT(bundles)

if len(seg_sft.sft) < 1:
if len(bundles) == 0:
raise ValueError("Fatal: No bundles recognized.")

seg_sft = aus.SegmentedSFT(bundles)

if is_trx:
seg_sft.sft.dtype_dict = {"positions": np.float16, "offsets": np.uint32}
tgram = TrxFile.from_sft(seg_sft.sft)
tgram.groups = seg_sft.bundle_idxs

else:
tgram = seg_sft.sft

Expand Down Expand Up @@ -209,15 +206,9 @@ def export_bundle_lengths(bundles):
len_data[f"{bundle} Median"] = 0
len_data[f"{bundle} Min"] = 0
len_data[f"{bundle} Max"] = 0
len_data["Total Recognized Median"] = np.median(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Min"] = np.min(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Max"] = np.max(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Median"] = np.median(seg_sft.get_lengths())
len_data["Total Recognized Min"] = np.min(seg_sft.get_lengths())
len_data["Total Recognized Max"] = np.max(seg_sft.get_lengths())

counts_df = pd.DataFrame(
data=len_data,
Expand Down Expand Up @@ -297,7 +288,7 @@ def tract_profiles(
reference = nib.load(scalar_dict[list(scalar_dict.keys())[0]])
seg_sft = aus.SegmentedSFT.fromfile(bundles, reference=reference)

seg_sft.sft.to_rasmm()
seg_sft.to_rasmm()
for bundle_name in seg_sft.bundle_names:
this_sl = seg_sft.get_bundle(bundle_name).streamlines
if len(this_sl) == 0:
Expand Down
22 changes: 11 additions & 11 deletions AFQ/tasks/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def viz_bundles(
)

fname = None
if "no_gif" not in viz_backend.backend:
fname = get_fname(base_fname, ".gif", "..")
if "no_mp4" not in viz_backend.backend:
fname = get_fname(base_fname, ".mp4", "..")

Comment thread
36000 marked this conversation as resolved.
try:
viz_backend.create_gif(figure, fname)
viz_backend.create_mp4(figure, fname)
except PermissionError as e:
logger.warning(f"Failed to write GIF file: {fname} \n{e}")
logger.warning(f"Failed to write MP4 file: {fname} \n{e}")
if "plotly" in viz_backend.backend:
fname = get_fname(base_fname, ".html", "..")

Expand Down Expand Up @@ -156,7 +156,7 @@ def viz_indivBundle(
n_points_indiv=40,
):
"""
list of full paths to html or gif files
list of full paths to html or mp4 files
containing visualizations of individual bundles

Parameters
Expand Down Expand Up @@ -255,15 +255,15 @@ def viz_indivBundle(

base_fname = op.join(output_dir, op.split(base_fname)[1])
figures[bundle_name] = figure
if "no_gif" not in viz_backend.backend:
if "no_mp4" not in viz_backend.backend:
fname = get_fname(
base_fname,
f"_desc-{str_to_desc(bundle_name)}_tractography.gif",
f"_desc-{str_to_desc(bundle_name)}_tractography.mp4",
"viz_bundles",
)

try:
viz_backend.create_gif(figure, fname)
viz_backend.create_mp4(figure, fname)
except PermissionError as e:
if not failed_write:
logger.warning(
Expand Down Expand Up @@ -380,7 +380,7 @@ def plot_tract_profiles(base_fname, output_dir, scalars, segmentation_imap):


@immlib.calc("viz_backend")
def init_viz_backend(viz_backend_spec="plotly_no_gif"):
def init_viz_backend(viz_backend_spec="plotly_no_mp4"):
"""
An instance of the `AFQ.viz.utils.viz_backend` class.

Expand All @@ -390,8 +390,8 @@ def init_viz_backend(viz_backend_spec="plotly_no_gif"):
Which visualization backend to use.
See Visualization Backends page in documentation for details
https://tractometry.org/pyAFQ/reference/viz_backend.html
One of {"fury", "plotly", "plotly_no_gif"}.
Default: "plotly_no_gif"
One of {"fury", "plotly", "plotly_no_mp4"}.
Default: "plotly_no_mp4"
"""
if "fury" not in viz_backend_spec and "plotly" not in viz_backend_spec:
raise TypeError("viz_backend_spec must contain either 'fury' or 'plotly'")
Expand Down
Loading
Loading