search_hub

at 2ab9aa7 Raw

use anyhow::{bail, Context, Result};
use reqwest::Client;
use std::fmt;
use std::path::PathBuf;
use std::str::FromStr;

/// Atom feed listing published releases. `SelfUpdate::new` uses it unless it is
/// overridden with `SelfUpdate::with_feed_url`.
const DEFAULT_FEED_URL: &str = "https://vit.am/~ololduck/search_hub/releases.atom";

/// A release version `MAJOR.MINOR.PATCH`, stored as `(major, minor, patch)`.
///
/// Only three numeric components are represented. There is no slot for pre-release
/// or build-metadata suffixes. The type is a small plain value, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Version(u32, u32, u32);

impl FromStr for Version {
    type Err = String;

    /// Parses `MAJOR.MINOR.PATCH`, optionally prefixed by a single `v` (as in `v1.2.3`).
    ///
    /// Everything after the second `.` must parse as a `u32` patch number. Extra
    /// components (`1.2.3.4`) and suffixes (`1.2.3-rc1`) are therefore rejected.
    /// Error messages quote the input with the `v` prefix already removed.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let body = match s.strip_prefix('v') {
            Some(rest) => rest,
            None => s,
        };

        // splitn(3) yields at most three pieces, so three `Some`s means exactly three.
        let mut fields = body.splitn(3, '.');
        let (major, minor, patch) = match (fields.next(), fields.next(), fields.next()) {
            (Some(major), Some(minor), Some(patch)) => (major, minor, patch),
            _ => return Err(format!("invalid version: {}", body)),
        };

        let major = major
            .parse::<u32>()
            .map_err(|_| format!("invalid major in {}", body))?;
        let minor = minor
            .parse::<u32>()
            .map_err(|_| format!("invalid minor in {}", body))?;
        let patch = patch
            .parse::<u32>()
            .map_err(|_| format!("invalid patch in {}", body))?;

        Ok(Version(major, minor, patch))
    }
}

impl fmt::Display for Version {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}.{}.{}", self.0, self.1, self.2)
    }
}

/// Delegates to the total order from `Ord`, so `partial_cmp` never returns `None`.
impl PartialOrd for Version {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

/// Orders by major, then minor, then patch (lexicographic over the three fields).
impl Ord for Version {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = (self.0, self.1, self.2);
        let rhs = (other.0, other.1, other.2);
        lhs.cmp(&rhs)
    }
}

/// Returns this build's own version, parsed from `CARGO_PKG_VERSION`, which is fixed
/// at compile time.
///
/// # Panics
///
/// Panics if the crate version is not a plain `MAJOR.MINOR.PATCH`. For example, a
/// pre-release suffix such as `-rc1` cannot be parsed as the `u32` patch field.
fn current_version() -> Version {
    let own = env!("CARGO_PKG_VERSION");
    own.parse::<Version>()
        .expect("CARGO_PKG_VERSION is not valid semver")
}

/// Returns the default target-triple suffix of release binaries: the CPU architecture
/// this binary was compiled for, followed by `unknown-linux-gnu`.
///
/// NOTE(review): the OS/ABI part is hard-coded. musl or non-Linux hosts get a
/// Linux-gnu triple and must pick their target explicitly.
fn default_target() -> String {
    let arch = std::env::consts::ARCH;
    [arch, "unknown-linux-gnu"].join("-")
}

/// `User-Agent` header value for HTTP requests: `search_hub/<crate version>`.
fn user_agent() -> String {
    // Built entirely at compile time; only the final String allocation happens at runtime.
    String::from(concat!("search_hub/", env!("CARGO_PKG_VERSION")))
}

/// Extracts the `<title>` text of every `<entry>` in an Atom feed, in document order.
///
/// This is a deliberately small tag scanner, not an XML parser:
/// - Tags may carry attributes (`<entry xml:lang="en">`, `<title type="text">`).
///   Tags that merely share a prefix (`<titles>`) are not matched, and self-closing
///   tags are skipped.
/// - A title is looked up only inside its own entry, so a title-less entry
///   contributes nothing instead of reading the next entry's title. An unterminated
///   entry extends to the end of the input.
/// - Surrounding whitespace and a `<![CDATA[...]]>` wrapper are removed.
/// - Only the first title of each entry is used.
/// - XML entities (`&amp;` etc.) are not decoded.
///
/// Callers decide which of the returned titles are valid versions.
fn parse_versions_from_atom(xml: &str) -> Vec<String> {
    const ENTRY_CLOSE: &str = "</entry>";

    let mut versions = Vec::new();
    let mut rest = xml;
    while let Some(body_start) = open_tag_end(rest, "entry") {
        let body = &rest[body_start..];
        // Bound the title search to this entry.
        let (entry, remainder) = match body.find(ENTRY_CLOSE) {
            Some(end) => (&body[..end], &body[end + ENTRY_CLOSE.len()..]),
            None => (body, ""),
        };
        if let Some(title) = entry_title(entry) {
            versions.push(title.to_string());
        }
        rest = remainder;
    }
    versions
}

/// Returns the trimmed, CDATA-unwrapped text of the first `<title>` in `entry`, if any.
fn entry_title(entry: &str) -> Option<&str> {
    let start = open_tag_end(entry, "title")?;
    let len = entry[start..].find("</title>")?;
    let text = entry[start..start + len].trim();
    let text = text
        .strip_prefix("<![CDATA[")
        .and_then(|inner| inner.strip_suffix("]]>"))
        .unwrap_or(text);
    Some(text.trim())
}

/// Finds the first non-self-closing opening tag `<name ...>` in `s` and returns the
/// byte offset just past its `>`. A tag only counts when `name` is followed by `>`,
/// `/` or whitespace, so `<titles>` does not match `title`.
fn open_tag_end(s: &str, name: &str) -> Option<usize> {
    let needle = format!("<{}", name);
    let mut from = 0;
    while let Some(offset) = s[from..].find(needle.as_str()) {
        let after_name = from + offset + needle.len();
        let tail = &s[after_name..];
        let is_whole_name = match tail.chars().next() {
            Some('>') | Some('/') => true,
            Some(c) => c.is_ascii_whitespace(),
            None => false,
        };
        if !is_whole_name {
            from = after_name;
            continue;
        }
        // No closing '>' anywhere means no later tag can be complete either.
        let end = after_name + tail.find('>')? + 1;
        if s[..end].ends_with("/>") {
            // Self-closing tag: it has no content, so keep looking.
            from = end;
            continue;
        }
        return Some(end);
    }
    None
}

/// Self-updater: reads a release feed and, if it lists a newer version, replaces the
/// running executable with that release's binary.
///
/// Create one with `SelfUpdate::new` (or `Default`). Adjust it with the `with_*` and
/// `dry_run` builder methods before calling `run`.
#[derive(Debug, Clone)]
pub struct SelfUpdate {
    /// URL of the Atom feed whose entry titles are release versions.
    feed_url: String,
    /// Target triple embedded in the release binary's file name.
    target: String,
    /// When `true`, only report what would be downloaded; nothing is downloaded or installed.
    dry_run: bool,
}

impl Default for SelfUpdate {
    fn default() -> Self {
        Self::new()
    }
}

impl SelfUpdate {
    pub fn new() -> Self {
        Self {
            feed_url: DEFAULT_FEED_URL.to_string(),
            target: default_target(),
            dry_run: false,
        }
    }

    pub fn with_feed_url(mut self, url: String) -> Self {
        self.feed_url = url;
        self
    }

    pub fn with_target(mut self, target: String) -> Self {
        self.target = target;
        self
    }

    pub fn dry_run(mut self, yes: bool) -> Self {
        self.dry_run = yes;
        self
    }

    pub async fn run(&self) -> Result<()> {
        let client = Client::builder()
            .user_agent(&user_agent())
            .build()?;

        println!("Checking for updates at {}...", self.feed_url);
        let xml = client
            .get(&self.feed_url)
            .send()
            .await
            .context("Failed to fetch release feed")?
            .text()
            .await
            .context("Failed to read release feed")?;

        let version_strings = parse_versions_from_atom(&xml);
        if version_strings.is_empty() {
            println!("No releases found in feed.");
            return Ok(());
        }

        let current = current_version();

        let latest = version_strings
            .iter()
            .filter_map(|s| {
                let v = s.parse::<Version>().ok()?;
                Some((v, s))
            })
            .max_by(|(a, _), (b, _)| a.cmp(b));

        let latest_version = match latest {
            Some((v, s)) => {
                println!("Latest release: {}", s);
                v
            }
            None => {
                println!("No valid releases found in feed.");
                return Ok(());
            }
        };

        if latest_version <= current {
            println!(
                "Already up to date (current: {}, latest: {})",
                current, latest_version
            );
            return Ok(());
        }

        println!("Update available: {} -> {}", current, latest_version);

        if self.dry_run {
            let url = format!(
                "https://vit.am/~ololduck/search_hub/{}/dist/search_hub-{}-{}",
                latest_version, latest_version, self.target
            );
            println!("Would download: {}", url);
            return Ok(());
        }

        let url = format!(
            "https://vit.am/~ololduck/search_hub/{}/dist/search_hub-{}-{}",
            latest_version, latest_version, self.target
        );

        println!("Downloading {}...", url);
        let response = client
            .get(&url)
            .send()
            .await
            .context("Failed to download update")?;

        let status = response.status();
        if !status.is_success() {
            bail!(
                "Download failed (HTTP {}) - no binary available for target '{}'. \
                 Try --target (e.g. x86_64-unknown-linux-gnu) or build from source.",
                status,
                self.target
            );
        }

        let bytes = response.bytes().await?;

        if bytes.len() < 4 || bytes[..4] != [0x7f, b'E', b'L', b'F'] {
            bail!("Downloaded file is not a valid ELF binary");
        }

        let current_exe = std::env::current_exe()
            .context("Cannot determine current executable path")?;

        let temp_path = {
            let mut p = current_exe.clone().into_os_string();
            p.push(format!(".{}.tmp", std::process::id()));
            PathBuf::from(p)
        };

        tokio::fs::write(&temp_path, &bytes)
            .await
            .context("Failed to write temporary file")?;

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            std::fs::set_permissions(&temp_path, std::fs::Permissions::from_mode(0o755))
                .context("Failed to set executable permissions")?;
        }

        std::fs::rename(&temp_path, &current_exe)
            .context("Failed to replace current binary - maybe missing write permission? \
                      Try running with sudo or reinstall via cargo install --path .")?;

        println!("Successfully updated to v{}!", latest_version);
        println!("Restart any running search_hub processes to use the new version.");

        Ok(())
    }
}

// Unit tests for the pure helpers: Atom feed scanning and `Version` parsing, ordering
// and display. `SelfUpdate::run` needs the network and is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    // Plain, attribute-free feed: one title per entry, returned in document order.
    #[test]
    fn parse_versions_simple() {
        let xml = r#"<?xml version="1.0"?>
<feed>
  <entry><title>0.3.0</title></entry>
  <entry><title>0.2.3</title></entry>
  <entry><title>0.2.2</title></entry>
</feed>"#;
        let versions = parse_versions_from_atom(xml);
        assert_eq!(versions, vec!["0.3.0", "0.2.3", "0.2.2"]);
    }

    // The `xmlns` attribute sits on `<feed>`, and sibling elements (`<id>`,
    // `<updated>`) inside the entry must not disturb title extraction.
    #[test]
    fn parse_versions_with_namespace() {
        let xml = r#"<?xml version="1.0"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <entry>
    <title>1.0.0</title>
    <id>urn:test:1.0.0</id>
    <updated>2026-01-01T00:00:00Z</updated>
  </entry>
</feed>"#;
        let versions = parse_versions_from_atom(xml);
        assert_eq!(versions, vec!["1.0.0"]);
    }

    #[test]
    fn parse_versions_empty() {
        assert!(parse_versions_from_atom("").is_empty());
        assert!(parse_versions_from_atom("<feed></feed>").is_empty());
    }

    #[test]
    fn version_strips_v_prefix() {
        assert_eq!("v1.2.3".parse::<Version>().unwrap(), Version(1, 2, 3));
    }

    #[test]
    fn version_without_prefix() {
        assert_eq!("1.2.3".parse::<Version>().unwrap(), Version(1, 2, 3));
    }

    // Ordering compares major, then minor, then patch numerically (not as strings).
    #[test]
    fn version_ordering() {
        assert!(Version(0, 3, 0) > Version(0, 2, 9));
        assert!(Version(1, 0, 0) > Version(0, 99, 99));
        assert!(Version(0, 0, 1) < Version(0, 0, 2));
        assert_eq!(Version(1, 2, 3), Version(1, 2, 3));
    }

    // Too few components, too many components, and non-numeric parts are all rejected.
    #[test]
    fn version_invalid() {
        assert!("abc".parse::<Version>().is_err());
        assert!("1.2".parse::<Version>().is_err());
        assert!("1.2.3.4".parse::<Version>().is_err());
        assert!("1.2.x".parse::<Version>().is_err());
    }

    #[test]
    fn version_display() {
        assert_eq!(Version(0, 3, 0).to_string(), "0.3.0");
        assert_eq!(Version(10, 20, 30).to_string(), "10.20.30");
    }

    // Only the first `<title>` of an entry is taken.
    #[test]
    fn parse_versions_skips_irrelevant_tags() {
        let xml = r#"<feed>
  <entry><title>0.3.0</title><title>ignored</title></entry>
</feed>"#;
        let versions = parse_versions_from_atom(xml);
        assert_eq!(versions, vec!["0.3.0"]);
    }

    // Mirrors the latest-version selection chain in `SelfUpdate::run`; keep the two in sync.
    #[test]
    fn selects_highest_version() {
        let versions = vec!["0.2.0".to_string(), "0.3.0".to_string(), "0.1.0".to_string()];
        let latest = versions
            .iter()
            .filter_map(|s| {
                let v = s.parse::<Version>().ok()?;
                Some((v, s))
            })
            .max_by(|(a, _), (b, _)| a.cmp(b));
        assert_eq!(latest.unwrap().0, Version(0, 3, 0));
    }
}