1use std::marker::PhantomData;
9use std::net::{SocketAddr, TcpListener as StdTcpListener};
10use std::sync::mpsc::{Receiver, Sender, channel};
11use std::sync::{Arc, Mutex};
12use std::thread;
13
14use bytes::Bytes;
15use http::{Method, StatusCode};
16use log::{debug, error, trace, warn};
17use tokio::net::TcpListener;
18use url::{Host, Url};
19use warp::{Buf, Filter, Rejection};
20use webdriver::command::{WebDriverCommand, WebDriverMessage};
21use webdriver::error::{ErrorStatus, WebDriverError, WebDriverResult};
22use webdriver::httpapi::{
23 Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute, standard_routes,
24};
25use webdriver::response::{CloseWindowResponse, WebDriverResponse};
26
27use crate::Parameters;
28
/// Message sent from the HTTP server side (see `build_route`) to the dispatcher thread.
// NOTE(review): `Quit` is matched in `Dispatcher::run` but no code visible in this module
// constructs it, which is presumably why `dead_code` is allowed — confirm before removing.
#[allow(dead_code)]
enum DispatchMessage<U: WebDriverExtensionRoute> {
    /// A parsed WebDriver command, plus the channel on which the dispatcher sends back the
    /// handler's result.
    HandleWebDriver(
        WebDriverMessage<U>,
        Sender<WebDriverResult<WebDriverResponse>>,
    ),
    /// Makes `Dispatcher::run` leave its receive loop.
    Quit,
}
38
/// How a session ended, as reported to [`WebDriverHandler::teardown_session`].
#[derive(Clone, Debug, PartialEq)]
pub enum SessionTeardownKind {
    /// A `DeleteSession` command was handled successfully.
    Deleted,
    /// The session state was torn down without a successful `DeleteSession` command.
    NotDeleted,
}
48
/// The session currently owned by the dispatcher.
#[derive(Clone, Debug, PartialEq)]
pub struct Session {
    /// Identifier taken from the `NewSession` response that created the session.
    pub id: String,
}

impl Session {
    /// Wraps a session identifier.
    fn new(id: String) -> Session {
        Self { id }
    }
}
59
/// Backend that executes WebDriver commands on behalf of the dispatcher.
///
/// `U` is the vendor extension route type; it defaults to [`VoidWebDriverExtensionRoute`].
/// The `Send` bound is what lets [`start`] move the handler onto the "webdriver dispatcher"
/// thread.
pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send {
    /// Executes one command and returns the response for the client.
    ///
    /// `session` is the dispatcher's current session (`None` before one is created or after
    /// teardown). Client messages reach this method only after the dispatcher's session-id
    /// check has passed. The dispatcher also calls it directly with its own `DeleteSession`
    /// command when tearing down a session that was not explicitly deleted.
    fn handle_command(
        &mut self,
        session: &Option<Session>,
        msg: WebDriverMessage<U>,
    ) -> WebDriverResult<WebDriverResponse>;
    /// Notifies the handler that the dispatcher is discarding its session state.
    ///
    /// `kind` is `Deleted` when a `DeleteSession` command succeeded, and `NotDeleted`
    /// otherwise. This may be called even when no session is active, for example after an
    /// error flagged `delete_session`. The dispatcher clears its stored session right after
    /// this returns.
    fn teardown_session(&mut self, kind: SessionTeardownKind);
}
68
/// Owns the [`WebDriverHandler`] and the active session.
///
/// It runs on the "webdriver dispatcher" thread and processes queued messages one at a time.
#[derive(Debug)]
struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> {
    /// Backend that executes commands.
    handler: T,
    /// Active session: set from a `NewSession` response and cleared on teardown.
    session: Option<Session>,
    /// Ties the extension route type `U` to the struct without storing a value of it.
    extension_type: PhantomData<U>,
}
75
76impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> {
77 fn new(handler: T) -> Dispatcher<T, U> {
78 Dispatcher {
79 handler,
80 session: None,
81 extension_type: PhantomData,
82 }
83 }
84
85 fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) {
86 loop {
87 match msg_chan.recv() {
88 Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => {
89 let resp = match self.check_session(&msg) {
90 Ok(_) => self.handler.handle_command(&self.session, msg),
91 Err(e) => Err(e),
92 };
93
94 match resp {
95 Ok(WebDriverResponse::NewSession(ref new_session)) => {
96 self.session = Some(Session::new(new_session.session_id.clone()));
97 },
98 Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles)))
99 if handles.is_empty() =>
100 {
101 debug!("Last window was closed, deleting session");
102 self.teardown_session(SessionTeardownKind::NotDeleted);
105 },
106 Ok(WebDriverResponse::DeleteSession) => {
107 self.teardown_session(SessionTeardownKind::Deleted);
108 },
109 Err(ref x) if x.delete_session => {
110 self.teardown_session(SessionTeardownKind::NotDeleted)
112 },
113 _ => {},
114 }
115
116 if resp_chan.send(resp).is_err() {
117 error!("Sending response to the main thread failed");
118 };
119 },
120 Ok(DispatchMessage::Quit) => break,
121 Err(e) => panic!("Error receiving message in handler: {:?}", e),
122 }
123 }
124 }
125
126 fn teardown_session(&mut self, kind: SessionTeardownKind) {
127 debug!("Teardown session");
128 let final_kind = match kind {
129 SessionTeardownKind::NotDeleted if self.session.is_some() => {
130 let delete_session = WebDriverMessage {
131 session_id: Some(
132 self.session
133 .as_ref()
134 .expect("Failed to get session")
135 .id
136 .clone(),
137 ),
138 command: WebDriverCommand::DeleteSession,
139 };
140 match self.handler.handle_command(&self.session, delete_session) {
141 Ok(_) => SessionTeardownKind::Deleted,
142 Err(_) => SessionTeardownKind::NotDeleted,
143 }
144 },
145 _ => kind,
146 };
147 self.handler.teardown_session(final_kind);
148 self.session = None;
149 }
150
151 fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> {
152 match msg.session_id {
153 Some(ref msg_session_id) => match self.session {
154 Some(ref existing_session) => {
155 if existing_session.id != *msg_session_id {
156 Err(WebDriverError::new(
157 ErrorStatus::InvalidSessionId,
158 format!("Got unexpected session id {}", msg_session_id),
159 ))
160 } else {
161 Ok(())
162 }
163 },
164 None => Ok(()),
165 },
166 None => {
167 match self.session {
168 Some(_) => {
169 match msg.command {
170 WebDriverCommand::Status => Ok(()),
171 WebDriverCommand::NewSession(_) => Err(WebDriverError::new(
172 ErrorStatus::SessionNotCreated,
173 "Session is already started",
174 )),
175 _ => {
176 error!("Got a message with no session id");
178 Err(WebDriverError::new(
179 ErrorStatus::UnknownError,
180 "Got a command with no session?!",
181 ))
182 },
183 }
184 },
185 None => match msg.command {
186 WebDriverCommand::NewSession(_) => Ok(()),
187 WebDriverCommand::Status => Ok(()),
188 _ => Err(WebDriverError::new(
189 ErrorStatus::InvalidSessionId,
190 "Tried to run a command before creating a session",
191 )),
192 },
193 }
194 },
195 }
196 }
197}
198
/// Handle to a running WebDriver HTTP server, returned by [`start`].
///
/// Dropping it waits for the server thread to exit.
pub struct Listener {
    /// Join handle of the "webdriver server" thread; taken (and joined) on drop.
    guard: Option<thread::JoinHandle<()>>,
    /// Local address the server socket is bound to, including the OS-assigned port when
    /// port 0 was requested.
    pub socket: SocketAddr,
}
203
204impl Drop for Listener {
205 fn drop(&mut self) {
206 let _ = self.guard.take().map(|j| j.join());
207 }
208}
209
210pub fn start<T, U>(
211 mut address: SocketAddr,
212 allow_hosts: Vec<Host>,
213 allow_origins: Vec<Url>,
214 handler: T,
215 extension_routes: Vec<(Method, &'static str, U)>,
216) -> ::std::io::Result<Listener>
217where
218 T: 'static + WebDriverHandler<U>,
219 U: 'static + WebDriverExtensionRoute + Send + Sync,
220{
221 let listener = StdTcpListener::bind(address)?;
222 listener.set_nonblocking(true)?;
223 let addr = listener.local_addr()?;
224 if address.port() == 0 {
225 address.set_port(addr.port())
228 }
229 let (msg_send, msg_recv) = channel();
230
231 let builder = thread::Builder::new().name("webdriver server".to_string());
232 let handle = builder.spawn(move || {
233 let rt = tokio::runtime::Builder::new_current_thread()
234 .enable_io()
235 .build()
236 .unwrap();
237 let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() });
238 let wroutes = build_warp_routes(
239 address,
240 allow_hosts,
241 allow_origins,
242 &extension_routes,
243 msg_send.clone(),
244 );
245 let fut = warp::serve(wroutes).incoming(listener).run();
246 rt.block_on(fut);
247 })?;
248
249 let builder = thread::Builder::new().name("webdriver dispatcher".to_string());
250 builder.spawn(move || {
251 let mut dispatcher = Dispatcher::new(handler);
252 dispatcher.run(&msg_recv);
253 })?;
254
255 Ok(Listener {
256 guard: Some(handle),
257 socket: addr,
258 })
259}
260
261fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>(
262 address: SocketAddr,
263 allow_hosts: Vec<Host>,
264 allow_origins: Vec<Url>,
265 ext_routes: &[(Method, &'static str, U)],
266 chan: Sender<DispatchMessage<U>>,
267) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone + 'static {
268 let chan = Arc::new(Mutex::new(chan));
269 let mut std_routes = standard_routes::<U>();
270
271 let (method, path, res) = std_routes.pop().unwrap();
272 trace!("Build standard route for {path}");
273 let mut wroutes = build_route(
274 address,
275 allow_hosts.clone(),
276 allow_origins.clone(),
277 convert_method(method),
278 path,
279 res,
280 chan.clone(),
281 );
282
283 for (method, path, res) in std_routes {
284 trace!("Build standard route for {path}");
285 wroutes = wroutes
286 .or(build_route(
287 address,
288 allow_hosts.clone(),
289 allow_origins.clone(),
290 convert_method(method),
291 path,
292 res.clone(),
293 chan.clone(),
294 ))
295 .unify()
296 .boxed()
297 }
298
299 for (method, path, res) in ext_routes {
300 trace!("Build vendor route for {path}");
301 wroutes = wroutes
302 .or(build_route(
303 address,
304 allow_hosts.clone(),
305 allow_origins.clone(),
306 method.clone(),
307 path,
308 Route::Extension(res.clone()),
309 chan.clone(),
310 ))
311 .unify()
312 .boxed()
313 }
314
315 wroutes
316}
317
318fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool {
319 let header_host_url = match Url::parse(&format!("http://{}", &host_header)) {
322 Ok(x) => x,
323 Err(_) => {
324 return false;
325 },
326 };
327
328 let host = match header_host_url.host() {
329 Some(host) => host.to_owned(),
330 None => {
331 return false;
335 },
336 };
337 let port = match header_host_url.port_or_known_default() {
338 Some(port) => port,
339 None => {
340 return false;
344 },
345 };
346
347 let host_matches = match host {
348 Host::Domain(_) => allow_hosts.contains(&host),
349 Host::Ipv4(_) | Host::Ipv6(_) => true,
350 };
351 let port_matches = server_address.port() == port;
352 host_matches && port_matches
353}
354
355fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool {
356 allow_origins.contains(&origin_url)
358}
359
360fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>(
361 server_address: SocketAddr,
362 allow_hosts: Vec<Host>,
363 allow_origins: Vec<Url>,
364 method: Method,
365 path: &'static str,
366 route: Route<U>,
367 chan: Arc<Mutex<Sender<DispatchMessage<U>>>>,
368) -> warp::filters::BoxedFilter<(impl warp::Reply,)> {
369 let mut subroute = match method {
372 Method::GET => warp::get().boxed(),
373 Method::POST => warp::post().boxed(),
374 Method::DELETE => warp::delete().boxed(),
375 Method::OPTIONS => warp::options().boxed(),
376 Method::PUT => warp::put().boxed(),
377 _ => panic!("Unsupported method"),
378 }
379 .or(warp::head())
380 .unify()
381 .map(Parameters::new)
382 .boxed();
383
384 for part in path.split('/') {
388 if part.is_empty() {
389 continue;
390 } else if part.starts_with('{') {
391 assert!(part.ends_with('}'));
392
393 subroute = subroute
394 .and(warp::path::param())
395 .map(move |mut params: Parameters, param: String| {
396 let name = &part[1..part.len() - 1];
397 params.insert(name.to_string(), param);
398 params
399 })
400 .boxed();
401 } else {
402 subroute = subroute.and(warp::path(part)).boxed();
403 }
404 }
405
406 subroute
408 .and(warp::path::end())
409 .and(warp::path::full())
410 .and(warp::method())
411 .and(warp::header::optional::<String>("origin"))
412 .and(warp::header::optional::<String>("host"))
413 .and(warp::header::optional::<String>("content-type"))
414 .and(warp::body::bytes())
415 .map(
416 move |params,
417 full_path: warp::path::FullPath,
418 method,
419 origin_header: Option<String>,
420 host_header: Option<String>,
421 content_type_header: Option<String>,
422 body: Bytes| {
423 if method == Method::HEAD {
424 return warp::reply::with_status("".into(), StatusCode::OK);
425 }
426 if let Some(host) = host_header {
427 if !is_host_allowed(&server_address, &allow_hosts, &host) {
428 warn!(
429 "Rejected request with Host header {}, allowed values are [{}]",
430 host,
431 allow_hosts
432 .iter()
433 .map(|x| format!("{}:{}", x, server_address.port()))
434 .collect::<Vec<_>>()
435 .join(",")
436 );
437 let err = WebDriverError::new(
438 ErrorStatus::UnknownError,
439 format!("Invalid Host header {}", host),
440 );
441 return warp::reply::with_status(
442 serde_json::to_string(&err).unwrap(),
443 StatusCode::INTERNAL_SERVER_ERROR,
444 );
445 };
446 } else {
447 warn!("Rejected request with missing Host header");
448 let err = WebDriverError::new(
449 ErrorStatus::UnknownError,
450 "Missing Host header".to_string(),
451 );
452 return warp::reply::with_status(
453 serde_json::to_string(&err).unwrap(),
454 StatusCode::INTERNAL_SERVER_ERROR,
455 );
456 }
457 if let Some(origin) = origin_header {
458 let make_err = || {
459 warn!(
460 "Rejected request with Origin header {}, allowed values are [{}]",
461 origin,
462 allow_origins
463 .iter()
464 .map(|x| x.to_string())
465 .collect::<Vec<_>>()
466 .join(",")
467 );
468 WebDriverError::new(
469 ErrorStatus::UnknownError,
470 format!("Invalid Origin header {}", origin),
471 )
472 };
473 let origin_url = match Url::parse(&origin) {
474 Ok(url) => url,
475 Err(_) => {
476 return warp::reply::with_status(
477 serde_json::to_string(&make_err()).unwrap(),
478 StatusCode::INTERNAL_SERVER_ERROR,
479 );
480 },
481 };
482 if !is_origin_allowed(&allow_origins, origin_url) {
483 return warp::reply::with_status(
484 serde_json::to_string(&make_err()).unwrap(),
485 StatusCode::INTERNAL_SERVER_ERROR,
486 );
487 }
488 }
489 if method == Method::POST {
490 let content_type = content_type_header
493 .as_ref()
494 .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x))
495 .map(|x| x.trim())
496 .map(|x| x.to_lowercase());
497 match content_type.as_ref().map(|x| x.as_ref()) {
498 Some("application/x-www-form-urlencoded") |
499 Some("multipart/form-data") |
500 Some("text/plain") => {
501 warn!(
502 "Rejected POST request with disallowed content type {}",
503 content_type.unwrap_or_else(|| "".into())
504 );
505 let err = WebDriverError::new(
506 ErrorStatus::UnknownError,
507 "Invalid Content-Type",
508 );
509 return warp::reply::with_status(
510 serde_json::to_string(&err).unwrap(),
511 StatusCode::INTERNAL_SERVER_ERROR,
512 );
513 },
514 Some(_) | None => {},
515 }
516 }
517 let body = String::from_utf8(body.chunk().to_vec());
518 if body.is_err() {
519 let err = WebDriverError::new(
520 ErrorStatus::UnknownError,
521 "Request body wasn't valid UTF-8",
522 );
523 return warp::reply::with_status(
524 serde_json::to_string(&err).unwrap(),
525 StatusCode::INTERNAL_SERVER_ERROR,
526 );
527 }
528 let body = body.unwrap();
529
530 debug!("-> {} {} {}", method, full_path.as_str(), body);
531 let msg_result = WebDriverMessage::from_http(
532 route.clone(),
533 ¶ms,
534 &body,
535 method == Method::POST,
536 );
537
538 let (status, resp_body) = match msg_result {
539 Ok(message) => {
540 let (send_res, recv_res) = channel();
541 match chan.lock() {
542 Ok(ref c) => {
543 let res =
544 c.send(DispatchMessage::HandleWebDriver(message, send_res));
545 match res {
546 Ok(x) => x,
547 Err(e) => panic!("Error: {:?}", e),
548 }
549 },
550 Err(e) => panic!("Error reading response: {:?}", e),
551 }
552
553 match recv_res.recv() {
554 Ok(data) => match data {
555 Ok(response) => {
556 (StatusCode::OK, serde_json::to_string(&response).unwrap())
557 },
558 Err(e) => (
559 StatusCode::from_u16(e.http_status().as_u16()).unwrap(),
560 serde_json::to_string(&e).unwrap(),
561 ),
562 },
563 Err(e) => panic!("Error reading response: {:?}", e),
564 }
565 },
566 Err(e) => (
567 convert_status(e.http_status()),
568 serde_json::to_string(&e).unwrap(),
569 ),
570 };
571
572 debug!("<- {} {}", status, resp_body);
573 warp::reply::with_status(resp_body, status)
574 },
575 )
576 .with(warp::reply::with::header(
577 http::header::CONTENT_TYPE,
578 "application/json; charset=utf-8",
579 ))
580 .with(warp::reply::with::header(
581 http::header::CACHE_CONTROL,
582 "no-cache",
583 ))
584 .boxed()
585}
586
587fn convert_status(status: http02::StatusCode) -> StatusCode {
589 StatusCode::from_u16(status.as_u16()).unwrap()
590}
591
592fn convert_method(method: http02::Method) -> Method {
594 match method {
595 http02::Method::OPTIONS => http::Method::OPTIONS,
596 http02::Method::GET => http::Method::GET,
597 http02::Method::POST => http::Method::POST,
598 http02::Method::PUT => http::Method::PUT,
599 http02::Method::DELETE => http::Method::DELETE,
600 http02::Method::HEAD => http::Method::HEAD,
601 http02::Method::TRACE => http::Method::TRACE,
602 http02::Method::CONNECT => http::Method::CONNECT,
603 http02::Method::PATCH => http::Method::PATCH,
604 _ => http::Method::from_bytes(method.as_str().as_bytes()).unwrap(),
605 }
606}
607
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::str::FromStr;

    use super::*;

    /// Builds a socket address from an IP literal and a port.
    fn socket_addr(ip: &str, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::from_str(ip).unwrap(), port)
    }

    #[test]
    fn test_host_allowed() {
        let addr_80 = socket_addr("127.0.0.1", 80);
        let addr_8000 = socket_addr("127.0.0.1", 8000);
        let addr_v6_80 = socket_addr("::1", 80);
        let addr_v6_8000 = socket_addr("::1", 8000);

        let localhost = Host::Domain("localhost".to_string());
        let example = Host::Domain("example.test".to_string());
        let subdomain = Host::Domain("subdomain.localhost".to_string());

        // (server address, allowed hosts, Host header, expected verdict)
        let cases = vec![
            // Allowed domain names, with an explicit or a default port.
            (addr_80, vec![localhost.clone()], "localhost:80", true),
            (addr_80, vec![example.clone()], "example.test:80", true),
            (
                addr_80,
                vec![example.clone(), localhost.clone()],
                "example.test",
                true,
            ),
            (addr_80, vec![subdomain.clone()], "subdomain.localhost", true),
            // IP addresses are accepted whatever the allow list; only the port must match.
            (addr_80, vec![], "127.0.0.1:80", true),
            (addr_v6_80, vec![], "127.0.0.1", true),
            (addr_80, vec![], "[::1]", true),
            (addr_8000, vec![], "127.0.0.1:8000", true),
            (addr_80, vec![subdomain.clone()], "[::1]", true),
            (addr_v6_8000, vec![subdomain.clone()], "[::1]:8000", true),
            // Domains missing from the allow list are rejected.
            (addr_80, vec![example.clone()], "localhost", false),
            (addr_80, vec![], "localhost:80", false),
            // Port mismatches are rejected.
            (addr_80, vec![localhost.clone()], "localhost:8000", false),
            (addr_8000, vec![localhost.clone()], "localhost", false),
            (addr_v6_8000, vec![localhost.clone()], "[::1]", false),
        ];

        for (server, allowed, header, expected) in &cases {
            assert_eq!(
                is_host_allowed(server, allowed, header),
                *expected,
                "server {server}, allowed {allowed:?}, Host header {header:?}"
            );
        }
    }

    #[test]
    fn test_origin_allowed() {
        // (allowed origins, Origin header, expected verdict)
        let cases: Vec<(Vec<&str>, &str, bool)> = vec![
            (vec!["http://localhost"], "http://localhost", true),
            // An explicit default port compares equal to the bare origin.
            (vec!["http://localhost"], "http://localhost:80", true),
            (
                vec!["https://test.example", "http://localhost"],
                "http://localhost",
                true,
            ),
            (
                vec!["https://test.example", "http://localhost"],
                "https://test.example:443",
                true,
            ),
            (vec![], "http://localhost", false),
            (vec!["http://localhost"], "http://localhost:8000", false),
            (vec!["https://localhost"], "http://localhost", false),
            (
                vec!["https://example.test"],
                "http://subdomain.example.test",
                false,
            ),
        ];

        for (allowed, origin, expected) in cases {
            let allowed: Vec<Url> = allowed
                .iter()
                .map(|url| Url::parse(url).unwrap())
                .collect();
            assert_eq!(
                is_origin_allowed(&allowed, Url::parse(origin).unwrap()),
                expected,
                "allowed {allowed:?}, Origin header {origin:?}"
            );
        }
    }
}