diff --git a/httpd.ml b/httpd.ml index b528411..6383b97 100644 --- a/httpd.ml +++ b/httpd.ml @@ -117,13 +117,23 @@ let escape_url_char s = | _ -> failwith ("Unexpected URL char to escape: " ^ s) let url_escape s = Str.global_substitute url_escape_re escape_url_char s -let url_unescape_re = Str.regexp "%[0-9a-zA-Z][0-9a-zA-Z]" let unhex_char c = match c with | '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9' -> Char.code c - Char.code '0' | 'a' | 'b' | 'c' | 'd' | 'e' | 'f' -> Char.code c - Char.code 'a' | 'A' | 'B' | 'C' | 'D' | 'E' | 'F' -> Char.code c - Char.code 'A' | _ -> 0 + +let unhex s = + let len = String.length s in + let rec loop index acc = + if index = len + then acc + else loop (index + 1) (acc * 16 + unhex_char (String.get s index)) + in + loop 0 0 + +let url_unescape_re = Str.regexp "%[0-9a-zA-Z][0-9a-zA-Z]" let unescape_url_char s = String.make 1 (Char.chr (unhex_char (String.get s 1) * 16 + unhex_char (String.get s 2))) let url_unescape s = Str.global_substitute url_unescape_re unescape_url_char s @@ -196,7 +206,20 @@ let split_query p = | path :: [] -> (path, "") | [] -> ("", "") -let parse_body cin = empty_body +let find_header name hs = + let lc_name = String.lowercase name in + let rec search hs = + match hs with + | [] -> raise Not_found + | (k, v) :: hs' -> + if String.lowercase k = lc_name + then v + else search hs' + in + search hs + +let find_header' name hs = + try Some (find_header name hs) with Not_found -> None let input_crlf cin = let line = input_line cin in @@ -205,6 +228,42 @@ let input_crlf cin = then String.sub line 0 (len - 1) else line +let rec parse_headers cin = + match Str.bounded_split (Str.regexp ":") (input_crlf cin) 2 with + | [] -> + [] + | [k; v] -> + (k, Util.strip v) :: parse_headers cin + | k :: _ -> + http_error_html 400 ("Bad header: "^k) [] + +let parse_chunks cin = + fun () -> + let hexlen_str = input_crlf cin in + let chunk_len = unhex hexlen_str in + let buffer = String.make chunk_len '\000' in + really_input cin buffer 0 chunk_len; + (if input_crlf cin <> "" then http_error_html 400 "Invalid chunk boundary" [] else ()); + if chunk_len = 0 then None else Some buffer + +let parse_body cin = + let headers = parse_headers cin in + match find_header' "Transfer-Encoding" headers with + | None | Some "identity" -> + (match find_header' "Content-Length" headers with + | None -> + (* http_error_html 411 "Length required" [] *) + {headers = headers; content = Fixed ""} + | Some length_str -> + let length = int_of_string length_str in + let buffer = String.make length '\000' in + really_input cin buffer 0 length; + {headers = headers; content = Fixed buffer}) + | Some "chunked" -> + {headers = headers; content = Variable (Stringstream.from_iter (parse_chunks cin))} + | Some unsupported -> + http_error_html 400 ("Unsupported Transfer-Encoding: "^unsupported) [] + let rec parse_req cin spurious_newline_credit = match Str.bounded_split (Str.regexp " ") (input_crlf cin) 3 with | [] -> @@ -220,13 +279,27 @@ let rec parse_req cin spurious_newline_credit = { verb = verb; path = path; query = query; req_version = version; req_body = body } | _ -> http_error_html 400 "Bad request line" [] +let discard_unread_body req = + match req.req_body.content with + | Fixed _ -> () + | Variable s -> Stringstream.iter (fun v -> ()) s (* force chunks to be read *) + +let connection_keepalive req = + find_header' "Connection" req.req_body.headers = Some "keep-alive" + let main handle_req (s, peername) = let cin = in_channel_of_descr s in let cout = out_channel_of_descr s in (try (try - let req = parse_req cin 512 in - render_resp cout req.req_version (handle_req req) + let rec request_loop () = + let req = parse_req cin 512 in + render_resp cout req.req_version (handle_req req); + discard_unread_body req; + flush cout; + if connection_keepalive req then request_loop () else () + in + request_loop () with HTTPError (code, reason, body) -> render_resp cout `HTTP_1_0 { resp_version = `HTTP_1_0; status = code; reason = reason; resp_body = body }) diff --git a/util.ml b/util.ml index 3093f50..7951b73 100644 --- a/util.ml +++ b/util.ml @@ -50,3 +50,12 @@ let starts_with s1 s2 = let ends_with s1 s2 = try Str.last_chars s1 (String.length s2) = s2 with _ -> false + +let strip s = + let len = String.length s in + let ws i = Char.code (String.get s i) <= 32 in + let rec left index = if index < len && ws index then left (index + 1) else index in + let rec right index = if index >= 0 && ws index then right (index - 1) else index in + let l = left 0 in + let r = 1 + right (len - 1) in + if r <= l then "" else String.sub s l (r - l)