Commit
Message
Changed Files (19)
-
modified .abbaye/theme/static/site.css
diff --git a/.abbaye/theme/static/site.css b/.abbaye/theme/static/site.css index 21ca1c1..ca84d7e 100644 --- a/.abbaye/theme/static/site.css +++ b/.abbaye/theme/static/site.css @@ -120,7 +120,7 @@ header { letter-spacing: 0.15em; margin-bottom: 2rem; border-bottom: 3px solid var(--border-strong); - padding-bottom: 0.5rem; + padding: 0.5rem 1.5rem; background: none; color: var(--text); flex-wrap: wrap; -
modified CHANGELOG.md
diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d9278e..4eab7b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/). ## [Unreleased] +### Added +- `self-update` CLI command: checks abbaye Atom feed for new releases, downloads and replaces the binary +- `contrib/search-hub-self-update.{service,timer}` for weekly automated updates + ## [0.3.0] - 2026-06-19 ### Added -
modified README.md
diff --git a/README.md b/README.md index 76b72fa..795cdc3 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ search_hub serve Open http://127.0.0.1:8080 in your browser. You can now search your bookmarks. -Search queries are also forwarded to external engines: [crates.io](https://crates.io) via its public JSON API, and optionally [SearXNG](https://searx.space) (which aggregates Google, Bing, DDG, and dozens more) if `[engines.searxng]` is configured. Works as a custom search provider in Firefox/Zen via the OpenSearch protocol (your browser should auto-discover it at `/opensearch.xml`). +Search queries are also forwarded to external engines: [crates.io](https://crates.io) via its public JSON API, and optionally [SearXNG](https://searx.space) (which aggregates Google, Bing, DDG, and dozens more) if `[[engines]]` is configured. Works as a custom search provider in Firefox/Zen via the OpenSearch protocol (your browser should auto-discover it at `/opensearch.xml`). ## CLI reference @@ -52,6 +52,9 @@ Search queries are also forwarded to external engines: [crates.io](https://crate | `search_hub remove --id 1` | Delete a bookmark by ID | | `search_hub retag --all` | Re-run auto-tagging (requires `tagging_enabled = true` in config) | | `search_hub init-config` | Create a default config file at `~/.config/search_hub/config.toml` | +| `search_hub self-update` | Check abbaye Atom feed and update to the latest release | +| `search_hub self-update --dry-run` | Check for updates without downloading | +| `search_hub self-update --target x86_64-unknown-linux-gnu` | Override the target triple | All commands use `~/.local/share/search_hub/bookmarks.db` by default. Override with `--db-path` or set `db_path` in the config file. 
@@ -83,11 +86,20 @@ Run `search_hub init-config` to create `~/.config/search_hub/config.toml` with a # exclude_urls = ["localhost", "127.0.0.1", "::1"] # Per-engine configuration (optional) -# [engines.searxng] -# instance = "https://search.kael.ink" +# Multiple instances supported (e.g., public + private crates.io registries) +[[engines]] +type = "searxng" +instance = "https://search.kael.ink" +# timeout_secs = 10.0 # optional per-engine timeout # Best: use an existing public instance (see https://searx.space). # Also possible: run your own with Docker: # docker run -d --name searxng -p 8888:8080 searxng/searxng + +# Custom crates.io registry (optional) +# [[engines]] +# type = "crates_io" +# url = "https://registry.example.com/api/v1/crates?q={}&per_page=10" +# timeout_secs = 5.0 ``` ## Run the web server as a systemd user service @@ -113,6 +125,17 @@ systemctl --user enable --now search-hub-import.timer This imports bookmarks from Zen Browser daily. Edit the file to import from another browser. +## Auto-update with systemd + +```sh +cp contrib/search-hub-self-update.service ~/.config/systemd/user/ +cp contrib/search-hub-self-update.timer ~/.config/systemd/user/ +systemctl --user daemon-reload +systemctl --user enable --now search-hub-self-update.timer +``` + +This checks for new releases weekly and updates the binary automatically. + ## Resources - **Downloads:** [vit.am/~ololduck/search_hub/latest](https://vit.am/~ololduck/search_hub/latest/) -
modified abbaye.toml
diff --git a/abbaye.toml b/abbaye.toml index 6042789..e990fa8 100644 --- a/abbaye.toml +++ b/abbaye.toml @@ -13,6 +13,7 @@ dirty_suffix = "-dirty" [git_ui] clone_url = "https://vit.am/~ololduck/search_hub/repository.git" +include = ["main", "develop", "v*"] [[builders]] type = "archive" -
modified contrib/release.sh
diff --git a/contrib/release.sh b/contrib/release.sh index 5403d0d..b73a5e0 100755 --- a/contrib/release.sh +++ b/contrib/release.sh @@ -10,7 +10,7 @@ fi VERSION="$1" # Validate semver -if ! echo "$VERSION" | grep -qP '^\d+\.\d+\.\d+$'; then +if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+$'; then echo "error: version must be semver (e.g. 0.2.0)" exit 1 fi @@ -54,10 +54,10 @@ echo "==> Bumping version to $VERSION in Cargo.toml" sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml echo "==> Running cargo test" -cargo test 2>&1 | tail -5 +cargo test 2>&1 echo "==> Running cargo build --release" -cargo build --release 2>&1 | tail -3 +cargo build --release 2>&1 echo "==> Updating CHANGELOG.md for v$VERSION" RELEASE_DATE=$(date +%Y-%m-%d) -
added contrib/search-hub-self-update.service
diff --git a/contrib/search-hub-self-update.service b/contrib/search-hub-self-update.service new file mode 100644 index 0000000..2d5319c --- /dev/null +++ b/contrib/search-hub-self-update.service @@ -0,0 +1,6 @@ +[Unit] +Description=SearchHub self-update + +[Service] +Type=oneshot +ExecStart=%h/.cargo/bin/search_hub self-update -
added contrib/search-hub-self-update.timer
diff --git a/contrib/search-hub-self-update.timer b/contrib/search-hub-self-update.timer new file mode 100644 index 0000000..5c76b5e --- /dev/null +++ b/contrib/search-hub-self-update.timer @@ -0,0 +1,10 @@ +[Unit] +Description=Weekly SearchHub self-update + +[Timer] +OnCalendar=weekly +OnBootSec=30min +Persistent=true + +[Install] +WantedBy=timers.target -
modified src/config.rs
diff --git a/src/config.rs b/src/config.rs index 67f39f7..bd17f27 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,54 +1,33 @@ use figment::Figment; use figment::providers::{Format, Toml}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use serde::Deserialize; use std::path::PathBuf; -/// A single search engine definition with an optional CSS selector for -/// inline result extraction. -/// -/// When `selector` is `Some`, the search handler uses `scraper` to find -/// that container in the engine's HTML and extract `<a>` links from it. -/// Engines without a selector are skipped for inline extraction. -/// -/// # Example -/// -/// ```rust -/// use search_hub::config::ForwarderDef; -/// -/// let ddg = ForwarderDef { -/// id: "duckduckgo".into(), -/// name: "DuckDuckGo".into(), -/// url: "https://duckduckgo.com/?q={}".into(), -/// selector: Some("article[data-testid='result']".into()), -/// }; -/// assert_eq!(ddg.id, "duckduckgo"); -/// ``` -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ForwarderDef { - /// URL query parameter identifier (e.g. "duckduckgo"). - pub id: String, - /// Display name (e.g. "DuckDuckGo"). - pub name: String, - /// URL template with `{}` placeholder for the query string. - pub url: String, - /// CSS selector for the result container in the engine's HTML page. - /// Used for inline result extraction; `None` skips inline extraction. - #[serde(default)] - pub selector: Option<String>, +/// Configuration for a single search engine instance. +#[derive(Debug, Deserialize, Clone)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum EngineConfig { + /// crates.io registry (public or private) + CratesIo { + #[serde(default)] + url: Option<String>, + #[serde(default)] + timeout_secs: Option<f32>, + }, + /// SearXNG meta-search engine + SearXng { + instance: String, + #[serde(default)] + timeout_secs: Option<f32>, + }, } /// Application configuration loaded from the TOML config file. 
/// -/// Supports `[[tags]]`, `enabled_engines`, `tagging_enabled`, `tagging_threshold`, and `[engines.*]`. -/// /// # Example /// /// ```ignore /// let cfg = search_hub::config::Config::load(); -/// if cfg.tags.is_empty() { -/// println!("using default tags"); -/// } /// let engines = cfg.resolve_engines(); /// println!("{} engines enabled", engines.len()); /// ``` @@ -58,27 +37,43 @@ pub struct Config { #[serde(default)] pub tags: Vec<crate::tagging::TagDef>, /// List of engine IDs to enable for inline search results. - /// If `None`, all engines from `search_engines::default_search_engines()` are used. #[serde(default)] pub enabled_engines: Option<Vec<String>>, /// Whether auto-tagging is enabled. Defaults to `false` if not set. #[serde(default)] pub tagging_enabled: Option<bool>, - /// Tagging threshold (0.0 to 1.0). Tags with a score below this are - /// discarded. Defaults to 0.60 if not set. + /// Tagging threshold (0.0 to 1.0). Defaults to 0.60 if not set. #[serde(default)] pub tagging_threshold: Option<f64>, - /// Hostnames to exclude from content fetching during import. - /// Defaults to localhost addresses if not set. + /// Hostnames to exclude from content fetching. #[serde(default)] pub exclude_urls: Option<Vec<String>>, - /// Per-engine configuration, keyed by engine ID. - /// For example: `[engines.searxng]` with `instance = "https://..."`. + /// Per-engine configuration (supports multiple instances of the same engine type). #[serde(default)] - pub engines: Option<HashMap<String, toml::Table>>, - /// Default bookmark database path. Overrides the platform default. + pub engines: Vec<EngineConfig>, + /// Default bookmark database path. #[serde(default)] pub db_path: Option<String>, + + /// Server bind address (default: "127.0.0.1"). + #[serde(default)] + pub bind_address: Option<String>, + /// Results per page (default: 20). + #[serde(default)] + pub page_size: Option<usize>, + /// Actix worker threads (default: 2). 
+ #[serde(default)] + pub workers: Option<usize>, + + /// ONNX embedding model name (default: "BGESmallENV15"). + #[serde(default)] + pub onnx_model: Option<String>, + /// Max characters to use from page content for tagging (default: 2000). + #[serde(default)] + pub truncation: Option<usize>, + /// Max tags to assign per bookmark (default: 5). + #[serde(default)] + pub max_tags: Option<usize>, } impl Config { @@ -129,8 +124,14 @@ impl Default for Config { tagging_enabled: None, tagging_threshold: None, exclude_urls: None, - engines: None, + engines: Vec::new(), db_path: None, + bind_address: None, + page_size: None, + workers: None, + onnx_model: None, + truncation: None, + max_tags: None, } } } @@ -139,8 +140,8 @@ impl Config { /// Resolve the list of enabled search engines. /// /// Default engines (`crates.io`) are included unless filtered by - /// `enabled_engines`. Engines with configuration in the `engines` map - /// (e.g. `searxng`) are added subject to the same filter. + /// `enabled_engines`. Engines with configuration in the `engines` vec + /// are added subject to the same filter. pub fn resolve_engines(&self) -> Vec<Box<dyn crate::search_engines::SearchEngine>> { let is_enabled = |id: &str| -> bool { self.enabled_engines @@ -151,24 +152,40 @@ impl Config { let mut engines: Vec<Box<dyn crate::search_engines::SearchEngine>> = Vec::new(); - for e in crate::search_engines::default_search_engines() { - if is_enabled(e.id()) { - engines.push(e); + // Add default crates.io if enabled and not explicitly configured + if is_enabled("crates.io") { + let has_custom_crates_io = self.engines.iter().any(|cfg| { + matches!(cfg, EngineConfig::CratesIo { .. 
}) + }); + if !has_custom_crates_io { + engines.push(Box::new(crate::search_engines::crates_io::CratesIo { + timeout_secs: None, + api_url: crate::search_engines::crates_io::DEFAULT_API_URL.into(), + })); } } - if let Some(ref configs) = self.engines { - for (id, config) in configs { - if !is_enabled(id) { - continue; + // Add configured engines + for cfg in &self.engines { + match cfg { + EngineConfig::CratesIo { url, timeout_secs } => { + if is_enabled("crates.io") { + let api_url = url.clone().unwrap_or_else(|| crate::search_engines::crates_io::DEFAULT_API_URL.into()); + engines.push(Box::new(crate::search_engines::crates_io::CratesIo { + timeout_secs: *timeout_secs, + api_url, + })); + } } - match id.as_str() { - "searxng" => { - if let Some(engine) = crate::search_engines::searxng::SearXng::from_config(config) { - engines.push(engine); - } + EngineConfig::SearXng { instance, timeout_secs } => { + if is_enabled("searxng") { + let url_tpl = format!("{}/search?format=json&q={{}}", instance.trim_end_matches('/')); + engines.push(Box::new(crate::search_engines::searxng::SearXng { + instance: instance.clone(), + url_tpl, + timeout_secs: *timeout_secs, + })); } - _ => {} } } } @@ -205,29 +222,34 @@ mod tests { fn load_from_missing_file_returns_default() { let cfg = Config::load_from(&PathBuf::from("/nonexistent/path.toml")); assert!(cfg.tags.is_empty()); - assert!(cfg.engines.is_none()); + assert!(cfg.engines.is_empty()); } #[test] fn load_from_valid_file_with_engines() { let mut file = NamedTempFile::new().unwrap(); write!(file, r#" -[engines.searxng] +[[engines]] +type = "searxng" instance = "https://search.example.com" "#).unwrap(); let cfg = Config::load_from(&file.path().to_path_buf()); - let engines = cfg.engines.unwrap(); - assert!(engines.contains_key("searxng")); - let searxng = &engines["searxng"]; - assert_eq!(searxng.get("instance").unwrap().as_str(), Some("https://search.example.com")); + assert!(!cfg.engines.is_empty()); + 
assert!(matches!(cfg.engines[0], EngineConfig::SearXng { .. })); + if let EngineConfig::SearXng { instance, .. } = &cfg.engines[0] { + assert_eq!(instance, "https://search.example.com"); + } else { + panic!("expected SearXng"); + } } #[test] - fn resolve_engines_includes_searxng_from_engines_map() { + fn resolve_engines_includes_searxng_from_engines_vec() { let mut file = NamedTempFile::new().unwrap(); write!(file, r#" -[engines.searxng] +[[engines]] +type = "searxng" instance = "https://search.example.com" "#).unwrap(); @@ -241,7 +263,8 @@ instance = "https://search.example.com" let mut file = NamedTempFile::new().unwrap(); write!(file, r#" enabled_engines = ["crates.io"] -[engines.searxng] +[[engines]] +type = "searxng" instance = "https://search.example.com" "#).unwrap(); @@ -258,7 +281,8 @@ instance = "https://search.example.com" let mut file = NamedTempFile::new().unwrap(); write!(file, r#" enabled_engines = ["crates.io", "searxng"] -[engines.searxng] +[[engines]] +type = "searxng" instance = "https://search.example.com" "#).unwrap(); @@ -274,7 +298,7 @@ instance = "https://search.example.com" write!(file, "invalid toml [[[").unwrap(); let cfg = Config::load_from(&file.path().to_path_buf()); assert!(cfg.tags.is_empty()); - assert!(cfg.engines.is_none()); + assert!(cfg.engines.is_empty()); } #[test] -
modified src/lib.rs
diff --git a/src/lib.rs b/src/lib.rs index f6fcf48..f66e51e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod config; pub mod importer; pub mod models; pub mod search_engines; +pub mod self_update; pub mod storage; pub mod tagging; pub mod web; -
modified src/main.rs
diff --git a/src/main.rs b/src/main.rs index 61a7960..d548bb0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,38 +1,68 @@ +use chrono::{DateTime, Local, TimeZone, Utc}; use clap::{Parser, Subcommand}; +use colored::Colorize; use indicatif::{ProgressBar, ProgressStyle}; +use robotstxt::DefaultMatcher; use search_hub::config::{config_file_path, Config}; -use search_hub::search_engines::SearchEngine; use search_hub::importer::chrome::ChromeImporter; use search_hub::importer::firefox::FirefoxImporter; use search_hub::importer::zen::ZenImporter; use search_hub::importer::Importer; use search_hub::models::Bookmark; +use search_hub::search_engines::SearchEngine; use search_hub::storage; use search_hub::tagging::{default_tags, TagDef, TaggingEngine}; use search_hub::web; -use chrono::{Local, TimeZone, Utc}; -use colored::Colorize; -use robotstxt::DefaultMatcher; +use serde::{Deserialize, Serialize}; use std::collections::{HashMap, VecDeque}; +use std::fs; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tracing::{error, info}; const USER_AGENT: &str = concat!("search_hub/", env!("CARGO_PKG_VERSION")); +#[derive(Serialize, Deserialize)] +struct CacheEntry { + body: String, + fetched_at: DateTime<Utc>, +} + struct Fetcher { client: reqwest::Client, rt: tokio::runtime::Runtime, - robots_cache: Mutex<HashMap<String, String>>, + robots_cache: Mutex<HashMap<String, CacheEntry>>, + cache_path: PathBuf, } impl Fetcher { - fn new() -> anyhow::Result<Self> { + fn new(cache_path: PathBuf) -> anyhow::Result<Self> { let client = reqwest::Client::builder() .user_agent(USER_AGENT) .build()?; let rt = tokio::runtime::Runtime::new()?; - Ok(Self { client, rt, robots_cache: Mutex::new(HashMap::new()) }) + let cache = Self::load_cache(&cache_path); + Ok(Self { client, rt, robots_cache: Mutex::new(cache), cache_path }) + } + + fn load_cache(path: &PathBuf) -> HashMap<String, CacheEntry> { + if path.exists() { + match fs::read_to_string(path) { + Ok(data) => { + 
serde_json::from_str(&data).unwrap_or_default() + } + Err(_) => HashMap::new(), + } + } else { + HashMap::new() + } + } + + fn save_cache(&self) { + let cache = self.robots_cache.lock().unwrap(); + if let Ok(data) = serde_json::to_string(&*cache) { + let _ = fs::write(&self.cache_path, data); + } } fn fetch(&mut self, url: &str) -> Result<String, String> { @@ -42,23 +72,25 @@ impl Fetcher { async fn fetch_async(&self, url: &str) -> Result<String, String> { let parsed = url::Url::parse(url).map_err(|e| format!("invalid url: {}", e))?; let domain = parsed.host_str().unwrap_or("").to_string(); + let one_month = chrono::Duration::days(30); { let mut cache = self.robots_cache.lock().unwrap(); - if !cache.contains_key(&domain) { + let expired = cache.get(&domain).map_or(true, |e| Utc::now() - e.fetched_at > one_month); + if !cache.contains_key(&domain) || expired { let robots_url = format!("{}://{}/robots.txt", parsed.scheme(), domain); let body = match self.client.get(&robots_url).send().await { Ok(resp) => resp.text().await.unwrap_or_default(), Err(_) => String::new(), }; - cache.insert(domain.clone(), body); + cache.insert(domain.clone(), CacheEntry { body, fetched_at: Utc::now() }); } } let allowed = { let cache = self.robots_cache.lock().unwrap(); DefaultMatcher::default() - .one_agent_allowed_by_robots(&cache[&domain], USER_AGENT, url) + .one_agent_allowed_by_robots(&cache[&domain].body, USER_AGENT, url) }; if !allowed { return Err(format!("blocked by robots.txt")); @@ -73,6 +105,12 @@ impl Fetcher { } } +impl Drop for Fetcher { + fn drop(&mut self) { + self.save_cache(); + } +} + #[derive(Parser)] #[command(author, version, about)] struct Args { @@ -144,6 +182,18 @@ Import { }, /// Create a default config file at the default config path InitConfig, +/// Check for updates and apply them automatically +SelfUpdate { + #[arg(long)] + /// Release feed URL (default: abbaye Atom feed) + feed_url: Option<String>, + #[arg(long)] + /// Target triple for binary download 
(default: auto-detected) + target: Option<String>, + #[arg(long)] + /// Check for updates without downloading + dry_run: bool, +}, } #[derive(Subcommand)] @@ -259,12 +309,27 @@ async fn main() { } else { config.tags }; + let bind_address = config.bind_address.clone().unwrap_or_else(|| "127.0.0.1".into()); + let page_size = config.page_size.unwrap_or(20); + let workers = config.workers.unwrap_or(2); + let onnx_model = config.onnx_model.clone().unwrap_or_else(|| "BGESmallENV15".into()); + let truncation = config.truncation.unwrap_or(2000); + let max_tags = config.max_tags.unwrap_or(5); + let cache_dir = directories::ProjectDirs::from("com", "search_hub", "search_hub") + .map(|d| d.cache_dir().to_path_buf()) + .unwrap_or_else(|| { + let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".into()); + PathBuf::from(home).join(".cache").join("search_hub") + }); + let _ = fs::create_dir_all(&cache_dir); + let cache_path = cache_dir.join("robots_cache.json"); match args.command { Command::Serve { port, db_path } => { let db_path = resolve_db_path(db_path, config.db_path.as_deref()); - info!("Starting server on 127.0.0.1:{}", port); - if let Err(e) = web::run_server(&db_path.to_string_lossy(), port, engines).await { + info!("Starting server on {}:{}", bind_address, port); + let srv_cfg = web::ServerConfig { port, bind_address, page_size, workers }; + if let Err(e) = web::run_server(&db_path.to_string_lossy(), srv_cfg, engines).await { error!("Server error: {}", e); } } @@ -277,7 +342,7 @@ async fn main() { return; } - let mut fetcher = match Fetcher::new() { + let mut fetcher = match Fetcher::new(cache_path.clone()) { Ok(f) => f, Err(e) => { eprintln!("Warning: failed to create HTTP client: {}", e); @@ -309,9 +374,9 @@ async fn main() { return None; } info!("tagging content..."); - match TaggingEngine::new(&tags, tag_threshold) { + match TaggingEngine::new(&tags, tag_threshold, max_tags, truncation, &onnx_model) { Ok(mut engine) => { - let tags = engine.tags_for(c, 
5).unwrap_or_default(); + let tags = engine.tags_for(c).unwrap_or_default(); if tags.is_empty() { info!("no tags matched"); None @@ -379,7 +444,7 @@ async fn main() { let db_path = resolve_db_path(db_path, config.db_path.as_deref()); let conn = storage::init_db(&db_path.to_string_lossy()).expect("Failed to open database"); - let mut engine = match TaggingEngine::new(&tags, tag_threshold) { + let mut engine = match TaggingEngine::new(&tags, tag_threshold, max_tags, truncation, &onnx_model) { Ok(e) => e, Err(e) => { eprintln!("Warning: failed to initialize tagger: {}", e); @@ -451,7 +516,7 @@ async fn main() { Ok(Some(b)) => { match b.content { Some(ref content) => { - match engine.tags_for(content, 5) { + match engine.tags_for(content) { Ok(tags) => { let tags_str = if tags.is_empty() { None } else { Some(tags.join(", ")) }; storage::update_bookmark_tags(&conn, *rowid, tags_str.as_deref()) @@ -539,19 +604,32 @@ async fn main() { tokio::fs::write(&path, content).await.expect("Failed to write config file"); println!("Default config created at {:?}", path); } + Command::SelfUpdate { feed_url, target, dry_run } => { + let mut updater = search_hub::self_update::SelfUpdate::new() + .dry_run(dry_run); + if let Some(url) = feed_url { + updater = updater.with_feed_url(url.clone()); + } + if let Some(t) = target { + updater = updater.with_target(t.clone()); + } + if let Err(e) = updater.run().await { + error!("Self-update failed: {}", e); + } + } Command::Import { action } => { match action { ImportAction::Bookmarks { source, profile, db_path } => { let db_path = resolve_db_path(db_path, config.db_path.as_deref()); - run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::Bookmarks).await; + run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::Bookmarks, cache_path.clone(), max_tags, truncation, onnx_model.clone()).await; } 
ImportAction::History { source, profile, db_path } => { let db_path = resolve_db_path(db_path, config.db_path.as_deref()); - run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::History).await; + run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::History, cache_path.clone(), max_tags, truncation, onnx_model.clone()).await; } ImportAction::All { source, profile, db_path } => { let db_path = resolve_db_path(db_path, config.db_path.as_deref()); - run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::All).await; + run_import(&source, profile, &db_path.to_string_lossy(), tags.clone(), tagging_enabled, tag_threshold, &exclude_hosts, ImportKind::All, cache_path.clone(), max_tags, truncation, onnx_model.clone()).await; } } } @@ -611,7 +689,7 @@ fn resolve_profiles(importer: &(impl Importer + ?Sized), profile: Option<String> } } -async fn run_import(source: &str, profile: Option<String>, db_path: &str, tags: Vec<TagDef>, tagging_enabled: bool, tag_threshold: f32, exclude_hosts: &[String], kind: ImportKind) { +async fn run_import(source: &str, profile: Option<String>, db_path: &str, tags: Vec<TagDef>, tagging_enabled: bool, tag_threshold: f32, exclude_hosts: &[String], kind: ImportKind, cache_path: PathBuf, max_tags: usize, truncation: usize, onnx_model: String) { let importer: Box<dyn Importer> = match source { "firefox" => Box::new(FirefoxImporter), "zen" => Box::new(ZenImporter), @@ -719,8 +797,10 @@ async fn run_import(source: &str, profile: Option<String>, db_path: &str, tags: let task_threshold = tag_threshold; let task_exclude = exclude_hosts.to_vec(); let task_tagging_enabled = tagging_enabled; + let task_cache = cache_path.clone(); + let task_onnx = onnx_model.clone(); tokio::task::spawn_blocking(move || { - let mut fetcher = match 
Fetcher::new() { + let mut fetcher = match Fetcher::new(task_cache) { Ok(f) => f, Err(e) => { let _ = tx.send(format!("[{}] failed to create HTTP client: {}", task_id, e)); @@ -728,7 +808,7 @@ async fn run_import(source: &str, profile: Option<String>, db_path: &str, tags: } }; let mut tagger = if task_tagging_enabled { - TaggingEngine::new(&task_tags, task_threshold).ok() + TaggingEngine::new(&task_tags, task_threshold, max_tags, truncation, &task_onnx).ok() } else { None }; @@ -743,7 +823,7 @@ async fn run_import(source: &str, profile: Option<String>, db_path: &str, tags: match fetch_and_convert(&mut fetcher, url, Some(task_id)) { Some(md) => { let entry_tags = tagger.as_mut() - .and_then(|e| e.tags_for(&md, 5).ok()) + .and_then(|e| e.tags_for(&md).ok()) .unwrap_or_default(); let tags_str = if entry_tags.is_empty() { None } else { Some(entry_tags.join(", ")) }; storage::update_bookmark_content_tags( -
modified src/models.rs
diff --git a/src/models.rs b/src/models.rs index 3e505c8..0f12448 100644 --- a/src/models.rs +++ b/src/models.rs @@ -44,5 +44,4 @@ pub struct Bookmark { pub created_at: DateTime<Utc>, } -/// Re-export of the search engine result type for external search results. -pub use crate::search_engines::ResultEntry as ExternalResult; + -
modified src/search_engines/crates_io.rs
diff --git a/src/search_engines/crates_io.rs b/src/search_engines/crates_io.rs index c9fadad..b9b0b68 100644 --- a/src/search_engines/crates_io.rs +++ b/src/search_engines/crates_io.rs @@ -3,7 +3,12 @@ use serde::Deserialize; use crate::search_engines::{EngineError, ResultEntry, SearchEngine}; -pub struct CratesIo; +pub const DEFAULT_API_URL: &str = "https://crates.io/api/v1/crates?q={}&per_page=10"; + +pub struct CratesIo { + pub timeout_secs: Option<f32>, + pub api_url: String, +} #[derive(Deserialize)] struct CrateResult { @@ -30,13 +35,19 @@ impl SearchEngine for CratesIo { } fn url_template(&self) -> &str { - "https://crates.io/api/v1/crates?q={}&per_page=10" + &self.api_url } fn selector(&self) -> &str { "" } + fn timeout(&self) -> std::time::Duration { + self.timeout_secs + .map(|s| std::time::Duration::from_secs_f32(s)) + .unwrap_or_else(|| std::time::Duration::from_secs(5)) + } + async fn fetch_results( &self, query: &str, @@ -86,10 +97,6 @@ impl SearchEngine for CratesIo { } } -pub fn engine() -> CratesIo { - CratesIo -} - #[cfg(test)] mod tests { use super::*; @@ -97,30 +104,34 @@ mod tests { #[test] fn test_id() { - assert_eq!(CratesIo.id(), "crates.io"); + let e = CratesIo { timeout_secs: None, api_url: DEFAULT_API_URL.into() }; + assert_eq!(e.id(), "crates.io"); } #[test] fn test_name() { - assert_eq!(CratesIo.name(), "crates.io"); + let e = CratesIo { timeout_secs: None, api_url: DEFAULT_API_URL.into() }; + assert_eq!(e.name(), "crates.io"); } #[test] fn test_url_template() { + let e = CratesIo { timeout_secs: None, api_url: DEFAULT_API_URL.into() }; assert_eq!( - CratesIo.url_template(), + e.url_template(), "https://crates.io/api/v1/crates?q={}&per_page=10" ); } #[test] fn test_selector() { - assert_eq!(CratesIo.selector(), ""); + let e = CratesIo { timeout_secs: None, api_url: DEFAULT_API_URL.into() }; + assert_eq!(e.selector(), ""); } #[test] fn test_engine_construct() { - let e = engine(); + let e = CratesIo { timeout_secs: None, api_url: 
DEFAULT_API_URL.into() }; assert_eq!(e.id(), "crates.io"); } } -
modified src/search_engines/mod.rs
diff --git a/src/search_engines/mod.rs b/src/search_engines/mod.rs index 31a2158..3e6ee03 100644 --- a/src/search_engines/mod.rs +++ b/src/search_engines/mod.rs @@ -6,6 +6,7 @@ use scraper::{Html, Selector}; use serde::Serialize; use std::collections::HashSet; use std::fmt; +use std::time::Duration; /// A single search result returned by an external search engine. /// @@ -87,6 +88,11 @@ pub trait SearchEngine: Send + Sync { /// CSS selector targeting the result container in the engine's HTML page. fn selector(&self) -> &str; + /// Maximum time to wait for this engine to respond (default: 5s). + fn timeout(&self) -> Duration { + Duration::from_secs(5) + } + /// Build a search URL from the given query by replacing `{}` with the /// URL-encoded query string. fn search_url(&self, query: &str) -> String { @@ -187,7 +193,7 @@ pub trait SearchEngine: Send + Sync { /// ``` pub fn default_search_engines() -> Vec<Box<dyn SearchEngine>> { vec![ - Box::new(crates_io::CratesIo), + Box::new(crates_io::CratesIo { timeout_secs: None, api_url: crates_io::DEFAULT_API_URL.into() }), ] } -
modified src/search_engines/searxng.rs
diff --git a/src/search_engines/searxng.rs b/src/search_engines/searxng.rs index ea9699e..5d8218e 100644 --- a/src/search_engines/searxng.rs +++ b/src/search_engines/searxng.rs @@ -6,14 +6,7 @@ use crate::search_engines::{EngineError, ResultEntry, SearchEngine}; pub struct SearXng { pub instance: String, pub url_tpl: String, -} - -impl SearXng { - pub fn from_config(config: &toml::Table) -> Option<Box<dyn SearchEngine>> { - let instance = config.get("instance")?.as_str()?.to_string(); - let url_tpl = format!("{}/search?format=json&q={{}}", instance.trim_end_matches('/')); - Some(Box::new(SearXng { instance, url_tpl })) - } + pub timeout_secs: Option<f32>, } #[derive(Deserialize)] @@ -47,6 +40,12 @@ impl SearchEngine for SearXng { "" } + fn timeout(&self) -> std::time::Duration { + self.timeout_secs + .map(|s| std::time::Duration::from_secs_f32(s)) + .unwrap_or_else(|| std::time::Duration::from_secs(5)) + } + async fn fetch_results( &self, query: &str, @@ -102,6 +101,7 @@ mod tests { let e = SearXng { instance: "https://example.com".into(), url_tpl: "https://example.com/search?format=json&q={}".into(), + timeout_secs: None, }; assert_eq!(e.id(), "searxng"); } @@ -111,6 +111,7 @@ mod tests { let e = SearXng { instance: "https://example.com".into(), url_tpl: "https://example.com/search?format=json&q={}".into(), + timeout_secs: None, }; assert_eq!(e.name(), "SearXNG"); } @@ -120,6 +121,7 @@ mod tests { let e = SearXng { instance: "https://example.com".into(), url_tpl: "https://example.com/search?format=json&q={}".into(), + timeout_secs: None, }; assert_eq!(e.selector(), ""); } @@ -129,6 +131,7 @@ mod tests { let e = SearXng { instance: "https://my-instance.net".into(), url_tpl: "https://my-instance.net/search?format=json&q={}".into(), + timeout_secs: None, }; assert_eq!( e.url_template(), @@ -136,51 +139,12 @@ mod tests { ); } - #[test] - fn test_from_config_valid() { - let mut config = toml::Table::new(); - config.insert("instance".into(), 
toml::Value::String("https://search.example.com".into())); - let result = SearXng::from_config(&config); - assert!(result.is_some()); - let engine = result.unwrap(); - assert_eq!(engine.id(), "searxng"); - assert_eq!( - engine.url_template(), - "https://search.example.com/search?format=json&q={}" - ); - } - - #[test] - fn test_from_config_trailing_slash_stripped() { - let mut config = toml::Table::new(); - config.insert("instance".into(), toml::Value::String("https://search.example.com/".into())); - let result = SearXng::from_config(&config); - assert!(result.is_some()); - let engine = result.unwrap(); - assert_eq!( - engine.url_template(), - "https://search.example.com/search?format=json&q={}" - ); - } - - #[test] - fn test_from_config_missing_instance() { - let config = toml::Table::new(); - assert!(SearXng::from_config(&config).is_none()); - } - - #[test] - fn test_from_config_non_string_instance() { - let mut config = toml::Table::new(); - config.insert("instance".into(), toml::Value::Integer(42)); - assert!(SearXng::from_config(&config).is_none()); - } - #[test] fn test_search_url_uses_template() { let e = SearXng { instance: "https://example.com".into(), url_tpl: "https://example.com/search?format=json&q={}".into(), + timeout_secs: None, }; assert_eq!( e.search_url("tokio"), -
added src/self_update.rs
diff --git a/src/self_update.rs b/src/self_update.rs new file mode 100644 index 0000000..e8b246c --- /dev/null +++ b/src/self_update.rs @@ -0,0 +1,324 @@ +use anyhow::{bail, Context, Result}; +use reqwest::Client; +use std::fmt; +use std::path::PathBuf; +use std::str::FromStr; + +const DEFAULT_FEED_URL: &str = "https://vit.am/~ololduck/search_hub/releases.atom"; + +#[derive(Debug, PartialEq, Eq)] +struct Version(u32, u32, u32); + +impl FromStr for Version { + type Err = String; + fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { + let s = s.strip_prefix('v').unwrap_or(s); + let parts: Vec<&str> = s.splitn(3, '.').collect(); + if parts.len() != 3 { + return Err(format!("invalid version: {}", s)); + } + Ok(Version( + parts[0].parse().map_err(|_| format!("invalid major in {}", s))?, + parts[1].parse().map_err(|_| format!("invalid minor in {}", s))?, + parts[2].parse().map_err(|_| format!("invalid patch in {}", s))?, + )) + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}.{}.{}", self.0, self.1, self.2) + } +} + +impl PartialOrd for Version { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for Version { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0 + .cmp(&other.0) + .then(self.1.cmp(&other.1)) + .then(self.2.cmp(&other.2)) + } +} + +fn current_version() -> Version { + env!("CARGO_PKG_VERSION") + .parse() + .expect("CARGO_PKG_VERSION is not valid semver") +} + +fn default_target() -> String { + format!("{}-unknown-linux-gnu", std::env::consts::ARCH) +} + +fn user_agent() -> String { + format!("search_hub/{}", env!("CARGO_PKG_VERSION")) +} + +fn parse_versions_from_atom(xml: &str) -> Vec<String> { + let mut versions = Vec::new(); + let mut pos = 0; + while let Some(entry_start) = xml[pos..].find("<entry>") { + pos += entry_start + "<entry>".len(); + if let Some(title_start) = xml[pos..].find("<title>") { + let 
content_pos = pos + title_start + "<title>".len(); + if let Some(content_end) = xml[content_pos..].find("</title>") { + versions.push(xml[content_pos..content_pos + content_end].to_string()); + pos = content_pos + content_end + "</title>".len(); + } + } + } + versions +} + +pub struct SelfUpdate { + feed_url: String, + target: String, + dry_run: bool, +} + +impl Default for SelfUpdate { + fn default() -> Self { + Self::new() + } +} + +impl SelfUpdate { + pub fn new() -> Self { + Self { + feed_url: DEFAULT_FEED_URL.to_string(), + target: default_target(), + dry_run: false, + } + } + + pub fn with_feed_url(mut self, url: String) -> Self { + self.feed_url = url; + self + } + + pub fn with_target(mut self, target: String) -> Self { + self.target = target; + self + } + + pub fn dry_run(mut self, yes: bool) -> Self { + self.dry_run = yes; + self + } + + pub async fn run(&self) -> Result<()> { + let client = Client::builder() + .user_agent(&user_agent()) + .build()?; + + println!("Checking for updates at {}...", self.feed_url); + let xml = client + .get(&self.feed_url) + .send() + .await + .context("Failed to fetch release feed")? 
+ .text() + .await + .context("Failed to read release feed")?; + + let version_strings = parse_versions_from_atom(&xml); + if version_strings.is_empty() { + println!("No releases found in feed."); + return Ok(()); + } + + let current = current_version(); + + let latest = version_strings + .iter() + .filter_map(|s| { + let v = s.parse::<Version>().ok()?; + Some((v, s)) + }) + .max_by(|(a, _), (b, _)| a.cmp(b)); + + let latest_version = match latest { + Some((v, s)) => { + println!("Latest release: {}", s); + v + } + None => { + println!("No valid releases found in feed."); + return Ok(()); + } + }; + + if latest_version <= current { + println!( + "Already up to date (current: {}, latest: {})", + current, latest_version + ); + return Ok(()); + } + + println!("Update available: {} -> {}", current, latest_version); + + if self.dry_run { + let url = format!( + "https://vit.am/~ololduck/search_hub/{}/dist/search_hub-{}-{}", + latest_version, latest_version, self.target + ); + println!("Would download: {}", url); + return Ok(()); + } + + let url = format!( + "https://vit.am/~ololduck/search_hub/{}/dist/search_hub-{}-{}", + latest_version, latest_version, self.target + ); + + println!("Downloading {}...", url); + let response = client + .get(&url) + .send() + .await + .context("Failed to download update")?; + + let status = response.status(); + if !status.is_success() { + bail!( + "Download failed (HTTP {}) - no binary available for target '{}'. \ + Try --target (e.g. 
x86_64-unknown-linux-gnu) or build from source.", + status, + self.target + ); + } + + let bytes = response.bytes().await?; + + if bytes.len() < 4 || bytes[..4] != [0x7f, b'E', b'L', b'F'] { + bail!("Downloaded file is not a valid ELF binary"); + } + + let current_exe = std::env::current_exe() + .context("Cannot determine current executable path")?; + + let temp_path = { + let mut p = current_exe.clone().into_os_string(); + p.push(format!(".{}.tmp", std::process::id())); + PathBuf::from(p) + }; + + tokio::fs::write(&temp_path, &bytes) + .await + .context("Failed to write temporary file")?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&temp_path, std::fs::Permissions::from_mode(0o755)) + .context("Failed to set executable permissions")?; + } + + std::fs::rename(&temp_path, &current_exe) + .context("Failed to replace current binary - maybe missing write permission? \ + Try running with sudo or reinstall via cargo install --path .")?; + + println!("Successfully updated to v{}!", latest_version); + println!("Restart any running search_hub processes to use the new version."); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_versions_simple() { + let xml = r#"<?xml version="1.0"?> +<feed> + <entry><title>0.3.0</title></entry> + <entry><title>0.2.3</title></entry> + <entry><title>0.2.2</title></entry> +</feed>"#; + let versions = parse_versions_from_atom(xml); + assert_eq!(versions, vec!["0.3.0", "0.2.3", "0.2.2"]); + } + + #[test] + fn parse_versions_with_namespace() { + let xml = r#"<?xml version="1.0"?> +<feed xmlns="http://www.w3.org/2005/Atom"> + <entry> + <title>1.0.0</title> + <id>urn:test:1.0.0</id> + <updated>2026-01-01T00:00:00Z</updated> + </entry> +</feed>"#; + let versions = parse_versions_from_atom(xml); + assert_eq!(versions, vec!["1.0.0"]); + } + + #[test] + fn parse_versions_empty() { + assert!(parse_versions_from_atom("").is_empty()); + 
assert!(parse_versions_from_atom("<feed></feed>").is_empty()); + } + + #[test] + fn version_strips_v_prefix() { + assert_eq!("v1.2.3".parse::<Version>().unwrap(), Version(1, 2, 3)); + } + + #[test] + fn version_without_prefix() { + assert_eq!("1.2.3".parse::<Version>().unwrap(), Version(1, 2, 3)); + } + + #[test] + fn version_ordering() { + assert!(Version(0, 3, 0) > Version(0, 2, 9)); + assert!(Version(1, 0, 0) > Version(0, 99, 99)); + assert!(Version(0, 0, 1) < Version(0, 0, 2)); + assert_eq!(Version(1, 2, 3), Version(1, 2, 3)); + } + + #[test] + fn version_invalid() { + assert!("abc".parse::<Version>().is_err()); + assert!("1.2".parse::<Version>().is_err()); + assert!("1.2.3.4".parse::<Version>().is_err()); + assert!("1.2.x".parse::<Version>().is_err()); + } + + #[test] + fn version_display() { + assert_eq!(Version(0, 3, 0).to_string(), "0.3.0"); + assert_eq!(Version(10, 20, 30).to_string(), "10.20.30"); + } + + #[test] + fn parse_versions_skips_irrelevant_tags() { + let xml = r#"<feed> + <entry><title>0.3.0</title><title>ignored</title></entry> +</feed>"#; + let versions = parse_versions_from_atom(xml); + assert_eq!(versions, vec!["0.3.0"]); + } + + #[test] + fn selects_highest_version() { + let versions = vec!["0.2.0".to_string(), "0.3.0".to_string(), "0.1.0".to_string()]; + let latest = versions + .iter() + .filter_map(|s| { + let v = s.parse::<Version>().ok()?; + Some((v, s)) + }) + .max_by(|(a, _), (b, _)| a.cmp(b)); + assert_eq!(latest.unwrap().0, Version(0, 3, 0)); + } +} -
modified src/tagging.rs
diff --git a/src/tagging.rs b/src/tagging.rs index ecb4e5d..2c7dce3 100644 --- a/src/tagging.rs +++ b/src/tagging.rs @@ -1,6 +1,24 @@ use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions}; use fastembed::similarity::cosine_similarity; use serde::Deserialize; +use std::path::PathBuf; + +/// Return the standard cache directory for ONNX embedding models. +fn cache_dir_for_models() -> PathBuf { + let dirs = directories::ProjectDirs::from("com", "search_hub", "search_hub") + .expect("no valid cache directory"); + let cache_dir = dirs.cache_dir().join("fastembed"); + std::fs::create_dir_all(&cache_dir).ok(); + cache_dir +} + +fn parse_model(name: &str) -> anyhow::Result<EmbeddingModel> { + Ok(match name { + "AllMiniLML6V2" => EmbeddingModel::AllMiniLML6V2, + "BGESmallENV15" => EmbeddingModel::BGESmallENV15, + other => anyhow::bail!("unknown ONNX model '{}'", other), + }) +} /// A named tag with example texts used for semantic similarity scoring. /// @@ -170,9 +188,9 @@ pub fn default_tags() -> Vec<TagDef> { /// /// ```ignore /// let tags = search_hub::tagging::default_tags(); -/// let mut engine = search_hub::tagging::TaggingEngine::new(&tags, 0.40) +/// let mut engine = search_hub::tagging::TaggingEngine::new(&tags, 0.40, 5, 2000, "BGESmallENV15") /// .expect("failed to init tagging engine"); -/// let matched = engine.tags_for("the rust programming language borrow checker", 3) +/// let matched = engine.tags_for("the rust programming language borrow checker") /// .expect("tagging failed"); /// assert!(matched.contains(&"rust".to_string())); /// ``` @@ -180,39 +198,40 @@ pub struct TaggingEngine { model: TextEmbedding, tag_examples: Vec<(String, Vec<Vec<f32>>)>, threshold: f32, + max_tags: usize, + truncation: usize, } impl TaggingEngine { - /// Create a new tagging engine from the given tag definitions. + /// Create a new tagging engine. /// /// Downloads the ONNX embedding model on first run (cached afterwards). 
/// /// # Parameters /// - /// * `tags` - Slice of `TagDef` entries (from config or `default_tags()`). - /// * `threshold` - Minimum cosine-similarity score (0.0 to 1.0) for a tag - /// to be assigned. Default 0.40 in `tags_for()` but can - /// be overridden per-call with `tags_for_with_threshold()`. - /// - /// # Returns - /// - /// A `TaggingEngine` ready to score content. + /// * `tags` - Slice of `TagDef` entries. + /// * `threshold` - Minimum cosine-similarity score (0.0 to 1.0). + /// * `max_tags` - Default max tags to assign. + /// * `truncation` - Max characters to use from page content. + /// * `model_name` - ONNX model name (e.g. "BGESmallENV15"). /// /// # Errors /// - /// Returns an error if the embedding model cannot be loaded or the - /// tag examples fail to embed. + /// Returns an error if the model is unknown or fails to load. /// /// # Example /// /// ```ignore /// let tags = search_hub::tagging::default_tags(); - /// let mut engine = search_hub::tagging::TaggingEngine::new(&tags, 0.60) + /// let mut engine = search_hub::tagging::TaggingEngine::new(&tags, 0.60, 5, 2000, "BGESmallENV15") /// .expect("model init"); /// ``` - pub fn new(tags: &[TagDef], threshold: f32) -> anyhow::Result<Self> { - let mut model = TextEmbedding::try_new( - TextInitOptions::new(EmbeddingModel::BGESmallENV15) + pub fn new(tags: &[TagDef], threshold: f32, max_tags: usize, truncation: usize, model_name: &str) -> anyhow::Result<Self> { + let cache_dir = cache_dir_for_models(); + let model = parse_model(model_name)?; + let mut embedder = TextEmbedding::try_new( + TextInitOptions::new(model) + .with_cache_dir(cache_dir) .with_show_download_progress(true), )?; @@ -226,7 +245,7 @@ impl TaggingEngine { } } - let embeddings = model.embed(all_examples, None)?; + let embeddings = embedder.embed(all_examples, None)?; let mut tag_examples: Vec<(String, Vec<Vec<f32>>)> = tags .iter() @@ -237,12 +256,12 @@ impl TaggingEngine { tag_examples[*ti].1.push(emb.clone()); } - Ok(Self { model, 
tag_examples, threshold }) + Ok(Self { model: embedder, tag_examples, threshold, max_tags, truncation }) } - fn truncate(content: &str, max_chars: usize) -> &str { + fn truncate<'a>(&self, content: &'a str) -> &'a str { let end = content.char_indices() - .take(max_chars) + .take(self.truncation) .last() .map(|(i, c)| i + c.len_utf8()) .unwrap_or(content.len()); @@ -250,7 +269,7 @@ impl TaggingEngine { } fn score_content(&mut self, content: &str) -> anyhow::Result<Vec<(String, f32)>> { - let truncated = Self::truncate(content, 2000); + let truncated = self.truncate(content); let emb = self.model.embed( vec![format!("passage: {}", truncated)], None, @@ -303,13 +322,13 @@ impl TaggingEngine { /// let tags = search_hub::tagging::default_tags(); /// let mut engine = search_hub::tagging::TaggingEngine::new(&tags, 0.40) /// .expect("model init"); - /// let matched = engine.tags_for("the rust programming language", 3) + /// let matched = engine.tags_for("the rust programming language") /// .expect("tagging failed"); /// println!("{:?}", matched); /// ``` - pub fn tags_for(&mut self, content: &str, max_tags: usize) -> anyhow::Result<Vec<String>> { + pub fn tags_for(&mut self, content: &str) -> anyhow::Result<Vec<String>> { Ok(self - .tags_for_with_threshold(content, max_tags, self.threshold)? + .tags_for_with_threshold(content, self.max_tags, self.threshold)? .into_iter() .map(|(tag, _)| tag) .collect()) -
modified src/web.rs
diff --git a/src/web.rs b/src/web.rs index 71a97e6..0740238 100644 --- a/src/web.rs +++ b/src/web.rs @@ -1,4 +1,4 @@ -use crate::search_engines::{ResultEntry, SearchEngine}; +use crate::search_engines::{EngineError, ResultEntry, SearchEngine}; use crate::storage; use actix_web::{get, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use rusqlite::Connection; @@ -29,12 +29,20 @@ const VERSION: &str = concat!( ")", ); +#[derive(Clone)] +pub struct ServerConfig { + pub port: u16, + pub bind_address: String, + pub page_size: usize, + pub workers: usize, +} + #[get("/")] -async fn index(templates: web::Data<Tera>, port: web::Data<Port>) -> impl Responder { +async fn index(templates: web::Data<Tera>, cfg: web::Data<ServerConfig>) -> impl Responder { info!("serving index page"); let mut ctx = tera::Context::new(); ctx.insert("version", VERSION); - ctx.insert("port", &(**port).0); + ctx.insert("port", &cfg.port); match templates.render("index.html", &ctx) { Ok(rendered) => HttpResponse::Ok().content_type("text/html").body(rendered), Err(e) => { @@ -45,9 +53,9 @@ async fn index(templates: web::Data<Tera>, port: web::Data<Port>) -> impl Respon } #[get("/opensearch.xml")] -async fn opensearch(templates: web::Data<Tera>, port: web::Data<Port>) -> impl Responder { +async fn opensearch(templates: web::Data<Tera>, cfg: web::Data<ServerConfig>) -> impl Responder { let mut ctx = tera::Context::new(); - ctx.insert("port", &(**port).0); + ctx.insert("port", &cfg.port); match templates.render("opensearch.xml", &ctx) { Ok(xml) => HttpResponse::Ok().content_type("application/opensearchdescription+xml").body(xml), Err(e) => { @@ -57,8 +65,6 @@ async fn opensearch(templates: web::Data<Tera>, port: web::Data<Port>) -> impl R } } -struct Port(u16); - #[get("/search")] async fn search( req: HttpRequest, @@ -66,11 +72,12 @@ async fn search( templates: web::Data<Tera>, db_pool: web::Data<DbPool>, engines: web::Data<Vec<Box<dyn SearchEngine>>>, + cfg: web::Data<ServerConfig>, ) -> 
impl Responder { let start = Instant::now(); let q = query.q.as_deref().unwrap_or(""); let page = query.page.unwrap_or(1).max(1); - let page_size: usize = 20; + let page_size = cfg.page_size; let has_query = !q.is_empty(); info!("search request: query=\"{}\" page={}", q, page); @@ -110,9 +117,25 @@ async fn search( let engines = engines.clone(); handles.push(tokio::spawn(async move { let t0 = Instant::now(); - let result = engines[i].fetch_results(&q_owned, &client).await; + let timeout_dur = engines[i].timeout(); + let result = tokio::time::timeout( + timeout_dur, + engines[i].fetch_results(&q_owned, &client), + ) + .await; let elapsed = t0.elapsed(); - (engine_name, result, elapsed) + match result { + Ok(Ok(results)) => (engine_name, Ok(results), elapsed), + Ok(Err(e)) => (engine_name, Err(e), elapsed), + Err(_) => ( + engine_name, + Err(EngineError(format!( + "timed out after {:?}", + timeout_dur + ))), + elapsed, + ), + } })); } for handle in handles { @@ -183,12 +206,12 @@ pub struct SearchQuery { pub async fn run_server( db_path: &str, - port: u16, + cfg: ServerConfig, engines: Vec<Box<dyn SearchEngine>>, ) -> std::io::Result<()> { let db_pool = web::Data::new(DbPool::new(db_path)); let engines = web::Data::new(engines); - let port_data = web::Data::new(Port(port)); + let cfg = web::Data::new(cfg); let mut tera = Tera::default(); tera.add_raw_template("index.html", include_str!("../templates/index.html")) .expect("Failed to parse index template"); @@ -196,18 +219,27 @@ pub async fn run_server( .expect("Failed to parse opensearch template"); let tera = web::Data::new(tera); + let bind_addr = cfg.bind_address.clone(); + let bind_port = cfg.port; + let workers = cfg.workers; + + info!( + "Starting server on {}:{}, {} workers", + bind_addr, bind_port, workers + ); + HttpServer::new(move || { App::new() .app_data(tera.clone()) .app_data(db_pool.clone()) .app_data(engines.clone()) - .app_data(port_data.clone()) + .app_data(cfg.clone()) .service(index) 
.service(search) .service(opensearch) }) - .workers(2) - .bind(("127.0.0.1", port))? + .workers(workers) + .bind((bind_addr.as_str(), bind_port))? .run() .await } -
modified tests/search_engines_integration.rs
diff --git a/tests/search_engines_integration.rs b/tests/search_engines_integration.rs index 9bb1151..335e968 100644 --- a/tests/search_engines_integration.rs +++ b/tests/search_engines_integration.rs @@ -17,7 +17,7 @@ fn client() -> reqwest::Client { #[test] fn crates_io_returns_results_for_generic_query() { - let engine = search_hub::search_engines::crates_io::CratesIo; + let engine = search_hub::search_engines::crates_io::CratesIo { timeout_secs: None, api_url: "https://crates.io/api/v1/crates?q={}&per_page=10".into() }; let client = client(); let results = rt().block_on(engine.fetch_results("tokio", &client)); @@ -42,7 +42,7 @@ fn crates_io_returns_results_for_generic_query() { #[test] fn crates_io_search_uses_https_urls() { - let engine = search_hub::search_engines::crates_io::CratesIo; + let engine = search_hub::search_engines::crates_io::CratesIo { timeout_secs: None, api_url: "https://crates.io/api/v1/crates?q={}&per_page=10".into() }; let client = client(); let results = rt().block_on(engine.fetch_results("serde", &client)); @@ -62,7 +62,7 @@ fn crates_io_search_uses_https_urls() { #[test] fn crates_io_empty_query_returns_error() { - let engine = search_hub::search_engines::crates_io::CratesIo; + let engine = search_hub::search_engines::crates_io::CratesIo { timeout_secs: None, api_url: "https://crates.io/api/v1/crates?q={}&per_page=10".into() }; let client = client(); let results = rt().block_on(engine.fetch_results("zzzzzzzzzz_nonexistent_crate_xxxxxxxxx", &client)); @@ -83,6 +83,7 @@ fn searxng_returns_results_if_configured() { let engine = search_hub::search_engines::searxng::SearXng { instance: instance.clone(), url_tpl: format!("{}/search?format=json&q={{}}", instance.trim_end_matches('/')), + timeout_secs: None, }; let client = client(); -
modified tests/tagging_thresholds.rs
diff --git a/tests/tagging_thresholds.rs b/tests/tagging_thresholds.rs index 9fd93a8..9730dac 100644 --- a/tests/tagging_thresholds.rs +++ b/tests/tagging_thresholds.rs @@ -396,7 +396,7 @@ so users can challenge moderation decisions they disagree with. #[test] fn explore_tagging_thresholds() { let tags = default_tags(); - let mut engine = TaggingEngine::new(&tags, 0.40).expect("failed to init tagging engine"); + let mut engine = TaggingEngine::new(&tags, 0.40, 5, 2000, "BGESmallENV15").expect("failed to init tagging engine"); let thresholds = [ 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90,