| @@ -40,7 +40,7 @@ impl Fetcher { |
| .user_agent(USER_AGENT) |
| .build()?; |
| let rt = tokio::runtime::Builder::new_current_thread() |
| - .enable_time() |
| + .enable_all() |
| .build()?; |
| let cache = Self::load_cache(&cache_path); |
| Ok(Self { client, rt, robots_cache: Mutex::new(cache), cache_path }) |
| @@ -72,43 +72,51 @@ impl Fetcher { |
| } |
| |
| + /// Fetches `url` and returns its body after checking the host's robots.txt. |
| + /// robots.txt is cached per host (plus explicit port) for 30 days; a failed |
| + /// robots.txt request is cached as an empty body. Errors are plain strings: |
| + /// invalid URL, poisoned lock, "blocked by robots.txt", or a failed request. |
| async fn fetch_async(&self, url: &str) -> Result<String, String> { |
| - let parsed = url::Url::parse(url).map_err(|e| format!("invalid url: {}", e))?; |
| + let parsed = url::Url::parse(url).map_err(|e| format!("invalid url: {e}"))?; |
| let domain = parsed.host_str().unwrap_or("").to_string(); |
| + let port = parsed.port(); |
| + // robots.txt is served per host:port, so an explicit port is part of the cache key. |
| + let cache_key = port.map_or_else(|| domain.clone(), |p| format!("{domain}:{p}")); |
| let one_month = chrono::Duration::days(30); |
| |
| - let allowed = { |
| - let mut cache = match self.robots_cache.lock() { |
| - Ok(c) => c, |
| - Err(_) => return Err("internal error: lock poisoned".into()), |
| - }; |
| - let expired = cache.get(&domain).map_or(true, |e| Utc::now() - e.fetched_at > one_month); |
| - if !cache.contains_key(&domain) || expired { |
| - let robots_url = format!("{}://{}/robots.txt", parsed.scheme(), domain); |
| - let body = match self.client.get(&robots_url).send().await { |
| - Ok(resp) => resp.text().await.unwrap_or_default(), |
| - Err(_) => String::new(), |
| - }; |
| - cache.insert(domain.clone(), CacheEntry { body, fetched_at: Utc::now() }); |
| - } |
| - drop(cache); |
| + // Check cache (no await while holding lock); a missing entry counts as expired |
| + let need_fetch = { |
| + let cache = self.robots_cache.lock() |
| + .map_err(|_| "internal error: lock poisoned".to_string())?; |
| + cache.get(&cache_key) |
| + .map_or(true, |e| Utc::now() - e.fetched_at > one_month) |
| + }; |
| |
| - let cache = match self.robots_cache.lock() { |
| - Ok(c) => c, |
| - Err(_) => return Err("internal error: lock poisoned".into()), |
| + // Fetch robots.txt if needed (no lock held during await) |
| + if need_fetch { |
| + // `cache_key` is already `host` or `host:port`, i.e. the URL authority. |
| + let robots_url = format!("{}://{}/robots.txt", parsed.scheme(), cache_key); |
| + let body = match self.client.get(&robots_url).send().await { |
| + Ok(resp) => resp.text().await.unwrap_or_default(), |
| + Err(_) => String::new(), |
| }; |
| - DefaultMatcher::default() |
| - .one_agent_allowed_by_robots(&cache[&domain].body, USER_AGENT, url) |
| - }; |
| - if !allowed { |
| - return Err(format!("blocked by robots.txt")); |
| + // Brief lock to update cache (no await) |
| + let mut cache = self.robots_cache.lock() |
| + .map_err(|_| "internal error: lock poisoned".to_string())?; |
| + cache.insert(cache_key.clone(), CacheEntry { body, fetched_at: Utc::now() }); |
| } |
| |
| - let resp = self.client.get(url) |
| - .send() |
| - .await |
| - .map_err(|e| format!("failed to fetch: {}", e))?; |
| + // Check if allowed (brief lock, no await); the body is cloned so the guard drops first |
| + let body = { |
| + let cache = self.robots_cache.lock() |
| + .map_err(|_| "internal error: lock poisoned".to_string())?; |
| + cache.get(&cache_key).map(|e| e.body.clone()).unwrap_or_default() |
| + }; |
| + if !DefaultMatcher::default().one_agent_allowed_by_robots(&body, USER_AGENT, url) { |
| + return Err("blocked by robots.txt".to_string()); |
| + } |
| |
| - resp.text().await.map_err(|e| format!("failed to read body: {}", e)) |
| + let resp = self.client.get(url).send().await |
| + .map_err(|e| format!("failed to fetch {url}: {e}"))?; |
| + resp.text().await.map_err(|e| format!("failed to read body for {url}: {e}")) |
| } |
| } |
| |
| @@ -941,6 +949,125 @@ fn fetch_and_convert(fetcher: &mut Fetcher, url: &str, task_id: Option<usize>) - |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| + use std::io::{Read, Write}; |
| + use std::net::TcpListener; |
| + |
| + /// Binds a loopback listener on an OS-assigned port; returns the listener and that port. |
| + fn test_server() -> (TcpListener, u16) { |
| + let socket = TcpListener::bind("127.0.0.1:0").unwrap(); |
| + socket.local_addr().map(|addr| (socket, addr.port())).unwrap() |
| + } |
| + |
| + /// Test helper: reads one request from `stream` and writes a canned 200 response chosen by path. |
| + fn handle_connection(mut stream: std::net::TcpStream) { |
| + let mut buf = [0; 4096]; |
| + // Decode only the bytes actually read; the rest of `buf` is zero padding. |
| + let request = match stream.read(&mut buf) { |
| + Ok(len) => String::from_utf8_lossy(&buf[..len]), |
| + Err(_) => return, |
| + }; |
| + let (status, body) = if request.contains("/robots.txt") { |
| + ("200 OK", "User-agent: *\nDisallow: /private/\n") |
| + } else if request.contains("/private/") { |
| + ("200 OK", "<html><body>private</body></html>") |
| + } else { |
| + ("200 OK", "<html><body>hello</body></html>") |
| + }; |
| + let response = format!( |
| + "HTTP/1.1 {status}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", |
| + body.len() |
| + ); |
| + let _ = stream.write_all(response.as_bytes()); |
| + let _ = stream.flush(); |
| + // Give the client time to consume the body before we close |
| + std::thread::sleep(std::time::Duration::from_millis(10)); |
| + } |
| + |
| + /// A page on the local test server that robots.txt does not disallow is fetched and returned. |
| + #[test] |
| + fn fetcher_fetch_basic_page() { |
| + let (server, port) = test_server(); |
| + // Detached thread: answers each successfully accepted connection in turn. |
| + std::thread::spawn(move || { |
| + server.incoming().flatten().for_each(handle_connection); |
| + }); |
| + |
| + let tmp = tempfile::tempdir().unwrap(); |
| + let cache_file = tmp.path().join("robots_cache.json"); |
| + let mut fetcher = Fetcher::new(cache_file).unwrap(); |
| + |
| + let page_url = format!("http://127.0.0.1:{port}/page"); |
| + let outcome = fetcher.fetch(&page_url); |
| + assert!(outcome.is_ok(), "fetch should succeed: {:?}", outcome); |
| + let html = outcome.unwrap(); |
| + assert!(html.contains("hello")); |
| + } |
| + |
| + /// A path under `/private/`, which the test server's robots.txt disallows, must be refused |
| + /// with a "blocked by robots.txt" error. |
| + #[test] |
| + fn fetcher_blocks_disallowed_path() { |
| + let (server, port) = test_server(); |
| + std::thread::spawn(move || { |
| + server.incoming().flatten().for_each(handle_connection); |
| + }); |
| + |
| + let tmp = tempfile::tempdir().unwrap(); |
| + let cache_file = tmp.path().join("robots_cache.json"); |
| + let mut fetcher = Fetcher::new(cache_file).unwrap(); |
| + |
| + // The first request pulls robots.txt into the cache and fetches an allowed page. |
| + let allowed = fetcher.fetch(&format!("http://127.0.0.1:{port}/page")); |
| + assert!(allowed.is_ok(), "first fetch should succeed: {:?}", allowed); |
| + |
| + // The second request targets a disallowed path and must be rejected. |
| + let denied = fetcher.fetch(&format!("http://127.0.0.1:{port}/private/page")); |
| + assert!(denied.is_err(), "disallowed path should be blocked"); |
| + let reason = denied.unwrap_err(); |
| + assert!(reason.contains("blocked by robots.txt")); |
| + } |
| + |
| + /// Two fetches against the same host: the second must cost exactly one more connection. |
| + #[test] |
| + fn fetcher_caches_robots_txt() { |
| + use std::sync::atomic::{AtomicUsize, Ordering}; |
| + let (server, port) = test_server(); |
| + let hits = std::sync::Arc::new(AtomicUsize::new(0)); |
| + let server_hits = hits.clone(); |
| + std::thread::spawn(move || { |
| + for mut conn in server.incoming().flatten() { |
| + server_hits.fetch_add(1, Ordering::SeqCst); |
| + let mut raw = [0; 4096]; |
| + let _ = conn.read(&mut raw); |
| + let request = String::from_utf8_lossy(&raw[..]); |
| + let (status, body) = if request.contains("/robots.txt") { |
| + ("200 OK", "User-agent: *\nDisallow:\n") |
| + } else { |
| + ("200 OK", "<html><body>hello</body></html>") |
| + }; |
| + let response = format!( |
| + "HTTP/1.1 {status}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", |
| + body.len() |
| + ); |
| + let _ = conn.write_all(response.as_bytes()); |
| + let _ = conn.flush(); |
| + } |
| + }); |
| + |
| + let tmp = tempfile::tempdir().unwrap(); |
| + let cache_file = tmp.path().join("robots_cache.json"); |
| + let mut fetcher = Fetcher::new(cache_file).unwrap(); |
| + |
| + // Baseline: connections used by the first fetch (robots.txt plus the page itself). |
| + assert!(fetcher.fetch(&format!("http://127.0.0.1:{port}/page")).is_ok()); |
| + let after_first = hits.load(Ordering::SeqCst); |
| + |
| + // Same host again: robots.txt should come from the cache, so only the page is requested. |
| + assert!(fetcher.fetch(&format!("http://127.0.0.1:{port}/other")).is_ok()); |
| + let after_second = hits.load(Ordering::SeqCst); |
| + assert_eq!(after_second, after_first + 1, "only one additional HTTP request (not two)"); |
| + } |
| |
| #[test] |
| fn strip_fragment_removes_hash() { |