Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
opts.hf_prune_old_files = params.hf_prune_old_files;
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;

Expand Down Expand Up @@ -331,7 +332,8 @@ struct handle_model_result {

static handle_model_result common_params_handle_model(struct common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool offline,
bool hf_prune_old_files) {
handle_model_result result;

if (!model.docker_repo.empty()) {
Expand All @@ -346,6 +348,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
opts.hf_prune_old_files = hf_prune_old_files;
auto download_result = common_download_model(model, opts, true);

if (download_result.model_path.empty()) {
Expand All @@ -370,6 +373,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
opts.hf_prune_old_files = hf_prune_old_files;
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
Expand Down Expand Up @@ -576,7 +580,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context

// handle model and download
if (!skip_model_download) {
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, params.hf_prune_old_files);
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
Expand All @@ -586,12 +590,12 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// only download mmproj if the current example is using it
for (const auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
common_params_handle_model(params.mmproj, params.hf_token, params.offline, params.hf_prune_old_files);
break;
}
}
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline, params.hf_prune_old_files);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline, params.hf_prune_old_files);
}

// model is required (except for server)
Expand Down Expand Up @@ -2634,6 +2638,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_token = value;
}
).set_env("HF_TOKEN"));
add_opt(common_arg(
{"-hfp", "--hf-prune-old-files"},
string_format("Keep only latest version of model files, delete old ones (default: %s)", params.hf_prune_old_files ? "true" : "false"),
[](common_params & params) {
params.hf_prune_old_files = true;
}
).set_env("LLAMA_ARG_HF_PRUNE_OLD_FILES"));
add_opt(common_arg(
{"--context-file"}, "FNAME",
"file to load context from (use comma-separated values to specify multiple files)",
Expand Down
7 changes: 4 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,10 @@ struct common_params {

struct common_params_model model;

std::set<std::string> model_alias; // model aliases // NOLINT
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
std::string hf_token = ""; // HF token // NOLINT
std::set<std::string> model_alias; // model aliases // NOLINT
std::set<std::string> model_tags; // model tags (informational, not used for routing) // NOLINT
std::string hf_token = ""; // HF token // NOLINT
bool hf_prune_old_files = false; // whether to keep only latest version of model files // NOLINT
std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT
Expand Down
5 changes: 5 additions & 0 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,11 @@ common_download_model_result common_download_model(const common_params_model &
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}

if (opts.hf_prune_old_files) {
auto hf_repo_with_tag = common_download_split_repo_tag(model.hf_repo);
hf_cache::prune_old_files(hf_repo_with_tag.first, hf.model_files, hf.mmproj);
}
} else {
result.model_path = model.path;
}
Expand Down
1 change: 1 addition & 0 deletions common/download.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct common_download_opts {
common_header_list headers;
bool offline = false;
common_download_callback * callback = nullptr;
bool hf_prune_old_files = false;
};

// Result of common_download_model
Expand Down
129 changes: 129 additions & 0 deletions common/hf-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,57 @@ hf_files get_cached_files(const std::string & repo_id) {
return files;
}

hf_files get_all_snapshot_files(const std::string & repo_id) {
fs::path cache_dir = get_cache_directory();
if (!fs::exists(cache_dir)) {
return {};
}

if (!repo_id.empty() && !is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return {};
}

hf_files files;

for (const auto & repo : fs::directory_iterator(cache_dir)) {
if (!repo.is_directory()) {
continue;
}
fs::path snapshots_path = repo.path() / "snapshots";

if (!fs::exists(snapshots_path)) {
continue;
}
std::string _repo_id = folder_name_to_repo(repo.path().filename().string());

if (!is_valid_repo_id(_repo_id)) {
continue;
}
if (!repo_id.empty() && _repo_id != repo_id) {
continue;
}

for (const auto & entry : fs::recursive_directory_iterator(snapshots_path)) {
if (!entry.is_regular_file() && !fs::is_directory(entry.path())) {
continue;
}
fs::path path = entry.path();

if (!path.empty()) {
hf_file file;
file.repo_id = _repo_id;
file.path = path.generic_string();
file.local_path = entry.path().string();
file.final_path = file.local_path;
files.push_back(std::move(file));
}
}
}

return files;
}

std::string finalize_file(const hf_file & file) {
static std::atomic<bool> symlinks_disabled{false};

Expand Down Expand Up @@ -501,6 +552,84 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}

void prune_old_files(const std::string & hf_repo, const hf_cache::hf_files & current_model_files, const hf_cache::hf_file & current_mmproj) {
std::vector<std::string> filenames_to_delete;
std::vector<std::string> files_to_keep;

const auto get_symlink_target = [&](const std::string & file) {
std::error_code ec;

const auto & parent = fs::path(file).parent_path();
const auto & target_relative = fs::read_symlink(file, ec);
if (ec) {
LOG_DBG("%s: failed to read symlink %s: %s\n", __func__, file.c_str(), ec.message().c_str());
return std::string();
}
const auto & target_unresolved = parent / target_relative;
const auto & target = fs::weakly_canonical(target_unresolved, ec);
if (ec) {
LOG_DBG("%s: failed to resolve symlink target %s: %s\n", __func__, file.c_str(), ec.message().c_str());
return std::string();
}
return std::string(target);
};

for (const auto & file : current_model_files) {
files_to_keep.push_back(file.local_path);
filenames_to_delete.push_back(fs::path(file.local_path).filename());
const auto & target = get_symlink_target(file.local_path);
if (!target.empty()) {
files_to_keep.push_back(target);
}
}

if (!current_mmproj.local_path.empty()) {
files_to_keep.push_back(current_mmproj.local_path);
filenames_to_delete.push_back(fs::path(current_mmproj.local_path).filename());
const auto & target = get_symlink_target(current_mmproj.local_path);
if (!target.empty()) {
files_to_keep.push_back(target);
}
}

const auto cached_snapshot_files = hf_cache::get_all_snapshot_files(hf_repo);
for (int i = cached_snapshot_files.size() - 1; i >= 0; --i) {
const auto & file_path = cached_snapshot_files[i].local_path;
if (std::find(files_to_keep.begin(), files_to_keep.end(), file_path) != files_to_keep.end()) {
continue;
}
std::error_code ec;
for (const auto & filename : filenames_to_delete) {
if (string_ends_with(file_path, filename) && fs::is_symlink(file_path)) {
const auto & commit = fs::path(file_path).parent_path();
const auto & blob_file = get_symlink_target(file_path);

if (!fs::remove(file_path.c_str(), ec)) {
LOG_ERR("%s: error deleting old symlink file from hf cache %s: %s\n", __func__, file_path.c_str(), ec.message().c_str());
return;
}

if (fs::is_empty(commit)) {
if (!fs::remove(commit.c_str(), ec)) {
LOG_ERR("%s: error deleting old commit directory from hf cache %s: %s\n", __func__, commit.c_str(), ec.message().c_str());
return;
}
}

if (!blob_file.empty() && std::find(files_to_keep.begin(), files_to_keep.end(), blob_file) == files_to_keep.end()) {
LOG_INF("deleting old blob file from hf cache: %s -> %s\n", file_path.c_str(), blob_file.c_str());
if (fs::exists(blob_file)) {
if (!fs::remove(blob_file.c_str(), ec)) {
LOG_ERR("%s: error deleting old hf blob file %s: %s\n", __func__, file_path.c_str(), ec.message().c_str());
return;
}
}
}
}
}
}
}

// delete everything after this line, one day

// copied from download.cpp without the tag part
Expand Down
3 changes: 3 additions & 0 deletions common/hf-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ hf_files get_repo_files(
);

hf_files get_cached_files(const std::string & repo_id = {});
hf_files get_all_snapshot_files(const std::string & repo_id = {});

// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);

void prune_old_files(const std::string & hf_repo, const hf_files & current_model_files, const hf_file & current_mmproj);

// TODO: Remove later
void migrate_old_cache_to_hf_cache(const std::string & token, bool offline = false);

Expand Down