Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[codespell]
skip = [setup.cfg]
ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo
ignore-words-list = Reson, DNE, ACI, FPT, sagital, saggital, abd, Joo, Mapp, Commun, vor, Claus, coo, ans
2 changes: 0 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ AFQ/version.py
docs/_build
docs/build
docs/source/auto_examples/
docs/source/tutorials/tutorial_examples/
docs/source/howto/howto_examples/
docs/source/reference/config.rst
examples/**/*.nii.gz
examples/**/*.trk
Expand Down
77 changes: 57 additions & 20 deletions AFQ/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
from math import radians

import imageio
import numpy as np
from dipy.align import vector_fields as vfu
from dipy.align.imwarp import DiffeomorphicMap, mult_aff
Expand Down Expand Up @@ -361,17 +362,17 @@ def _weighting_failed():
return w


def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
def make_mp4(show_m, out_path, n_frames=720, az_ang=-0.5, fps=30, crf=35, verbose=True):
"""
Make a video from a Fury Show Manager.
Make an MP4 video from a Fury Show Manager with auto-cropping.

Parameters
----------
show_m : Fury Show Manager
The Fury Show Manager to use for rendering.

out_path : str
The name of the output file.
The name of the output file

n_frames : int
The number of frames to render.
Expand All @@ -382,17 +383,32 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
z-axis for each frame, in degrees.
Default: -10

Comment thread
36000 marked this conversation as resolved.
duration : int
The duration of each frame in the output GIF, in milliseconds.
Default: 150
fps : float
The frames per second for the output video.
Default: 30

crf : int
The Constant Rate Factor for the output video, which controls the
quality and file size. Lower values result in
higher quality and larger file sizes.
Default: 35 (very low quality, small file size)

verbose : bool
Whether to show a progress bar while generating the video.
Default: True
"""
if not out_path.lower().endswith(".mp4"):
out_path += ".mp4"

video = []

show_m.render()
show_m.window.draw()

with tempfile.TemporaryDirectory() as tmp_dir:
for ii in tqdm(range(n_frames), desc="Generating GIF", leave=False):
for ii in tqdm(
range(n_frames), desc="Generating MP4", leave=False, disable=not verbose
):
frame_fname = f"{tmp_dir}/{ii}.png"
show_m.screens[0].controller.rotate((radians(az_ang), 0), None)
show_m.render()
Expand All @@ -406,7 +422,6 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
for img in video:
arr = np.array(img)
bg_color = arr[0, 0]

mask = np.any(arr != bg_color, axis=-1)

if np.any(mask):
Expand All @@ -421,20 +436,42 @@ def make_gif(show_m, out_path, n_frames=36, az_ang=-10, duration=150):
all_lower = max(all_lower, ymax)

if all_left < all_right:
crop_box = (
max(0, all_left),
max(0, all_upper),
min(video[0].width, all_right),
min(video[0].height, all_lower),
)

def align_down(x, multiple=16):
return (x // multiple) * multiple

def align_up(x, multiple=16):
return ((x + multiple - 1) // multiple) * multiple

left = align_up(max(0, all_left), 16)
upper = align_up(max(0, all_upper), 16)
right = align_down(min(video[0].width, all_right), 16)
lower = align_down(min(video[0].height, all_lower), 16)
Comment thread
36000 marked this conversation as resolved.

crop_box = (left, upper, right, lower)
cropped_video = [img.crop(crop_box) for img in video]
else:
cropped_video = video

cropped_video[0].save(
width, height = cropped_video[0].size
with imageio.get_writer(
out_path,
save_all=True,
append_images=cropped_video[1:],
duration=duration,
loop=1,
)
fps=fps,
format="ffmpeg",
codec="libx264",
pixelformat="yuv420p",
output_params=[
"-crf",
f"{str(int(crf))}",
"-preset",
"veryslow",
"-movflags",
"+faststart",
],
) as writer:
for img in cropped_video:
if img.size != (width, height):
img = img.crop((0, 0, width, height))

frame_arr = np.array(img)
writer.append_data(frame_arr)
4 changes: 2 additions & 2 deletions AFQ/api/bundle_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def default_bd():
"primary_axis": "I/S",
"ORG_spectral_subbundles": SpectralSubbundleDict(
{
Comment thread
36000 marked this conversation as resolved.
"Left V1V3": {
"Left Early Visual": {
"cluster_IDs": [78],
"Left Optic Radiation": {
"core": "Anterior",
Expand Down Expand Up @@ -449,7 +449,7 @@ def default_bd():
"primary_axis": "I/S",
"ORG_spectral_subbundles": SpectralSubbundleDict(
{
"Right V1V3": {
"Right Early Visual": {
"cluster_IDs": [78],
"Right Optic Radiation": {
"core": "Anterior",
Expand Down
4 changes: 2 additions & 2 deletions AFQ/api/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def load_next_subject():
ses
]
seg_sft = aus.SegmentedSFT.fromfile(this_bundles_file, this_img)
seg_sft.sft.to_rasmm()
seg_sft.to_rasmm()
subses_info.append((seg_sft, this_mapping, this_img, this_reg_template))

bundle_dict = self.export("bundle_dict", collapse=False)[
Expand All @@ -567,7 +567,7 @@ def load_next_subject():
for b in bundle_dict.bundle_names:
for i in range(len(self.valid_sub_list)):
seg_sft, mapping, img, reg_template = subses_info[i]
idx = seg_sft.bundle_idxs[b]
idx = seg_sft.get_bundle_idxs(b)
# use the first subses that works
# otherwise try each successive subses
if len(idx) == 0:
Expand Down
19 changes: 5 additions & 14 deletions AFQ/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@

@immlib.calc("bundles")
@as_file("_desc-bundles_tractography")
def segment(
structural_imap, data_imap, mapping_imap, tractography_imap, segmentation_params
):
def segment(data_imap, mapping_imap, tractography_imap, segmentation_params):
"""
full path to a trk/trx file containing containing
segmented streamlines, labeled by bundle
Expand Down Expand Up @@ -102,7 +100,6 @@ def segment(
seg_sft.sft.dtype_dict = {"positions": np.float16, "offsets": np.uint32}
tgram = TrxFile.from_sft(seg_sft.sft)
tgram.groups = seg_sft.bundle_idxs

else:
tgram = seg_sft.sft

Expand Down Expand Up @@ -209,15 +206,9 @@ def export_bundle_lengths(bundles):
len_data[f"{bundle} Median"] = 0
len_data[f"{bundle} Min"] = 0
len_data[f"{bundle} Max"] = 0
len_data["Total Recognized Median"] = np.median(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Min"] = np.min(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Max"] = np.max(
seg_sft.sft._tractogram._streamlines._lengths
)
len_data["Total Recognized Median"] = np.median(seg_sft.get_lengths())
len_data["Total Recognized Min"] = np.min(seg_sft.get_lengths())
len_data["Total Recognized Max"] = np.max(seg_sft.get_lengths())

counts_df = pd.DataFrame(
data=len_data,
Expand Down Expand Up @@ -297,7 +288,7 @@ def tract_profiles(
reference = nib.load(scalar_dict[list(scalar_dict.keys())[0]])
seg_sft = aus.SegmentedSFT.fromfile(bundles, reference=reference)

seg_sft.sft.to_rasmm()
seg_sft.to_rasmm()
for bundle_name in seg_sft.bundle_names:
this_sl = seg_sft.get_bundle(bundle_name).streamlines
if len(this_sl) == 0:
Expand Down
22 changes: 11 additions & 11 deletions AFQ/tasks/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def viz_bundles(
)

fname = None
if "no_gif" not in viz_backend.backend:
fname = get_fname(base_fname, ".gif", "..")
if "no_mp4" not in viz_backend.backend:
fname = get_fname(base_fname, ".mp4", "..")

Comment thread
36000 marked this conversation as resolved.
try:
viz_backend.create_gif(figure, fname)
viz_backend.create_mp4(figure, fname)
except PermissionError as e:
logger.warning(f"Failed to write GIF file: {fname} \n{e}")
logger.warning(f"Failed to write MP4 file: {fname} \n{e}")
if "plotly" in viz_backend.backend:
fname = get_fname(base_fname, ".html", "..")

Expand Down Expand Up @@ -156,7 +156,7 @@ def viz_indivBundle(
n_points_indiv=40,
):
"""
list of full paths to html or gif files
list of full paths to html or mp4 files
containing visualizations of individual bundles

Parameters
Expand Down Expand Up @@ -255,15 +255,15 @@ def viz_indivBundle(

base_fname = op.join(output_dir, op.split(base_fname)[1])
figures[bundle_name] = figure
if "no_gif" not in viz_backend.backend:
if "no_mp4" not in viz_backend.backend:
fname = get_fname(
base_fname,
f"_desc-{str_to_desc(bundle_name)}_tractography.gif",
f"_desc-{str_to_desc(bundle_name)}_tractography.mp4",
"viz_bundles",
)

try:
viz_backend.create_gif(figure, fname)
viz_backend.create_mp4(figure, fname)
except PermissionError as e:
if not failed_write:
logger.warning(
Expand Down Expand Up @@ -380,7 +380,7 @@ def plot_tract_profiles(base_fname, output_dir, scalars, segmentation_imap):


@immlib.calc("viz_backend")
def init_viz_backend(viz_backend_spec="plotly_no_gif"):
def init_viz_backend(viz_backend_spec="plotly_no_mp4"):
"""
An instance of the `AFQ.viz.utils.viz_backend` class.

Expand All @@ -390,8 +390,8 @@ def init_viz_backend(viz_backend_spec="plotly_no_gif"):
Which visualization backend to use.
See Visualization Backends page in documentation for details
https://tractometry.org/pyAFQ/reference/viz_backend.html
One of {"fury", "plotly", "plotly_no_gif"}.
Default: "plotly_no_gif"
One of {"fury", "plotly", "plotly_no_mp4"}.
Default: "plotly_no_mp4"
"""
if "fury" not in viz_backend_spec and "plotly" not in viz_backend_spec:
raise TypeError("viz_backend_spec must contain either 'fury' or 'plotly'")
Expand Down
65 changes: 17 additions & 48 deletions AFQ/utils/docs.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,24 @@
import os
import shutil
from glob import glob
import base64

from sphinx_gallery.scrapers import figure_rst
from IPython.display import HTML


class PNGScraper(object):
def __init__(self):
self.seen = set()
def embed_video(path):
with open(path, "rb") as f:
mp4_data = base64.b64encode(f.read()).decode()
return HTML(
f'<video controls><source src="data:video/mp4;base64,{mp4_data}" '
'type="video/mp4"></video>'
)

def __repr__(self):
return "PNGScraper"

def __call__(self, block, block_vars, gallery_conf):
# Find all PNG files in the directory of this example.
path_current_example = os.path.dirname(block_vars["src_file"])
pngs = sorted(glob(os.path.join(path_current_example, "*.png")))
def embed_image(path):
with open(path, "rb") as f:
img_data = base64.b64encode(f.read()).decode()
return HTML(f'<img src="data:image/png;base64,{img_data}"/>')

# Iterate through PNGs, copy them to the sphinx-gallery output directory
image_names = list()
image_path_iterator = block_vars["image_path_iterator"]
for png in pngs:
if png not in self.seen:
self.seen |= set(png)
this_image_path = image_path_iterator.next()
image_names.append(this_image_path)
shutil.move(png, this_image_path)
# Use the `figure_rst` helper function to generate rST for image files
return figure_rst(image_names, gallery_conf["src_dir"])


class GIFScraper(object):
def __init__(self):
self.seen = set()

def __repr__(self):
return "GIFScraper"

def __call__(self, block, block_vars, gallery_conf):
# Find all GIF files in the directory of this example.
path_current_example = os.path.dirname(block_vars["src_file"])
gifs = sorted(glob(os.path.join(path_current_example, "*.gif")))

# Iterate through GIFs, copy them to the sphinx-gallery output directory
image_names = list()
image_path_iterator = block_vars["image_path_iterator"]
for gif in gifs:
if gif not in self.seen:
self.seen |= set(gif)
this_image_path = image_path_iterator.next()
image_names.append(this_image_path)
shutil.move(gif, this_image_path)
# Use the `figure_rst` helper function to generate rST for image files
return figure_rst(image_names, gallery_conf["src_dir"])
def embed_html(path):
with open(path, "r") as f:
html_content = f.read()
return HTML(html_content)
Loading
Loading