(* Lightweight thread library for Objective Caml
 * http://www.ocsigen.org/lwt
 * Module Lwt_ssl
 * Copyright (C) 2005-2008 Jérôme Vouillon
 * Laboratoire PPS - CNRS Université Paris Diderot
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, with linking exceptions;
 * either version 2.1 of the License, or (at your option) any later
 * version. See COPYING file for details.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
 * 02111-1307, USA.
 *)
|
|
|
|
(* Transport used on a descriptor: raw I/O ([Plain]) or I/O routed through
   an [Ssl.socket] embedded on that descriptor ([SSL]). *)
type t =
  | Plain
  | SSL of Ssl.socket

(* A Lwt descriptor paired with the transport used to talk over it. *)
type socket = Lwt_unix.file_descr * t
|
|
|
|
(* [is_ssl s] is [true] exactly when [s] uses the [SSL] transport. *)
let is_ssl (_, kind) =
  match kind with
  | Plain -> false
  | SSL _ -> true
|
|
|
|
(* [wrap_call f ()] runs [f ()] and returns its result.
   If [f] raises one of the SSL errors [Connection_error], [Accept_error],
   [Read_error] or [Write_error] carrying:
   - [Error_want_read]: raises [Lwt_unix.Retry_read] instead;
   - [Error_want_write]: raises [Lwt_unix.Retry_write] instead;
   - anything else: the original exception is re-raised unchanged.
   [repeat_call] catches the two Retry exceptions to decide which readiness
   event to wait for. *)
let wrap_call f () =
  let retry_or_reraise original = function
    | Ssl.Error_want_read -> raise Lwt_unix.Retry_read
    | Ssl.Error_want_write -> raise Lwt_unix.Retry_write
    | _ -> raise original
  in
  try f () with
  | Ssl.Connection_error err as original -> retry_or_reraise original err
  | Ssl.Accept_error err as original -> retry_or_reraise original err
  | Ssl.Read_error err as original -> retry_or_reraise original err
  | Ssl.Write_error err as original -> retry_or_reraise original err
|
|
|
|
(* [repeat_call fd f] runs the SSL operation [f] on descriptor [fd] as a Lwt
   thread.
   - [fd] is first validated with [Lwt_unix.check_descriptor].
   - [f] is then attempted once, immediately.
   - If the attempt reports it must wait for readability or writability
     (via [wrap_call]), the retry is handed to [Lwt_unix.register_action]
     with the matching [Read]/[Write] action.
   - Any other exception, including one from [check_descriptor], yields a
     failed thread rather than escaping synchronously. *)
let repeat_call fd f =
  let attempt = wrap_call f in
  try
    Lwt_unix.check_descriptor fd;
    let result = attempt () in
    Lwt.return result
  with
  | Lwt_unix.Retry_read -> Lwt_unix.register_action Lwt_unix.Read fd attempt
  | Lwt_unix.Retry_write -> Lwt_unix.register_action Lwt_unix.Write fd attempt
  | e -> raise_lwt e
|
|
|
|
(****)
|
|
|
|
(* [plain fd] wraps [fd] for unencrypted I/O. *)
let plain fd = fd, Plain
|
|
|
|
(* [embed_socket fd context] attaches an SSL socket built from [context] to
   the Unix descriptor underlying [fd]. No handshake is started here: unlike
   [ssl_accept]/[ssl_connect], this calls neither [Ssl.accept] nor
   [Ssl.connect]. *)
let embed_socket fd context =
  let ssl_sock = Ssl.embed_socket (Lwt_unix.unix_file_descr fd) context in
  (fd, SSL ssl_sock)
|
|
|
|
(* [ssl_accept fd ctx] embeds an SSL socket from [ctx] on [fd] and runs the
   server-side handshake ([Ssl.accept]) without blocking other Lwt threads.
   Resolves to the [SSL] socket pair once the handshake completes. *)
let ssl_accept fd ctx =
  let ssl_sock = Ssl.embed_socket (Lwt_unix.unix_file_descr fd) ctx in
  let handshake = repeat_call fd (fun () -> Ssl.accept ssl_sock) in
  Lwt.bind handshake (fun () -> Lwt.return (fd, SSL ssl_sock))
|
|
|
|
(* [ssl_connect fd ctx] embeds an SSL socket from [ctx] on [fd] and runs the
   client-side handshake ([Ssl.connect]) without blocking other Lwt threads.
   Resolves to the [SSL] socket pair once the handshake completes. *)
let ssl_connect fd ctx =
  let ssl_sock = Ssl.embed_socket (Lwt_unix.unix_file_descr fd) ctx in
  let handshake = repeat_call fd (fun () -> Ssl.connect ssl_sock) in
  Lwt.bind handshake (fun () -> Lwt.return (fd, SSL ssl_sock))
|
|
|
|
(* [read s buf pos len] reads up to [len] bytes into the string [buf],
   starting at [pos], and returns the number of bytes read.
   - Plain sockets delegate to [Lwt_unix.read].
   - SSL sockets with [len = 0] return 0 without touching the session.
   - SSL [Error_zero_return] is reported as a 0-byte read. *)
let read (fd, s) buf pos len =
  match s with
  | Plain -> Lwt_unix.read fd buf pos len
  | SSL _ when len = 0 -> Lwt.return 0
  | SSL sock ->
    repeat_call fd (fun () ->
      try Ssl.read sock buf pos len
      with Ssl.Read_error Ssl.Error_zero_return -> 0)
|
|
|
|
(* [read_bytes s buf pos len] reads up to [len] bytes into the Lwt_bytes
   buffer [buf], starting at [pos], and returns the number of bytes read.
   - Plain sockets delegate to [Lwt_bytes.read].
   - SSL sockets with [len = 0] return 0 without touching the session.
   - SSL [Error_zero_return] is reported as a 0-byte read.
   Only the first [n] bytes of [buf] after [pos] are written, where [n] is
   the returned count. *)
let read_bytes (fd, s) buf pos len =
  match s with
  | Plain -> Lwt_bytes.read fd buf pos len
  | SSL _ when len = 0 -> Lwt.return 0
  | SSL sock ->
    repeat_call fd (fun () ->
      try
        (* Ssl.read fills a string, so read into a staging string and copy
           the result into [buf] afterwards. *)
        let staged = String.create len in
        let n = Ssl.read sock staged 0 len in
        (* Copy only the [n] bytes actually received. Copying [len] would
           overwrite [buf] beyond the received data with the uninitialised
           remainder of [staged]. *)
        Lwt_bytes.blit_string_bytes staged 0 buf pos n;
        n
      with Ssl.Read_error Ssl.Error_zero_return -> 0)
|
|
|
|
(* [write s buf pos len] writes up to [len] bytes of the string [buf],
   starting at [pos], and returns the number of bytes written.
   - Plain sockets delegate to [Lwt_unix.write].
   - SSL sockets with [len = 0] return 0 without touching the session. *)
let write (fd, s) buf pos len =
  match s with
  | Plain -> Lwt_unix.write fd buf pos len
  | SSL _ when len = 0 -> Lwt.return 0
  | SSL sock -> repeat_call fd (fun () -> Ssl.write sock buf pos len)
|
|
|
|
(* [write_bytes s buf pos len] writes up to [len] bytes of the Lwt_bytes
   buffer [buf], starting at [pos], and returns the number of bytes written.
   - Plain sockets delegate to [Lwt_bytes.write].
   - SSL sockets with [len = 0] return 0 without touching the session. *)
let write_bytes (fd, s) buf pos len =
  match s with
  | Plain -> Lwt_bytes.write fd buf pos len
  | SSL _ when len = 0 -> Lwt.return 0
  | SSL sock ->
    (* Ssl.write takes a string, so each attempt stages the slice of [buf]
       into a fresh string first. *)
    let send_copy () =
      let staged = String.create len in
      Lwt_bytes.blit_bytes_string buf pos staged 0 len;
      Ssl.write sock staged 0 len
    in
    repeat_call fd send_copy
|
|
|
|
(* [wait_read s] waits until [s] is readable.
   - Plain sockets wait on the descriptor.
   - SSL sockets only yield to other threads.
   NOTE(review): presumably because data already buffered inside the SSL
   session would not make the descriptor readable; confirm. *)
let wait_read (fd, kind) =
  match kind with
  | SSL _ -> Lwt_unix.yield ()
  | Plain -> Lwt_unix.wait_read fd
|
|
|
|
(* [wait_write s] waits until [s] is writable.
   - Plain sockets wait on the descriptor.
   - SSL sockets only yield to other threads, mirroring [wait_read]. *)
let wait_write (fd, kind) =
  match kind with
  | SSL _ -> Lwt_unix.yield ()
  | Plain -> Lwt_unix.wait_write fd
|
|
|
|
(* [out_channel_of_descr s] builds a Lwt_io output channel whose flushes go
   through [write_bytes s], so they work for both plain and SSL sockets. *)
let out_channel_of_descr s =
  Lwt_io.make ~mode:Lwt_io.output (write_bytes s)
|
|
|
|
(* [in_channel_of_descr s] builds a Lwt_io input channel whose refills go
   through [read_bytes s], so they work for both plain and SSL sockets. *)
let in_channel_of_descr s =
  Lwt_io.make ~mode:Lwt_io.input (read_bytes s)
|
|
|
|
(* [ssl_shutdown s] runs [Ssl.shutdown] on the SSL session of [s], retrying
   without blocking when needed. Plain sockets have no session, so this
   resolves immediately for them. The descriptor itself is left open. *)
let ssl_shutdown (fd, kind) =
  match kind with
  | SSL sock -> repeat_call fd (fun () -> Ssl.shutdown sock)
  | Plain -> Lwt.return ()
|
|
|
|
(* [shutdown s cmd] applies [Lwt_unix.shutdown] to the descriptor of [s].
   Any SSL session is not touched; use [ssl_shutdown] for that. *)
let shutdown s cmd = Lwt_unix.shutdown (fst s) cmd

(* [close s] closes the descriptor of [s] with [Lwt_unix.close]. *)
let close s = Lwt_unix.close (fst s)

(* [abort s] applies [Lwt_unix.abort] to the descriptor of [s]. *)
let abort s = Lwt_unix.abort (fst s)
|
|
|
|
(* [get_fd s] returns the raw Unix descriptor behind [s].
   - Plain sockets: taken from the Lwt descriptor.
   - SSL sockets: taken from the embedded [Ssl.socket]. *)
let get_fd = function
  | (fd, Plain) -> Lwt_unix.unix_file_descr fd
  | (_, SSL ssl_sock) -> Ssl.file_descr_of_socket ssl_sock
|
|
|
|
(* [getsockname s] is [Unix.getsockname] applied to the raw descriptor of [s]. *)
let getsockname sock =
  let unix_fd = get_fd sock in
  Unix.getsockname unix_fd

(* [getpeername s] is [Unix.getpeername] applied to the raw descriptor of [s]. *)
let getpeername sock =
  let unix_fd = get_fd sock in
  Unix.getpeername unix_fd
|
|
|