diff --git a/src/main.rs b/src/main.rs index c39fae0..c316a9e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -89,7 +89,9 @@ struct BootstrapArgs { long = "restart", help = "Remove all toolchain and restart the bootstrap", default_value = "false", - action = ArgAction::SetTrue)] + action = ArgAction::SetTrue, + conflicts_with = "upgrade" + )] restart: bool, #[arg( @@ -108,9 +110,17 @@ struct BootstrapArgs { )] debug: bool, + #[arg( + long = "upstream-verus", + help = "Pull the upstream verus-lang/verus instead of asterinas/verus", + default_value = "false", + action = ArgAction::SetTrue + )] + upstream_verus: bool, + #[arg( long = "branch", - help = "The branch name to pull (only used with --upgrade)", + help = "The branch name to pull", value_name = "BRANCH_NAME" )] branch: Option, @@ -453,6 +463,7 @@ fn bootstrap(args: &BootstrapArgs) -> Result<(), DynError> { restart: args.restart, branch: args.branch.clone(), force_reset: args.upgrade, + upstream_verus: args.upstream_verus, }; if args.upgrade { diff --git a/src/verus.rs b/src/verus.rs index 0d6ce61..89a7bd2 100644 --- a/src/verus.rs +++ b/src/verus.rs @@ -1402,10 +1402,124 @@ pub mod install { pub release: bool, pub branch: Option, pub force_reset: bool, + pub upstream_verus: bool, } pub const VERUS_REPO_HTTPS: &str = "https://github.com/asterinas/verus.git"; pub const VERUS_REPO_SSH: &str = "git@github.com:asterinas/verus.git"; + pub const UPSTREAM_VERUS_REPO_HTTPS: &str = "https://github.com/verus-lang/verus.git"; + pub const UPSTREAM_VERUS_REPO_SSH: &str = "git@github.com:verus-lang/verus.git"; + + fn target_repo_urls(upstream: bool) -> (&'static str, &'static str) { + if upstream { + (UPSTREAM_VERUS_REPO_SSH, UPSTREAM_VERUS_REPO_HTTPS) + } else { + (VERUS_REPO_SSH, VERUS_REPO_HTTPS) + } + } + + fn is_target_origin(origin_url: &str, upstream: bool) -> bool { + let (ssh, https) = target_repo_urls(upstream); + origin_url == ssh || origin_url == https + } + + fn maybe_switch_origin_remote(dir: &Path, upstream: bool) -> Result { + let repo = Repository::open(dir).unwrap_or_else(|e| { + error!( + "Unable to find the git repo of verus under {}: {}", + dir.display(), + e + ); + }); + + let mut remote = repo.find_remote("origin")?; + let origin_url = remote.url().unwrap_or(""); + if is_target_origin(origin_url, upstream) { + return Ok(false); + } + + let (target_ssh, target_https) = target_repo_urls(upstream); + let use_ssh = origin_url.starts_with("git@") || origin_url.contains("ssh://"); + let target_url = if use_ssh { target_ssh } else { target_https }; + + info!( + "Switching origin remote from {} to {}", + origin_url, target_url + ); + repo.remote_set_url("origin", target_url)?; + + // Refresh local remote handle after remote update. + remote = repo.find_remote("origin")?; + debug!("Updated origin remote URL: {}", remote.url().unwrap_or("")); + Ok(true) + } + + fn force_reset_to_origin(dir: &Path, branch: Option<&str>) -> Result<(), DynError> { + let repo = Repository::open(dir).unwrap_or_else(|e| { + error!( + "Unable to find the git repo of verus under {}: {}", + dir.display(), + e + ); + }); + + let target_branch = branch.unwrap_or("main"); + + let mut remote = repo.find_remote("origin")?; + let remote_url = remote.url().unwrap_or(""); + let is_ssh = remote_url.starts_with("git@") || remote_url.contains("ssh://"); + + let mut callbacks = git2::RemoteCallbacks::new(); + if is_ssh { + callbacks.credentials(|_url, username_from_url, _allowed_types| { + git2::Cred::ssh_key_from_agent(username_from_url.unwrap_or("git")) + }); + } + + let mut fetch_opts = git2::FetchOptions::new(); + fetch_opts.remote_callbacks(callbacks); + remote.fetch(&[target_branch], Some(&mut fetch_opts), None)?; + + let upstream_branch = format!("refs/remotes/origin/{}", target_branch); + let upstream_ref = repo.find_reference(&upstream_branch).map_err(|_| { + format!( + "Branch '{}' does not exist in the remote repository. Please check the branch name.", + target_branch + ) + })?; + let upstream_commit = upstream_ref.peel_to_commit()?; + + repo.reset(upstream_commit.as_object(), git2::ResetType::Hard, None)?; + + let refname = format!("refs/heads/{}", target_branch); + if repo.find_reference(&refname).is_err() { + repo.reference( + &refname, + upstream_commit.id(), + false, + &format!("Create branch {}", target_branch), + )?; + } else { + let mut reference = repo.find_reference(&refname)?; + reference.set_target( + upstream_commit.id(), + &format!("Force reset to origin/{}", target_branch), + )?; + } + + repo.set_head(&refname)?; + let mut checkout_opts = git2::build::CheckoutBuilder::new(); + checkout_opts.force(); + repo.checkout_head(Some(&mut checkout_opts))?; + + status!("Force reset to origin/{} completed", target_branch); + status!( + "Repo {} updated to commit {}", + dir.display(), + upstream_commit.id() + ); + Ok(()) + } #[memoize] pub fn tools_dir() -> PathBuf { @@ -1427,10 +1541,34 @@ pub mod install { tools_dir().join("patches") } - pub fn clone_repo(verus_dir: &Path) -> Result<(), DynError> { - info!("Cloning Verus repo to {} ...", verus_dir.display()); + pub fn clone_repo( + verus_dir: &Path, + branch: Option<&str>, + upstream: bool, + ) -> Result<(), DynError> { + let repo_ssh = if upstream { + UPSTREAM_VERUS_REPO_SSH + } else { + VERUS_REPO_SSH + }; + let repo_https = if upstream { + UPSTREAM_VERUS_REPO_HTTPS + } else { + VERUS_REPO_HTTPS + }; + + let branch_name = branch.unwrap_or("main"); + + info!( + "Cloning Verus repo from {} (branch: {}) to {} ...", + repo_ssh, + branch_name, + verus_dir.display() + ); let mut builder = git2::build::RepoBuilder::new(); + builder.branch(branch_name); + let mut callbacks = git2::RemoteCallbacks::new(); callbacks.credentials(|_url, username_from_url, _allowed_types| { @@ -1441,12 +1579,17 @@ pub mod install { fetch_opts.remote_callbacks(callbacks); builder.fetch_options(fetch_opts); - let ssh_result = builder.clone(VERUS_REPO_SSH, verus_dir); + let ssh_result = builder.clone(repo_ssh, verus_dir); if ssh_result.is_ok() { return Ok(()); } - Repository::clone(VERUS_REPO_HTTPS, verus_dir) + info!("SSH failed, trying HTTPS: {}", repo_https); + + let mut builder_https = git2::build::RepoBuilder::new(); + builder_https.branch(branch_name); + builder_https + .clone(repo_https, verus_dir) .map_err(|e| format!("Failed to clone verus repo: {}", e))?; Ok(()) @@ -1596,10 +1739,6 @@ pub mod install { pub fn exec_bootstrap(options: &VerusInstallOpts) -> Result<(), DynError> { let verus_dir = verus_dir(); - if options.branch.is_some() { - error!("Specifying a branch is only supported during upgrade."); - } - if options.restart && verus_dir.exists() { info!("Removing old verus installation..."); std::fs::remove_dir_all(&verus_dir)?; @@ -1607,7 +1746,11 @@ pub mod install { // Clone the Verus repo if it doesn't exist if !verus_dir.exists() { - clone_repo(&verus_dir)?; + clone_repo( + &verus_dir, + options.branch.as_deref(), + options.upstream_verus, + )?; } // Download Z3 @@ -1804,17 +1947,19 @@ pub mod install { pub fn exec_upgrade(options: &VerusInstallOpts) -> Result<(), DynError> { // rebuild if required or if the directory doesn't exist - if options.restart || !verus_dir().exists() { + if !verus_dir().exists() { return exec_bootstrap(options); } // git pull the Verus repo let verus_dir = verus_dir(); - git_pull( - &verus_dir, - options.branch.as_ref().map(|s| s.as_str()), - options.force_reset, - )?; + let branch = options.branch.as_ref().map(|s| s.as_str()); + let switched = maybe_switch_origin_remote(&verus_dir, options.upstream_verus)?; + if switched { + force_reset_to_origin(&verus_dir, branch)?; + } else { + git_pull(&verus_dir, branch, options.force_reset)?; + } status!("Verus repo updated to the latest version"); // Build Verus