hop-2012/server/thirdparty/lwt-2.3.2/src/unix/lwt_io.ml

1502 lines
48 KiB
OCaml

(* Lightweight thread library for Objective Caml
* http://www.ocsigen.org/lwt
* Module Lwt_io
* Copyright (C) 2009 Jérémie Dimino
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation, with linking exceptions;
* either version 2.1 of the License, or (at your option) any later
* version. See COPYING file for details.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
* 02111-1307, USA.
*)
open Lwt
(* Exception carried by operations performed on a channel that has
   been closed. The argument is the direction of the channel
   ("input" or "output"). *)
exception Channel_closed of string

(* Smallest buffer size accepted for a channel: *)
let min_buffer_size = 16

(* [check_buffer_size fun_name buffer_size] returns [()] when
   [buffer_size] lies between [min_buffer_size] and
   [Sys.max_string_length] (both included). Otherwise it raises
   [Invalid_argument] with a message mentioning [Lwt_io.<fun_name>]
   and the rejected size. *)
let check_buffer_size fun_name buffer_size =
  (* Both error messages take the same two arguments, only the
     format differs: *)
  let reject fmt = Printf.ksprintf invalid_arg fmt fun_name buffer_size in
  if buffer_size < min_buffer_size then
    reject "Lwt_io.%s: too small buffer size (%d)"
  else if buffer_size > Sys.max_string_length then
    reject "Lwt_io.%s: too big buffer size (%d)"

(* Buffer size used by [make] when none is given: *)
let default_buffer_size = ref 4096
(* +-----------------------------------------------------------------+
| Types |
+-----------------------------------------------------------------+ *)
(* Phantom types: they have no values and are only used as the type
parameter of ['a mode] and ['a channel] to distinguish input
channels from output channels. *)
type input
type output
(* The mode of a channel. With OCaml >= 3.13 it is declared as a
GADT, so that matching on a mode refines the type parameter; with
older compilers it is an ordinary variant with a phantom
parameter. *)
#if ocaml_version >= (3, 13)
type 'a mode =
| Input : input mode
| Output : output mode
#else
type 'a mode =
| Input
| Output
#endif
(* Mode values with their type parameter fixed: *)
let input : input mode = Input
let output : output mode = Output
(* A channel state *)
type 'mode state =
| Busy_primitive
(* A primitive is running on the channel *)
| Busy_atomic of 'mode channel
(* An atomic operation is being performed on the channel. The
argument is the temporary atomic wrapper. *)
| Waiting_for_busy
(* A queued operation has been woken up by [unlock] but has not
yet started. *)
| Idle
(* The channel is unused *)
| Closed
(* The channel has been closed *)
| Invalid
(* The channel is a temporary channel created for an atomic
operation which has terminated. *)
(* A wrapper, which ensures that io operations are atomic: *)
and 'mode channel = {
mutable state : 'mode state;
(* Current state of this wrapper *)
channel : 'mode _channel;
(* The real channel *)
mutable queued : unit Lwt.u Lwt_sequence.t;
(* Queued operations: wakeners of the operations waiting for the
wrapper to be unlocked *)
}
(* The real channel: the buffer and everything needed to fill or
empty it. It is shared by the main wrapper and by the temporary
wrappers created by [atomic]. *)
and 'mode _channel = {
mutable buffer : Lwt_bytes.t;
(* The buffer itself *)
mutable length : int;
(* Size of the buffer *)
mutable ptr : int;
(* Current position *)
mutable max : int;
(* Position of the end of data in the buffer. It is equal to
[length] for output channels. *)
abort_waiter : int Lwt.t;
(* Thread which is woken up with an exception when the channel is
closed. *)
abort_wakener : int Lwt.u;
(* Wakener of [abort_waiter] *)
mutable auto_flushing : bool;
(* Whether the auto-flusher is currently running or not *)
main : 'mode channel;
(* The main wrapper *)
close : unit Lwt.t Lazy.t;
(* Close function. It is lazy so that closing twice returns the
same thread. *)
mode : 'mode mode;
(* The channel mode *)
mutable offset : int64;
(* Number of bytes really read/written *)
typ : typ;
(* Type of the channel. *)
}
and typ =
| Type_normal of (Lwt_bytes.t -> int -> int -> int Lwt.t) * (int64 -> Unix.seek_command -> int64 Lwt.t)
(* The channel has been created with [make]. The first argument
is the refill/flush function and the second is the seek
function. *)
| Type_bytes
(* The channel has been created with [of_bytes]. *)
(* Shortcuts for the two kinds of channels: *)
type input_channel = input channel
type output_channel = output channel
(* Record handed to the callback of [direct_access], giving it
direct access to the internal buffer of a channel. *)
type direct_access = {
da_buffer : Lwt_bytes.t;
(* The internal buffer of the channel *)
mutable da_ptr : int;
(* Current position in [da_buffer]; initialised from the [ptr]
field of the channel *)
mutable da_max : int;
(* End of the usable area of [da_buffer]; initialised from the
[max] field of the channel *)
da_perform : unit -> int Lwt.t;
(* Refills or flushes the buffer through [perform_io], then
updates [da_ptr] and [da_max] *)
}
(* [mode wrapper] is the mode of the real channel behind [wrapper]. *)
let mode wrapper =
  let ch = wrapper.channel in
  ch.mode
(* +-----------------------------------------------------------------+
| Creations, closing, locking, ... |
+-----------------------------------------------------------------+ *)
(* Weak hash set of output channels. Two channels are the same
   element only if they are physically equal. *)
module Outputs = Weak.Make(struct
  type t = output_channel
  let equal a b = a == b
  let hash ch = Hashtbl.hash ch
end)

(* Table of all opened output channels. On exit they are all
   flushed: *)
let outputs = Outputs.create 32
(* [position wrapper] is the current position in the channel: the
   number of bytes really read minus what is still unread in the
   buffer for an input channel, the number of bytes really written
   plus what is pending in the buffer for an output channel. *)
#if ocaml_version >= (3, 13)
let position : type mode. mode channel -> int64 = fun wrapper ->
#else
let position wrapper =
#endif
  let ch = wrapper.channel in
  (* Signed distance between [offset] and the current position: *)
  let delta =
    match ch.mode with
    | Input -> ch.ptr - ch.max
    | Output -> ch.ptr
  in
  Int64.add ch.offset (Int64.of_int delta)
(* [name ch] is the direction of [ch] as a string, used in error
   messages. *)
#if ocaml_version >= (3, 13)
let name : type mode. mode _channel -> string = fun ch ->
#else
let name ch =
#endif
  let direction = ch.mode in
  match direction with
  | Output -> "output"
  | Input -> "input"

(* Exception used when an operation hits a closed channel: *)
let closed_channel ch =
  let direction = name ch in
  Channel_closed direction

(* Exception used when a temporary wrapper created by [atomic] is
   used after the atomic operation terminated: *)
let invalid_channel ch =
  let msg = Printf.sprintf "temporary atomic %s channel no more valid" (name ch) in
  Failure msg
(* [is_busy ch] tells whether an operation is running or queued on
   the wrapper [ch]. Raises the "no more valid" failure if [ch] is an
   expired temporary wrapper. *)
let is_busy ch =
  match ch.state with
  | Busy_primitive | Busy_atomic _ | Waiting_for_busy -> true
  | Idle | Closed -> false
  | Invalid -> raise (invalid_channel ch.channel)
(* Flush/refill the buffer. No race condition could happen because
this function is always called atomically.

For an input channel it compacts the unread data at the beginning
of the buffer and reads after it; for an output channel it writes
the pending data and shifts what could not be written to the
beginning of the buffer. The result is the number of bytes
transferred by the user function.

The main wrapper must be locked ([Busy_primitive] or
[Busy_atomic]); calling it in another live state is a programming
error ([assert false]). The returned thread fails if the channel is
closed or invalid, or if the user function returns an impossible
count. *)
#if ocaml_version >= (3, 13)
let perform_io : type mode. mode _channel -> int Lwt.t = fun ch -> match ch.main.state with
#else
let perform_io ch = match ch.main.state with
#endif
| Busy_primitive | Busy_atomic _ -> begin
match ch.typ with
| Type_normal(perform_io, seek) ->
(* [ptr, len] is the area of the buffer handed to the user
function: *)
let ptr, len = match ch.mode with
| Input ->
(* Size of data in the buffer *)
let size = ch.max - ch.ptr in
(* If there are still data in the buffer, keep them: *)
if size > 0 then Lwt_bytes.unsafe_blit ch.buffer ch.ptr ch.buffer 0 size;
(* Update positions: *)
ch.ptr <- 0;
ch.max <- size;
(size, ch.length - size)
| Output ->
(0, ch.ptr) in
(* [abort_waiter] fails when the channel is closed, which
interrupts the pending operation: *)
lwt n = pick [ch.abort_waiter; perform_io ch.buffer ptr len] in
(* Never trust user functions... *)
if n < 0 || n > len then
raise_lwt (Failure (Printf.sprintf "Lwt_io: invalid result of the [%s] function(request=%d,result=%d)"
(match ch.mode with Input -> "read" | Output -> "write") len n))
else begin
(* Update the global offset: *)
ch.offset <- Int64.add ch.offset (Int64.of_int n);
(* Update buffer positions: *)
begin match ch.mode with
| Input ->
ch.max <- ch.max + n
| Output ->
(* Shift remaining data: *)
let len = len - n in
Lwt_bytes.unsafe_blit ch.buffer n ch.buffer 0 len;
ch.ptr <- len
end;
return n
end
| Type_bytes -> begin
(* Channels over a fixed byte array have nothing behind them:
reading past the end is an end of file (0 byte), flushing is an
error. *)
match ch.mode with
| Input ->
return 0
| Output ->
raise_lwt (Failure "cannot flush a channel created with Lwt_io.of_string")
end
end
| Closed ->
raise_lwt (closed_channel ch)
| Invalid ->
raise_lwt (invalid_channel ch)
| Idle | Waiting_for_busy ->
assert false
(* [perform_io] under the name matching each direction: *)
let refill = perform_io
let flush_partial = perform_io

(* [flush_total oc] flushes [oc] repeatedly until its buffer is
   empty. *)
let rec flush_total oc =
  if oc.ptr <= 0 then
    return ()
  else
    flush_partial oc >>= fun _ -> flush_total oc

(* Same as [flush_total] but any error, raised directly or through
   the thread, is ignored. *)
let safe_flush_total oc =
  catch (fun () -> flush_total oc) (fun _ -> return ())
(* [deepest_wrapper ch] follows the chain of temporary wrappers
   created by nested atomic operations, starting from the main
   wrapper of [ch], and returns the innermost one. *)
let deepest_wrapper ch =
  let rec descend current =
    match current.state with
    | Busy_atomic inner -> descend inner
    | Busy_primitive | Waiting_for_busy | Idle | Closed | Invalid -> current
  in
  descend ch.main
(* The auto-flusher: [auto_flush oc] first yields with [Lwt.pause],
then flushes [oc] if the deepest wrapper is idle at that point.
Flush errors are ignored (see [safe_flush_total]). Callers set
[oc.auto_flushing] to [true] before starting it; it resets the flag
when it runs or when it gives up because the channel is in use. *)
let auto_flush oc =
lwt () = Lwt.pause () in
let wrapper = deepest_wrapper oc in
match wrapper.state with
| Busy_primitive | Waiting_for_busy ->
(* The channel is used, cancel auto flushing. It will be
restarted when the channel returns to the [Idle] state: *)
oc.auto_flushing <- false;
return ()
| Busy_atomic _ ->
(* Cannot happen since we took the deepest wrapper: *)
assert false
| Idle ->
oc.auto_flushing <- false;
(* Lock the wrapper while flushing: *)
wrapper.state <- Busy_primitive;
lwt () = safe_flush_total oc in
(* Release the lock only if the state was not changed meanwhile
(for instance by closing the channel): *)
if wrapper.state = Busy_primitive then
wrapper.state <- Idle;
(* Wake up the next queued operation, if any: *)
if not (Lwt_sequence.is_empty wrapper.queued) then
wakeup_later (Lwt_sequence.take_l wrapper.queued) ();
return ()
| Closed | Invalid ->
return ()
(* A ``locked'' channel is a channel in the state [Busy_primitive] or
[Busy_atomic].

[unlock wrapper] releases such a lock: the wrapper becomes [Idle],
or [Waiting_for_busy] if an operation was queued, in which case
that operation is woken up. When an output channel ends up idle
the auto-flusher is started. Unlocking a wrapper which is not
locked is a programming error ([assert false]), except for closed
or invalid wrappers for which only the next queued operation is
woken up. *)
#if ocaml_version >= (3, 13)
let unlock : type m. m channel -> unit = fun wrapper -> match wrapper.state with
#else
let unlock wrapper = match wrapper.state with
#endif
| Busy_primitive | Busy_atomic _ ->
if Lwt_sequence.is_empty wrapper.queued then
wrapper.state <- Idle
else begin
(* Hand the channel over to the oldest queued operation: *)
wrapper.state <- Waiting_for_busy;
wakeup_later (Lwt_sequence.take_l wrapper.queued) ()
end;
(* Launches the auto-flusher: *)
let ch = wrapper.channel in
if (* Launch the auto-flusher only if the channel is not busy: *)
(wrapper.state = Idle &&
(* Launch the auto-flusher only for output channel: *)
(match ch.mode with Input -> false | Output -> true) &&
(* Do not launch two auto-flusher: *)
not ch.auto_flushing &&
(* Do not launch the auto-flusher if operations are queued: *)
Lwt_sequence.is_empty wrapper.queued) then begin
ch.auto_flushing <- true;
ignore (auto_flush ch)
end
| Closed | Invalid ->
(* Do not change channel state if the channel has been closed *)
if not (Lwt_sequence.is_empty wrapper.queued) then
wakeup_later (Lwt_sequence.take_l wrapper.queued) ()
| Idle | Waiting_for_busy ->
(* We must never unlock an unlocked channel *)
assert false
(* Wrap primitives into atomic io operations.

   [primitive f wrapper] runs [f] on the real channel of [wrapper]
   while holding the lock of the wrapper. If the wrapper is already
   in use, the call is queued and started when its turn comes; the
   queued call is removed from the queue if its thread is canceled.
   The result fails with [Channel_closed] if the channel is closed,
   possibly while waiting, and with a [Failure] if [wrapper] is an
   expired temporary wrapper. *)
let primitive f wrapper =
  (* Take the lock, run [f] and always release the lock: *)
  let run_locked () =
    wrapper.state <- Busy_primitive;
    try_lwt
      f wrapper.channel
    finally
      unlock wrapper;
      return ()
  in
  match wrapper.state with
  | Closed ->
    raise_lwt (closed_channel wrapper.channel)
  | Invalid ->
    raise_lwt (invalid_channel wrapper.channel)
  | Idle ->
    run_locked ()
  | Busy_primitive | Busy_atomic _ | Waiting_for_busy ->
    let (waiter, wakener) = task () in
    let node = Lwt_sequence.add_r wakener wrapper.queued in
    Lwt.on_cancel waiter (fun _ -> Lwt_sequence.remove node);
    waiter >>= fun () ->
    begin match wrapper.state with
      | Idle | Waiting_for_busy ->
        run_locked ()
      | Closed ->
        (* The channel has been closed while we were waiting; let
           the next queued operation notice it too: *)
        unlock wrapper;
        raise_lwt (closed_channel wrapper.channel)
      | Invalid ->
        raise_lwt (invalid_channel wrapper.channel)
      | Busy_primitive | Busy_atomic _ ->
        assert false
    end
(* Wrap a sequence of io operations into an atomic operation.

   [atomic f wrapper] locks [wrapper] and calls [f] with a fresh
   temporary wrapper over the same real channel. When [f] terminates
   the temporary wrapper is invalidated and [wrapper] is unlocked.
   Queuing and failure cases are the same as for [primitive]. *)
let atomic f wrapper =
  (* Create the temporary wrapper, run [f] on it, then clean up: *)
  let run_atomic () =
    let tmp_wrapper = {
      state = Idle;
      channel = wrapper.channel;
      queued = Lwt_sequence.create ();
    } in
    wrapper.state <- Busy_atomic tmp_wrapper;
    try_lwt
      f tmp_wrapper
    finally
      (* The temporary wrapper is no more valid: *)
      tmp_wrapper.state <- Invalid;
      unlock wrapper;
      return ()
  in
  match wrapper.state with
  | Closed ->
    raise_lwt (closed_channel wrapper.channel)
  | Invalid ->
    raise_lwt (invalid_channel wrapper.channel)
  | Idle ->
    run_atomic ()
  | Busy_primitive | Busy_atomic _ | Waiting_for_busy ->
    let (waiter, wakener) = task () in
    let node = Lwt_sequence.add_r wakener wrapper.queued in
    Lwt.on_cancel waiter (fun _ -> Lwt_sequence.remove node);
    waiter >>= fun () ->
    begin match wrapper.state with
      | Idle | Waiting_for_busy ->
        run_atomic ()
      | Closed ->
        (* The channel has been closed while we were waiting *)
        unlock wrapper;
        raise_lwt (closed_channel wrapper.channel)
      | Invalid ->
        raise_lwt (invalid_channel wrapper.channel)
      | Busy_primitive | Busy_atomic _ ->
        assert false
    end
(* [abort wrapper] closes the channel immediately, without flushing.
   The deepest wrapper currently in use is the one marked as closed.
   Any real read or write in progress is interrupted with
   [Channel_closed], then the close function of the channel is
   forced. *)
let rec abort wrapper =
  let ch = wrapper.channel in
  match wrapper.state with
  | Invalid ->
    raise_lwt (invalid_channel ch)
  | Busy_atomic tmp_wrapper ->
    (* Close the deepest opened wrapper: *)
    abort tmp_wrapper
  | Closed ->
    (* Double close, just returns the same thing as before *)
    Lazy.force ch.close
  | Idle | Busy_primitive | Waiting_for_busy ->
    wrapper.state <- Closed;
    (* Abort any current real reading/writing operation on the
       channel: *)
    wakeup_exn ch.abort_wakener (closed_channel ch);
    Lazy.force ch.close
(* [close wrapper] closes a channel. Only the main wrapper of a
   channel can be closed: the result fails for a wrapper obtained
   via [atomic]. Input channels are closed at once. Output channels
   wait for their turn, are flushed (errors ignored) and then closed;
   if that fails the channel is aborted anyway. *)
#if ocaml_version >= (3, 13)
let close : type mode. mode channel -> unit Lwt.t = fun wrapper ->
#else
let close wrapper =
#endif
  let ch = wrapper.channel in
  if ch.main == wrapper then begin
    match ch.mode with
    | Input ->
      (* Just close it now: *)
      abort wrapper
    | Output ->
      (* Performs all pending actions, flush the buffer, then
         close it: *)
      catch
        (fun () ->
           primitive
             (fun oc -> safe_flush_total oc >>= fun () -> abort wrapper)
             wrapper)
        (fun _ -> abort wrapper)
  end else
    raise_lwt (Failure "Lwt_io.close: cannot close a channel obtained via Lwt_io.atomic")
(* [flush_all ()] flushes, in parallel, every output channel still
   present in [outputs]. Errors on individual channels are ignored. *)
let flush_all () =
  let flush_one wrapper =
    catch
      (fun () -> primitive safe_flush_total wrapper)
      (fun _ -> return ())
  in
  let wrappers = Outputs.fold (fun ch acc -> ch :: acc) outputs [] in
  Lwt_list.iter_p flush_one wrappers

(* Flush all opened output channels on exit: *)
let () = Lwt_main.at_exit flush_all
(* Default seek function of [make]: it ignores its arguments and
   always fails. *)
let no_seek _pos _cmd =
  fail (Failure "Lwt_io.seek: seek not supported on this channel")
(* Without GADTs the type checker cannot see that a channel created
   with the [Output] mode is an output channel, hence this cast: *)
#if ocaml_version < (3, 13)
external unsafe_output : 'a channel -> output channel = "%identity"
#endif
(* [make ?buffer_size ?close ?seek ~mode perform_io] creates a
   channel.

   - [buffer_size] is the size of the internal buffer. It defaults
     to [!default_buffer_size]; an explicit value is validated with
     [check_buffer_size], which raises [Invalid_argument] if it is
     out of bounds.
   - [close] is called at most once when the channel is closed;
     exceptions it raises directly are turned into a failed thread.
     It does nothing by default.
   - [seek] is the seek function; by default seeking fails.
     Exceptions it raises directly are turned into a failed thread.
   - [perform_io buffer ofs len] must read or write, according to
     [mode], up to [len] bytes of [buffer] starting at [ofs] and
     return the number of bytes transferred.

   Output channels are registered in [outputs] so that they are
   flushed on exit. *)
#if ocaml_version >= (3, 13)
let make :
  type m.
  ?buffer_size : int ->
  ?close : (unit -> unit Lwt.t) ->
  ?seek : (int64 -> Unix.seek_command -> int64 Lwt.t) ->
  mode : m mode ->
  (Lwt_bytes.t -> int -> int -> int Lwt.t) ->
  m channel = fun ?buffer_size ?(close=return) ?(seek=no_seek) ~mode perform_io ->
#else
let make ?buffer_size ?(close=return) ?(seek=no_seek) ~mode perform_io =
#endif
  let size =
    match buffer_size with
    | None ->
      !default_buffer_size
    | Some size ->
      (* [check_buffer_size] already prefixes the message with
         "Lwt_io.", so only the function name is passed: *)
      check_buffer_size "make" size;
      size
  in
  let buffer = Lwt_bytes.create size and abort_waiter, abort_wakener = Lwt.wait () in
  let rec ch = {
    buffer = buffer;
    length = size;
    ptr = 0;
    (* An input buffer starts empty, an output buffer can be filled
       up to its size: *)
    max = (match mode with
           | Input -> 0
           | Output -> size);
    close = lazy(try_lwt close ());
    abort_waiter = abort_waiter;
    abort_wakener = abort_wakener;
    main = wrapper;
    auto_flushing = false;
    mode = mode;
    offset = 0L;
    typ = Type_normal(perform_io, fun pos cmd -> try seek pos cmd with e -> raise_lwt e);
  } and wrapper = {
    state = Idle;
    channel = ch;
    queued = Lwt_sequence.create ();
  } in
#if ocaml_version < (3, 13)
  if mode = Output then Outputs.add outputs (unsafe_output wrapper);
#else
  (match mode with
   | Input -> ()
   | Output -> Outputs.add outputs wrapper);
#endif
  wrapper
(* [of_bytes ~mode bytes] creates a channel working directly on
   [bytes]: the byte array is the buffer of the channel, and the
   whole array is available from the start. Such a channel is not
   registered in [outputs]. *)
let of_bytes ~mode bytes =
  let size = Lwt_bytes.length bytes in
  let waiter, wakener = Lwt.wait () in
  let rec ch = {
    buffer = bytes;
    length = size;
    ptr = 0;
    max = size;
    abort_waiter = waiter;
    abort_wakener = wakener;
    (* Auto flush is set to [true] to prevent writing functions from
       trying to launch the auto-flusher. *)
    auto_flushing = true;
    main = wrapper;
    close = lazy (return ());
    mode = mode;
    offset = 0L;
    typ = Type_bytes;
  } and wrapper = {
    state = Idle;
    channel = ch;
    queued = Lwt_sequence.create ();
  } in
  wrapper

(* [of_string ~mode str] is [of_bytes] applied to the byte array
   built from [str] by [Lwt_bytes.of_string]. *)
let of_string ~mode str =
  let bytes = Lwt_bytes.of_string str in
  of_bytes ~mode bytes
(* [of_fd ?buffer_size ?close ~mode fd] creates a channel reading
   from or writing to the file descriptor [fd]. Unless [close] is
   given, closing the channel closes [fd]. Seeking is done with
   [Lwt_unix.LargeFile.lseek]. *)
#if ocaml_version >= (3, 13)
let of_fd : type m. ?buffer_size : int -> ?close : (unit -> unit Lwt.t) -> mode : m mode -> Lwt_unix.file_descr -> m channel = fun ?buffer_size ?close ~mode fd ->
#else
let of_fd ?buffer_size ?close ~mode fd =
#endif
  let close =
    match close with
    | None -> (fun () -> Lwt_unix.close fd)
    | Some f -> f
  in
  let seek pos cmd = Lwt_unix.LargeFile.lseek fd pos cmd in
  let perform_io =
    match mode with
    | Input -> Lwt_bytes.read fd
    | Output -> Lwt_bytes.write fd
  in
  make ?buffer_size ~close ~seek ~mode perform_io
(* Same as [of_fd] for a [Unix.file_descr], which is first converted
   with [Lwt_unix.of_unix_file_descr]. *)
#if ocaml_version >= (3, 13)
let of_unix_fd : type m. ?buffer_size : int -> ?close : (unit -> unit Lwt.t) -> mode : m mode -> Unix.file_descr -> m channel = fun ?buffer_size ?close ~mode fd ->
#else
let of_unix_fd ?buffer_size ?close ~mode fd =
#endif
  let lwt_fd = Lwt_unix.of_unix_file_descr fd in
  of_fd ?buffer_size ?close ~mode lwt_fd
(* [buffered ch] is the number of bytes held in the buffer: unread
   bytes for an input channel, unflushed bytes for an output
   channel. *)
#if ocaml_version >= (3, 13)
let buffered : type m. m channel -> int = fun ch ->
#else
let buffered ch =
#endif
  let inner = ch.channel in
  match inner.mode with
  | Output -> inner.ptr
  | Input -> inner.max - inner.ptr

(* [buffer_size ch] is the size of the internal buffer of [ch]. *)
let buffer_size ch =
  let inner = ch.channel in
  inner.length
(* [resize_buffer wrapper len] replaces the internal buffer of the
channel by a new one of size [len], keeping the buffered data.

Raises [Invalid_argument] directly if [len] is smaller than
[min_buffer_size]. The returned thread fails with [Failure] for
channels working on a fixed byte array, and for input channels
holding more unread data than [len]. Output channels are flushed
until the pending data fits in the new buffer. *)
#if ocaml_version >= (3, 13)
let resize_buffer : type m. m channel -> int -> unit Lwt.t = fun wrapper len ->
#else
let resize_buffer wrapper len =
#endif
if len < min_buffer_size then invalid_arg "Lwt_io.resize_buffer";
match wrapper.channel.typ with
| Type_bytes ->
raise_lwt (Failure "Lwt_io.resize_buffer: cannot resize the buffer of a channel created with Lwt_io.of_string")
| Type_normal _ ->
(* The resizing itself, run as a primitive so that no other
operation touches the buffer meanwhile: *)
#if ocaml_version >= (3, 13)
let f : type m. m _channel -> unit Lwt.t = fun ch ->
#else
let f ch =
#endif
match ch.mode with
| Input ->
let unread_count = ch.max - ch.ptr in
(* Fail if we want to decrease the buffer size and there is
too much unread data in the buffer: *)
if len < unread_count then
raise_lwt (Failure "Lwt_io.resize_buffer: cannot decrease buffer size")
else begin
(* Move the unread data to the beginning of the new buffer: *)
let buffer = Lwt_bytes.create len in
Lwt_bytes.unsafe_blit ch.buffer ch.ptr buffer 0 unread_count;
ch.buffer <- buffer;
ch.length <- len;
ch.ptr <- 0;
ch.max <- unread_count;
return ()
end
| Output ->
(* If we decrease the buffer size, flush the buffer until
the number of buffered bytes fits into the new buffer: *)
let rec loop () =
if ch.ptr > len then
lwt _ = flush_partial ch in
loop ()
else
return ()
in
lwt () = loop () in
(* Copy the pending data; [ptr] is left unchanged: *)
let buffer = Lwt_bytes.create len in
Lwt_bytes.unsafe_blit ch.buffer 0 buffer 0 ch.ptr;
ch.buffer <- buffer;
ch.length <- len;
ch.max <- len;
return ()
in
primitive f wrapper
(* +-----------------------------------------------------------------+
| Byte-order |
+-----------------------------------------------------------------+ *)
(* Byte orders. [posN_K] is the offset, inside an N-bit number
stored in a buffer, of byte number K of that number, byte 0 being
the least significant one (see how [MakeNumberIO] combines them). *)
module ByteOrder =
struct
(* Signature of a byte order: *)
module type S = sig
val pos16_0 : int
val pos16_1 : int
val pos32_0 : int
val pos32_1 : int
val pos32_2 : int
val pos32_3 : int
val pos64_0 : int
val pos64_1 : int
val pos64_2 : int
val pos64_3 : int
val pos64_4 : int
val pos64_5 : int
val pos64_6 : int
val pos64_7 : int
end
(* Little-endian: the least significant byte comes first. *)
module LE =
struct
let pos16_0 = 0
let pos16_1 = 1
let pos32_0 = 0
let pos32_1 = 1
let pos32_2 = 2
let pos32_3 = 3
let pos64_0 = 0
let pos64_1 = 1
let pos64_2 = 2
let pos64_3 = 3
let pos64_4 = 4
let pos64_5 = 5
let pos64_6 = 6
let pos64_7 = 7
end
(* Big-endian: the least significant byte comes last. *)
module BE =
struct
let pos16_0 = 1
let pos16_1 = 0
let pos32_0 = 3
let pos32_1 = 2
let pos32_2 = 1
let pos32_3 = 0
let pos64_0 = 7
let pos64_1 = 6
let pos64_2 = 5
let pos64_3 = 4
let pos64_4 = 3
let pos64_5 = 2
let pos64_6 = 1
let pos64_7 = 0
end
end
module Primitives =
struct
(* This module contains all primitives operations. The operates
without protection regarding locking, they are wrapped after into
safe operations. *)
(* +---------------------------------------------------------------+
| Reading |
+---------------------------------------------------------------+ *)
(* [read_char ic] returns the next character of [ic], refilling the
   buffer when it is exhausted. Fails with [End_of_file] when a
   refill yields no data. *)
let rec read_char ic =
  let ptr = ic.ptr in
  if ptr <> ic.max then begin
    ic.ptr <- ptr + 1;
    return (Lwt_bytes.unsafe_get ic.buffer ptr)
  end else
    refill ic >>= fun count ->
    if count = 0 then
      raise_lwt End_of_file
    else
      read_char ic
(* Same as [read_char] but returns [None] instead of failing with
   [End_of_file]. Other errors are propagated. *)
let read_char_opt ic =
  catch
    (fun () -> read_char ic >>= fun ch -> return (Some ch))
    (function
      | End_of_file -> return None
      | exn -> fail exn)
(* [read_line ic] reads characters up to the next end of line and
   returns them without the terminator, which is either "\n" or
   "\r\n". A '\r' which is not followed by '\n' is part of the line.
   If the end of input is reached after at least one character, what
   has been read is returned; if it is reached before the first
   character the result fails with [End_of_file]. *)
let read_line ic =
  let line = Buffer.create 128 in
  (* A '\r' is only stored once we know it does not start a
     terminator: *)
  let emit_pending_cr pending =
    if pending then Buffer.add_char line '\r'
  in
  let finish () = return (Buffer.contents line) in
  let rec loop cr_read =
    let on_char = function
      | '\n' ->
        finish ()
      | '\r' ->
        emit_pending_cr cr_read;
        loop true
      | ch ->
        emit_pending_cr cr_read;
        Buffer.add_char line ch;
        loop false
    and on_error = function
      | End_of_file ->
        emit_pending_cr cr_read;
        finish ()
      | exn ->
        raise_lwt exn
    in
    try_bind (fun () -> read_char ic) on_char on_error
  in
  read_char ic >>= fun first ->
  match first with
  | '\n' -> return ""
  | '\r' -> loop true
  | ch -> Buffer.add_char line ch; loop false
(* Same as [read_line] but returns [None] instead of failing with
   [End_of_file]. Other errors are propagated. *)
let read_line_opt ic =
  catch
    (fun () -> read_line ic >>= fun line -> return (Some line))
    (function
      | End_of_file -> return None
      | exn -> fail exn)
(* [unsafe_read_into ic str ofs len] copies at most [len] bytes from
   [ic] into [str] starting at [ofs] and returns the number of bytes
   copied. It never refills more than once, so it may return less
   than [len]; it returns 0 when a refill yields no data. The bounds
   of [str] are not checked. *)
let unsafe_read_into ic str ofs len =
  let avail = ic.max - ic.ptr in
  if avail > 0 then begin
    (* Serve the request from the buffer only: *)
    let len = min len avail in
    Lwt_bytes.unsafe_blit_bytes_string ic.buffer ic.ptr str ofs len;
    ic.ptr <- ic.ptr + len;
    return len
  end else begin
    refill ic >>= fun n ->
    if n = 0 then
      (* Nothing was read. Positions are left as [refill] set them:
         for a channel working on a fixed byte array [refill] does
         not touch them, and resetting [max] here would discard the
         data of the channel, making it impossible to go back with
         [set_position]. *)
      return 0
    else begin
      (* The buffer was empty, so the fresh data starts at 0: *)
      let len = min len n in
      Lwt_bytes.unsafe_blit_bytes_string ic.buffer 0 str ofs len;
      ic.ptr <- len;
      ic.max <- n;
      return len
    end
  end
(* Checked version of [unsafe_read_into]: the result fails with
   [Invalid_argument] if [ofs] and [len] do not designate a valid
   range of [str]. A request for 0 byte returns 0 at once. *)
let read_into ic str ofs len =
  let str_len = String.length str in
  if ofs < 0 || len < 0 || ofs + len > str_len then
    raise_lwt (Invalid_argument (Printf.sprintf
                                   "Lwt_io.read_into(ofs=%d,len=%d,str_len=%d)"
                                   ofs len str_len))
  else if len = 0 then
    return 0
  else
    unsafe_read_into ic str ofs len
(* [unsafe_read_into_exactly ic str ofs len] calls
   [unsafe_read_into] until exactly [len] bytes have been stored in
   [str] from [ofs]. Fails with [End_of_file] as soon as a read
   returns 0. Bounds are not checked. *)
let rec unsafe_read_into_exactly ic str ofs len =
  unsafe_read_into ic str ofs len >>= fun count ->
  if count = 0 then
    raise_lwt End_of_file
  else if count = len then
    return ()
  else
    unsafe_read_into_exactly ic str (ofs + count) (len - count)
(* Checked version of [unsafe_read_into_exactly]: the result fails
   with [Invalid_argument] if [ofs] and [len] do not designate a
   valid range of [str]. A request for 0 byte succeeds at once. *)
let read_into_exactly ic str ofs len =
  let str_len = String.length str in
  if ofs < 0 || len < 0 || ofs + len > str_len then
    raise_lwt (Invalid_argument (Printf.sprintf
                                   "Lwt_io.read_into_exactly(ofs=%d,len=%d,str_len=%d)"
                                   ofs len str_len))
  else if len = 0 then
    return ()
  else
    unsafe_read_into_exactly ic str ofs len
(* [rev_concat len l] concatenates the strings of [l] in reverse
   order: the first element of [l] ends up at the end of the result.
   [len] must be the sum of the lengths of the strings of [l]; this
   is not checked. *)
let rev_concat len l =
  let buf = String.create len in
  (* Fill [buf] from the end, moving the write position backward by
     the length of each chunk: *)
  let pos = ref len in
  List.iter
    (fun chunk ->
       let chunk_len = String.length chunk in
       pos := !pos - chunk_len;
       String.unsafe_blit chunk 0 buf !pos chunk_len)
    l;
  buf
(* [read_all ic total_len acc] reads [ic] until a refill yields no
   data and returns everything read. [acc] holds the chunks already
   read, most recent first, and [total_len] is the sum of their
   lengths. *)
let rec read_all ic total_len acc =
  (* Take everything currently in the buffer: *)
  let len = ic.max - ic.ptr in
  let chunk = String.create len in
  Lwt_bytes.unsafe_blit_bytes_string ic.buffer ic.ptr chunk 0 len;
  ic.ptr <- ic.max;
  let total_len = len + total_len and acc = chunk :: acc in
  refill ic >>= fun count ->
  if count = 0 then
    return (rev_concat total_len acc)
  else
    read_all ic total_len acc
(* [read count ic] reads the whole remaining input when [count] is
   [None]. With [Some len] it performs a single [unsafe_read_into]
   and returns what was obtained, which may be shorter than [len]. *)
let read count ic =
  match count with
  | Some len ->
    let str = String.create len in
    unsafe_read_into ic str 0 len >>= fun real_len ->
    return (if real_len < len then String.sub str 0 real_len else str)
  | None ->
    read_all ic 0 []
(* [read_value ic] reads a marshaled value: first the 20 bytes of
   the header, from which [Marshal.data_size] gives the size of the
   data, then the data itself. Fails with [End_of_file] if the input
   ends before the whole value has been read. *)
let read_value ic =
  let header_size = 20 in
  let header = String.create header_size in
  unsafe_read_into_exactly ic header 0 header_size >>= fun () ->
  let bsize = Marshal.data_size header 0 in
  (* [Marshal.from_string] needs header and data in one string: *)
  let buffer = String.create (header_size + bsize) in
  String.unsafe_blit header 0 buffer 0 header_size;
  unsafe_read_into_exactly ic buffer header_size bsize >>= fun () ->
  return (Marshal.from_string buffer 0)
(* +---------------------------------------------------------------+
| Writing |
+---------------------------------------------------------------+ *)
(* Flushing a channel is emptying its buffer completely: *)
let flush = flush_total

(* [write_char oc ch] stores [ch] in the buffer of [oc], flushing
   first as many times as needed to get room for it. *)
let rec write_char oc ch =
  let ptr = oc.ptr in
  if ptr >= oc.length then
    flush_partial oc >>= fun _ -> write_char oc ch
  else begin
    oc.ptr <- ptr + 1;
    Lwt_bytes.unsafe_set oc.buffer ptr ch;
    return ()
  end
(* [unsafe_write_from oc str ofs len] copies bytes of [str], from
[ofs], into the buffer of [oc] and returns the number of bytes
which have NOT been taken (0 when all of the [len] bytes were
accepted). Bounds are not checked. *)
let rec unsafe_write_from oc str ofs len =
let avail = oc.length - oc.ptr in
if avail >= len then begin
(* Everything fits in the free part of the buffer: *)
Lwt_bytes.unsafe_blit_string_bytes str ofs oc.buffer oc.ptr len;
oc.ptr <- oc.ptr + len;
return 0
end else begin
(* Fill the buffer completely, then flush it once: *)
Lwt_bytes.unsafe_blit_string_bytes str ofs oc.buffer oc.ptr avail;
oc.ptr <- oc.length;
lwt _ = flush_partial oc in
(* What remains to be taken from [str]: *)
let len = len - avail in
if oc.ptr = 0 then begin
if len = 0 then
return 0
else
(* Everything has been written, try to write more: *)
unsafe_write_from oc str (ofs + avail) len
end else
(* Not everything has been written, just what is
remaining: *)
return len
end
(* [write_from oc str ofs len] writes at most [len] bytes of [str]
   from [ofs] and returns the number of bytes accepted. The result
   fails with [Invalid_argument] if [ofs] and [len] do not designate
   a valid range of [str]. *)
let write_from oc str ofs len =
  let str_len = String.length str in
  if ofs < 0 || len < 0 || ofs + len > str_len then
    raise_lwt (Invalid_argument (Printf.sprintf
                                   "Lwt_io.write_from(ofs=%d,len=%d,str_len=%d)"
                                   ofs len str_len))
  else if len = 0 then
    return 0
  else
    unsafe_write_from oc str ofs len >>= fun remaining ->
    return (len - remaining)
(* [unsafe_write_from_exactly oc str ofs len] calls
   [unsafe_write_from] until all of the [len] bytes have been
   accepted. Bounds are not checked. *)
let rec unsafe_write_from_exactly oc str ofs len =
  unsafe_write_from oc str ofs len >>= fun remaining ->
  if remaining = 0 then
    return ()
  else
    (* Resume at the first byte which was not taken: *)
    unsafe_write_from_exactly oc str (ofs + len - remaining) remaining
(* Checked version of [unsafe_write_from_exactly]: the result fails
   with [Invalid_argument] if [ofs] and [len] do not designate a
   valid range of [str]. Writing 0 byte succeeds at once. *)
let write_from_exactly oc str ofs len =
  let str_len = String.length str in
  if ofs < 0 || len < 0 || ofs + len > str_len then
    raise_lwt (Invalid_argument (Printf.sprintf
                                   "Lwt_io.write_from_exactly(ofs=%d,len=%d,str_len=%d)"
                                   ofs len str_len))
  else if len = 0 then
    return ()
  else
    unsafe_write_from_exactly oc str ofs len
(* [write oc str] writes the whole of [str]. *)
let write oc str =
  let len = String.length str in
  unsafe_write_from_exactly oc str 0 len

(* [write_line oc str] writes [str] followed by a newline
   character. *)
let write_line oc str =
  unsafe_write_from_exactly oc str 0 (String.length str) >>= fun () ->
  write_char oc '\n'

(* [write_value oc ?flags x] writes the marshaled representation of
   [x]; [flags] (none by default) are passed to [Marshal.to_string]. *)
let write_value oc ?(flags=[]) x =
  let data = Marshal.to_string x flags in
  write oc data
(* +---------------------------------------------------------------+
| Low-level access |
+---------------------------------------------------------------+ *)
(* [read_block_unsafe ic size f] refills [ic] until at least [size]
   bytes are buffered, then consumes them and calls [f buffer ptr]
   where [ptr] is the position of the first of these bytes. Fails
   with [End_of_file] if a refill yields no data. [size] is not
   checked against the buffer size. *)
let rec read_block_unsafe ic size f =
  if ic.max - ic.ptr >= size then begin
    let ptr = ic.ptr in
    ic.ptr <- ptr + size;
    f ic.buffer ptr
  end else
    refill ic >>= fun count ->
    if count = 0 then
      raise_lwt End_of_file
    else
      read_block_unsafe ic size f

(* [write_block_unsafe oc size f] flushes [oc] until at least [size]
   bytes are free in its buffer, then reserves them and calls
   [f buffer ptr] where [ptr] is the position of the first of these
   bytes. [size] is not checked against the buffer size. *)
let rec write_block_unsafe oc size f =
  if oc.max - oc.ptr >= size then begin
    let ptr = oc.ptr in
    oc.ptr <- ptr + size;
    f oc.buffer ptr
  end else
    flush_partial oc >>= fun _ ->
    write_block_unsafe oc size f
(* [block ch size f] gives [f] access to [size] consecutive bytes of
   the buffer of [ch], refilling or flushing first when needed.
   [size] must be between 0 and [min_buffer_size], otherwise the
   result fails with [Invalid_argument]. *)
#if ocaml_version >= (3, 13)
let block : type m. m _channel -> int -> (Lwt_bytes.t -> int -> 'a Lwt.t) -> 'a Lwt.t = fun ch size f ->
#else
let block ch size f =
#endif
  let in_range = 0 <= size && size <= min_buffer_size in
  if not in_range then
    raise_lwt (Invalid_argument(Printf.sprintf "Lwt_io.block(size=%d)" size))
  else if ch.max - ch.ptr < size then begin
    (* Not enough room or data: make some first. *)
    match ch.mode with
    | Input ->
      read_block_unsafe ch size f
    | Output ->
      write_block_unsafe ch size f
  end else begin
    (* Fast path, the block is directly available: *)
    let ptr = ch.ptr in
    ch.ptr <- ptr + size;
    f ch.buffer ptr
  end
(* Implementation of the [da_perform] field. [token] tells whether
   the direct access is still in progress. The positions stored in
   [da] are validated and copied to [ch], the buffer is refilled or
   flushed with [perform_io], then [da] is updated from [ch]. The
   result is the one of [perform_io]. *)
let perform token da ch =
  if not !token then
    raise_lwt (Failure "Lwt_io.direct_access.perform: this function can not be called outside Lwt_io.direct_access")
  else if da.da_max <> ch.max || da.da_ptr < ch.ptr || da.da_ptr > ch.max then
    raise_lwt (Invalid_argument "Lwt_io.direct_access.perform")
  else begin
    ch.ptr <- da.da_ptr;
    perform_io ch >>= fun count ->
    da.da_ptr <- ch.ptr;
    da.da_max <- ch.max;
    return count
  end
(* [direct_access ch f] calls [f] with a record giving direct access
   to the buffer of [ch]. When [f] terminates, the position it left
   in the record is validated and becomes the position of [ch]; the
   result fails with [Failure] if that position is inconsistent.

   The [da_perform] function of the record is only usable while [f]
   is running: it is disabled as soon as the thread returned by [f]
   terminates, whether it succeeds or fails. *)
let direct_access ch f =
  let token = ref true in
  let rec da = {
    da_ptr = ch.ptr;
    da_max = ch.max;
    da_buffer = ch.buffer;
    da_perform = (fun _ -> perform token da ch);
  } in
  (* [finalize] guarantees that the token is revoked even when [f]
     fails; otherwise [da_perform] would remain usable after the
     channel has been unlocked. *)
  lwt x = finalize (fun () -> f da) (fun () -> token := false; return ()) in
  if da.da_max <> ch.max || da.da_ptr < ch.ptr || da.da_ptr > ch.max then
    raise_lwt (Failure "Lwt_io.direct_access: invalid result of [f]")
  else begin
    ch.ptr <- da.da_ptr;
    return x
  end
(* Reading and writing of fixed-size numbers, for the byte order
given as argument. Every function transfers a block of 2, 4 or 8
bytes with [read_block_unsafe] or [write_block_unsafe], and finds
byte number K of the number at offset [posN_K] in that block. *)
module MakeNumberIO(ByteOrder : ByteOrder.S) =
struct
open ByteOrder
(* +-------------------------------------------------------------+
| Reading numbers |
+-------------------------------------------------------------+ *)
(* [get buffer ptr] is the byte at [ptr] as an integer: *)
let get buffer ptr = Char.code (Lwt_bytes.unsafe_get buffer ptr)
(* Reads a signed 32-bit integer as a native [int]. *)
let read_int ic =
read_block_unsafe ic 4
(fun buffer ptr ->
let v0 = get buffer (ptr + pos32_0)
and v1 = get buffer (ptr + pos32_1)
and v2 = get buffer (ptr + pos32_2)
and v3 = get buffer (ptr + pos32_3) in
let v = v0 lor (v1 lsl 8) lor (v2 lsl 16) lor (v3 lsl 24) in
(* The top bit of the most significant byte is the sign: *)
if v3 land 0x80 = 0 then
return v
else
(* NOTE(review): [1 lsl 32] presumably needs native integers
wider than 32 bits - confirm for 32-bit platforms. *)
return (v - (1 lsl 32)))
(* Reads a signed 16-bit integer. *)
let read_int16 ic =
read_block_unsafe ic 2
(fun buffer ptr ->
let v0 = get buffer (ptr + pos16_0)
and v1 = get buffer (ptr + pos16_1) in
let v = v0 lor (v1 lsl 8) in
(* The top bit of the most significant byte is the sign: *)
if v1 land 0x80 = 0 then
return v
else
return (v - (1 lsl 16)))
(* Reads a 32-bit integer as an [int32]. *)
let read_int32 ic =
read_block_unsafe ic 4
(fun buffer ptr ->
let v0 = get buffer (ptr + pos32_0)
and v1 = get buffer (ptr + pos32_1)
and v2 = get buffer (ptr + pos32_2)
and v3 = get buffer (ptr + pos32_3) in
return (Int32.logor
(Int32.logor
(Int32.of_int v0)
(Int32.shift_left (Int32.of_int v1) 8))
(Int32.logor
(Int32.shift_left (Int32.of_int v2) 16)
(Int32.shift_left (Int32.of_int v3) 24))))
(* Reads a 64-bit integer as an [int64]. *)
let read_int64 ic =
read_block_unsafe ic 8
(fun buffer ptr ->
let v0 = get buffer (ptr + pos64_0)
and v1 = get buffer (ptr + pos64_1)
and v2 = get buffer (ptr + pos64_2)
and v3 = get buffer (ptr + pos64_3)
and v4 = get buffer (ptr + pos64_4)
and v5 = get buffer (ptr + pos64_5)
and v6 = get buffer (ptr + pos64_6)
and v7 = get buffer (ptr + pos64_7) in
return (Int64.logor
(Int64.logor
(Int64.logor
(Int64.of_int v0)
(Int64.shift_left (Int64.of_int v1) 8))
(Int64.logor
(Int64.shift_left (Int64.of_int v2) 16)
(Int64.shift_left (Int64.of_int v3) 24)))
(Int64.logor
(Int64.logor
(Int64.shift_left (Int64.of_int v4) 32)
(Int64.shift_left (Int64.of_int v5) 40))
(Int64.logor
(Int64.shift_left (Int64.of_int v6) 48)
(Int64.shift_left (Int64.of_int v7) 56)))))
(* Floats are read as the integer of the same size holding their
bit representation: *)
let read_float32 ic = read_int32 ic >>= fun x -> return (Int32.float_of_bits x)
let read_float64 ic = read_int64 ic >>= fun x -> return (Int64.float_of_bits x)
(* +-------------------------------------------------------------+
| Writing numbers |
+-------------------------------------------------------------+ *)
(* [set buffer ptr x] stores [x] as a byte at [ptr]. The conversion
is done with [Char.unsafe_chr]; callers pass values whose higher
bits are not cleared. *)
let set buffer ptr x = Lwt_bytes.unsafe_set buffer ptr (Char.unsafe_chr x)
(* Writes a native [int] as a 32-bit integer. *)
let write_int oc v =
write_block_unsafe oc 4
(fun buffer ptr ->
set buffer (ptr + pos32_0) v;
set buffer (ptr + pos32_1) (v lsr 8);
set buffer (ptr + pos32_2) (v lsr 16);
set buffer (ptr + pos32_3) (v asr 24);
return ())
(* Writes the 16 lower bits of an integer. *)
let write_int16 oc v =
write_block_unsafe oc 2
(fun buffer ptr ->
set buffer (ptr + pos16_0) v;
set buffer (ptr + pos16_1) (v lsr 8);
return ())
(* Writes an [int32]. *)
let write_int32 oc v =
write_block_unsafe oc 4
(fun buffer ptr ->
set buffer (ptr + pos32_0) (Int32.to_int v);
set buffer (ptr + pos32_1) (Int32.to_int (Int32.shift_right v 8));
set buffer (ptr + pos32_2) (Int32.to_int (Int32.shift_right v 16));
set buffer (ptr + pos32_3) (Int32.to_int (Int32.shift_right v 24));
return ())
(* Writes an [int64]. *)
let write_int64 oc v =
write_block_unsafe oc 8
(fun buffer ptr ->
set buffer (ptr + pos64_0) (Int64.to_int v);
set buffer (ptr + pos64_1) (Int64.to_int (Int64.shift_right v 8));
set buffer (ptr + pos64_2) (Int64.to_int (Int64.shift_right v 16));
set buffer (ptr + pos64_3) (Int64.to_int (Int64.shift_right v 24));
set buffer (ptr + pos64_4) (Int64.to_int (Int64.shift_right v 32));
set buffer (ptr + pos64_5) (Int64.to_int (Int64.shift_right v 40));
set buffer (ptr + pos64_6) (Int64.to_int (Int64.shift_right v 48));
set buffer (ptr + pos64_7) (Int64.to_int (Int64.shift_right v 56));
return ())
(* Floats are written as the integer of the same size holding their
bit representation: *)
let write_float32 oc v = write_int32 oc (Int32.bits_of_float v)
let write_float64 oc v = write_int64 oc (Int64.bits_of_float v)
end
(* +---------------------------------------------------------------+
| Random access |
+---------------------------------------------------------------+ *)
(* [do_seek seek pos] moves to the absolute position [pos] using the
   seek function [seek], and fails if the position reached is not
   [pos]. *)
let do_seek seek pos =
  seek pos Unix.SEEK_SET >>= fun offset ->
  if offset = pos then
    return ()
  else
    raise_lwt (Failure "Lwt_io.set_position: seek failed")
(* [set_position ch pos] moves the current position of [ch] to
   [pos].

   - Channels working on a fixed byte array only move their pointer;
     [pos] must lie between 0 and the size of the array.
   - Input channels avoid a real seek when [pos] falls inside the
     data held in the buffer; otherwise they seek and drop the
     buffer.
   - Output channels are flushed before seeking. *)
#if ocaml_version >= (3, 13)
let set_position : type m. m _channel -> int64 -> unit Lwt.t = fun ch pos -> match ch.typ, ch.mode with
#else
let set_position ch pos = match ch.typ, ch.mode with
#endif
  | Type_bytes, _ ->
    if pos >= 0L && pos <= Int64.of_int ch.length then begin
      ch.ptr <- Int64.to_int pos;
      return ()
    end else
      raise_lwt (Failure "Lwt_io.set_position: out of bounds")
  | Type_normal(_, seek), Input ->
    (* Position, in the underlying stream, of the next byte to be
       read from the buffer: *)
    let unread = Int64.of_int (ch.max - ch.ptr) in
    let current = Int64.sub ch.offset unread in
    if pos < current || pos > ch.offset then begin
      do_seek seek pos >>= fun () ->
      ch.offset <- pos;
      ch.ptr <- 0;
      ch.max <- 0;
      return ()
    end else begin
      ch.ptr <- ch.max - (Int64.to_int (Int64.sub ch.offset pos));
      return ()
    end
  | Type_normal(_, seek), Output ->
    flush_total ch >>= fun () ->
    do_seek seek pos >>= fun () ->
    ch.offset <- pos;
    return ()
(* [length ch] is the size of what [ch] is working on: the size of
   the byte array for channels created over one, otherwise the
   position of the end as reported by the seek function, after
   which the underlying position is restored to [ch.offset]. *)
let length ch = match ch.typ with
  | Type_bytes ->
    return (Int64.of_int ch.length)
  | Type_normal(_, seek) ->
    seek 0L Unix.SEEK_END >>= fun len ->
    do_seek seek ch.offset >>= fun () ->
    return len
end
(* +-----------------------------------------------------------------+
| Primitive operations |
+-----------------------------------------------------------------+ *)
(* [read_char wrapper] reads one character. When the wrapper is idle
   and a character is available in the buffer it is returned
   directly, without going through [primitive]: this speed-up
   increases performances by 10x. *)
let read_char wrapper =
  let ch = wrapper.channel in
  let pos = ch.ptr in
  match wrapper.state with
  | Idle when pos < ch.max ->
    ch.ptr <- pos + 1;
    return (Lwt_bytes.unsafe_get ch.buffer pos)
  | _ ->
    primitive Primitives.read_char wrapper
(* Same as [read_char], including the fast path, but the result is
   an option and is [None] at the end of input. *)
let read_char_opt wrapper =
  let ch = wrapper.channel in
  let pos = ch.ptr in
  match wrapper.state with
  | Idle when pos < ch.max ->
    ch.ptr <- pos + 1;
    return (Some (Lwt_bytes.unsafe_get ch.buffer pos))
  | _ ->
    primitive Primitives.read_char_opt wrapper
(* Safe versions of the reading primitives and of [flush]: each one
   runs the function of the same name from [Primitives] through
   [primitive]. *)
let read_line wrapper =
  primitive (fun ic -> Primitives.read_line ic) wrapper

let read_line_opt wrapper =
  primitive (fun ic -> Primitives.read_line_opt ic) wrapper

let read ?count wrapper =
  primitive (Primitives.read count) wrapper

let read_into wrapper str ofs len =
  let run ic = Primitives.read_into ic str ofs len in
  primitive run wrapper

let read_into_exactly wrapper str ofs len =
  let run ic = Primitives.read_into_exactly ic str ofs len in
  primitive run wrapper

let read_value wrapper =
  primitive (fun ic -> Primitives.read_value ic) wrapper

let flush wrapper =
  primitive (fun oc -> Primitives.flush oc) wrapper
(* [write_char wrapper x] writes one character. When the wrapper is
   idle and there is room in the buffer the character is stored
   directly, without going through [primitive], and the
   auto-flusher is started if it is not already running. *)
let write_char wrapper x =
  let ch = wrapper.channel in
  let pos = ch.ptr in
  match wrapper.state with
  | Idle when pos < ch.max ->
    ch.ptr <- pos + 1;
    Lwt_bytes.unsafe_set ch.buffer pos x;
    (* Fast launching of the auto flusher: *)
    if not ch.auto_flushing then begin
      ch.auto_flushing <- true;
      ignore (auto_flush ch)
    end;
    return ()
  | _ ->
    primitive (fun oc -> Primitives.write_char oc x) wrapper
(* Writing and positioning operations.  As with the reading functions,
   each one forwards to the function of the same name in Primitives and
   runs it through [primitive] on the given channel. *)

let write oc str =
  let job ch = Primitives.write ch str in
  primitive job oc

let write_line oc x =
  let job ch = Primitives.write_line ch x in
  primitive job oc

let write_from oc str ofs len =
  let job ch = Primitives.write_from ch str ofs len in
  primitive job oc

let write_from_exactly oc str ofs len =
  let job ch = Primitives.write_from_exactly ch str ofs len in
  primitive job oc

(* The optional [flags] argument is handed over to the primitive as is. *)
let write_value oc ?flags x =
  let job ch = Primitives.write_value ch ?flags x in
  primitive job oc

let block ch size f =
  let job inner = Primitives.block inner size f in
  primitive job ch

let direct_access ch f =
  let job inner = Primitives.direct_access inner f in
  primitive job ch

let set_position ch pos =
  let job inner = Primitives.set_position inner pos in
  primitive job ch

let length ch = primitive Primitives.length ch
(* Signature of the numeric reading and writing functions.

   Every reader takes an input channel and returns the decoded value in
   an Lwt thread; every writer takes an output channel and the value to
   emit, and returns a unit Lwt thread.  The modules LE and BE defined
   further down are packed with this signature.

   NOTE(review): the numeric suffixes (16, 32, 64) presumably give the
   width in bits of the encoded value -- confirm against
   Primitives.MakeNumberIO. *)
module type NumberIO = sig
(* Readers *)
val read_int : input_channel -> int Lwt.t
val read_int16 : input_channel -> int Lwt.t
val read_int32 : input_channel -> int32 Lwt.t
val read_int64 : input_channel -> int64 Lwt.t
val read_float32 : input_channel -> float Lwt.t
val read_float64 : input_channel -> float Lwt.t
(* Writers *)
val write_int : output_channel -> int -> unit Lwt.t
val write_int16 : output_channel -> int -> unit Lwt.t
val write_int32 : output_channel -> int32 -> unit Lwt.t
val write_int64 : output_channel -> int64 -> unit Lwt.t
val write_float32 : output_channel -> float -> unit Lwt.t
val write_float64 : output_channel -> float -> unit Lwt.t
end
(* Build the numeric reading and writing functions for the byte order
   described by [ByteOrder].  Every function forwards to the function of
   the same name in Primitives.MakeNumberIO(ByteOrder) and runs it
   through [primitive] on the given channel. *)
module MakeNumberIO(ByteOrder : ByteOrder.S) =
struct
  (* Inside this functor, Primitives designates the primitives
     specialised for [ByteOrder]; it shadows the outer module. *)
  module Primitives = Primitives.MakeNumberIO(ByteOrder)

  (* Readers *)
  let read_int ic = primitive Primitives.read_int ic
  let read_int16 ic = primitive Primitives.read_int16 ic
  let read_int32 ic = primitive Primitives.read_int32 ic
  let read_int64 ic = primitive Primitives.read_int64 ic
  let read_float32 ic = primitive Primitives.read_float32 ic
  let read_float64 ic = primitive Primitives.read_float64 ic

  (* Writers *)
  let write_int oc x =
    let job ch = Primitives.write_int ch x in
    primitive job oc

  let write_int16 oc x =
    let job ch = Primitives.write_int16 ch x in
    primitive job oc

  let write_int32 oc x =
    let job ch = Primitives.write_int32 ch x in
    primitive job oc

  let write_int64 oc x =
    let job ch = Primitives.write_int64 ch x in
    primitive job oc

  let write_float32 oc x =
    let job ch = Primitives.write_float32 ch x in
    primitive job oc

  let write_float64 oc x =
    let job ch = Primitives.write_float64 ch x in
    primitive job oc
end
(* Instances of the number functions for the two byte orders provided
   by ByteOrder. *)
module LE = MakeNumberIO(ByteOrder.LE)
module BE = MakeNumberIO(ByteOrder.BE)
(* Re-export of the byte order type of Lwt_sys together with its
   constructors, so that they can be used unqualified here. *)
type byte_order = Lwt_sys.byte_order = Little_endian | Big_endian
(* Byte order reported by Lwt_sys. *)
let system_byte_order = Lwt_sys.byte_order
(* Include at this level whichever of LE or BE corresponds to
   system_byte_order, so that the unqualified read_int, write_int and
   friends use that byte order.  The choice is made once, when this
   module is initialised. *)
include (val (match system_byte_order with
| Little_endian -> (module LE : NumberIO)
| Big_endian -> (module BE : NumberIO)) : NumberIO)
(* +-----------------------------------------------------------------+
| Other |
+-----------------------------------------------------------------+ *)
(* Stream whose elements are obtained by calling read_char_opt on [ic]
   each time a new element is requested. *)
let read_chars ic =
  let next () = read_char_opt ic in
  Lwt_stream.from next

(* Write every character of the stream [chars] to [oc], one after the
   other. *)
let write_chars oc chars =
  Lwt_stream.iter_s (write_char oc) chars

(* Stream whose elements are obtained by calling read_line_opt on [ic]
   each time a new element is requested. *)
let read_lines ic =
  let next () = read_line_opt ic in
  Lwt_stream.from next

(* Write every line of the stream [lines] to [oc], one after the
   other. *)
let write_lines oc lines =
  Lwt_stream.iter_s (write_line oc) lines
(* Input channel whose io function fills the requested part of the
   buffer with zero bytes and reports the whole request as satisfied. *)
let zero =
  let fill_with_zeros buf ofs len =
    Lwt_bytes.fill buf ofs len '\x00';
    return len
  in
  make ~mode:input ~buffer_size:min_buffer_size fill_with_zeros

(* Output channel whose io function does nothing with the data and
   reports the whole request as written. *)
let null =
  let discard _ _ len = return len in
  make ~mode:output ~buffer_size:min_buffer_size discard
(* Channels built on the three standard descriptors of Lwt_unix.

   The original note attached to these definitions reads: do not close
   standard ios on close, otherwise uncaught exceptions will not be
   printed.

   NOTE(review): no ~close argument is passed to of_fd below, so whether
   closing one of these channels really leaves the descriptor open
   depends on the default close action of of_fd, which is defined
   earlier in this file -- confirm. *)
let stdin = of_fd ~mode:input Lwt_unix.stdin
let stdout = of_fd ~mode:output Lwt_unix.stdout
let stderr = of_fd ~mode:output Lwt_unix.stderr
(* Printing helpers.

   The f-prefixed functions take the output channel explicitly, the
   unprefixed ones write to stdout and the e-prefixed ones to stderr.
   Names ending in l go through write_line instead of write, and names
   ending in f first format their arguments with Printf.ksprintf. *)

(* Explicit channel *)
let fprint oc txt = write oc txt
let fprintl oc txt = write_line oc txt
let fprintf oc fmt = Printf.ksprintf (write oc) fmt
let fprintlf oc fmt = Printf.ksprintf (write_line oc) fmt

(* Standard output *)
let print txt = fprint stdout txt
let printl txt = fprintl stdout txt
let printf fmt = fprintf stdout fmt
let printlf fmt = fprintlf stdout fmt

(* Standard error *)
let eprint txt = fprint stderr txt
let eprintl txt = fprintl stderr txt
let eprintf fmt = fprintf stderr fmt
let eprintlf fmt = fprintlf stderr fmt
(* Create a pipe with Lwt_unix.pipe and return a pair of channels on
   it: an input channel on the reading end and an output channel on the
   writing end.  [buffer_size], when given, applies to both channels.
   The last argument is ignored. *)
let pipe ?buffer_size _ =
  let read_end, write_end = Lwt_unix.pipe () in
  (of_fd ?buffer_size ~mode:input read_end,
   of_fd ?buffer_size ~mode:output write_end)
(* Name of a file, as passed to Lwt_unix.openfile. *)
type file_name = string

(* [open_file ?buffer_size ?flags ?perm ~mode filename] opens [filename]
   with Lwt_unix.openfile and wraps the resulting descriptor in a
   channel of the requested mode.

   When [flags] is omitted, the file is opened non-blocking: read-only
   for an input channel, write-only with creation and truncation for an
   output channel.  When [perm] is omitted, it is 0 for an input channel
   and 0o666 for an output channel.  Explicit values are used as given,
   whatever the mode. *)
#if ocaml_version >= (3, 13)
let open_file : type m. ?buffer_size : int -> ?flags : Unix.open_flag list -> ?perm : Unix.file_perm -> mode : m mode -> file_name -> m channel Lwt.t = fun ?buffer_size ?flags ?perm ~mode filename ->
#else
let open_file ?buffer_size ?flags ?perm ~mode filename =
#endif
  let flags =
    match flags, mode with
    | None, Input ->
      [Unix.O_RDONLY; Unix.O_NONBLOCK]
    | None, Output ->
      [Unix.O_WRONLY; Unix.O_CREAT; Unix.O_TRUNC; Unix.O_NONBLOCK]
    | Some given, _ ->
      given
  in
  let perm =
    match perm, mode with
    | None, Input -> 0
    | None, Output -> 0o666
    | Some given, _ -> given
  in
  Lwt_unix.openfile filename flags perm >>= fun fd ->
  return (of_fd ?buffer_size ~mode fd)
(* [with_file ?buffer_size ?flags ?perm ~mode filename f] opens
   [filename] with open_file, applies [f] to the channel and closes the
   channel once [f] has terminated, whether it succeeded or failed.
   The result is the result of [f]. *)
let with_file ?buffer_size ?flags ?perm ~mode filename f =
  open_file ?buffer_size ?flags ?perm ~mode filename >>= fun ch ->
  Lwt.finalize
    (fun () -> f ch)
    (fun () -> close ch)
(* Length of the file [filename]: the file is opened for input with
   with_file and [length] is applied to the resulting channel. *)
let file_length filename =
  with_file ~mode:input filename (fun ic -> length ic)
(* [open_connection ?buffer_size sockaddr] creates a stream socket in
   the domain of [sockaddr], connects it to [sockaddr] and returns a
   pair made of an input channel and an output channel that both
   operate on that socket.

   If connecting or building the channels fails, the socket is closed
   and the exception is re-raised through the returned thread. *)
let open_connection ?buffer_size sockaddr =
let fd = Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0 in
(* Close action shared by the two channels.  It is lazy so that both
   channels force the same thread: the socket is shut down and closed
   when the first of them is closed, and closing the second one only
   returns that same thread again. *)
let close = lazy begin
try_lwt
Lwt_unix.shutdown fd Unix.SHUTDOWN_ALL;
return ()
with Unix.Unix_error(Unix.ENOTCONN, _, _) ->
(* This may happen if the server closed the connection before us *)
return ()
finally
(* The descriptor is closed whether or not the shutdown succeeded. *)
Lwt_unix.close fd
end in
try_lwt
lwt () = Lwt_unix.connect fd sockaddr in
(* Invalid_argument is deliberately ignored here.  NOTE(review):
   presumably it is raised where close-on-exec is not supported --
   confirm against Lwt_unix. *)
(try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ());
return (make ?buffer_size
~close:(fun _ -> Lazy.force close)
~mode:input (Lwt_bytes.read fd),
make ?buffer_size
~close:(fun _ -> Lazy.force close)
~mode:output (Lwt_bytes.write fd))
with exn ->
(* No channel owns the descriptor at this point, so it is closed here
   before the failure is reported. *)
lwt () = Lwt_unix.close fd in
raise_lwt exn
(* [with_connection ?buffer_size sockaddr f] opens a connection with
   open_connection, applies [f] to the pair of channels and closes both
   channels once [f] has terminated, whether it succeeded or failed.
   The result is the result of [f]. *)
let with_connection ?buffer_size sockaddr f =
  open_connection ?buffer_size sockaddr >>= fun (ic, oc) ->
  Lwt.finalize
    (fun () -> f (ic, oc))
    (fun () -> close ic <&> close oc)
(* Handle on a server started by establish_server.  The [shutdown]
   field holds the suspended action that stops it. *)
type server = {
  shutdown : unit Lazy.t;
}

(* Stop the given server by forcing its shutdown action.  Since the
   action is lazy, it is carried out the first time only. *)
let shutdown_server { shutdown = stop } = Lazy.force stop
(* [establish_server ?buffer_size ?backlog sockaddr f] creates a stream
   socket, sets SO_REUSEADDR on it, binds it to [sockaddr], listens on
   it with a queue of [backlog] (5 by default) and starts accepting
   connections in the background.

   For every accepted connection [f] is called with a pair made of an
   input channel and an output channel operating on the accepted
   descriptor.  The two channels share one close action: the connection
   is shut down and its descriptor closed when the first of them is
   closed.

   The returned handle stops the server when its [shutdown] field is
   forced: the accept loop ends, the listening socket is closed and,
   for a unix domain address whose path is non-empty and does not start
   with a NUL byte, the path is unlinked.

   Raises whatever exception setsockopt, bind or listen raise; the
   listening socket is closed before such an exception is propagated.

   NOTE(review): [f] is called synchronously inside the accept loop.  An
   exception escaping from it makes the loop thread fail, and since that
   thread is ignored the server then silently stops accepting. *)
let establish_server ?buffer_size ?(backlog=5) sockaddr f =
  let sock = Lwt_unix.socket (Unix.domain_of_sockaddr sockaddr) Unix.SOCK_STREAM 0 in
  (* Do not leak the listening socket when it cannot be set up. *)
  begin
    try
      Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true;
      Lwt_unix.bind sock sockaddr;
      Lwt_unix.listen sock backlog
    with exn ->
      ignore (Lwt_unix.close sock);
      raise exn
  end;
  let abort_waiter, abort_wakener = wait () in
  let abort_waiter = abort_waiter >> return `Shutdown in
  let rec loop () =
    pick [Lwt_unix.accept sock >|= (fun x -> `Accept x); abort_waiter] >>= function
      | `Accept(fd, _) ->
          (try Lwt_unix.set_close_on_exec fd with Invalid_argument _ -> ());
          (* Close action shared by both channels.  As in
             open_connection, a connection that the peer has already
             closed must not prevent the descriptor from being
             released. *)
          let close = lazy begin
            try_lwt
              Lwt_unix.shutdown fd Unix.SHUTDOWN_ALL;
              return ()
            with Unix.Unix_error(Unix.ENOTCONN, _, _) ->
              (* This may happen if the client closed the connection
                 before us *)
              return ()
            finally
              Lwt_unix.close fd
          end in
          f (of_fd ?buffer_size ~mode:input ~close:(fun () -> Lazy.force close) fd,
             of_fd ?buffer_size ~mode:output ~close:(fun () -> Lazy.force close) fd);
          loop ()
      | `Shutdown ->
          lwt () = Lwt_unix.close sock in
          begin
            match sockaddr with
              | Unix.ADDR_UNIX path when path <> "" && path.[0] <> '\x00' ->
                  (* Remove the socket file left behind by bind. *)
                  Unix.unlink path;
                  return ()
              | _ ->
                  return ()
          end
  in
  ignore (loop ());
  { shutdown = lazy(wakeup abort_wakener `Shutdown) }
(* Start closing [ch] and drop the resulting thread.  Used as a
   finaliser, where no result can be returned. *)
let ignore_close ch =
  let (_ : unit Lwt.t) = close ch in
  ()
(* [make_stream f lazy_ic] builds a stream whose elements are produced
   by applying [f] to the channel obtained from [lazy_ic].

   The channel is only opened when the first element is requested.  A
   finaliser is attached to it at that point, so that it gets closed if
   it is collected before the end of the stream is reached.  When [f]
   returns [None] the channel is closed before the end of the stream is
   reported. *)
let make_stream f lazy_ic =
  let opened =
    lazy (Lazy.force lazy_ic >>= fun ic ->
          Gc.finalise ignore_close ic;
          return ic)
  in
  let next _ =
    Lazy.force opened >>= fun ic ->
    f ic >>= fun item ->
    match item with
    | None -> close ic >>= fun () -> return item
    | Some _ -> return item
  in
  Lwt_stream.from next
(* Stream of the lines of [filename], read with read_line_opt.  The
   file is opened for input when the stream is first read. *)
let lines_of_file filename =
  let ic = lazy (open_file ~mode:input filename) in
  make_stream read_line_opt ic

(* Open [filename] for output and write the stream [lines] to it with
   write_lines. *)
let lines_to_file filename lines =
  let dump oc = write_lines oc lines in
  with_file ~mode:output filename dump

(* Stream of the characters of [filename], read with read_char_opt.
   The file is opened for input when the stream is first read. *)
let chars_of_file filename =
  let ic = lazy (open_file ~mode:input filename) in
  make_stream read_char_opt ic

(* Open [filename] for output and write the stream [chars] to it with
   write_chars. *)
let chars_to_file filename chars =
  let dump oc = write_chars oc chars in
  with_file ~mode:output filename dump
(* Write to [oc] the lines produced by Lwt_stream.hexdump for the
   stream [stream]. *)
let hexdump_stream oc stream =
  let dump_lines = Lwt_stream.hexdump stream in
  write_lines oc dump_lines

(* Same as hexdump_stream, the input being the string [buf]. *)
let hexdump oc buf =
  let chars = Lwt_stream.of_string buf in
  hexdump_stream oc chars
(* Change the default buffer size.  The new size is validated by
   check_buffer_size first, which raises Invalid_argument when it is
   below min_buffer_size or above Sys.max_string_length. *)
let set_default_buffer_size size =
  let () = check_buffer_size "set_default_buffer_size" size in
  default_buffer_size := size

(* Current default buffer size.  This definition shadows the reference
   of the same name, which is why it has to come after
   set_default_buffer_size.  The argument is ignored. *)
let default_buffer_size _ = !default_buffer_size