hop-2012/server/thirdparty/lwt-2.3.2/src/ssl/lwt_ssl.ml

176 lines
4.7 KiB
OCaml

(* Lightweight thread library for Objective Caml
* http://www.ocsigen.org/lwt
* Module Lwt_ssl
* Copyright (C) 2005-2008 Jérôme Vouillon
* Laboratoire PPS - CNRS Université Paris Diderot
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation, with linking exceptions;
* either version 2.1 of the License, or (at your option) any later
* version. See COPYING file for details.
*
* This program is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
* 02111-1307, USA.
*)
type t =
Plain
| SSL of Ssl.socket
type socket = Lwt_unix.file_descr * t
let is_ssl s =
match snd s with
Plain -> false
| _ -> true
let wrap_call f () =
try
f ()
with
(Ssl.Connection_error err | Ssl.Accept_error err |
Ssl.Read_error err | Ssl.Write_error err) as e ->
match err with
Ssl.Error_want_read ->
raise Lwt_unix.Retry_read
| Ssl.Error_want_write ->
raise Lwt_unix.Retry_write
| _ ->
raise e
let repeat_call fd f =
try
Lwt_unix.check_descriptor fd;
Lwt.return (wrap_call f ())
with
Lwt_unix.Retry_read ->
Lwt_unix.register_action Lwt_unix.Read fd (wrap_call f)
| Lwt_unix.Retry_write ->
Lwt_unix.register_action Lwt_unix.Write fd (wrap_call f)
| e ->
raise_lwt e
(****)
let plain fd = (fd, Plain)
let embed_socket fd context = (fd, SSL(Ssl.embed_socket (Lwt_unix.unix_file_descr fd) context))
let ssl_accept fd ctx =
let socket = Ssl.embed_socket (Lwt_unix.unix_file_descr fd) ctx in
Lwt.bind
(repeat_call fd (fun () -> Ssl.accept socket)) (fun () ->
Lwt.return (fd, SSL socket))
let ssl_connect fd ctx =
let socket = Ssl.embed_socket (Lwt_unix.unix_file_descr fd) ctx in
Lwt.bind
(repeat_call fd (fun () -> Ssl.connect socket)) (fun () ->
Lwt.return (fd, SSL socket))
let read (fd, s) buf pos len =
match s with
| Plain ->
Lwt_unix.read fd buf pos len
| SSL s ->
if len = 0 then
Lwt.return 0
else
repeat_call fd
(fun () ->
try
Ssl.read s buf pos len
with Ssl.Read_error Ssl.Error_zero_return ->
0)
let read_bytes (fd, s) buf pos len =
match s with
| Plain ->
Lwt_bytes.read fd buf pos len
| SSL s ->
if len = 0 then
Lwt.return 0
else
repeat_call fd
(fun () ->
try
let str = String.create len in
let n = Ssl.read s str 0 len in
Lwt_bytes.blit_string_bytes str 0 buf pos len;
n
with Ssl.Read_error Ssl.Error_zero_return ->
0)
let write (fd, s) buf pos len =
match s with
| Plain ->
Lwt_unix.write fd buf pos len
| SSL s ->
if len = 0 then
Lwt.return 0
else
repeat_call fd
(fun () ->
Ssl.write s buf pos len)
let write_bytes (fd, s) buf pos len =
match s with
| Plain ->
Lwt_bytes.write fd buf pos len
| SSL s ->
if len = 0 then
Lwt.return 0
else
repeat_call fd
(fun () ->
let str = String.create len in
Lwt_bytes.blit_bytes_string buf pos str 0 len;
Ssl.write s str 0 len)
let wait_read (fd, s) =
match s with
Plain -> Lwt_unix.wait_read fd
| SSL _ -> Lwt_unix.yield ()
let wait_write (fd, s) =
match s with
Plain -> Lwt_unix.wait_write fd
| SSL _ -> Lwt_unix.yield ()
let out_channel_of_descr s =
Lwt_io.make ~mode:Lwt_io.output (fun buf pos len -> write_bytes s buf pos len)
let in_channel_of_descr s =
Lwt_io.make ~mode:Lwt_io.input (fun buf pos len -> read_bytes s buf pos len)
let ssl_shutdown (fd, s) =
match s with
Plain -> Lwt.return ()
| SSL s -> repeat_call fd (fun () -> Ssl.shutdown s)
let shutdown (fd, _) cmd = Lwt_unix.shutdown fd cmd
let close (fd, _) = Lwt_unix.close fd
let abort (fd, _) = Lwt_unix.abort fd
let get_fd (fd,socket) =
match socket with
| Plain -> Lwt_unix.unix_file_descr fd
| SSL socket -> (Ssl.file_descr_of_socket socket)
let getsockname s =
Unix.getsockname (get_fd s)
let getpeername s =
Unix.getpeername (get_fd s)