def _longest_common_substring(strings):
    """Return the longest contiguous substring found in every string.

    Candidates are drawn from the shortest input string, longest first and
    left-to-right, so the first match is the longest one.  Returns an empty
    string when the inputs share no characters at all.
    """
    seed = min(strings, key=len)
    for length in range(len(seed), 0, -1):
        for start in range(len(seed) - length + 1):
            candidate = seed[start : start + length]
            if all(candidate in s for s in strings):
                return candidate
    return ""


def replace_largest_common_substring(set0, set1, set2, max_len=65):
    """Shorten over-long column names by trimming their largest shared substring.

    The largest contiguous substring common to every column name longer than
    ``max_len`` is located and replaced, in every column name that contains
    it, by a shortened version.  Removing ``-`` and ``_`` characters from the
    shared substring is preferred; if that does not save enough characters,
    the required number of leading characters is dropped from it instead.

    Parameters
    ----------
    set0 : list
        set0, as produced by st_file
    set1 : list
        set1, as produced by st_file
    set2 : list
        set2, as produced by st_file
    max_len : int, optional (default=65)
        Maximum allowed length of a column name

    Returns
    -------
    set0 : list
        set0, with any necessary tag trimming to allow FITS table to write
    set1 : list
        set1, with any necessary tag trimming to allow FITS table to write
    set2 : list
        set2, with any necessary tag trimming to allow FITS table to write

    Raises
    ------
    ValueError
        If the over-long column names share no substring, or if the shared
        substring is shorter than the number of characters that must be
        removed.
    """

    # Find all column names longer than the allowed maximum
    all_tags = list(set0) + list(set1) + list(set2)
    too_long = [tag for tag in all_tags if len(tag) > max_len]

    # If no column names are too long, return the inputs unchanged
    if not too_long:
        return set0, set1, set2

    # Largest substring shared by every over-long column name
    shared = _longest_common_substring(too_long)
    if not shared:
        raise ValueError(
            "Some column names are too long to write to FITS table, "
            + "but no shared string found to trim"
        )

    # The longest column name sets how many characters must be removed
    n_char_trim = max(len(tag) for tag in too_long) - max_len

    # The shared string must be at least as long as the amount to remove,
    # otherwise the longest name would still be too long after trimming
    if len(shared) < n_char_trim:
        raise ValueError(
            "No shared string in column names is long enough to enable "
            + "enough trimming to get column names short enough to allow "
            + "FITS table to write without crashing"
        )

    # Decide the replacement string: first check whether simply dropping the
    # - and _ characters saves enough, otherwise cut leading characters
    squeezed = shared.replace("-", "").replace("_", "")
    if len(shared) - len(squeezed) >= n_char_trim:
        replacement = squeezed
    else:
        replacement = shared[n_char_trim:]

    # Replace the shared string with its trimmed version in every column name
    # that contains it (names without it are left untouched by str.replace)
    set0_trim = [tag.replace(shared, replacement) for tag in set0]
    set1_trim = [tag.replace(shared, replacement) for tag in set1]
    set2_trim = [tag.replace(shared, replacement) for tag in set2]

    # Return trimmed column names
    return set0_trim, set1_trim, set2_trim