Steps toward being able to run some of the tests
This commit is contained in:
parent
c5a665adb2
commit
69c009b4a3
|
@ -206,7 +206,7 @@ def print_codec():
|
|||
print 'let method_name class_index method_index = match (class_index, method_index) with'
|
||||
for m in methods:
|
||||
print ' | (%d, %d) -> "%s"' % (m.class_index, m.index, ctor(m.full_name))
|
||||
print ' | _ -> "??unknownmethod??"'
|
||||
print ' | _ -> Printf.sprintf "unknown(%d/%d)" class_index method_index'
|
||||
print
|
||||
print 'let read_method class_index method_index input_buf = match (class_index, method_index) with'
|
||||
for m in methods:
|
||||
|
@ -247,6 +247,16 @@ def print_codec():
|
|||
print ' write_%s output_buf %s;' % (mlify(f.type), source)
|
||||
print ' ()'
|
||||
print
|
||||
print 'let sexp_of_properties p = match p with '
|
||||
for c in classes:
|
||||
if c.fields:
|
||||
print c.match_clause
|
||||
print ' let fields__ = [] in'
|
||||
for f in reversed(c.accessible_fields):
|
||||
print ' let fields__ = (match %s with Some v -> Arr [Str "%s"; sexp_of_%s(v)] :: fields__ | None -> fields__) in' % \
|
||||
(mlify(f.name), f.name, mlify(f.type))
|
||||
print ' Arr fields__'
|
||||
print
|
||||
print 'let read_properties class_index input_buf = match class_index with'
|
||||
for c in classes:
|
||||
if c.fields:
|
||||
|
|
188
amqp_relay.ml
188
amqp_relay.ml
|
@ -2,6 +2,7 @@ open Unix
|
|||
open Printf
|
||||
open Thread
|
||||
open Amqp_spec
|
||||
open Amqp_wireformat
|
||||
|
||||
type connection_t = {
|
||||
n: Node.t;
|
||||
|
@ -11,7 +12,8 @@ type connection_t = {
|
|||
mutable input_buf: string;
|
||||
mutable output_buf: Buffer.t;
|
||||
mutable frame_max: int;
|
||||
mutable connection_closed: bool
|
||||
mutable connection_closed: bool;
|
||||
mutable recent_queue_name: string option;
|
||||
}
|
||||
|
||||
let read_frame conn =
|
||||
|
@ -20,13 +22,12 @@ let read_frame conn =
|
|||
let channel_lo = input_byte conn.cin in
|
||||
let channel = (channel_hi lsr 8) lor channel_lo in
|
||||
let length = input_binary_int conn.cin in
|
||||
printf "Length %d\n%!" length;
|
||||
if length > conn.frame_max
|
||||
then raise (Amqp_wireformat.Amqp_exception (frame_error, "Frame longer than current frame_max"))
|
||||
then die frame_error "Frame longer than current frame_max"
|
||||
else
|
||||
(really_input conn.cin conn.input_buf 0 length;
|
||||
if input_byte conn.cin <> frame_end
|
||||
then raise (Amqp_wireformat.Amqp_exception (frame_error, "Missing frame_end octet"))
|
||||
then die frame_error "Missing frame_end octet"
|
||||
else (frame_type, channel, length))
|
||||
|
||||
let write_frame conn frame_type channel =
|
||||
|
@ -40,22 +41,28 @@ let write_frame conn frame_type channel =
|
|||
|
||||
let serialize_method buf m =
|
||||
let (class_id, method_id) = method_index m in
|
||||
Amqp_wireformat.write_short buf class_id;
|
||||
Amqp_wireformat.write_short buf method_id;
|
||||
write_short buf class_id;
|
||||
write_short buf method_id;
|
||||
write_method m buf
|
||||
|
||||
let deserialize_method buf =
|
||||
let class_id = Amqp_wireformat.read_short buf in
|
||||
let method_id = Amqp_wireformat.read_short buf in
|
||||
let class_id = read_short buf in
|
||||
let method_id = read_short buf in
|
||||
read_method class_id method_id buf
|
||||
|
||||
let serialize_header buf body_size p =
|
||||
let class_id = class_index p in
|
||||
Amqp_wireformat.write_short buf class_id;
|
||||
Amqp_wireformat.write_short buf 0;
|
||||
Amqp_wireformat.write_longlong buf (Int64.of_int body_size);
|
||||
write_short buf class_id;
|
||||
write_short buf 0;
|
||||
write_longlong buf (Int64.of_int body_size);
|
||||
write_properties p buf
|
||||
|
||||
let deserialize_header buf =
|
||||
let class_id = read_short buf in
|
||||
let _ = read_short buf in
|
||||
let body_size = Int64.to_int (read_longlong buf) in
|
||||
(body_size, read_properties class_id buf)
|
||||
|
||||
let send_content_body conn channel body =
|
||||
let offset = ref 0 in
|
||||
let len = String.length body in
|
||||
|
@ -66,22 +73,34 @@ let send_content_body conn channel body =
|
|||
offset := !offset + snip_len
|
||||
done
|
||||
|
||||
let next_method conn =
|
||||
let next_frame conn required_type =
|
||||
let (frame_type, channel, length) = read_frame conn in
|
||||
if frame_type <> frame_method
|
||||
then raise (Amqp_wireformat.Amqp_exception
|
||||
(command_invalid,
|
||||
(Printf.sprintf "Unexpected frame type %d" frame_type)))
|
||||
else
|
||||
let buf = Ibuffer.create conn.input_buf 0 length in
|
||||
(channel, deserialize_method buf)
|
||||
if frame_type <> required_type
|
||||
then die command_invalid (Printf.sprintf "Unexpected frame type %d" frame_type)
|
||||
else (channel, length)
|
||||
|
||||
let next_method conn =
|
||||
let (channel, length) = next_frame conn frame_method in
|
||||
(channel, deserialize_method (Ibuffer.create conn.input_buf 0 length))
|
||||
|
||||
let next_header conn =
|
||||
let (channel, length) = next_frame conn frame_header in
|
||||
(channel, deserialize_header (Ibuffer.create conn.input_buf 0 length))
|
||||
|
||||
let recv_content_body conn body_size =
|
||||
let buf = Buffer.create body_size in
|
||||
while Buffer.length buf < body_size do
|
||||
let (_, length) = next_frame conn frame_body in
|
||||
Buffer.add_substring buf conn.input_buf 0 length
|
||||
done;
|
||||
Buffer.contents buf
|
||||
|
||||
let with_conn_mutex conn thunk = Util.with_mutex0 conn.mtx thunk
|
||||
|
||||
let send_method conn m =
|
||||
let send_method conn channel m =
|
||||
with_conn_mutex conn (fun () ->
|
||||
serialize_method conn.output_buf m;
|
||||
write_frame conn frame_method 0;
|
||||
write_frame conn frame_method channel;
|
||||
flush conn.cout)
|
||||
|
||||
let send_error conn code message =
|
||||
|
@ -91,7 +110,7 @@ let send_error conn code message =
|
|||
else
|
||||
let m = Connection_close (code, message, 0, 0) in
|
||||
Log.warn "Sending error" [sexp_of_method m];
|
||||
send_method conn m
|
||||
send_method conn 0 m
|
||||
|
||||
let issue_banner cin cout =
|
||||
let handshake = String.create 8 in
|
||||
|
@ -103,17 +122,116 @@ let issue_banner cin cout =
|
|||
with End_of_file -> false
|
||||
|
||||
let amqp_handler mtx cin cout n m =
|
||||
raise (Amqp_wireformat.Amqp_exception (not_implemented, "TODO"))
|
||||
die not_implemented "TODO:amqp_handler"
|
||||
|
||||
let handle_method conn m =
|
||||
let get_recent_queue_name conn =
|
||||
match conn.recent_queue_name with
|
||||
| Some q -> q
|
||||
| None -> die syntax_error "Attempt to use nonexistent most-recently-declared-queue name"
|
||||
|
||||
let expand_mrdq conn queue =
|
||||
match queue with
|
||||
| "" -> get_recent_queue_name conn
|
||||
| other -> other
|
||||
|
||||
let handle_method conn channel m =
|
||||
if channel > 1 then die channel_error "Unsupported channel number" else ();
|
||||
match m with
|
||||
| Connection_close (code, text, _, _) ->
|
||||
Log.info "Client closed AMQP connection" [Sexp.Str (string_of_int code); Sexp.Str text];
|
||||
send_method conn channel Connection_close_ok;
|
||||
conn.connection_closed <- true
|
||||
| Channel_open -> send_method conn channel Channel_open_ok
|
||||
| Channel_close (code, text, _, _) ->
|
||||
Log.info "Client closed AMQP channel" [Sexp.Str (string_of_int code); Sexp.Str text];
|
||||
send_method conn channel Channel_close_ok;
|
||||
| Exchange_declare (exchange, type_, passive, durable, no_wait, arguments) ->
|
||||
Log.info "XDeclare%%%" [Sexp.Str exchange; Sexp.Str type_];
|
||||
send_method conn channel Exchange_declare_ok
|
||||
| Queue_declare (queue, passive, durable, exclusive, auto_delete, no_wait, arguments) ->
|
||||
let queue = (if queue = "" then Uuid.create () else queue) in
|
||||
Log.info "QDeclare%%%" [Sexp.Str queue];
|
||||
conn.recent_queue_name <- Some queue;
|
||||
send_method conn channel (Queue_declare_ok (queue, Int32.of_int 0, Int32.of_int 0))
|
||||
| Queue_bind (queue, exchange, routing_key, no_wait, arguments) ->
|
||||
let queue = expand_mrdq conn queue in
|
||||
Log.info "QBind%%%" [Sexp.Str queue; Sexp.Str exchange; Sexp.Str routing_key];
|
||||
send_method conn channel Queue_bind_ok
|
||||
| Basic_consume (queue, consumer_tag, no_local, no_ack, exclusive, no_wait, arguments) ->
|
||||
let queue = expand_mrdq conn queue in
|
||||
Log.info "Consume%%%" [Sexp.Str queue; Sexp.Str consumer_tag];
|
||||
send_method conn channel (Basic_consume_ok consumer_tag)
|
||||
| Basic_publish (exchange, routing_key, false, false) ->
|
||||
let (_, (body_size, properties)) = next_header conn in
|
||||
let body = recv_content_body conn body_size in
|
||||
Log.info "Publish%%%" [Sexp.Str exchange; Sexp.Str routing_key;
|
||||
sexp_of_properties properties; Sexp.Str body]
|
||||
| _ ->
|
||||
let (cid, mid) = method_index m in
|
||||
raise (Amqp_wireformat.Amqp_exception (not_implemented,
|
||||
Printf.sprintf "Unsupported method %s (%d/%d)"
|
||||
(method_name cid mid) cid mid))
|
||||
die not_implemented (Printf.sprintf "Unsupported method (or method arguments) %s"
|
||||
(method_name cid mid))
|
||||
|
||||
let initial_frame_size = frame_min_size
|
||||
let suggested_frame_max = 131072
|
||||
|
||||
let server_properties = table_of_list [
|
||||
("product", Table_string App_info.product);
|
||||
("version", Table_string App_info.version);
|
||||
("copyright", Table_string App_info.copyright);
|
||||
("licence", Table_string App_info.licence);
|
||||
("capabilities", Table_table (table_of_list []));
|
||||
]
|
||||
|
||||
let check_login_details mechanism response =
|
||||
match mechanism with
|
||||
| "PLAIN" ->
|
||||
(match (Str.split (Str.regexp "\000") response) with
|
||||
| ["guest"; "guest"] -> ()
|
||||
| _ -> die access_refused "Access refused")
|
||||
| "AMQPLAIN" ->
|
||||
(let fields = decode_named_fields (Ibuffer.of_string response) in
|
||||
match (field_lookup_some "LOGIN" fields, field_lookup_some "PASSWORD" fields) with
|
||||
| (Some (Table_string "guest"), Some (Table_string "guest")) -> ()
|
||||
| _ -> die access_refused "Access refused")
|
||||
| _ -> die access_refused "Bad auth mechanism"
|
||||
|
||||
let tune_connection conn frame_max =
|
||||
with_conn_mutex conn (fun () ->
|
||||
conn.input_buf <- String.create frame_max;
|
||||
conn.output_buf <- Buffer.create frame_max;
|
||||
conn.frame_max <- frame_max)
|
||||
|
||||
let handshake_and_tune conn =
|
||||
let (major_version, minor_version, revision) = version in
|
||||
send_method conn 0 (Connection_start (major_version, minor_version, server_properties,
|
||||
"PLAIN AMQPLAIN", "en_US"));
|
||||
let (client_properties, mechanism, response, locale) =
|
||||
match next_method conn with
|
||||
| (0, Connection_start_ok props) -> props
|
||||
| _ -> die not_allowed "Expected Connection_start_ok on channel 0"
|
||||
in
|
||||
check_login_details mechanism response;
|
||||
Log.info "Connection from AMQP client" [sexp_of_table client_properties];
|
||||
send_method conn 0 (Connection_tune (1, Int32.of_int suggested_frame_max, 0));
|
||||
let (channel_max, frame_max, heartbeat) =
|
||||
match next_method conn with
|
||||
| (0, Connection_tune_ok props) -> props
|
||||
| _ -> die not_allowed "Expected Connection_tune_ok on channel 0"
|
||||
in
|
||||
if channel_max > 1
|
||||
then die not_implemented "Channel numbers higher than 1 are not supported" else ();
|
||||
if (Int32.to_int frame_max) > suggested_frame_max
|
||||
then die syntax_error "Requested frame max too large" else ();
|
||||
if heartbeat > 0
|
||||
then die not_implemented "Heartbeats not yet implemented (patches welcome)" else ();
|
||||
tune_connection conn (Int32.to_int frame_max);
|
||||
let (virtual_host) =
|
||||
match next_method conn with
|
||||
| (0, Connection_open props) -> props
|
||||
| _ -> die not_allowed "Expected Connection_open on channel 0"
|
||||
in
|
||||
Log.info "Connected to vhost" [Sexp.Str virtual_host];
|
||||
send_method conn 0 Connection_open_ok
|
||||
|
||||
let amqp_mainloop peername mtx cin cout n =
|
||||
let conn = {
|
||||
|
@ -124,21 +242,17 @@ let amqp_mainloop peername mtx cin cout n =
|
|||
input_buf = String.create initial_frame_size;
|
||||
output_buf = Buffer.create initial_frame_size;
|
||||
frame_max = initial_frame_size;
|
||||
connection_closed = false
|
||||
connection_closed = false;
|
||||
recent_queue_name = None;
|
||||
} in
|
||||
(try
|
||||
let (major_version, minor_version, revision) = version in
|
||||
send_method conn (Connection_start (major_version, minor_version,
|
||||
Amqp_wireformat.table_of_list [],
|
||||
"", ""));
|
||||
let (_, Connection_start_ok (client_properties, mechanism, response, locale))
|
||||
= next_method conn in
|
||||
while true do
|
||||
handshake_and_tune conn;
|
||||
while not conn.connection_closed do
|
||||
let (channel, m) = next_method conn in
|
||||
handle_method conn m
|
||||
handle_method conn channel m
|
||||
done
|
||||
with
|
||||
| Amqp_wireformat.Amqp_exception (code, message) ->
|
||||
| Amqp_exception (code, message) ->
|
||||
send_error conn code message
|
||||
)
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ open Sexp
|
|||
|
||||
exception Amqp_exception of (int * string)
|
||||
|
||||
let die code message = raise (Amqp_exception (code, message))
|
||||
|
||||
type octet_t = int
|
||||
type short_t = int
|
||||
type long_t = int32
|
||||
|
@ -108,8 +110,7 @@ and read_table_value input_buf =
|
|||
| 'T' -> Table_timestamp (read_longlong input_buf)
|
||||
| 'F' -> Table_table { table_body = Encoded_table (read_longstr input_buf) }
|
||||
| 'V' -> Table_void
|
||||
| c -> raise (Amqp_exception (502 (*syntax-error*),
|
||||
Printf.sprintf "Unknown table field type code '%c'" c))
|
||||
| c -> die 502 (*syntax-error*) (Printf.sprintf "Unknown table field type code '%c'" c)
|
||||
|
||||
and decoded_table t =
|
||||
match t.table_body with
|
||||
|
@ -252,3 +253,15 @@ let reserved_value_longstr = ""
|
|||
let reserved_value_bit = false
|
||||
let reserved_value_timestamp = Int64.zero
|
||||
let reserved_value_table = { table_body = Encoded_table "" }
|
||||
|
||||
let field_lookup k def fs =
|
||||
try List.assoc k fs
|
||||
with Not_found -> def
|
||||
|
||||
let field_lookup_some k fs =
|
||||
try Some (List.assoc k fs)
|
||||
with Not_found -> None
|
||||
|
||||
let table_lookup k t = List.assoc k (decoded_table t)
|
||||
let table_lookup_default k def t = field_lookup k def (decoded_table t)
|
||||
let table_lookup_some k t = field_lookup_some k (decoded_table t)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
let product = "ocamlmsg"
|
||||
let version = "ALPHA"
|
||||
let copyright = "Copyright (C) 2012 Tony Garnock-Jones."
|
||||
let licence = "All rights reserved."
|
|
@ -10,6 +10,8 @@ let create s ofs len = {
|
|||
buf = s
|
||||
}
|
||||
|
||||
let of_string s = create s 0 (String.length s)
|
||||
|
||||
let sub b ofs len =
|
||||
if b.pos + ofs + len > b.limit
|
||||
then
|
||||
|
|
|
@ -25,7 +25,7 @@ let hook_log () =
|
|||
Log.hook := new_hook
|
||||
|
||||
let _ =
|
||||
printf "ocamlmsg ALPHA, Copyright (C) 2012 Tony Garnock-Jones. All rights reserved.\n%!";
|
||||
printf "%s %s, %s %s\n%!" App_info.product App_info.version App_info.copyright App_info.licence;
|
||||
Sys.set_signal Sys.sigpipe Sys.Signal_ignore;
|
||||
Uuid.init ();
|
||||
Factory.init ();
|
||||
|
|
Loading…
Reference in New Issue