(* Copyright 2012 Tony Garnock-Jones . *)
(* This file is part of Hop. *)
(* Hop is free software: you can redistribute it and/or modify it *)
(* under the terms of the GNU General Public License as published by the *)
(* Free Software Foundation, either version 3 of the License, or (at your *)
(* option) any later version. *)
(* Hop is distributed in the hope that it will be useful, but *)
(* WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU *)
(* General Public License for more details. *)
(* You should have received a copy of the GNU General Public License *)
(* along with Hop. If not, see . *)

open Lwt
open Hof

(* HTTP protocol versions this module can parse and render. *)
type version = [`HTTP_1_0 | `HTTP_1_1]

(* Version to use for a response; `SAME_AS_REQUEST is resolved against
   the request version in render_resp. *)
type resp_version = [version | `SAME_AS_REQUEST]

(* Message payload: either fully materialised bytes, or a stream of
   byte chunks. *)
type content =
  | Fixed of bytes
  | Variable of bytes Lwt_stream.t

(* Headers plus payload, shared by requests and responses. *)
type body = {
  headers: (string * string) list;
  content: content
}

let empty_content = Fixed Bytes.empty
let empty_body = {headers = []; content = empty_content}

(* A parsed request. query holds the decoded bindings of the query
   string; a binding without a value part has value None. *)
type req = {
  verb: string;
  path: string;
  query: (string * string option) list;
  req_version: version;
  req_body: body
}

(* A response. completion_callbacks are queued by main when the
   response is about to be rendered and fired later by its
   fire_pending_callbacks helper. *)
type resp = {
  resp_version: resp_version;
  status: int;
  reason: string;
  resp_body: body;
  completion_callbacks: (unit -> unit Lwt.t) list
}

(* Carries (status code, reason phrase, body); main renders it as the
   response when it escapes the request loop. *)
exception HTTPError of (int * string * body)

(* Raised by the synchronous parsing helpers; trap_syntax_errors turns
   it into a 400 HTTPError. *)
exception HTTPSyntaxError of string

let html_content_type = "text/html;charset=utf-8"
let text_content_type = "text/plain;charset=utf-8"
let content_type_header_name = "Content-Type"
let html_content_type_header = (content_type_header_name, html_content_type)
let text_content_type_header = (content_type_header_name, text_content_type)

(* Headers instructing clients not to cache; Last-Modified is computed
   at call time. *)
let disable_cache_headers () =
  ["Expires", "Thu, 01 Jan 1981 00:00:00 GMT";
   "Last-Modified", Httpd_date.http_gmtime (Unix.time ());
   "Cache-Control", "no-cache, must-revalidate, max-age=0";
   "Pragma", "no-cache"]

(* Waits for resp_thr and appends headers after the headers the
   response already carries. *)
let add_headers headers resp_thr =
  lwt resp = resp_thr in
  let b = resp.resp_body in
  return {resp with resp_body = {b with headers = b.headers @ headers}}

let add_disable_cache_headers resp =
  add_headers (disable_cache_headers ()) resp

let add_date_header resp =
  add_headers ["Date", Httpd_date.http_gmtime (Unix.time ())] resp

(* Waits for resp_thr and pushes cb onto the front of its completion
   callbacks. *)
let add_completion_callback cb resp_thr =
  lwt resp = resp_thr in
  return {resp with completion_callbacks = cb :: resp.completion_callbacks}

(* Fails the current Lwt thread with HTTPError (code, reason, body). *)
let http_error code reason body =
  raise_lwt (HTTPError (code, reason, body))

(* As http_error, with the reason phrase as a plain-text body. *)
let http_error_plain code reason =
  http_error code reason
    {headers = [text_content_type_header];
     content = Fixed (Bytes.of_string reason)}

(* As http_error, with doc streamed as an HTML body. *)
let http_error_html_doc code reason doc =
  http_error code reason
    {headers = [html_content_type_header];
     content = Variable (Streamutil.stream_encode (Html.stream_of_html_doc doc))}

(* Builds an HTML document titled with code and reason, whose body is
   an h1 holding the reason followed by extra_body. *)
let html_error_doc code reason extra_body =
  let code_str = string_of_int code in
  (Html.html_document (code_str^" "^reason) []
     ((Html.tag "h1" [] [Html.text reason]) :: extra_body))

let http_error_html code reason extra_body =
  http_error_html_doc code reason (html_error_doc code reason extra_body)

(* Runs the synchronous thunk; an HTTPSyntaxError becomes a 400
   HTTPError whose reason is the syntax error message. *)
let trap_syntax_errors thunk =
  try return (thunk ())
  with HTTPSyntaxError message -> http_error_html 400 message []

(* Builds a response with no completion callbacks that reuses the
   request's protocol version. *)
let resp_generic code reason headers content =
  return { resp_version = `SAME_AS_REQUEST;
           status = code;
           reason = reason;
           resp_body = {headers = headers; content = content};
           completion_callbacks = [] }

let resp_generic_ok headers content =
  resp_generic 200 "OK" headers content

(* HTML response whose body is doc, streamed. *)
let resp_html_doc code reason extra_headers doc =
  resp_generic code reason
    (html_content_type_header :: extra_headers)
    (Variable (Streamutil.stream_encode (Html.stream_of_html_doc doc)))

let resp_html_doc_ok extra_headers doc =
  resp_html_doc 200 "OK" extra_headers doc

let resp_html code reason extra_headers title content =
  resp_html_doc code reason extra_headers (Html.html_document title [] content)

let resp_html_ok extra_headers title content =
  resp_html 200 "OK" extra_headers title content

(* Plain-text response; text is the body bytes. *)
let resp_plain code reason extra_headers text =
  resp_generic code reason
    (text_content_type_header :: extra_headers)
    (Fixed text)

let resp_plain_ok extra_headers text =
  resp_plain 200 "OK" extra_headers text

(* 301 response with a Location header and a small HTML page linking
   to new_path. *)
let resp_redirect_permanent new_path =
  resp_html_doc 301 "Moved permanently" ["Location", new_path]
    (html_error_doc 301 "Moved permanently"
       [Html.text "The document has moved ";
        Html.tag "a" ["href", new_path] [Html.text "here"];
        Html.text "."])

(* Substitution rule for Util.strsub: only the percent sign and the
   space character are escaped. *)
let escape_url_char c =
  match c with
  | '%' -> Some (fun (_, pos) -> ("%25", pos + 1))
  | ' ' -> Some (fun (_, pos) -> ("%20", pos + 1))
  | _ -> None

let url_escape s = Util.strsub escape_url_char s

(* Decodes the three-character percent escape starting at pos in s and
   returns the decoded character as a string with the position after
   the escape. Raises HTTPSyntaxError when the escape is truncated or
   when Util.unhex_char yields -1 for either digit. *)
let unescape_url_hex_code (s, pos) =
  let len = String.length s in
  if len - pos >= 3
  then
    let v1 = Util.unhex_char (String.get s (pos + 1)) in
    let v2 = Util.unhex_char (String.get s (pos + 2)) in
    if v1 = -1 || v2 = -1
    then raise (HTTPSyntaxError ("Bad percent escaping: '"^String.sub s pos 3^"'"))
    else (String.make 1 (Char.chr (v1 * 16 + v2)), pos + 3)
  else raise (HTTPSyntaxError ("Bad percent escaping: '"^String.sub s pos (len - pos)^"'"))

let unescape_url_char c =
  match c with
  | '%' -> Some unescape_url_hex_code
  | _ -> None

(* Percent-decodes s; malformed escapes fail the thread with a 400
   HTTPError. *)
let url_unescape s =
  trap_syntax_errors (fun () -> Util.strsub unescape_url_char s)

(* Writes one header line, terminated by CRLF. *)
let render_header cout (k, v) =
  lwt () = Lwt_io.write cout k in
  lwt () = Lwt_io.write cout ": " in
  lwt () = Lwt_io.write cout v in
  Lwt_io.write cout "\r\n"

(* Writes one chunk in chunked transfer coding. Empty chunks are
   skipped because a zero-length chunk marks the end of the body. *)
let render_chunk cout chunk =
  let chunk_len = Bytes.length chunk in
  if chunk_len = 0
  then return ()
  else
    lwt () = Lwt_io.write cout (Printf.sprintf "%x\r\n" chunk_len) in
    lwt () = Lwt_io.write_from_exactly cout chunk 0 chunk_len in
    Lwt_io.write cout "\r\n"

(* Writes Content-Length, the blank line ending the headers, and the
   payload unless headers_only is set. *)
let render_fixed_content cout s headers_only =
  lwt () = render_header cout ("Content-Length", string_of_int (Bytes.length s)) in
  lwt () = Lwt_io.write cout "\r\n" in
  if headers_only
  then return ()
  else Lwt_io.write_from_exactly cout s 0 (Bytes.length s)

(* Materialises content as bytes; a stream is collected with
   Streamutil.stream_to_bytes. *)
let bytes_of_content c =
  match c with
  | Fixed s -> return s
  | Variable s -> Streamutil.stream_to_bytes s

(* Ends the header section and writes the payload. Streamed content is
   collected into a fixed body for HTTP/1.0 and sent with chunked
   transfer coding for HTTP/1.1. *)
let render_content cout v c headers_only =
  match c with
  | Fixed s ->
    render_fixed_content cout s headers_only
  | Variable s ->
    (match v with
     | `HTTP_1_0 ->
       lwt str = Streamutil.stream_to_bytes s in
       render_fixed_content cout str headers_only
     | `HTTP_1_1 ->
       if headers_only
       then (Lwt_io.write cout "\r\n")
       else
         (lwt () = render_header cout ("Transfer-Encoding", "chunked") in
          lwt () = Lwt_io.write cout "\r\n" in
          lwt () = Lwt_stream.iter_s (render_chunk cout) s in
          Lwt_io.write cout "0\r\n\r\n"))

(* Writes the headers of b in order, then its content. *)
let render_body cout v b headers_only =
  lwt () = Lwt_list.iter_s (render_header cout) b.headers in
  render_content cout v b.content headers_only

let string_of_version v =
  match v with
  | `HTTP_1_0 -> "HTTP/1.0"
  | `HTTP_1_1 -> "HTTP/1.1"

(* Raises HTTPSyntaxError for anything other than the two supported
   version tokens. *)
let version_of_string v =
  match v with
  | "HTTP/1.0" -> `HTTP_1_0
  | "HTTP/1.1" -> `HTTP_1_1
  | _ -> raise (HTTPSyntaxError "Invalid HTTP version")

(* Writes the request line followed by the request body. The query
   field of r is not rendered. *)
let render_req cout r =
  lwt () = Lwt_io.write cout (r.verb^" "^url_escape r.path^" "^string_of_version r.req_version^"\r\n") in
  render_body cout r.req_version r.req_body false

(* Writes the status line and body of r. The payload is suppressed
   when req_verb is HEAD. *)
let render_resp cout req_version req_verb r =
  let resp_version =
    (match r.resp_version with
     | `SAME_AS_REQUEST -> req_version
     | #version as v -> v)
  in
  lwt () = Lwt_io.write cout (string_of_version resp_version^" "^string_of_int r.status^" "^r.reason^"\r\n") in
  render_body cout resp_version r.resp_body
    (match req_verb with "HEAD" -> true | _ -> false)

(* Splits a request target at the first question mark into
   (path, query); query is empty when there is none. *)
let split_query p =
  match Str.bounded_split (Str.regexp "\\?") p 2 with
  | path :: query :: _ -> (path, query)
  | path :: [] -> (path, "")
  | [] -> ("", "")

(* Decodes one binding of a urlencoded string.
   NOTE(review): Str.bounded_split drops leading and trailing
   delimiters, so a binding with an empty value after the equals sign
   is reported as having no value -- confirm whether callers rely on
   that. *)
let parse_urlencoded_binding s =
  match Str.bounded_split (Str.regexp "=") s 2 with
  | k :: v :: _ ->
    lwt k' = url_unescape k in
    lwt v' = url_unescape v in
    return (k', Some v')
  | k :: [] ->
    lwt k' = url_unescape k in
    return (k', None)
  | [] -> return ("", None)

(* Splits q at ampersands and decodes each binding in order. *)
let parse_urlencoded q =
  let pieces = Str.split (Str.regexp "&") q in
  Lwt_list.map_s parse_urlencoded_binding pieces

(* Returns the value of the first header whose name matches name,
   ignoring ASCII case. Raises Not_found when there is none. *)
let find_header' name hs =
  let lc_name = String.lowercase_ascii name in
  let rec search hs =
    match hs with
    | [] -> raise Not_found
    | (k, v) :: hs' ->
      if String.lowercase_ascii k = lc_name
      then v
      else search hs'
  in
  search hs

(* Option-returning variant of find_header'. *)
let find_header name hs =
  try Some (find_header' name hs)
  with Not_found -> None

(* Looks name up in an association list; the name must match exactly. *)
let find_param name params =
  try Some (List.assoc name params)
  with Not_found -> None

(* Reads one line and removes a trailing carriage return if present. *)
let input_crlf cin =
  lwt line = Lwt_io.read_line cin in
  let len = String.length line in
  if len > 0 && String.get line (len - 1) = '\r'
  then return (String.sub line 0 (len - 1))
  else return line

(* Reads header lines up to the first empty line. A non-empty line
   that does not split into a name and a value fails with 400. *)
let rec parse_headers cin =
  lwt header_line = input_crlf cin in
  match Str.bounded_split (Str.regexp ":") header_line 2 with
  | [] ->
    return []
  | [k; v] ->
    lwt headers = parse_headers cin in
    return ((k, Util.strip v) :: headers)
  | k :: _ ->
    http_error_html 400 ("Bad header: "^k) []

(* Generator for Lwt_stream.from: each call reads one chunk of a
   chunked body and yields None on the zero-length chunk.
   NOTE(review): the size line is handed to Util.unhex unchanged, so
   chunk extensions and trailers are presumably not supported --
   confirm against Util.unhex. *)
let parse_chunks cin = fun () ->
  lwt hexlen_str = input_crlf cin in
  let chunk_len = Util.unhex hexlen_str in
  let buffer = Bytes.make chunk_len '\000' in
  lwt () = Lwt_io.read_into_exactly cin buffer 0 chunk_len in
  lwt chunk_terminator = input_crlf cin in
  if chunk_terminator <> ""
  then http_error_html 400 "Invalid chunk boundary" []
  else
    if chunk_len = 0
    then return None
    else return (Some buffer)

(* Parses a Content-Length value. Returns None when the text is not an
   integer, is negative, or exceeds what Bytes.make can allocate. *)
let parse_content_length length_str =
  match (try Some (int_of_string length_str) with Failure _ -> None) with
  | Some n when n >= 0 && n <= Sys.max_string_length -> Some n
  | _ -> None

(* Reads the headers and then the payload they describe. A request
   without Content-Length gets an empty body. An invalid Content-Length
   or an unsupported Transfer-Encoding fails with 400. The
   Transfer-Encoding name is matched ignoring ASCII case. *)
let parse_body cin =
  lwt headers = parse_headers cin in
  let read_fixed () =
    match find_header "Content-Length" headers with
    | None ->
      (* http_error_html 411 "Length required" [] *)
      return {headers = headers; content = empty_content}
    | Some length_str ->
      (match parse_content_length length_str with
       | Some length ->
         let buffer = Bytes.make length '\000' in
         lwt () = Lwt_io.read_into_exactly cin buffer 0 length in
         return {headers = headers; content = Fixed buffer}
       | None ->
         (* Previously a malformed length escaped as Failure or
            Invalid_argument and the client received no response. *)
         http_error_html 400 ("Bad Content-Length: "^length_str) [])
  in
  match find_header "Transfer-Encoding" headers with
  | None ->
    read_fixed ()
  | Some encoding ->
    (match String.lowercase_ascii encoding with
     | "identity" ->
       read_fixed ()
     | "chunked" ->
       return {headers = headers;
               content = Variable (Lwt_stream.from (parse_chunks cin))}
     | _ ->
       http_error_html 400 ("Unsupported Transfer-Encoding: "^encoding) [])

(* Reads and parses one request. Up to spurious_newline_credit empty
   lines before the request line are tolerated. *)
let rec parse_req cin spurious_newline_credit =
  lwt req_line = input_crlf cin in
  parse_req' cin spurious_newline_credit req_line

and parse_req' cin spurious_newline_credit req_line =
  match Str.bounded_split (Str.regexp " ") req_line 3 with
  | [] ->
    (* HTTP spec requires that we ignore leading CRLFs. We choose to do
       so, up to a point. *)
    if spurious_newline_credit = 0
    then http_error_html 400 "Bad request: too many leading CRLFs" []
    else parse_req cin (spurious_newline_credit - 1)
  | [verb; path; version_str] ->
    (* Trapped so that a bad version token yields a 400 response
       instead of escaping as a bare HTTPSyntaxError. *)
    lwt version = trap_syntax_errors (fun () -> version_of_string version_str) in
    lwt body = parse_body cin in
    let (path, query) = split_query path in
    lwt path = url_unescape path in
    lwt query = parse_urlencoded query in
    return { verb = verb;
             path = path;
             query = query;
             req_version = version;
             req_body = body }
  | _ -> http_error_html 400 "Bad request line" []

(* Consumes whatever remains of a streamed request body so the next
   request starts at the right position in the input. *)
let discard_unread_body req =
  match req.req_body.content with
  | Fixed _ ->
    return ()
  | Variable s ->
    Lwt_stream.junk_while (fun _ -> true) s (* force chunks to be read *)

(* True when the request carries a Connection header whose value is
   keep-alive, ignoring ASCII case. Only a single-token header value is
   recognised. *)
let connection_keepalive req =
  match find_header "Connection" req.req_body.headers with
  | Some v -> String.lowercase_ascii v = "keep-alive"
  | None -> false

(* Serves one accepted connection: parses requests from socket s,
   passes each to handle_req, and renders the response, looping while
   the client asks for keep-alive. An HTTPError escaping the loop is
   rendered as an HTTP/1.0 response; other exceptions are logged. The
   socket is closed on every exit path. peername is not used here. *)
let main handle_req (s, peername) =
  let cin = Lwt_io.of_fd Lwt_io.input s in
  let cout = Lwt_io.of_fd Lwt_io.output s in
  let pending_completion_callbacks = Queue.create () in
  (* NOTE(review): the callback thread is discarded with ignore, so
     callbacks are started but not awaited and their failures are not
     observed -- presumably deliberate. *)
  let fire_pending_callbacks () =
    while_lwt not (Queue.is_empty pending_completion_callbacks) do
      let cbs = Queue.take pending_completion_callbacks in
      ignore (Lwt_list.iter_s (fun cb -> cb ()) cbs);
      return ()
    done
  in
  (* Stream generator: the arrival of a request, or a failure to read
     one, fires the callbacks queued for the previous response. End of
     input ends the stream. *)
  let next_request () =
    (try_lwt
       (try_lwt
          lwt req = parse_req cin 512 in
          lwt () = fire_pending_callbacks () in
          return (Some req)
        with e ->
          lwt () = fire_pending_callbacks () in
          raise_lwt e)
     with End_of_file -> return None)
  in
  let request_stream = Lwt_stream.from next_request in
  let rec request_loop () =
    match_lwt Lwt_stream.get request_stream with
    | None -> return ()
    | Some req ->
      lwt resp = handle_req req in
      (* Watch in the background for a new request arriving, and let
         the currently-streaming (well, the about-to-be-streaming)
         response know about it so it can decide to terminate if it
         likes. *)
      Queue.add resp.completion_callbacks pending_completion_callbacks;
      ignore (Lwt_stream.peek request_stream);
      lwt () =
        try_lwt
          lwt () = render_resp cout req.req_version req.verb resp in
          lwt () = discard_unread_body req in
          Lwt_io.flush cout
        with e ->
          lwt () = fire_pending_callbacks () in
          raise_lwt e
      in
      lwt () = fire_pending_callbacks () in
      if connection_keepalive req then request_loop () else return ()
  in
  let serve () =
    lwt () =
      try_lwt
        request_loop ()
      with
      | HTTPError (code, reason, body) ->
        render_resp cout `HTTP_1_0 "GET" (* ugh this should probably be done better *)
          { resp_version = `HTTP_1_0;
            status = code;
            reason = reason;
            resp_body = body;
            completion_callbacks = [] }
      | Sys_error message ->
        Log.info "Sys_error in httpd handler" [Sexp.str message]
      | exn ->
        Log.error "Uncaught exception in httpd handler" [Sexp.str (Printexc.to_string exn)]
    in
    lwt () = fire_pending_callbacks () in
    (try_lwt Lwt_io.flush cout with _ -> return ())
  in
  (* The handlers above can themselves fail, for instance while writing
     the error response to a peer that has gone away. Closing in
     finalize keeps the descriptor from leaking in that case. *)
  Lwt.finalize serve (fun () -> Lwt_unix.close s)