diff --git a/_tags b/_tags index 809c5b8..f7a5408 100644 --- a/_tags +++ b/_tags @@ -1,2 +1,3 @@ true: use_unix +true: use_str true: thread diff --git a/amqp_codegen.py b/amqp_codegen.py index c85165f..6fe197b 100644 --- a/amqp_codegen.py +++ b/amqp_codegen.py @@ -206,7 +206,7 @@ def print_codec(): print 'let method_name class_index method_index = match (class_index, method_index) with' for m in methods: print ' | (%d, %d) -> "%s"' % (m.class_index, m.index, ctor(m.full_name)) - print ' | _ -> "??unknownmethod??"' + print ' | _ -> Printf.sprintf "unknown(%d/%d)" class_index method_index' print print 'let read_method class_index method_index input_buf = match (class_index, method_index) with' for m in methods: @@ -247,6 +247,16 @@ def print_codec(): print ' write_%s output_buf %s;' % (mlify(f.type), source) print ' ()' print + print 'let sexp_of_properties p = match p with ' + for c in classes: + if c.fields: + print c.match_clause + print ' let fields__ = [] in' + for f in reversed(c.accessible_fields): + print ' let fields__ = (match %s with Some v -> Arr [Str "%s"; sexp_of_%s(v)] :: fields__ | None -> fields__) in' % \ + (mlify(f.name), f.name, mlify(f.type)) + print ' Arr fields__' + print print 'let read_properties class_index input_buf = match class_index with' for c in classes: if c.fields: diff --git a/amqp_relay.ml b/amqp_relay.ml index 70ee226..92a7814 100644 --- a/amqp_relay.ml +++ b/amqp_relay.ml @@ -2,6 +2,7 @@ open Unix open Printf open Thread open Amqp_spec +open Amqp_wireformat type connection_t = { n: Node.t; @@ -11,7 +12,8 @@ type connection_t = { mutable input_buf: string; mutable output_buf: Buffer.t; mutable frame_max: int; - mutable connection_closed: bool + mutable connection_closed: bool; + mutable recent_queue_name: string option; } let read_frame conn = @@ -20,13 +22,12 @@ let read_frame conn = let channel_lo = input_byte conn.cin in let channel = (channel_hi lsr 8) lor channel_lo in let length = input_binary_int conn.cin in - printf "Length %d\n%!" length; if length > conn.frame_max - then raise (Amqp_wireformat.Amqp_exception (frame_error, "Frame longer than current frame_max")) + then die frame_error "Frame longer than current frame_max" else (really_input conn.cin conn.input_buf 0 length; if input_byte conn.cin <> frame_end - then raise (Amqp_wireformat.Amqp_exception (frame_error, "Missing frame_end octet")) + then die frame_error "Missing frame_end octet" else (frame_type, channel, length)) let write_frame conn frame_type channel = @@ -40,22 +41,28 @@ let write_frame conn frame_type channel = let serialize_method buf m = let (class_id, method_id) = method_index m in - Amqp_wireformat.write_short buf class_id; - Amqp_wireformat.write_short buf method_id; + write_short buf class_id; + write_short buf method_id; write_method m buf let deserialize_method buf = - let class_id = Amqp_wireformat.read_short buf in - let method_id = Amqp_wireformat.read_short buf in + let class_id = read_short buf in + let method_id = read_short buf in read_method class_id method_id buf let serialize_header buf body_size p = let class_id = class_index p in - Amqp_wireformat.write_short buf class_id; - Amqp_wireformat.write_short buf 0; - Amqp_wireformat.write_longlong buf (Int64.of_int body_size); + write_short buf class_id; + write_short buf 0; + write_longlong buf (Int64.of_int body_size); write_properties p buf +let deserialize_header buf = + let class_id = read_short buf in + let _ = read_short buf in + let body_size = Int64.to_int (read_longlong buf) in + (body_size, read_properties class_id buf) + let send_content_body conn channel body = let offset = ref 0 in let len = String.length body in @@ -66,22 +73,34 @@ let send_content_body conn channel body = offset := !offset + snip_len done -let next_method conn = +let next_frame conn required_type = let (frame_type, channel, length) = read_frame conn in - if frame_type <> frame_method - then raise (Amqp_wireformat.Amqp_exception - (command_invalid, - (Printf.sprintf "Unexpected frame type %d" frame_type))) - else - let buf = Ibuffer.create conn.input_buf 0 length in - (channel, deserialize_method buf) + if frame_type <> required_type + then die command_invalid (Printf.sprintf "Unexpected frame type %d" frame_type) + else (channel, length) + +let next_method conn = + let (channel, length) = next_frame conn frame_method in + (channel, deserialize_method (Ibuffer.create conn.input_buf 0 length)) + +let next_header conn = + let (channel, length) = next_frame conn frame_header in + (channel, deserialize_header (Ibuffer.create conn.input_buf 0 length)) + +let recv_content_body conn body_size = + let buf = Buffer.create body_size in + while Buffer.length buf < body_size do + let (_, length) = next_frame conn frame_body in + Buffer.add_substring buf conn.input_buf 0 length + done; + Buffer.contents buf let with_conn_mutex conn thunk = Util.with_mutex0 conn.mtx thunk -let send_method conn m = +let send_method conn channel m = with_conn_mutex conn (fun () -> serialize_method conn.output_buf m; - write_frame conn frame_method 0; + write_frame conn frame_method channel; flush conn.cout) let send_error conn code message = @@ -91,7 +110,7 @@ let send_error conn code message = else let m = Connection_close (code, message, 0, 0) in Log.warn "Sending error" [sexp_of_method m]; - send_method conn m + send_method conn 0 m let issue_banner cin cout = let handshake = String.create 8 in @@ -103,17 +122,116 @@ let issue_banner cin cout = with End_of_file -> false let amqp_handler mtx cin cout n m = - raise (Amqp_wireformat.Amqp_exception (not_implemented, "TODO")) + die not_implemented "TODO:amqp_handler" -let handle_method conn m = +let get_recent_queue_name conn = + match conn.recent_queue_name with + | Some q -> q + | None -> die syntax_error "Attempt to use nonexistent most-recently-declared-queue name" + +let expand_mrdq conn queue = + match queue with + | "" -> get_recent_queue_name conn + | other -> other + +let handle_method conn channel m = + if channel > 1 then die channel_error "Unsupported channel number" else (); match m with + | Connection_close (code, text, _, _) -> + Log.info "Client closed AMQP connection" [Sexp.Str (string_of_int code); Sexp.Str text]; + send_method conn channel Connection_close_ok; + conn.connection_closed <- true + | Channel_open -> send_method conn channel Channel_open_ok + | Channel_close (code, text, _, _) -> + Log.info "Client closed AMQP channel" [Sexp.Str (string_of_int code); Sexp.Str text]; + send_method conn channel Channel_close_ok; + | Exchange_declare (exchange, type_, passive, durable, no_wait, arguments) -> + Log.info "XDeclare%%%" [Sexp.Str exchange; Sexp.Str type_]; + send_method conn channel Exchange_declare_ok + | Queue_declare (queue, passive, durable, exclusive, auto_delete, no_wait, arguments) -> + let queue = (if queue = "" then Uuid.create () else queue) in + Log.info "QDeclare%%%" [Sexp.Str queue]; + conn.recent_queue_name <- Some queue; + send_method conn channel (Queue_declare_ok (queue, Int32.of_int 0, Int32.of_int 0)) + | Queue_bind (queue, exchange, routing_key, no_wait, arguments) -> + let queue = expand_mrdq conn queue in + Log.info "QBind%%%" [Sexp.Str queue; Sexp.Str exchange; Sexp.Str routing_key]; + send_method conn channel Queue_bind_ok + | Basic_consume (queue, consumer_tag, no_local, no_ack, exclusive, no_wait, arguments) -> + let queue = expand_mrdq conn queue in + Log.info "Consume%%%" [Sexp.Str queue; Sexp.Str consumer_tag]; + send_method conn channel (Basic_consume_ok consumer_tag) + | Basic_publish (exchange, routing_key, false, false) -> + let (_, (body_size, properties)) = next_header conn in + let body = recv_content_body conn body_size in + Log.info "Publish%%%" [Sexp.Str exchange; Sexp.Str routing_key; + sexp_of_properties properties; Sexp.Str body] | _ -> let (cid, mid) = method_index m in - raise (Amqp_wireformat.Amqp_exception (not_implemented, - Printf.sprintf "Unsupported method %s (%d/%d)" - (method_name cid mid) cid mid)) + die not_implemented (Printf.sprintf "Unsupported method (or method arguments) %s" + (method_name cid mid)) let initial_frame_size = frame_min_size +let suggested_frame_max = 131072 + +let server_properties = table_of_list [ + ("product", Table_string App_info.product); + ("version", Table_string App_info.version); + ("copyright", Table_string App_info.copyright); + ("licence", Table_string App_info.licence); + ("capabilities", Table_table (table_of_list [])); +] + +let check_login_details mechanism response = + match mechanism with + | "PLAIN" -> + (match (Str.split (Str.regexp "\000") response) with + | ["guest"; "guest"] -> () + | _ -> die access_refused "Access refused") + | "AMQPLAIN" -> + (let fields = decode_named_fields (Ibuffer.of_string response) in + match (field_lookup_some "LOGIN" fields, field_lookup_some "PASSWORD" fields) with + | (Some (Table_string "guest"), Some (Table_string "guest")) -> () + | _ -> die access_refused "Access refused") + | _ -> die access_refused "Bad auth mechanism" + +let tune_connection conn frame_max = + with_conn_mutex conn (fun () -> + conn.input_buf <- String.create frame_max; + conn.output_buf <- Buffer.create frame_max; + conn.frame_max <- frame_max) + +let handshake_and_tune conn = + let (major_version, minor_version, revision) = version in + send_method conn 0 (Connection_start (major_version, minor_version, server_properties, + "PLAIN AMQPLAIN", "en_US")); + let (client_properties, mechanism, response, locale) = + match next_method conn with + | (0, Connection_start_ok props) -> props + | _ -> die not_allowed "Expected Connection_start_ok on channel 0" + in + check_login_details mechanism response; + Log.info "Connection from AMQP client" [sexp_of_table client_properties]; + send_method conn 0 (Connection_tune (1, Int32.of_int suggested_frame_max, 0)); + let (channel_max, frame_max, heartbeat) = + match next_method conn with + | (0, Connection_tune_ok props) -> props + | _ -> die not_allowed "Expected Connection_tune_ok on channel 0" + in + if channel_max > 1 + then die not_implemented "Channel numbers higher than 1 are not supported" else (); + if (Int32.to_int frame_max) > suggested_frame_max + then die syntax_error "Requested frame max too large" else (); + if heartbeat > 0 + then die not_implemented "Heartbeats not yet implemented (patches welcome)" else (); + tune_connection conn (Int32.to_int frame_max); + let (virtual_host) = + match next_method conn with + | (0, Connection_open props) -> props + | _ -> die not_allowed "Expected Connection_open on channel 0" + in + Log.info "Connected to vhost" [Sexp.Str virtual_host]; + send_method conn 0 Connection_open_ok let amqp_mainloop peername mtx cin cout n = let conn = { @@ -124,21 +242,17 @@ let amqp_mainloop peername mtx cin cout n = input_buf = String.create initial_frame_size; output_buf = Buffer.create initial_frame_size; frame_max = initial_frame_size; - connection_closed = false + connection_closed = false; + recent_queue_name = None; } in (try - let (major_version, minor_version, revision) = version in - send_method conn (Connection_start (major_version, minor_version, - Amqp_wireformat.table_of_list [], - "", "")); - let (_, Connection_start_ok (client_properties, mechanism, response, locale)) - = next_method conn in - while true do + handshake_and_tune conn; + while not conn.connection_closed do let (channel, m) = next_method conn in - handle_method conn m + handle_method conn channel m done with - | Amqp_wireformat.Amqp_exception (code, message) -> + | Amqp_exception (code, message) -> send_error conn code message ) diff --git a/amqp_wireformat.ml b/amqp_wireformat.ml index 652bc51..45fd2f4 100644 --- a/amqp_wireformat.ml +++ b/amqp_wireformat.ml @@ -2,6 +2,8 @@ open Sexp exception Amqp_exception of (int * string) +let die code message = raise (Amqp_exception (code, message)) + type octet_t = int type short_t = int type long_t = int32 @@ -108,8 +110,7 @@ and read_table_value input_buf = | 'T' -> Table_timestamp (read_longlong input_buf) | 'F' -> Table_table { table_body = Encoded_table (read_longstr input_buf) } | 'V' -> Table_void - | c -> raise (Amqp_exception (502 (*syntax-error*), - Printf.sprintf "Unknown table field type code '%c'" c)) + | c -> die 502 (*syntax-error*) (Printf.sprintf "Unknown table field type code '%c'" c) and decoded_table t = match t.table_body with @@ -252,3 +253,15 @@ let reserved_value_longstr = "" let reserved_value_bit = false let reserved_value_timestamp = Int64.zero let reserved_value_table = { table_body = Encoded_table "" } + +let field_lookup k def fs = + try List.assoc k fs + with Not_found -> def + +let field_lookup_some k fs = + try Some (List.assoc k fs) + with Not_found -> None + +let table_lookup k t = List.assoc k (decoded_table t) +let table_lookup_default k def t = field_lookup k def (decoded_table t) +let table_lookup_some k t = field_lookup_some k (decoded_table t) diff --git a/app_info.ml b/app_info.ml new file mode 100644 index 0000000..a45663d --- /dev/null +++ b/app_info.ml @@ -0,0 +1,4 @@ +let product = "ocamlmsg" +let version = "ALPHA" +let copyright = "Copyright (C) 2012 Tony Garnock-Jones." +let licence = "All rights reserved." diff --git a/ibuffer.ml b/ibuffer.ml index dc4d543..f557b8a 100644 --- a/ibuffer.ml +++ b/ibuffer.ml @@ -10,6 +10,8 @@ let create s ofs len = { buf = s } +let of_string s = create s 0 (String.length s) + let sub b ofs len = if b.pos + ofs + len > b.limit then diff --git a/ocamlmsg.ml b/ocamlmsg.ml index 9f66583..7d492e0 100644 --- a/ocamlmsg.ml +++ b/ocamlmsg.ml @@ -25,7 +25,7 @@ let hook_log () = Log.hook := new_hook let _ = - printf "ocamlmsg ALPHA, Copyright (C) 2012 Tony Garnock-Jones. All rights reserved.\n%!"; + printf "%s %s, %s %s\n%!" App_info.product App_info.version App_info.copyright App_info.licence; Sys.set_signal Sys.sigpipe Sys.Signal_ignore; Uuid.init (); Factory.init ();