(* Copyright 2012 Tony Garnock-Jones . *)

(* This file is part of Hop. *)

(* Hop is free software: you can redistribute it and/or modify it *)
(* under the terms of the GNU General Public License as published by the *)
(* Free Software Foundation, either version 3 of the License, or (at your *)
(* option) any later version. *)

(* Hop is distributed in the hope that it will be useful, but *)
(* WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU *)
(* General Public License for more details. *)

(* You should have received a copy of the GNU General Public License *)
(* along with Hop. If not, see . *)

(* AMQP relay: speaks the AMQP wire protocol to one client connection,
   turns the client's methods into messages for Hop nodes, and turns Hop
   replies and deliveries back into AMQP methods. *)

open Lwt
open Printf
open Amqp_spec
open Amqp_wireformat

(** Per-connection state for one AMQP client. *)
type connection_t = {
  peername: Unix.sockaddr;
  mtx: Lwt_mutex.t;                 (* serialises frame output and tag/tuning updates *)
  cin: Lwt_io.input_channel;
  cout: Lwt_io.output_channel;
  name: Node.name;                  (* Hop node name this connection is bound under *)
  mutable input_buf: bytes;         (* payload of the most recently read frame *)
  mutable output_buf: Obuffer.t;    (* payload of the next frame to be written *)
  mutable frame_max: int;           (* largest frame payload [read_frame] accepts *)
  mutable connection_closed: bool;
  mutable recent_queue_name: Node.name option;  (* target of the "" queue-name shorthand *)
  mutable delivery_tag: int         (* next tag handed out by [send_delivery] *)
}

let initial_frame_size = frame_min_size
let suggested_frame_max = 131072

(* Bytes of framing around every frame payload: type (1), channel (2),
   length (4) and the frame-end octet (1).  AMQP counts these against the
   negotiated frame_max, so outgoing payloads must leave room for them. *)
let frame_overhead = 8

(** Build the initial state for a freshly accepted connection. *)
let amqp_boot (peername, cin, cout) =
  return {
    peername = peername;
    mtx = Lwt_mutex.create ();
    cin = cin;
    cout = cout;
    name = Node.name_of_bytes (Bytes.of_string (Uuid.create ()));
    input_buf = Bytes.create initial_frame_size;
    output_buf = Obuffer.create initial_frame_size;
    frame_max = initial_frame_size;
    connection_closed = false;
    recent_queue_name = None;
    delivery_tag = 1 (* Not 0: 0 means "all deliveries" in an ack *)
  }

(** Read one octet from [c] as an int in [0, 255]. *)
let input_byte c =
  lwt b = Lwt_io.read_char c in
  return (int_of_char b)

(** Read one frame from the client into [conn.input_buf].
    Returns [(frame_type, channel, length)]; the payload occupies bytes
    [0, length) of [conn.input_buf].  Fails with [frame_error] when the
    length is out of range or the frame-end octet is missing. *)
let read_frame conn =
  lwt frame_type = input_byte conn.cin in
  lwt channel_hi = input_byte conn.cin in
  lwt channel_lo = input_byte conn.cin in
  (* Big-endian 16-bit channel number: the high octet is shifted *up*. *)
  let channel = (channel_hi lsl 8) lor channel_lo in
  lwt length = Lwt_io.BE.read_int conn.cin in
  if length < 0
  then die frame_error "Invalid frame length"
  else if length > conn.frame_max
  then die frame_error "Frame longer than current frame_max"
  else (lwt () = Lwt_io.read_into_exactly conn.cin conn.input_buf 0 length in
        lwt end_marker = input_byte conn.cin in
        if end_marker <> frame_end
        then die frame_error "Missing frame_end octet"
        else return (frame_type, channel, length))

(** Write the low octet of [b] to [c].  [b] must be in [0, 255]. *)
let output_byte c b = Lwt_io.write_char c (char_of_int b)

(** Send the contents of [conn.output_buf] as one frame of [frame_type] on
    [channel], then empty the buffer.  Every caller in this file runs under
    [conn.mtx], so frames from concurrent senders do not interleave. *)
let write_frame conn frame_type channel =
  lwt () = output_byte conn.cout frame_type in
  lwt () = output_byte conn.cout ((channel lsr 8) land 255) in
  lwt () = output_byte conn.cout (channel land 255) in
  lwt () = Lwt_io.BE.write_int conn.cout (Obuffer.length conn.output_buf) in
  lwt () = Obuffer.write conn.cout conn.output_buf in
  Obuffer.reset conn.output_buf;
  output_byte conn.cout frame_end

(** Append method [m] (class id, method id, arguments) to [buf]. *)
let serialize_method buf m =
  let (class_id, method_id) = method_index m in
  write_short buf class_id;
  write_short buf method_id;
  write_method m buf

(** Decode a method frame payload from [buf]. *)
let deserialize_method buf =
  let class_id = read_short buf in
  let method_id = read_short buf in
  read_method class_id method_id buf

(** Append a content header for a body of [body_size] bytes with
    properties [p] to [buf]. *)
let serialize_header buf body_size p =
  let class_id = class_index p in
  write_short buf class_id;
  write_short buf 0;
  write_longlong buf (Int64.of_int body_size);
  write_properties p buf

(** Decode a content header frame payload; returns [(body_size, props)]. *)
let deserialize_header buf =
  let class_id = read_short buf in
  (* Skip the weight field, which [serialize_header] always writes as 0. *)
  ignore (read_short buf);
  let body_size = Int64.to_int (read_longlong buf) in
  lwt props = read_properties class_id buf in
  return (body_size, props)

(** Send [body] on [channel] as a sequence of body frames, each small
    enough that payload plus framing fits in [conn.frame_max]. *)
let send_content_body conn channel body =
  let len = Bytes.length body in
  (* [max 1] guarantees progress even if frame_max were ever tiny. *)
  let chunk_max = max 1 (conn.frame_max - frame_overhead) in
  let rec send_remainder offset =
    if offset >= len
    then return ()
    else begin
      let snip_len = min chunk_max (len - offset) in
      Obuffer.add_substring conn.output_buf body offset snip_len;
      lwt () = write_frame conn frame_body channel in
      send_remainder (offset + snip_len)
    end
  in
  send_remainder 0

(** Read the next frame, failing with [command_invalid] unless it has
    type [required_type].  Returns [(channel, payload_length)]. *)
let next_frame conn required_type =
  lwt (frame_type, channel, length) = read_frame conn in
  if frame_type <> required_type
  then die command_invalid (Printf.sprintf "Unexpected frame type %d" frame_type)
  else return (channel, length)

(** Read and decode the next method frame; returns [(channel, method)]. *)
let next_method conn =
  lwt (channel, length) = next_frame conn frame_method in
  lwt m = deserialize_method (Ibuffer.create conn.input_buf 0 length) in
  return (channel, m)

(** Read and decode the next content header frame;
    returns [(channel, (body_size, properties))]. *)
let next_header conn =
  lwt (channel, length) = next_frame conn frame_header in
  lwt h = deserialize_header (Ibuffer.create conn.input_buf 0 length) in
  return (channel, h)

(** Read body frames until at least [body_size] bytes have arrived and
    return their concatenation. *)
let recv_content_body conn body_size =
  let buf = Obuffer.create body_size in
  lwt () =
    while_lwt Obuffer.length buf < body_size do
      lwt (_, length) = next_frame conn frame_body in
      return (Obuffer.add_substring buf conn.input_buf 0 length)
    done
  in
  return (Obuffer.contents buf)

(** Run [thunk] while holding the connection's mutex. *)
let with_conn_mutex conn thunk = Lwt_mutex.with_lock conn.mtx thunk

(** Send method [m] as a single method frame on [channel]. *)
let send_method conn channel m =
  with_conn_mutex conn (fun () ->
    serialize_method conn.output_buf m;
    write_frame conn frame_method channel)

(** Send a content-bearing method [m] with properties [p] and body
    [body_bs] on [channel], as method, header and body frames, without
    other frames interleaving. *)
let send_content_method conn channel m p body_bs =
  with_conn_mutex conn (fun () ->
    serialize_method conn.output_buf m;
    lwt () = write_frame conn frame_method channel in
    serialize_header conn.output_buf (Bytes.length body_bs) p;
    lwt () = write_frame conn frame_header channel in
    send_content_body conn channel body_bs)

(** Close the whole connection with [code]/[message], at most once. *)
let send_error conn code message =
  if conn.connection_closed
  then return ()
  else begin
    conn.connection_closed <- true;
    let m = Connection_close (code, Bytes.of_string message, 0, 0) in
    ignore (Log.warn "Sending error" [sexp_of_method m]);
    send_method conn 0 m
  end

(** Close channel 1 with [code]/[message]; the connection stays open. *)
let send_warning conn code message =
  let m = Channel_close (code, Bytes.of_string message, 0, 0) in
  ignore (Log.warn "Sending warning" [sexp_of_method m]);
  send_method conn 1 m

(** Read the 8-byte protocol header.  Returns [true] if it starts with
    "AMQP".  Otherwise the supported protocol header is written back and
    the result is [false].  Returns [false] if the client disconnects
    before sending 8 bytes. *)
let issue_banner cin cout =
  let handshake = Bytes.create 8 in
  (* try_lwt (not try) so that asynchronous End_of_file is caught too. *)
  try_lwt
    lwt () = Lwt_io.read_into_exactly cin handshake 0 8 in
    if Bytes.sub handshake 0 4 <> (Bytes.of_string "AMQP")
    then (lwt () = Lwt_io.write_from_exactly cout (Bytes.of_string "AMQP\000\000\009\001") 0 8 in
          return false)
    else begin
      let byte_sexp i = Sexp.str (string_of_int (int_of_char (Bytes.get handshake i))) in
      ignore (Log.info "AMQP handshake bytes"
                [byte_sexp 4; byte_sexp 5; byte_sexp 6; byte_sexp 7]);
      return true
    end
  with End_of_file -> return false

let reference_to_logs = (Bytes.of_string "See server logs for details")

(** The payload of a [Sexp.Str], or a pointer to the logs otherwise. *)
let extract_str v =
  match v with
  | Sexp.Str s -> s
  | _ -> reference_to_logs

(** Answer a declare request from the factory's [status] reply: send
    [ok_fn info] on success.  On failure, send a connection error or a
    channel warning depending on the reason. *)
let reply_to_declaration conn status ok_fn =
  match Message.message_of_sexp status with
  | Message.Create_ok info -> send_method conn 1 (ok_fn info)
  | Message.Create_failed reason ->
    (match reason with
     | Sexp.Arr [Sexp.Str who; Sexp.Str code]
       when who = (Bytes.of_string "factory") && code = (Bytes.of_string "class-not-found") ->
       send_error conn command_invalid "Object type not supported by server"
     | Sexp.Arr [Sexp.Str who; Sexp.Str code]
       when who = (Bytes.of_string "constructor") && code = (Bytes.of_string "class-mismatch") ->
       send_error conn not_allowed "Redeclaration with different object type not permitted"
     | Sexp.Arr [Sexp.Str who; explanation] ->
       send_warning conn precondition_failed
         ((Bytes.to_string who) ^ " failed: " ^ (Bytes.to_string (extract_str explanation)))
     | _ ->
       send_warning conn precondition_failed (Bytes.to_string reference_to_logs))
  | _ -> die internal_error "Declare reply malformed"

(** Build [Queue_declare_ok] from the queue name in a create reply. *)
let make_queue_declare_ok info =
  match info with
  | Sexp.Str queue_name -> Queue_declare_ok (queue_name, Int32.zero, Int32.zero)
  | _ -> die internal_error "Unusable queue name in declare response"

(** Relay a Hop-side AMQP message ([body_sexp]) to the client as a
    [Basic_deliver] for [consumer_tag], allocating the next delivery tag. *)
let send_delivery conn consumer_tag body_sexp =
  match body_sexp with
  | Sexp.Arr [Sexp.Hint {Sexp.hint = maybe_amqp; Sexp.body = h_body_bs};
              Sexp.Str exchange;
              Sexp.Str routing_key;
              properties_sexp;
              Sexp.Str body_bs]
    when maybe_amqp = (Bytes.of_string "amqp") && h_body_bs = Bytes.empty ->
    lwt tag = with_conn_mutex conn (fun () ->
      let v = conn.delivery_tag in
      conn.delivery_tag <- v + 1;
      return v)
    in
    send_content_method conn 1
      (Basic_deliver (consumer_tag, Int64.of_int tag, false, exchange, routing_key))
      (properties_of_sexp basic_class_id properties_sexp)
      body_bs
  | _ -> die internal_error "Malformed AMQP message body sexp"

(** Handle a message posted to this connection's Hop node: declare/bind/
    consume replies become [*_ok] methods, and "delivery" posts become
    [Basic_deliver]s.  An [Amqp_exception] closes the connection with its
    code.  Any other exception closes it with [internal_error] and is then
    re-raised. *)
let amqp_handler conn n m_sexp =
  (* try_lwt (not try) so failures of the returned promise are handled too. *)
  try_lwt
    (match Message.message_of_sexp m_sexp with
     | Message.Post (Sexp.Str type_bs, status, _) ->
       (match Bytes.to_string type_bs with
        | "Exchange_declare_reply" ->
          reply_to_declaration conn status (fun _ -> Exchange_declare_ok)
        | "Queue_declare_reply" ->
          reply_to_declaration conn status make_queue_declare_ok
        | "Queue_bind_reply" ->
          (match Message.message_of_sexp status with
           | Message.Subscribe_ok _ -> send_method conn 1 Queue_bind_ok
           | _ -> die internal_error "Queue bind reply malformed")
        | _ -> Log.warn "AMQP outbound relay ignoring message" [m_sexp])
     | Message.Post (Sexp.Arr [Sexp.Str type_bs; Sexp.Str consumer_tag], status_or_body, _) ->
       (match Bytes.to_string type_bs with
        | "Basic_consume_reply" ->
          (match Message.message_of_sexp status_or_body with
           | Message.Subscribe_ok _ -> send_method conn 1 (Basic_consume_ok consumer_tag)
           | _ -> die internal_error "Basic consume reply malformed")
        | "delivery" -> send_delivery conn consumer_tag status_or_body
        | _ -> Log.warn "AMQP outbound relay ignoring message" [m_sexp])
     | _ -> Log.warn "AMQP outbound relay ignoring message" [m_sexp])
  with
  | Amqp_exception (code, message) -> send_error conn code message
  | exn ->
    lwt () = send_error conn internal_error "" in
    raise_lwt exn

(** The most recently declared queue name, or a [syntax_error]. *)
let get_recent_queue_name conn =
  match conn.recent_queue_name with
  | Some q -> q
  | None -> die syntax_error "Attempt to use nonexistent most-recently-declared-queue name"

(** Resolve a client queue name: "" means the most recently declared queue. *)
let expand_mrdq conn queue =
  if queue = Bytes.empty
  then get_recent_queue_name conn
  else Node.name_of_bytes queue

(* Reply address (sink, name) for a Hop request made on the client's
   behalf.  When the client set no_wait it must not receive an *_ok, so
   both parts are empty and the reply goes nowhere. *)
let reply_target conn no_wait reply_name =
  if no_wait
  then (Bytes.empty, Bytes.empty)
  else (conn.name.Node.label, Bytes.of_string reply_name)

(** Dispatch one client method received on [channel] (only 0 and 1 are
    supported).  Fails with [not_implemented] for unsupported methods. *)
let handle_method conn channel m =
  (* ignore (Log.info "method" [sexp_of_method m]); *)
  if channel > 1 then die channel_error "Unsupported channel number" else ();
  match m with
  | Connection_close (code, text, _, _) ->
    ignore (Log.info "Client closed AMQP connection" [Sexp.str (string_of_int code); Sexp.Str text]);
    lwt () = send_method conn channel Connection_close_ok in
    return (conn.connection_closed <- true)
  | Channel_open ->
    conn.delivery_tag <- 1;
    send_method conn channel Channel_open_ok
  | Channel_close (code, text, _, _) ->
    ignore (Log.info "Client closed AMQP channel" [Sexp.str (string_of_int code); Sexp.Str text]);
    send_method conn channel Channel_close_ok
  | Channel_close_ok ->
    return ()
  | Exchange_declare (maybe_empty, _, _, _, no_wait, _) when maybe_empty = Bytes.empty ->
    (* Qpid does this bizarre thing of declaring the default exchange. *)
    if no_wait then return () else send_method conn channel Exchange_declare_ok
  | Exchange_declare (exchange, type_, _, _, no_wait, _) ->
    let (reply_sink, reply_name) = reply_target conn no_wait "Exchange_declare_reply" in
    Node.send_ignore' (Bytes.of_string "factory")
      (Message.create (Sexp.Str type_,
                       Sexp.Arr [Sexp.Str exchange],
                       Sexp.Str reply_sink,
                       Sexp.Str reply_name))
  | Queue_declare (queue, _, _, _, _, no_wait, _) ->
    let queue = if queue = Bytes.empty then Bytes.of_string (Uuid.create ()) else queue in
    conn.recent_queue_name <- Some (Node.name_of_bytes queue);
    let (reply_sink, reply_name) = reply_target conn no_wait "Queue_declare_reply" in
    Node.send_ignore' (Bytes.of_string "factory")
      (Message.create (Sexp.litstr "queue",
                       Sexp.Arr [Sexp.Str queue],
                       Sexp.Str reply_sink,
                       Sexp.Str reply_name))
  | Queue_bind (_, maybe_empty, _, no_wait, _) when maybe_empty = Bytes.empty ->
    (* Qpid does this bizarre thing of binding to the default exchange. *)
    if no_wait then return () else send_method conn channel Queue_bind_ok
  | Queue_bind (queue, exchange, routing_key, no_wait, _) ->
    let queue = expand_mrdq conn queue in
    if not (Node.approx_exists queue)
    then send_warning conn not_found ("Queue '" ^ (Bytes.to_string queue.Node.label) ^ "' not found")
    else begin
      let (reply_sink, reply_name) = reply_target conn no_wait "Queue_bind_reply" in
      match_lwt Node.send' exchange
                  (Message.subscribe (Sexp.Str routing_key,
                                      Sexp.Str queue.Node.label,
                                      Sexp.emptystr,
                                      Sexp.Str reply_sink,
                                      Sexp.Str reply_name)) with
      | true -> return ()
      | false -> send_warning conn not_found ("Exchange '" ^ (Bytes.to_string exchange) ^ "' not found")
    end
  | Basic_consume (queue, consumer_tag, _, _, _, no_wait, _) ->
    let queue = expand_mrdq conn queue in
    let consumer_tag =
      if consumer_tag = Bytes.empty then Bytes.of_string (Uuid.create ()) else consumer_tag in
    (* Deliveries always come back to us; only the consume-ok honours no_wait. *)
    let (reply_sink, reply_name) =
      if no_wait
      then (Sexp.Str Bytes.empty, Sexp.Str Bytes.empty)
      else (Sexp.Str conn.name.Node.label,
            Sexp.Arr [Sexp.litstr "Basic_consume_reply"; Sexp.Str consumer_tag]) in
    (match_lwt Node.send queue
                 (Message.subscribe (Sexp.emptystr,
                                     Sexp.Str conn.name.Node.label,
                                     Sexp.Arr [Sexp.litstr "delivery"; Sexp.Str consumer_tag],
                                     reply_sink,
                                     reply_name)) with
     | true -> return ()
     | false -> send_warning conn not_found ("Queue '" ^ (Bytes.to_string queue.Node.label) ^ "' not found"))
  | Basic_publish (exchange, routing_key, false, false) ->
    lwt (_, (body_size, properties)) = next_header conn in
    lwt body = recv_content_body conn body_size in
    (* An empty exchange name posts straight to the queue named by routing_key. *)
    let (pseudotype, sink, name) =
      if exchange = Bytes.empty
      then ("Queue", routing_key, Bytes.empty)
      else ("Exchange", exchange, routing_key) in
    (match_lwt Node.post' sink (Sexp.Str name)
                 (Sexp.Arr [Sexp.Hint {Sexp.hint = (Bytes.of_string "amqp"); Sexp.body = Bytes.empty};
                            Sexp.Str exchange;
                            Sexp.Str routing_key;
                            sexp_of_properties properties;
                            Sexp.Str body])
                 Sexp.emptystr with
     | true -> return ()
     | false -> send_warning conn not_found (pseudotype ^ " '" ^ (Bytes.to_string sink) ^ "' not found"))
  | Basic_ack (_, _) ->
    return ()
  | Basic_qos (_, _, _) ->
    ignore (Log.warn "Ignoring Basic_qos instruction from client" []);
    send_method conn channel Basic_qos_ok
  | Channel_flow (on) ->
    ignore (Log.warn "Ignoring Channel_flow setting" [Sexp.str (string_of_bool on)]);
    send_method conn channel (Channel_flow_ok on)
  | _ ->
    let (cid, mid) = method_index m in
    die not_implemented (Printf.sprintf "Unsupported method (or method arguments) %s" (method_name cid mid))

let server_properties = table_of_list [
  ((Bytes.of_string "product"), Table_string App_info.product);
  ((Bytes.of_string "version"), Table_string App_info.version);
  ((Bytes.of_string "copyright"), Table_string App_info.copyright);
  ((Bytes.of_string "licence"), Table_string App_info.licence_blurb);
  ((Bytes.of_string "capabilities"), Table_table (table_of_list []));
]

(** Check the client's SASL [response] for [mechanism] ("PLAIN" or
    "AMQPLAIN").  Only guest/guest is accepted; anything else fails with
    [access_refused].  Passwords are never written to the log. *)
let check_login_details mechanism response =
  let credentials =
    match Bytes.to_string mechanism with
    | "PLAIN" ->
      (* "NUL user NUL pass": only an empty authorisation identity is accepted. *)
      (match Bytes.index_opt response '\000' with
       | Some 0 ->
         let response = Bytes.sub response 1 ((Bytes.length response) - 1) in
         (match Bytes.index_opt response '\000' with
          | Some pos ->
            let user = Bytes.sub response 0 pos in
            let pass = Bytes.sub response (pos + 1) ((Bytes.length response) - (pos + 1)) in
            Some (Bytes.to_string user, Bytes.to_string pass)
          | None -> None)
       | Some _ | None -> None)
    | "AMQPLAIN" ->
      let fields = decode_named_fields (Ibuffer.of_bytes response) in
      (match (field_lookup_some (Bytes.of_string "LOGIN") fields,
              field_lookup_some (Bytes.of_string "PASSWORD") fields) with
       | (Some (Table_string user), Some (Table_string pass)) ->
         Some (Bytes.to_string user, Bytes.to_string pass)
       | _ -> None)
    | _ -> die access_refused "Bad auth mechanism"
  in
  match credentials with
  | Some ("guest", "guest") -> ()
  | Some (u, _) ->
    ignore (Log.info "Access refused" [Sexp.str u]);
    die access_refused "Access refused"
  | None ->
    ignore (Log.info "Access refused; bad credential format" [Sexp.Str mechanism]);
    die access_refused "Access refused; bad credential format"

(** Resize the connection's frame buffers for the negotiated [frame_max]. *)
let tune_connection conn frame_max =
  with_conn_mutex conn (fun () ->
    conn.input_buf <- Bytes.create frame_max;
    conn.output_buf <- Obuffer.create frame_max;
    conn.frame_max <- frame_max;
    return ())

(** Run the connection-level handshake on channel 0:
    start/start_ok (with login check), tune/tune_ok, open/open_ok. *)
let handshake_and_tune conn =
  let (major_version, minor_version, revision) = version in
  lwt () = send_method conn 0 (Connection_start (major_version, minor_version, server_properties,
                                                 (Bytes.of_string "PLAIN AMQPLAIN"),
                                                 (Bytes.of_string "en_US"))) in
  lwt (client_properties, mechanism, response, _locale) =
    match_lwt next_method conn with
    | (0, Connection_start_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_start_ok on channel 0"
  in
  check_login_details mechanism response;
  ignore (Log.info "Connection from AMQP client" [sexp_of_table client_properties]);
  lwt () = send_method conn 0 (Connection_tune (1, Int32.of_int suggested_frame_max, 0)) in
  lwt (channel_max, frame_max, heartbeat) =
    match_lwt next_method conn with
    | (0, Connection_tune_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_tune_ok on channel 0"
  in
  if channel_max > 1 then die not_implemented "Channel numbers higher than 1 are not supported" else ();
  (* 0 means the client imposes no limit, so our own proposal stands. *)
  let frame_max =
    match Int32.to_int frame_max with
    | 0 -> suggested_frame_max
    | n -> n
  in
  if frame_max > suggested_frame_max then die syntax_error "Requested frame max too large" else ();
  (* Smaller buffers could not even hold the framing of a body frame. *)
  if frame_max < frame_min_size then die syntax_error "Requested frame max too small" else ();
  if heartbeat > 0 then die not_implemented "Heartbeats not yet implemented (patches welcome)" else ();
  lwt () = tune_connection conn frame_max in
  lwt (virtual_host) =
    match_lwt next_method conn with
    | (0, Connection_open props) -> return props
    | _ -> die not_allowed "Expected Connection_open on channel 0"
  in
  ignore (Log.info "Connected to vhost" [Sexp.Str virtual_host]);
  send_method conn 0 Connection_open_ok

(** Bind [conn.name] to node [n], perform the handshake, then dispatch
    client methods until the connection is closed.  An [Amqp_exception]
    is reported to the client via [send_error]. *)
let amqp_mainloop conn n =
  lwt () = Node.bind_ignore (conn.name, n) in
  (try_lwt
     lwt () = handshake_and_tune conn in
     while_lwt not conn.connection_closed do
       lwt (channel, m) = next_method conn in
       handle_method conn channel m
     done
   with Amqp_exception (code, message) ->
     send_error conn code message)

(** Entry point for each accepted socket. *)
let start (s, peername) =
  Connections.start_connection "amqp" issue_banner amqp_boot amqp_handler amqp_mainloop (s, peername)

(** Create the standard amq.direct and amq.fanout exchanges and start the
    AMQP listener on the configured port. *)
let init () =
  lwt () = Node.send_ignore' (Bytes.of_string "factory")
             (Message.create (Sexp.litstr "direct",
                              Sexp.Arr [Sexp.litstr "amq.direct"],
                              Sexp.emptystr,
                              Sexp.emptystr)) in
  lwt () = Node.send_ignore' (Bytes.of_string "factory")
             (Message.create (Sexp.litstr "fanout",
                              Sexp.Arr [Sexp.litstr "amq.fanout"],
                              Sexp.emptystr,
                              Sexp.emptystr)) in
  let port = Config.get_int "amqp.port" Amqp_spec.port in
  Util.create_daemon_thread (Bytes.of_string "AMQP listener") None (Net.start_net "AMQP" port) start