diff --git a/src/toil_vg/vg_surject.py b/src/toil_vg/vg_surject.py index 6553499b..019ccd29 100644 --- a/src/toil_vg/vg_surject.py +++ b/src/toil_vg/vg_surject.py @@ -40,6 +40,8 @@ def surject_subparser(parser): help="Path to xg index") parser.add_argument("--paths", nargs='+', default = [], help="list of path names to surject to (default: all in xg)") + parser.add_argument("--ref_paths_file", type=make_url, + help="file containing reference paths for SAM header (if 2nd column present, it will be used for lengths)") parser.add_argument("--interleaved", action="store_true", default=False, help="treat gam as interleaved read pairs. overrides map-args") parser.add_argument("--gam_input_reads", type=make_url, required=True, @@ -58,7 +60,7 @@ def surject_subparser(parser): add_container_tool_parse_args(parser) -def run_surjecting(job, context, gam_input_reads_id, output_name, interleaved, xg_file_id, paths): +def run_surjecting(job, context, gam_input_reads_id, output_name, interleaved, xg_file_id, paths, ref_paths_id): """ split the fastq, then surject each chunk. returns outputgams, paired with total surject time (excluding toil-vg overhead such as transferring and splitting files )""" @@ -76,10 +78,10 @@ def run_surjecting(job, context, gam_input_reads_id, output_name, interleaved, x reads_chunk_ids = [[r] for r in [gam_input_reads_id]] return child_job.addFollowOnJobFn(run_whole_surject, context, reads_chunk_ids, output_name, - interleaved, xg_file_id, paths, cores=context.config.misc_cores, + interleaved, xg_file_id, paths, ref_paths_id, cores=context.config.misc_cores, memory=context.config.misc_mem, disk=context.config.misc_disk).rv() -def run_whole_surject(job, context, reads_chunk_ids, output_name, interleaved, xg_file_id, paths): +def run_whole_surject(job, context, reads_chunk_ids, output_name, interleaved, xg_file_id, paths, ref_paths_id): """ Surject all gam chunks in parallel. @@ -106,7 +108,7 @@ def run_whole_surject(job, context, reads_chunk_ids, output_name, interleaved, x for chunk_id, chunk_filename_ids in enumerate(zip(*reads_chunk_ids)): #Run graph surject on each gam chunk chunk_surject_job = child_job.addChildJobFn(run_chunk_surject, context, interleaved, xg_file_id, - paths, chunk_filename_ids, '{}_chunk{}'.format(output_name, chunk_id), + paths, ref_paths_id, chunk_filename_ids, '{}_chunk{}'.format(output_name, chunk_id), cores=context.config.alignment_cores, memory=context.config.alignment_mem, disk=context.config.alignment_disk) @@ -118,7 +120,7 @@ def run_whole_surject(job, context, reads_chunk_ids, output_name, interleaved, x memory=context.config.misc_mem, disk=context.config.misc_disk).rv() -def run_chunk_surject(job, context, interleaved, xg_file_id, paths, chunk_filename_ids, chunk_id): +def run_chunk_surject(job, context, interleaved, xg_file_id, paths, ref_paths_id, chunk_filename_ids, chunk_id): """ run surject on a chunk. interface mostly copied from run_chunk_alignment. Takes an xg file and path colleciton to surject against, a list of chunk @@ -149,6 +151,10 @@ def run_chunk_surject(job, context, interleaved, xg_file_id, paths, chunk_filena gam_file = os.path.join(work_dir, 'reads_chunk_{}_{}.{}'.format(chunk_id, j, reads_ext)) job.fileStore.readGlobalFile(chunk_filename_id, gam_file) gam_files.append(gam_file) + + if ref_paths_id: + ref_paths_file = os.path.join(work_dir, 'ref-paths.tsv') + job.fileStore.readGlobalFile(ref_paths_id, ref_paths_file) # And a temp file for our surject output output_file = os.path.join(work_dir, "surject_{}.bam".format(chunk_id)) @@ -162,8 +168,10 @@ def run_chunk_surject(job, context, interleaved, xg_file_id, paths, chunk_filena cmd += ['-x', os.path.basename(xg_file)] for surject_path in paths: cmd += ['--into-path', surject_path] + if ref_paths_id: + cmd += ['--ref-paths', os.path.basename(ref_paths_file)] cmd += ['-t', str(context.config.alignment_cores)] - + # Mark when we start the surjection start_time = timeit.default_timer() try: @@ -244,6 +252,7 @@ def surject_main(context, options): # Upload local files to the remote IO Store inputXGFileID = importer.load(options.xg_index) inputGAMFileID = importer.load(options.gam_input_reads) + inputRefPathsID = importer.load(options.ref_paths_file) if options.ref_paths_file else None importer.wait() @@ -251,6 +260,7 @@ def surject_main(context, options): root_job = Job.wrapJobFn(run_surjecting, context, importer.resolve(inputGAMFileID), 'surject', options.interleaved, importer.resolve(inputXGFileID), options.paths, + importer.resolve(inputRefPathsID) if options.ref_paths_file else None, cores=context.config.misc_cores, memory=context.config.misc_mem, disk=context.config.misc_disk)