hop-2012/server/amqp_relay.ml

490 lines
20 KiB
OCaml

(* Copyright 2012 Tony Garnock-Jones <tonygarnockjones@gmail.com>. *)
(* This file is part of Hop. *)
(* Hop is free software: you can redistribute it and/or modify it *)
(* under the terms of the GNU General Public License as published by the *)
(* Free Software Foundation, either version 3 of the License, or (at your *)
(* option) any later version. *)
(* Hop is distributed in the hope that it will be useful, but *)
(* WITHOUT ANY WARRANTY; without even the implied warranty of *)
(* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU *)
(* General Public License for more details. *)
(* You should have received a copy of the GNU General Public License *)
(* along with Hop. If not, see <http://www.gnu.org/licenses/>. *)
open Lwt
open Printf
open Amqp_spec
open Amqp_wireformat
(* Per-connection state for one AMQP client. *)
type connection_t = {
peername: Unix.sockaddr; (* remote address of the client socket *)
mtx: Lwt_mutex.t; (* held (via with_conn_mutex) while writing frames or retuning buffers *)
cin: Lwt_io.input_channel; (* frames are read from here *)
cout: Lwt_io.output_channel; (* frames are written here *)
name: Node.name; (* this connection's node name; its label is used as the reply sink *)
mutable input_buf: bytes; (* payload of the most recently read frame; resized by tune_connection *)
mutable output_buf: Obuffer.t; (* payload staged for the next write_frame *)
mutable frame_max: int; (* largest payload length read_frame accepts *)
mutable connection_closed: bool; (* set once a Connection_close is sent or handled; stops the main loop *)
mutable recent_queue_name: Node.name option; (* last declared queue, substituted for an empty queue name *)
mutable delivery_tag: int (* next tag handed out by send_delivery *)
}
(* Buffer/frame size used from boot until Connection_tune_ok negotiates one. *)
let initial_frame_size : int = frame_min_size
(* Largest frame_max offered in Connection_tune and accepted back in
   Connection_tune_ok (128 KiB). *)
let suggested_frame_max = 128 * 1024
(* Builds the initial state for a freshly accepted connection. The node name
   is a fresh UUID, and the buffers start at the pre-negotiation frame size. *)
let amqp_boot (peername, cin, cout) =
  let name = Node.name_of_bytes (Bytes.of_string (Uuid.create ())) in
  return {
    peername;
    mtx = Lwt_mutex.create ();
    cin;
    cout;
    name;
    input_buf = Bytes.create initial_frame_size;
    output_buf = Obuffer.create initial_frame_size;
    frame_max = initial_frame_size;
    connection_closed = false;
    recent_queue_name = None;
    delivery_tag = 1 (* Not 0: 0 means "all deliveries" in an ack *)
  }
(* Reads one octet from [c] and yields it as an int in 0..255. *)
let input_byte c =
  lwt ch = Lwt_io.read_char c in
  return (Char.code ch)
(* Reads one frame from the client into [conn.input_buf].
   Layout read here: type octet, big-endian 16-bit channel, 32-bit big-endian
   payload length, payload, frame-end octet.
   Returns [(frame_type, channel, payload_length)]; the payload occupies
   [conn.input_buf.[0 .. payload_length - 1]].
   Fails via [die frame_error] if the length is negative or exceeds
   [conn.frame_max], or if the frame-end octet is missing. *)
let read_frame conn =
  lwt frame_type = input_byte conn.cin in
  lwt channel_hi = input_byte conn.cin in
  lwt channel_lo = input_byte conn.cin in
  (* The high octet must be shifted LEFT into place; [lsr] would drop it. *)
  let channel = (channel_hi lsl 8) lor channel_lo in
  lwt length = Lwt_io.BE.read_int conn.cin in
  if length < 0
  (* A 32-bit length with the top bit set arrives negative; reject it here
     rather than letting read_into_exactly raise Invalid_argument. *)
  then die frame_error "Negative frame length"
  else if length > conn.frame_max
  then die frame_error "Frame longer than current frame_max"
  else
    (lwt () = Lwt_io.read_into_exactly conn.cin conn.input_buf 0 length in
     lwt end_marker = input_byte conn.cin in
     if end_marker <> frame_end
     then die frame_error "Missing frame_end octet"
     else return (frame_type, channel, length))
(* Writes the octet [b] (which must be in 0..255) to [c]. *)
let output_byte c b = b |> char_of_int |> Lwt_io.write_char c
(* Emits one complete frame: type octet, the channel as high then low octet,
   the big-endian payload length (Lwt_io.BE.write_int), the payload staged in
   [conn.output_buf], and finally the frame-end octet. The staging buffer is
   reset after its contents are written. Callers hold the connection mutex. *)
let write_frame conn frame_type channel =
  let out = conn.cout in
  let channel_hi = (channel lsr 8) land 255
  and channel_lo = channel land 255 in
  lwt () = output_byte out frame_type in
  lwt () = output_byte out channel_hi in
  lwt () = output_byte out channel_lo in
  lwt () = Lwt_io.BE.write_int out (Obuffer.length conn.output_buf) in
  lwt () = Obuffer.write out conn.output_buf in
  Obuffer.reset conn.output_buf;
  output_byte out frame_end
(* Appends method [m] to [buf]: its class id, its method id, then its
   arguments as encoded by write_method. *)
let serialize_method buf m =
  let class_id, method_id = method_index m in
  List.iter (write_short buf) [class_id; method_id];
  write_method m buf
(* Decodes a method from [buf]: class id first, then method id, then the
   arguments via read_method. *)
let deserialize_method buf =
  let class_id = read_short buf in
  (* class_id is already consumed, so this reads the method id next. *)
  read_method class_id (read_short buf) buf
(* Appends a content header for properties [p] and a body of [body_size]
   octets: class id, a zero short, the 64-bit body size, then the properties. *)
let serialize_header buf body_size p =
  write_short buf (class_index p);
  (* presumably the AMQP header "weight" field; always written as 0 *)
  write_short buf 0;
  write_longlong buf (Int64.of_int body_size);
  write_properties p buf
(* Decodes a content header from [buf], returning [(body_size, properties)].
   The short after the class id is read and discarded. *)
let deserialize_header buf =
  let class_id = read_short buf in
  ignore (read_short buf);
  let body_size = read_longlong buf |> Int64.to_int in
  lwt props = read_properties class_id buf in
  return (body_size, props)
(* Sends [body] as a sequence of body frames on [channel].
   AMQP's frame_max counts the whole frame, and write_frame adds 8 octets of
   framing (1 type + 2 channel + 4 length + 1 frame-end). So each frame may
   carry at most [frame_max - 8] payload octets. An empty body sends no body
   frames. Callers are expected to hold the connection mutex. *)
let send_content_body conn channel body =
  let frame_overhead = 8 in
  let len = Bytes.length body in
  (* [max 1] guarantees forward progress even with a degenerate frame_max. *)
  let max_payload = max 1 (conn.frame_max - frame_overhead) in
  let rec send_remainder offset =
    if offset >= len
    then return ()
    else
      let snip_len = min max_payload (len - offset) in
      Obuffer.add_substring conn.output_buf body offset snip_len;
      lwt () = write_frame conn frame_body channel in
      send_remainder (offset + snip_len)
  in
  send_remainder 0
(* Reads the next frame and insists that it is of [required_type].
   Returns [(channel, payload_length)]; any other frame type is a
   command_invalid error. *)
let next_frame conn required_type =
  lwt (frame_type, channel, length) = read_frame conn in
  if frame_type = required_type
  then return (channel, length)
  else die command_invalid (sprintf "Unexpected frame type %d" frame_type)
(* Reads the next frame, which must be a method frame, and decodes it.
   Returns [(channel, method)]. *)
let next_method conn =
  lwt (channel, length) = next_frame conn frame_method in
  let payload = Ibuffer.create conn.input_buf 0 length in
  lwt decoded = deserialize_method payload in
  return (channel, decoded)
(* Reads the next frame, which must be a content-header frame, and decodes
   it. Returns [(channel, (body_size, properties))]. *)
let next_header conn =
  lwt (channel, length) = next_frame conn frame_header in
  let payload = Ibuffer.create conn.input_buf 0 length in
  lwt header = deserialize_header payload in
  return (channel, header)
(* Reads body frames, appending each payload, until at least [body_size]
   octets have been collected, then returns the accumulated body. *)
let recv_content_body conn body_size =
  let acc = Obuffer.create body_size in
  let rec gather () =
    if Obuffer.length acc >= body_size
    then return (Obuffer.contents acc)
    else
      lwt (_, length) = next_frame conn frame_body in
      Obuffer.add_substring acc conn.input_buf 0 length;
      gather ()
  in
  gather ()
(* Runs [thunk] while holding the connection's write mutex. *)
let with_conn_mutex { mtx; _ } thunk = Lwt_mutex.with_lock mtx thunk
(* Serializes method [m] and sends it as one method frame on [channel],
   under the connection mutex. *)
let send_method conn channel m =
  let emit () =
    serialize_method conn.output_buf m;
    write_frame conn frame_method channel
  in
  with_conn_mutex conn emit
(* Sends a content-bearing method on [channel]: the method frame, a content
   header for properties [p], then [body_bs] as body frames. The whole
   sequence runs under the connection mutex so its frames are not interleaved
   with other writers. Previously [channel] was ignored and 1 was hard-coded;
   all existing callers pass 1, so their behaviour is unchanged. *)
let send_content_method conn channel m p body_bs =
  with_conn_mutex conn (fun () ->
    serialize_method conn.output_buf m;
    lwt () = write_frame conn frame_method channel in
    serialize_header conn.output_buf (Bytes.length body_bs) p;
    lwt () = write_frame conn frame_header channel in
    send_content_body conn channel body_bs)
(* Closes the connection with [code]/[message] via Connection_close on
   channel 0. Only the first call has any effect; later calls are no-ops
   because connection_closed is already set. *)
let send_error conn code message =
  if not conn.connection_closed then begin
    conn.connection_closed <- true;
    let close = Connection_close (code, Bytes.of_string message, 0, 0) in
    ignore (Log.warn "Sending error" [sexp_of_method close]);
    send_method conn 0 close
  end else
    return ()
(* Reports a soft error by closing channel 1 with [code]/[message]; the
   connection itself stays open. *)
let send_warning conn code message =
  let close = Channel_close (code, Bytes.of_string message, 0, 0) in
  let () = ignore (Log.warn "Sending warning" [sexp_of_method close]) in
  send_method conn 1 close
(* Reads the 8-octet protocol header sent by the client.
   If it does not start with "AMQP", the server's own header is written back
   and false is returned. Otherwise the four version octets are logged and
   true is returned. EOF before 8 octets arrive yields false.
   [try_lwt] is required: read_into_exactly reports End_of_file by rejecting
   its promise, which a plain [try] does not catch. *)
let issue_banner cin cout =
  let handshake = Bytes.create 8 in
  try_lwt
    lwt () = Lwt_io.read_into_exactly cin handshake 0 8 in
    if Bytes.sub handshake 0 4 <> (Bytes.of_string "AMQP")
    then (lwt () = Lwt_io.write_from_exactly cout (Bytes.of_string "AMQP\000\000\009\001") 0 8 in
          return false)
    else (ignore (Log.info "AMQP handshake bytes"
                    (List.map
                       (fun i -> Sexp.str (string_of_int (int_of_char (Bytes.get handshake i))))
                       [4; 5; 6; 7]));
          return true)
  with End_of_file -> return false
(* Generic error text sent to clients when details are only in the logs. *)
let reference_to_logs = Bytes.of_string "See server logs for details"
(* Returns the bytes of a [Sexp.Str]; any other sexp yields the generic
   "see logs" text. *)
let extract_str = function
  | Sexp.Str s -> s
  | _ -> reference_to_logs
(* Translates a factory reply [status] into the AMQP response.
   On success, [ok_fn info] is sent on channel 1. Known failures become
   connection errors; other failures become channel warnings. A reply that is
   neither Create_ok nor Create_failed is an internal error. *)
let reply_to_declaration conn status ok_fn =
  let report_failure = function
    | Sexp.Arr [Sexp.Str who; explanation] ->
      (match Bytes.to_string who, explanation with
       | "factory", Sexp.Str code when Bytes.to_string code = "class-not-found" ->
         send_error conn command_invalid "Object type not supported by server"
       | "constructor", Sexp.Str code when Bytes.to_string code = "class-mismatch" ->
         send_error conn not_allowed "Redeclaration with different object type not permitted"
       | who_s, _ ->
         send_warning conn precondition_failed
           (who_s ^ " failed: " ^ Bytes.to_string (extract_str explanation)))
    | _ ->
      send_warning conn precondition_failed (Bytes.to_string reference_to_logs)
  in
  match Message.message_of_sexp status with
  | Message.Create_ok info -> send_method conn 1 (ok_fn info)
  | Message.Create_failed reason -> report_failure reason
  | _ -> die internal_error "Declare reply malformed"
(* Builds Queue_declare_ok from a factory success payload, which must be the
   queue name as a [Sexp.Str]. Message and consumer counts are reported as 0. *)
let make_queue_declare_ok = function
  | Sexp.Str queue_name -> Queue_declare_ok (queue_name, Int32.zero, Int32.zero)
  | _ -> die internal_error "Unusable queue name in declare response"
(* Forwards a message to the client as Basic_deliver on channel 1.
   [body_sexp] must have the shape produced by Basic_publish:
   [<amqp-hint>; exchange; routing_key; properties; body].
   Each delivery gets the next tag from conn.delivery_tag. *)
let send_delivery conn consumer_tag body_sexp =
  let allocate_tag () =
    with_conn_mutex conn (fun () ->
      let tag = conn.delivery_tag in
      conn.delivery_tag <- tag + 1;
      return tag)
  in
  match body_sexp with
  | Sexp.Arr [Sexp.Hint {Sexp.hint = hint_bs; Sexp.body = hint_body};
              Sexp.Str exchange;
              Sexp.Str routing_key;
              properties_sexp;
              Sexp.Str body_bs]
    when Bytes.to_string hint_bs = "amqp" && Bytes.length hint_body = 0 ->
    lwt tag = allocate_tag () in
    let deliver = Basic_deliver (consumer_tag, Int64.of_int tag, false, exchange, routing_key) in
    send_content_method conn 1 deliver
      (properties_of_sexp basic_class_id properties_sexp)
      body_bs
  | _ -> die internal_error "Malformed AMQP message body sexp"
(* Handles a message posted to this connection's node, typically a reply to a
   declare/bind/consume, or a delivery for a consumer. Replies are turned into
   the corresponding AMQP methods; unknown messages are logged and ignored.
   An Amqp_exception closes the connection with its code. Any other exception
   closes it with internal_error and is then re-raised.
   [try_lwt] (rather than [try]) is needed so that exceptions surfacing as
   rejected promises, e.g. from inside send_delivery's continuation, are also
   handled. *)
let amqp_handler conn _n m_sexp =
  try_lwt
    (match Message.message_of_sexp m_sexp with
     | Message.Post (Sexp.Str type_bs, status, _) ->
       (match Bytes.to_string type_bs with
        | "Exchange_declare_reply" ->
          reply_to_declaration conn status (fun (_) -> Exchange_declare_ok)
        | "Queue_declare_reply" ->
          reply_to_declaration conn status make_queue_declare_ok
        | "Queue_bind_reply" ->
          (match Message.message_of_sexp status with
           | Message.Subscribe_ok _ -> send_method conn 1 Queue_bind_ok
           | _ -> die internal_error "Queue bind reply malformed")
        | _ ->
          Log.warn "AMQP outbound relay ignoring message" [m_sexp])
     | Message.Post (Sexp.Arr [Sexp.Str type_bs; Sexp.Str consumer_tag], status_or_body, _) ->
       (match Bytes.to_string type_bs with
        | "Basic_consume_reply" ->
          (match Message.message_of_sexp status_or_body with
           | Message.Subscribe_ok _ -> send_method conn 1 (Basic_consume_ok consumer_tag)
           | _ -> die internal_error "Basic consume reply malformed")
        | "delivery" ->
          send_delivery conn consumer_tag status_or_body
        | _ ->
          Log.warn "AMQP outbound relay ignoring message" [m_sexp])
     | _ ->
       Log.warn "AMQP outbound relay ignoring message" [m_sexp])
  with
  | Amqp_exception (code, message) ->
    send_error conn code message
  | exn ->
    lwt () = send_error conn internal_error "" in
    raise_lwt exn
(* Returns the most recently declared queue name, or fails with syntax_error
   if no queue has been declared on this connection. *)
let get_recent_queue_name { recent_queue_name; _ } =
  match recent_queue_name with
  | None -> die syntax_error "Attempt to use nonexistent most-recently-declared-queue name"
  | Some q -> q
(* Resolves a queue name: empty means "the most recently declared queue". *)
let expand_mrdq conn queue =
  match Bytes.length queue with
  | 0 -> get_recent_queue_name conn
  | _ -> Node.name_of_bytes queue
(* Dispatches one method received from the client on [channel].
   Only channels 0 and 1 are accepted; higher numbers fail with channel_error.
   Declarations, bindings and consumes are forwarded as messages to Hop nodes.
   Their replies arrive later through amqp_handler, which sends the *_ok
   methods. Publishes read the following header and body frames and post the
   message to the target exchange or queue. Unlisted methods (or argument
   combinations) fail with not_implemented. *)
let handle_method conn channel m =
  (* ignore (Log.info "method" [sexp_of_method m]); *)
  if channel > 1 then die channel_error "Unsupported channel number" else ();
  match m with
  | Connection_close (code, text, _, _) ->
    ignore (Log.info "Client closed AMQP connection" [Sexp.str (string_of_int code); Sexp.Str text]);
    lwt () = send_method conn channel Connection_close_ok in
    return (conn.connection_closed <- true)
  | Channel_open ->
    conn.delivery_tag <- 1;
    send_method conn channel Channel_open_ok
  | Channel_close (code, text, _, _) ->
    ignore (Log.info "Client closed AMQP channel" [Sexp.str (string_of_int code); Sexp.Str text]);
    send_method conn channel Channel_close_ok
  | Channel_close_ok ->
    return ()
  | Exchange_declare (maybe_empty, _, _, _, no_wait, _) when maybe_empty = Bytes.empty ->
    (* Qpid does this bizarre thing of declaring the default exchange. *)
    if no_wait
    then return ()
    else send_method conn channel Exchange_declare_ok
  | Exchange_declare (exchange, type_, _, _, no_wait, _) ->
    let (reply_sink, reply_name) =
      if no_wait
      then (Bytes.empty, Bytes.empty)
      else (conn.name.Node.label, Bytes.of_string "Exchange_declare_reply")
    in
    Node.send_ignore'
      (Bytes.of_string "factory")
      (Message.create (Sexp.Str type_,
                       Sexp.Arr [Sexp.Str exchange],
                       Sexp.Str reply_sink,
                       Sexp.Str reply_name))
  | Queue_declare (queue, _, _, _, _, no_wait, _) ->
    let queue = (if queue = Bytes.empty then Bytes.of_string (Uuid.create ()) else queue) in
    conn.recent_queue_name <- Some (Node.name_of_bytes queue);
    (* With no_wait the client expects no Queue_declare_ok. As for
       Exchange_declare, pass the factory an empty reply sink in that case. *)
    let (reply_sink, reply_name) =
      if no_wait
      then (Bytes.empty, Bytes.empty)
      else (conn.name.Node.label, Bytes.of_string "Queue_declare_reply")
    in
    Node.send_ignore'
      (Bytes.of_string "factory")
      (Message.create (Sexp.litstr "queue",
                       Sexp.Arr [Sexp.Str queue],
                       Sexp.Str reply_sink,
                       Sexp.Str reply_name))
  | Queue_bind (_, maybe_empty, _, no_wait, _) when maybe_empty = Bytes.empty ->
    (* Qpid does this bizarre thing of binding to the default exchange. *)
    if no_wait
    then return ()
    else send_method conn channel Queue_bind_ok
  | Queue_bind (queue, exchange, routing_key, _, _) ->
    let queue = expand_mrdq conn queue in
    if not (Node.approx_exists queue)
    then send_warning conn not_found ("Queue '"^(Bytes.to_string queue.Node.label)^"' not found")
    else
      (match_lwt Node.send' exchange (Message.subscribe (Sexp.Str routing_key,
                                                         Sexp.Str queue.Node.label,
                                                         Sexp.emptystr,
                                                         Sexp.Str conn.name.Node.label,
                                                         Sexp.litstr "Queue_bind_reply")) with
       | true -> return ()
       | false -> send_warning conn not_found ("Exchange '"^(Bytes.to_string exchange)^"' not found"))
  | Basic_consume (queue, consumer_tag, _, _, _, _, _) ->
    let queue = expand_mrdq conn queue in
    let consumer_tag = (if consumer_tag = Bytes.empty then (Bytes.of_string (Uuid.create ())) else consumer_tag) in
    (match_lwt Node.send queue (Message.subscribe
                                  (Sexp.emptystr,
                                   Sexp.Str conn.name.Node.label,
                                   Sexp.Arr [Sexp.litstr "delivery"; Sexp.Str consumer_tag],
                                   Sexp.Str conn.name.Node.label,
                                   Sexp.Arr [Sexp.litstr "Basic_consume_reply";
                                             Sexp.Str consumer_tag])) with
     | true -> return ()
     | false -> send_warning conn not_found ("Queue '"^(Bytes.to_string queue.Node.label)^"' not found"))
  | Basic_publish (exchange, routing_key, false, false) ->
    lwt (_, (body_size, properties)) = next_header conn in
    lwt body = recv_content_body conn body_size in
    (* An empty exchange name means "publish directly to the queue named by
       the routing key". *)
    let (pseudotype, sink, name) =
      if exchange = Bytes.empty
      then ("Queue", routing_key, Bytes.empty)
      else ("Exchange", exchange, routing_key)
    in
    (match_lwt
       Node.post' sink
         (Sexp.Str name)
         (Sexp.Arr [Sexp.Hint {Sexp.hint = (Bytes.of_string "amqp"); Sexp.body = Bytes.empty};
                    Sexp.Str exchange;
                    Sexp.Str routing_key;
                    sexp_of_properties properties;
                    Sexp.Str body])
         Sexp.emptystr
     with
     | true -> return ()
     | false -> send_warning conn not_found (pseudotype^" '"^(Bytes.to_string sink)^"' not found"))
  | Basic_ack (_, _) ->
    return ()
  | Basic_qos (_, _, _) ->
    ignore (Log.warn "Ignoring Basic_qos instruction from client" []);
    send_method conn channel Basic_qos_ok
  | Channel_flow (on) ->
    ignore (Log.warn "Ignoring Channel_flow setting" [Sexp.str (string_of_bool on)]);
    send_method conn channel (Channel_flow_ok on)
  | _ ->
    let (cid, mid) = method_index m in
    die not_implemented (Printf.sprintf "Unsupported method (or method arguments) %s"
                           (method_name cid mid))
(* Server properties advertised in Connection_start. *)
let server_properties =
  [ "product", Table_string App_info.product;
    "version", Table_string App_info.version;
    "copyright", Table_string App_info.copyright;
    "licence", Table_string App_info.licence_blurb;
    "capabilities", Table_table (table_of_list []) ]
  |> List.map (fun (key, value) -> (Bytes.of_string key, value))
  |> table_of_list
(* Validates the SASL credentials from Connection_start_ok.
   Supported mechanisms are PLAIN and AMQPLAIN; anything else is refused.
   Returns () only for user "guest" with password "guest". Otherwise it fails
   via [die access_refused].
   Security: the password (and the raw response, which embeds it) is never
   written to the logs; only the user name or mechanism is logged. *)
let check_login_details mechanism response =
  (* PLAIN: "\000user\000password". A non-empty authorisation identity (text
     before the first NUL) is treated as a bad format. *)
  let parse_plain () =
    match Bytes.index_opt response '\000' with
    | Some 0 ->
      (let rest = Bytes.sub response 1 ((Bytes.length response) - 1) in
       match Bytes.index_opt rest '\000' with
       | Some pos ->
         let user = Bytes.sub rest 0 pos in
         let pass = Bytes.sub rest (pos + 1) ((Bytes.length rest) - (pos + 1)) in
         Some (Bytes.to_string user, Bytes.to_string pass)
       | None -> None)
    | Some _ | None -> None
  in
  (* AMQPLAIN: a field table carrying LOGIN and PASSWORD strings. *)
  let parse_amqplain () =
    let fields = decode_named_fields (Ibuffer.of_bytes response) in
    match (field_lookup_some (Bytes.of_string "LOGIN") fields,
           field_lookup_some (Bytes.of_string "PASSWORD") fields) with
    | (Some (Table_string user), Some (Table_string pass)) ->
      Some (Bytes.to_string user, Bytes.to_string pass)
    | _ -> None
  in
  let credentials =
    match Bytes.to_string mechanism with
    | "PLAIN" -> parse_plain ()
    | "AMQPLAIN" -> parse_amqplain ()
    | _ -> die access_refused "Bad auth mechanism"
  in
  match credentials with
  | Some ("guest", "guest") -> ()
  | Some (user, _) ->
    (ignore (Log.info "Access refused" [Sexp.str user]);
     die access_refused "Access refused")
  | None ->
    (ignore (Log.info "Access refused; bad credential format" [Sexp.Str mechanism]);
     die access_refused "Access refused; bad credential format")
(* Replaces both frame buffers with ones sized [frame_max] and records the
   new limit. It holds the connection mutex, which frame writers also take,
   so no frame is mid-write while the buffers are swapped. *)
let tune_connection conn frame_max =
  let resize_buffers () =
    conn.input_buf <- Bytes.create frame_max;
    conn.output_buf <- Obuffer.create frame_max;
    conn.frame_max <- frame_max;
    return ()
  in
  with_conn_mutex conn resize_buffers
(* Runs the server side of connection setup on channel 0:
   Connection_start -> Start_ok (credentials checked) -> Connection_tune ->
   Tune_ok (limits validated, buffers resized) -> Connection_open -> Open_ok.
   Frame-max negotiation: the client's value 0 means "no limit of its own" in
   AMQP 0-9-1, so our suggested_frame_max is used. A value above our offer,
   or below frame_min_size (including negative wire values), is rejected with
   syntax_error. Previously 0 resized every buffer to 0 and broke the
   connection. *)
let handshake_and_tune conn =
  let (major_version, minor_version, _revision) = version in
  lwt () = send_method conn 0 (Connection_start (major_version,
                                                 minor_version,
                                                 server_properties,
                                                 (Bytes.of_string "PLAIN AMQPLAIN"),
                                                 (Bytes.of_string "en_US"))) in
  lwt (client_properties, mechanism, response, _locale) =
    match_lwt next_method conn with
    | (0, Connection_start_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_start_ok on channel 0"
  in
  check_login_details mechanism response;
  ignore (Log.info "Connection from AMQP client" [sexp_of_table client_properties]);
  lwt () = send_method conn 0 (Connection_tune (1, Int32.of_int suggested_frame_max, 0)) in
  lwt (channel_max, frame_max, heartbeat) =
    match_lwt next_method conn with
    | (0, Connection_tune_ok props) -> return props
    | _ -> die not_allowed "Expected Connection_tune_ok on channel 0"
  in
  if channel_max > 1
  then die not_implemented "Channel numbers higher than 1 are not supported" else ();
  let requested_frame_max = Int32.to_int frame_max in
  let negotiated_frame_max =
    if requested_frame_max = 0 then suggested_frame_max else requested_frame_max
  in
  if negotiated_frame_max > suggested_frame_max
  then die syntax_error "Requested frame max too large" else ();
  if negotiated_frame_max < frame_min_size
  then die syntax_error "Requested frame max too small" else ();
  if heartbeat > 0
  then die not_implemented "Heartbeats not yet implemented (patches welcome)" else ();
  lwt () = tune_connection conn negotiated_frame_max in
  lwt (virtual_host) =
    match_lwt next_method conn with
    | (0, Connection_open props) -> return props
    | _ -> die not_allowed "Expected Connection_open on channel 0"
  in
  ignore (Log.info "Connected to vhost" [Sexp.Str virtual_host]);
  send_method conn 0 Connection_open_ok
(* Binds this connection's node name, performs the handshake, then reads and
   dispatches methods until connection_closed is set. An Amqp_exception from
   the handshake or any handled method closes the connection with that code. *)
let amqp_mainloop conn n =
  let rec serve_methods () =
    if conn.connection_closed
    then return ()
    else
      (lwt (channel, m) = next_method conn in
       lwt () = handle_method conn channel m in
       serve_methods ())
  in
  lwt () = Node.bind_ignore (conn.name, n) in
  (try_lwt
     lwt () = handshake_and_tune conn in
     serve_methods ()
   with
   | Amqp_exception (code, message) ->
     send_error conn code message)
(* Per-socket entry point: hands the accepted socket to the shared
   Connections machinery, passing this module's banner, boot, message-handler
   and main-loop callbacks under the "amqp" label. *)
let start (s, peername) =
  let protocol = "amqp" in
  Connections.start_connection protocol
    issue_banner amqp_boot amqp_handler amqp_mainloop
    (s, peername)
(* Creates the built-in amq.direct and amq.fanout exchanges via the factory,
   with no reply requested, then starts the AMQP listener daemon on the
   configured port (defaulting to Amqp_spec.port). *)
let init () =
  let declare_builtin_exchange kind name =
    Node.send_ignore'
      (Bytes.of_string "factory")
      (Message.create (Sexp.litstr kind,
                       Sexp.Arr [Sexp.litstr name],
                       Sexp.emptystr,
                       Sexp.emptystr))
  in
  lwt () = declare_builtin_exchange "direct" "amq.direct" in
  lwt () = declare_builtin_exchange "fanout" "amq.fanout" in
  let port = Config.get_int "amqp.port" Amqp_spec.port in
  Util.create_daemon_thread (Bytes.of_string "AMQP listener") None (Net.start_net "AMQP" port) start