(* Copyright 2012 Tony Garnock-Jones . *)

(* This file is part of Hop. *)

(* Hop is free software: you can redistribute it and/or modify it *)
(* under the terms of the GNU General Public License as published by the *)
(* Free Software Foundation, either version 3 of the License, or (at your *)
(* option) any later version. *)

(* Hop is distributed in the hope that it will be useful, but *)
(* WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU *)
(* General Public License for more details. *)

(* You should have received a copy of the GNU General Public License *)
(* along with Hop. If not, see <http://www.gnu.org/licenses/>. *)

(* AMQP 0-9-1 front end: frame codec, connection handshake, and the
   translation between AMQP methods and Hop node messages. *)

open Lwt
open Printf
open Amqp_spec
open Amqp_wireformat

(** State for a single AMQP client connection. *)
type connection_t = {
  peername: Unix.sockaddr;
  mtx: Lwt_mutex.t;                            (* taken by [with_conn_mutex] *)
  cin: Lwt_io.input_channel;
  cout: Lwt_io.output_channel;
  name: Node.name;                             (* this connection's node name (fresh UUID) *)
  mutable input_buf: string;                   (* payload buffer filled by [read_frame] *)
  mutable output_buf: Obuffer.t;               (* payload for the next [write_frame] *)
  mutable frame_max: int;                      (* current maximum frame size *)
  mutable connection_closed: bool;             (* once true the main loop stops reading *)
  mutable recent_queue_name: Node.name option; (* target of the "" queue-name shorthand *)
  mutable delivery_tag: int                    (* tag for the next Basic_deliver *)
}

let initial_frame_size = frame_min_size
let suggested_frame_max = 131072

(* Bytes of framing around each payload: type (1) + channel (2) + size (4)
   + frame_end (1). AMQP's frame-max counts these as part of the frame. *)
let frame_overhead = 8

(** Build the initial state for a freshly accepted connection. *)
let amqp_boot (peername, cin, cout) =
  return {
    peername = peername;
    mtx = Lwt_mutex.create ();
    cin = cin;
    cout = cout;
    name = Node.name_of_string (Uuid.create ());
    input_buf = String.create initial_frame_size;
    output_buf = Obuffer.create initial_frame_size;
    frame_max = initial_frame_size;
    connection_closed = false;
    recent_queue_name = None;
    delivery_tag = 1 (* Not 0: 0 means "all deliveries" in an ack *)
  }

(** Read one octet from [c] as an int in 0..255. *)
let input_byte c =
  lwt b = Lwt_io.read_char c in
  return (int_of_char b)

(** Read one frame into [conn.input_buf] and return
    [(frame_type, channel, payload_length)].
    Fails with [frame_error] on a bad length or a missing frame_end octet. *)
let read_frame conn =
  lwt frame_type = input_byte conn.cin in
  lwt channel_hi = input_byte conn.cin in
  lwt channel_lo = input_byte conn.cin in
  (* Big-endian 16-bit channel number: the high byte must be shifted LEFT.
     (Shifting it right always produced 0, folding channel 256*h+l onto l.) *)
  let channel = (channel_hi lsl 8) lor channel_lo in
  (* read_int yields a signed 32-bit value, so a hostile size can be negative. *)
  lwt length = Lwt_io.BE.read_int conn.cin in
  if length < 0 then
    die frame_error "Negative frame length"
  else if length > conn.frame_max then
    die frame_error "Frame longer than current frame_max"
  else begin
    lwt () = Lwt_io.read_into_exactly conn.cin conn.input_buf 0 length in
    lwt end_marker = input_byte conn.cin in
    if end_marker <> frame_end then
      die frame_error "Missing frame_end octet"
    else
      return (frame_type, channel, length)
  end

(** Write the int [b] (expected 0..255) as one octet. *)
let output_byte c b = Lwt_io.write_char c (char_of_int b)

(** Emit a frame whose payload is the current contents of [conn.output_buf],
    then reset that buffer. *)
let write_frame conn frame_type channel =
  lwt () = output_byte conn.cout frame_type in
  lwt () = output_byte conn.cout ((channel lsr 8) land 255) in
  lwt () = output_byte conn.cout (channel land 255) in
  lwt () = Lwt_io.BE.write_int conn.cout (Obuffer.length conn.output_buf) in
  lwt () = Obuffer.write conn.cout conn.output_buf in
  Obuffer.reset conn.output_buf;
  output_byte conn.cout frame_end

(** Append class id, method id and arguments of [m] to [buf]. *)
let serialize_method buf m =
  let (class_id, method_id) = method_index m in
  write_short buf class_id;
  write_short buf method_id;
  write_method m buf

(** Inverse of [serialize_method]. *)
let deserialize_method buf =
  let class_id = read_short buf in
  let method_id = read_short buf in
  read_method class_id method_id buf

(** Append a content header (class id, weight 0, body size, properties). *)
let serialize_header buf body_size p =
  let class_id = class_index p in
  write_short buf class_id;
  write_short buf 0;
  write_longlong buf (Int64.of_int body_size);
  write_properties p buf

(** Inverse of [serialize_header]; returns [(body_size, properties)]. *)
let deserialize_header buf =
  let class_id = read_short buf in
  let _ = read_short buf in
  let body_size = Int64.to_int (read_longlong buf) in
  lwt props = read_properties class_id buf in
  return (body_size, props)

(** Send [body] as a sequence of body frames on [channel].
    Callers are expected to hold the connection mutex. *)
let send_content_body conn channel body =
  let len = String.length body in
  (* frame_max bounds the whole frame, framing included, so each body frame
     may carry at most frame_max - frame_overhead payload bytes. [max 1]
     guarantees forward progress even for a degenerate frame_max. *)
  let chunk_max = max 1 (conn.frame_max - frame_overhead) in
  let rec send_remainder offset =
    if offset >= len then return ()
    else begin
      let snip_len = min chunk_max (len - offset) in
      Obuffer.add_substring conn.output_buf body offset snip_len;
      lwt () = write_frame conn frame_body channel in
      send_remainder (offset + snip_len)
    end
  in
  send_remainder 0

(** Read a frame and insist that it has [required_type];
    returns [(channel, payload_length)]. *)
let next_frame conn required_type =
  lwt (frame_type, channel, length) = read_frame conn in
  if frame_type <> required_type then
    die command_invalid (Printf.sprintf "Unexpected frame type %d" frame_type)
  else
    return (channel, length)

(** Read the next method frame; returns [(channel, method)]. *)
let next_method conn =
  lwt (channel, length) = next_frame conn frame_method in
  lwt m = deserialize_method (Ibuffer.create conn.input_buf 0 length) in
  return (channel, m)

(** Read the next content-header frame; returns [(channel, (size, props))]. *)
let next_header conn =
  lwt (channel, length) = next_frame conn frame_header in
  lwt h = deserialize_header (Ibuffer.create conn.input_buf 0 length) in
  return (channel, h)

(** Accumulate body frames until at least [body_size] bytes have arrived. *)
let recv_content_body conn body_size =
  let buf = Obuffer.create body_size in
  lwt () =
    while_lwt Obuffer.length buf < body_size do
      lwt (_, length) = next_frame conn frame_body in
      return (Obuffer.add_substring buf conn.input_buf 0 length)
    done
  in
  return (Obuffer.contents buf)

(** Run [thunk] while holding [conn.mtx]. *)
let with_conn_mutex conn thunk = Lwt_mutex.with_lock conn.mtx thunk

(** Send a content-less method on [channel]. *)
let send_method conn channel m =
  with_conn_mutex conn (fun () ->
    serialize_method conn.output_buf m;
    write_frame conn frame_method channel)

(* Method + header + body frames on [channel]; caller must hold [conn.mtx]. *)
let send_content_method_unlocked conn channel m p body_str =
  serialize_method conn.output_buf m;
  lwt () = write_frame conn frame_method channel in
  serialize_header conn.output_buf (String.length body_str) p;
  lwt () = write_frame conn frame_header channel in
  send_content_body conn channel body_str

(** Send a content-bearing method [m] with properties [p] and body
    [body_str] on [channel] (previously the channel argument was ignored
    and 1 was always used). *)
let send_content_method conn channel m p body_str =
  with_conn_mutex conn (fun () ->
    send_content_method_unlocked conn channel m p body_str)

(** Close the connection with an error, at most once. *)
let send_error conn code message =
  if conn.connection_closed then
    return ()
  else begin
    conn.connection_closed <- true;
    let m = Connection_close (code, message, 0, 0) in
    ignore (Log.warn "Sending error" [sexp_of_method m]);
    send_method conn 0 m
  end

(** Close channel 1 with a (soft) error. *)
let send_warning conn code message =
  let m = Channel_close (code, message, 0, 0) in
  ignore (Log.warn "Sending warning" [sexp_of_method m]);
  send_method conn 1 m

(** Read the 8-byte protocol header. Returns [true] to accept the client;
    on a non-AMQP header it replies with our own header and returns [false].
    A client that disconnects early also yields [false]. *)
let issue_banner cin cout =
  let handshake = String.create 8 in
  (* try_lwt, not try: End_of_file arrives as a failed Lwt thread, which a
     plain try block never sees. *)
  try_lwt
    lwt () = Lwt_io.read_into_exactly cin handshake 0 8 in
    if String.sub handshake 0 4 <> "AMQP" then begin
      lwt () = Lwt_io.write cout "AMQP\000\000\009\001" in
      return false
    end else begin
      let byte_sexp i = Sexp.Str (string_of_int (int_of_char handshake.[i])) in
      ignore (Log.info "AMQP handshake bytes"
                [byte_sexp 4; byte_sexp 5; byte_sexp 6; byte_sexp 7]);
      return true
    end
  with End_of_file ->
    return false

let reference_to_logs = "See server logs for details"

(** The string inside [v], or a pointer to the logs if [v] is not a string. *)
let extract_str v =
  match v with
  | Sexp.Str s -> s
  | _ -> reference_to_logs

(** Translate a factory create reply into an AMQP *_ok (built by [ok_fn])
    or into a connection/channel error. *)
let reply_to_declaration conn status ok_fn =
  match Message.message_of_sexp status with
  | Message.Create_ok info ->
      send_method conn 1 (ok_fn info)
  | Message.Create_failed reason ->
      (match reason with
       | Sexp.Arr [Sexp.Str "factory"; Sexp.Str "class-not-found"] ->
           send_error conn command_invalid "Object type not supported by server"
       | Sexp.Arr [Sexp.Str "constructor"; Sexp.Str "class-mismatch"] ->
           send_error conn not_allowed
             "Redeclaration with different object type not permitted"
       | Sexp.Arr [Sexp.Str who; explanation] ->
           send_warning conn precondition_failed
             (who^" failed: "^(extract_str explanation))
       | _ ->
           send_warning conn precondition_failed reference_to_logs)
  | _ ->
      die internal_error "Declare reply malformed"

(** Build Queue_declare_ok from the created queue's name. *)
let make_queue_declare_ok info =
  match info with
  | Sexp.Str queue_name -> Queue_declare_ok (queue_name, Int32.zero, Int32.zero)
  | _ -> die internal_error "Unusable queue name in declare response"

(** Deliver an "amqp"-hinted message to the client as Basic_deliver on
    channel 1. The tag is allocated and the delivery written inside one
    critical section, so delivery tags reach the wire in increasing order. *)
let send_delivery conn consumer_tag body_sexp =
  match body_sexp with
  | Sexp.Hint {Sexp.hint = Sexp.Str "amqp";
               Sexp.body = Sexp.Arr [Sexp.Str exchange;
                                     Sexp.Str routing_key;
                                     properties_sexp;
                                     Sexp.Str body_str]} ->
      let properties = properties_of_sexp basic_class_id properties_sexp in
      with_conn_mutex conn (fun () ->
        let tag = conn.delivery_tag in
        conn.delivery_tag <- tag + 1;
        send_content_method_unlocked conn 1
          (Basic_deliver (consumer_tag, Int64.of_int tag, false, exchange, routing_key))
          properties
          body_str)
  | _ ->
      die internal_error "Malformed AMQP message body sexp"

(** Handle a Hop message addressed to this connection's node. *)
let amqp_handler conn n m_sexp =
  (* try_lwt so that asynchronous failures are handled too, not only
     exceptions raised synchronously while building the reply. *)
  try_lwt
    (match Message.message_of_sexp m_sexp with
     | Message.Post (Sexp.Str "Exchange_declare_reply", status, _) ->
         reply_to_declaration conn status (fun (_) -> Exchange_declare_ok)
     | Message.Post (Sexp.Str "Queue_declare_reply", status, _) ->
         reply_to_declaration conn status make_queue_declare_ok
     | Message.Post (Sexp.Str "Queue_bind_reply", status, _) ->
         (match Message.message_of_sexp status with
          | Message.Subscribe_ok _ -> send_method conn 1 Queue_bind_ok
          | _ -> die internal_error "Queue bind reply malformed")
     | Message.Post (Sexp.Arr [Sexp.Str "Basic_consume_reply"; Sexp.Str consumer_tag],
                     status, _) ->
         (match Message.message_of_sexp status with
          | Message.Subscribe_ok _ -> send_method conn 1 (Basic_consume_ok consumer_tag)
          | _ -> die internal_error "Basic consume reply malformed")
     | Message.Post (Sexp.Arr [Sexp.Str "delivery"; Sexp.Str consumer_tag], body, _) ->
         send_delivery conn consumer_tag body
     | _ ->
         Log.warn "AMQP outbound relay ignoring message" [m_sexp])
  with
  | Amqp_exception (code, message) ->
      send_error conn code message
  | exn ->
      lwt () = send_error conn internal_error "" in
      raise_lwt exn

(** Name of the most recently declared queue, or a syntax_error. *)
let get_recent_queue_name conn =
  match conn.recent_queue_name with
  | Some q -> q
  | None -> die syntax_error "Attempt to use nonexistent most-recently-declared-queue name"

(** Resolve the "" queue-name shorthand to the most recently declared queue. *)
let expand_mrdq conn queue =
  match queue with
  | "" -> get_recent_queue_name conn
  | other -> Node.name_of_string other

(* Reply address (sink, name) for a request to another node. When the client
   set no_wait, both are empty so that no reply is routed back to this
   connection (and hence no *_ok the client did not ask for is sent). *)
let reply_target conn no_wait reply_name =
  if no_wait then (Sexp.Str "", Sexp.Str "")
  else (Sexp.Str conn.name.Node.label, reply_name)

(** Dispatch one client method received on [channel] (only 0 and 1 allowed). *)
let handle_method conn channel m =
  (* ignore (Log.info "method" [sexp_of_method m]); *)
  if channel > 1 then die channel_error "Unsupported channel number" else ();
  match m with
  | Connection_close (code, text, _, _) ->
      ignore (Log.info "Client closed AMQP connection"
                [Sexp.Str (string_of_int code); Sexp.Str text]);
      lwt () = send_method conn channel Connection_close_ok in
      return (conn.connection_closed <- true)
  | Channel_open ->
      conn.delivery_tag <- 1;
      send_method conn channel Channel_open_ok
  | Channel_close (code, text, _, _) ->
      ignore (Log.info "Client closed AMQP channel"
                [Sexp.Str (string_of_int code); Sexp.Str text]);
      send_method conn channel Channel_close_ok
  | Channel_close_ok ->
      return ()
  | Exchange_declare ("", _, _, _, no_wait, _) ->
      (* Qpid does this bizarre thing of declaring the default exchange. *)
      if no_wait then return () else send_method conn channel Exchange_declare_ok
  | Exchange_declare (exchange, type_, _, _, no_wait, _) ->
      let (reply_sink, reply_name) =
        reply_target conn no_wait (Sexp.Str "Exchange_declare_reply") in
      Node.send_ignore' "factory"
        (Message.create (Sexp.Str type_, Sexp.Arr [Sexp.Str exchange],
                         reply_sink, reply_name))
  | Queue_declare (queue, _, _, _, _, no_wait, _) ->
      let queue = if queue = "" then Uuid.create () else queue in
      conn.recent_queue_name <- Some (Node.name_of_string queue);
      let (reply_sink, reply_name) =
        reply_target conn no_wait (Sexp.Str "Queue_declare_reply") in
      Node.send_ignore' "factory"
        (Message.create (Sexp.Str "queue", Sexp.Arr [Sexp.Str queue],
                         reply_sink, reply_name))
  | Queue_bind (_, "", _, no_wait, _) ->
      (* Qpid does this bizarre thing of binding to the default exchange. *)
      if no_wait then return () else send_method conn channel Queue_bind_ok
  | Queue_bind (queue, exchange, routing_key, no_wait, _) ->
      let queue = expand_mrdq conn queue in
      if not (Node.approx_exists queue) then
        send_warning conn not_found ("Queue '"^queue.Node.label^"' not found")
      else begin
        let (reply_sink, reply_name) =
          reply_target conn no_wait (Sexp.Str "Queue_bind_reply") in
        match_lwt Node.send' exchange
                    (Message.subscribe (Sexp.Str routing_key,
                                        Sexp.Str queue.Node.label,
                                        Sexp.Str "",
                                        reply_sink,
                                        reply_name)) with
        | true -> return ()
        | false -> send_warning conn not_found ("Exchange '"^exchange^"' not found")
      end
  | Basic_consume (queue, consumer_tag, _, _, _, no_wait, _) ->
      let queue = expand_mrdq conn queue in
      let consumer_tag = if consumer_tag = "" then Uuid.create () else consumer_tag in
      let (reply_sink, reply_name) =
        reply_target conn no_wait
          (Sexp.Arr [Sexp.Str "Basic_consume_reply"; Sexp.Str consumer_tag]) in
      (match_lwt Node.send queue
                   (Message.subscribe (Sexp.Str "",
                                       Sexp.Str conn.name.Node.label,
                                       Sexp.Arr [Sexp.Str "delivery"; Sexp.Str consumer_tag],
                                       reply_sink,
                                       reply_name)) with
       | true -> return ()
       | false -> send_warning conn not_found ("Queue '"^queue.Node.label^"' not found"))
  | Basic_publish (exchange, routing_key, false, false) ->
      lwt (_, (body_size, properties)) = next_header conn in
      lwt body = recv_content_body conn body_size in
      (* Publishing to the "" exchange means "send straight to the queue
         named by the routing key". *)
      let (pseudotype, sink, name) =
        if exchange = ""
        then ("Queue", routing_key, "")
        else ("Exchange", exchange, routing_key) in
      (match_lwt Node.post' sink (Sexp.Str name)
                   (Sexp.Hint {Sexp.hint = Sexp.Str "amqp";
                               Sexp.body = Sexp.Arr [Sexp.Str exchange;
                                                     Sexp.Str routing_key;
                                                     sexp_of_properties properties;
                                                     Sexp.Str body]})
                   (Sexp.Str "") with
       | true -> return ()
       | false -> send_warning conn not_found (pseudotype^" '"^sink^"' not found"))
  | Basic_ack (_, _) ->
      return ()
  | Basic_qos (_, _, _) ->
      ignore (Log.warn "Ignoring Basic_qos instruction from client" []);
      send_method conn channel Basic_qos_ok
  | Channel_flow (on) ->
      ignore (Log.warn "Ignoring Channel_flow setting" [Sexp.Str (string_of_bool on)]);
      send_method conn channel (Channel_flow_ok on)
  | _ ->
      let (cid, mid) = method_index m in
      die not_implemented
        (Printf.sprintf "Unsupported method (or method arguments) %s" (method_name cid mid))

let server_properties = table_of_list [
  ("product", Table_string App_info.product);
  ("version", Table_string App_info.version);
  ("copyright", Table_string App_info.copyright);
  ("licence", Table_string App_info.licence_blurb);
  ("capabilities", Table_table (table_of_list []));
]

(** Accept only guest/guest via PLAIN or AMQPLAIN; otherwise access_refused. *)
let check_login_details mechanism response =
  match mechanism with
  | "PLAIN" ->
      (match (Str.split (Str.regexp "\000") response) with
       | ["guest"; "guest"] -> ()
       | _ -> die access_refused "Access refused")
  | "AMQPLAIN" ->
      (let fields = decode_named_fields (Ibuffer.of_string response) in
       match (field_lookup_some "LOGIN" fields, field_lookup_some "PASSWORD" fields) with
       | (Some (Table_string "guest"), Some (Table_string "guest")) -> ()
       | _ -> die access_refused "Access refused")
  | _ ->
      die access_refused "Bad auth mechanism"

(** Resize the I/O buffers and record the new frame_max, under the mutex. *)
let tune_connection conn frame_max =
  with_conn_mutex conn (fun () ->
    conn.input_buf <- String.create frame_max;
    conn.output_buf <- Obuffer.create frame_max;
    conn.frame_max <- frame_max;
    return ())

(** Start / Start_ok / Tune / Tune_ok / Open / Open_ok on channel 0. *)
let handshake_and_tune conn =
  let (major_version, minor_version, revision) = version in
  lwt () = send_method conn 0 (Connection_start (major_version, minor_version,
                                                 server_properties,
                                                 "PLAIN AMQPLAIN", "en_US")) in
  lwt (client_properties, mechanism, response, _locale) =
    match_lwt next_method conn with
    | (0, Connection_start_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_start_ok on channel 0" in
  check_login_details mechanism response;
  ignore (Log.info "Connection from AMQP client" [sexp_of_table client_properties]);
  lwt () = send_method conn 0 (Connection_tune (1, Int32.of_int suggested_frame_max, 0)) in
  lwt (channel_max, frame_max, heartbeat) =
    match_lwt next_method conn with
    | (0, Connection_tune_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_tune_ok on channel 0" in
  (* frame_max = 0 means "no limit" on the client side, so fall back to our
     own proposal. Using 0 literally would reject every incoming frame and
     make send_content_body unable to make progress. *)
  let frame_max =
    match Int32.to_int frame_max with
    | 0 -> suggested_frame_max
    | n -> n in
  if channel_max > 1
  then die not_implemented "Channel numbers higher than 1 are not supported"
  else ();
  if frame_max > suggested_frame_max
  then die syntax_error "Requested frame max too large"
  else ();
  if frame_max < frame_min_size
  then die syntax_error "Requested frame max too small"
  else ();
  if heartbeat > 0
  then die not_implemented "Heartbeats not yet implemented (patches welcome)"
  else ();
  lwt () = tune_connection conn frame_max in
  lwt (virtual_host) =
    match_lwt next_method conn with
    | (0, Connection_open props) -> return props
    | _ -> die not_allowed "Expected Connection_open on channel 0" in
  ignore (Log.info "Connected to vhost" [Sexp.Str virtual_host]);
  send_method conn 0 Connection_open_ok

(** Bind this connection's node, handshake, then process methods until the
    connection is marked closed. AMQP errors close the connection. *)
let amqp_mainloop conn n =
  lwt () = Node.bind_ignore (conn.name, n) in
  (try_lwt
     lwt () = handshake_and_tune conn in
     while_lwt not conn.connection_closed do
       lwt (channel, m) = next_method conn in
       handle_method conn channel m
     done
   with
   | Amqp_exception (code, message) -> send_error conn code message)

(** Entry point for each accepted socket. *)
let start (s, peername) =
  Connections.start_connection "amqp" issue_banner amqp_boot amqp_handler amqp_mainloop
    (s, peername)

(** Create the standard amq.direct / amq.fanout exchanges and start listening. *)
let init () =
  lwt () = Node.send_ignore' "factory"
             (Message.create (Sexp.Str "direct", Sexp.Arr [Sexp.Str "amq.direct"],
                              Sexp.Str "", Sexp.Str "")) in
  lwt () = Node.send_ignore' "factory"
             (Message.create (Sexp.Str "fanout", Sexp.Arr [Sexp.Str "amq.fanout"],
                              Sexp.Str "", Sexp.Str "")) in
  Util.create_daemon_thread "AMQP listener" None (Net.start_net "AMQP" Amqp_spec.port) start