(* PG'OCaml is a set of OCaml bindings for the PostgreSQL database.
 *
 * PG'OCaml - type safe interface to PostgreSQL.
 * Copyright (C) 2005-2009 Richard Jones and other authors.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 *)

open PGOCaml_aux
open CalendarLib
open Printf

(* Abstraction over concurrency and socket I/O that the [Make] functor
 * below is parameterised by: a monad ['a t] together with the channel
 * operations, running in that monad, used to talk to the server. *)
module type THREAD = sig
  (* A computation in the monad that eventually yields an ['a]. *)
  type 'a t
  (* Wrap a plain value as an already-completed computation. *)
  val return : 'a -> 'a t
  (* Sequence two computations, feeding the first result to the second. *)
  val (>>=) : 'a t -> ('a -> 'b t) -> 'b t
  (* A computation that fails with the given exception. *)
  val fail : exn -> 'a t
  (* [catch f handler] runs [f ()] and passes any exception it fails with
   * to [handler]. *)
  val catch : (unit -> 'a t) -> (exn -> 'a t) -> 'a t

  (* Channels wrapping a connected socket. *)
  type in_channel
  type out_channel
  (* Connect to the given address, returning the channels used to read
   * from and write to the connection. *)
  val open_connection : Unix.sockaddr -> (in_channel * out_channel) t
  (* Write operations.  [output_binary_int] is used for the protocol's
   * 4-byte length words, so it must write a 32-bit big-endian integer
   * like [Stdlib.output_binary_int]. *)
  val output_char : out_channel -> char -> unit t
  val output_binary_int : out_channel -> int -> unit t
  val output_string : out_channel -> string -> unit t
  (* Push any buffered output to the server. *)
  val flush : out_channel -> unit t
  (* Read operations; [really_input ic buf pos len] is expected to read
   * exactly [len] bytes into [buf] at offset [pos], like the Stdlib
   * function of the same name. *)
  val input_char : in_channel -> char t
  val input_binary_int : in_channel -> int t
  val really_input : in_channel -> Bytes.t -> int -> int -> unit t
  (* Close an input channel. *)
  val close_in : in_channel -> unit t
end

(* Interface of the generic PG'OCaml library: the API that the [Make]
 * functor below implements, with every operation running in the monad
 * ['a monad] of the chosen [THREAD] implementation. *)
module type PGOCAML_GENERIC =
sig

type 'a t				(** Database handle. *)

type 'a monad
(** The monad in which all database operations run. *)

type isolation = [ `Serializable | `Repeatable_read | `Read_committed | `Read_uncommitted ]
(** Transaction isolation levels accepted by {!begin_work} and {!transact}. *)

type access = [ `Read_write | `Read_only ]
(** Transaction access modes accepted by {!begin_work} and {!transact}. *)

exception Error of string
(** For library errors. *)

exception PostgreSQL_Error of string * (char * string) list
(** For errors generated by the PostgreSQL database back-end.  The
  * first argument is a printable error message.  The second argument
  * is the complete set of error fields returned from the back-end.
  * See [http://www.postgresql.org/docs/8.1/static/protocol-error-fields.html]
  *)

(** {6 Connection management} *)

type connection_desc = {
  user: string;
  port: int;
  password: string;
  host: [ `Hostname of string | `Unix_domain_socket_dir of string];
  database: string
}

val describe_connection : ?host:string -> ?port:int -> ?user:string -> ?password:string -> ?database:string -> ?unix_domain_socket_dir:string -> unit -> connection_desc
(** Produce the actual, concrete connection parameters based on the values and
  * availability of the various configuration variables.
  *)

val connection_desc_to_string : connection_desc -> string
(** Produce a human-readable textual representation of a concrete connection
  * descriptor (the password is NOT included in the output of this function)
  * for logging and error reporting purposes.
  *)

val connect : ?host:string -> ?port:int -> ?user:string -> ?password:string -> ?database:string -> ?unix_domain_socket_dir:string -> ?desc:connection_desc -> unit -> 'a t monad
(** Connect to the database.  The normal [$PGDATABASE], etc. environment
  * variables are available.
  *)

val close : 'a t -> unit monad
(** Close the database handle.  You must call this after you have
  * finished with the handle, or else you will get leaked file
  * descriptors.
  *)

val ping : 'a t -> unit monad
(** Ping the database.  If the database is not available, some sort of
  * exception will be thrown.
  *)

val alive : 'a t -> bool monad
(** This function is a wrapper of [ping] that returns a boolean instead of
  * raising an exception.
  *)

(** {6 Transactions} *)

val begin_work : ?isolation:isolation -> ?access:access -> ?deferrable:bool -> 'a t -> unit monad
(** Start a transaction. *)

val commit : 'a t -> unit monad
(** Perform a COMMIT operation on the database. *)

val rollback : 'a t -> unit monad
(** Perform a ROLLBACK operation on the database. *)

val transact :
  'a t ->
  ?isolation:isolation ->
  ?access:access ->
  ?deferrable:bool ->
  ('a t -> 'b monad) ->
  'b monad
(** [transact db ?isolation ?access ?deferrable f] wraps your
  * function [f] inside a transactional block.
  * First it calls [begin_work] with [isolation], [access] and [deferrable],
  * then calls [f] and does a [rollback] if [f] raises
  * an exception, or a [commit] otherwise.
  *)

(** {6 Serial column} *)

val serial : 'a t -> string -> int64 monad
(** This is a shorthand for [SELECT CURRVAL(serial)].  For a table
  * called [table] with serial column [id] you would typically
  * call this as [serial dbh "table_id_seq"] after the previous INSERT
  * operation to get the serial number of the inserted row.
  *)

val serial4 : 'a t -> string -> int32 monad
(** As {!serial} but assumes that the column is a SERIAL or
  * SERIAL4 type.
  *)

val serial8 : 'a t -> string -> int64 monad
(** Same as {!serial}.
  *)

(** {6 Miscellaneous} *)

val max_message_length : int ref
(** Maximum message length accepted from the back-end.  The default
  * is [Sys.max_string_length], which means that we will try to read as
  * much data from the back-end as we can, and this may cause us to
  * run out of memory (particularly on 64 bit machines), causing a
  * possible denial of service.  You may want to set this to a smaller
  * size to avoid this happening.
  *)

val verbose : int ref
(** Verbosity.  0 means don't print anything.  1 means print short
  * error messages as returned from the back-end.  2 means print all
  * messages as returned from the back-end.  Messages are printed on [stderr].
  * Default verbosity level is 1.
  *)

val set_private_data : 'a t -> 'a -> unit
(** Attach some private data to the database handle.
  *
  * NB. The pa_pgsql camlp4 extension uses this for its own purposes, which
  * means that in most programs you will not be able to attach private data
  * to the database handle.
  *)

val private_data : 'a t -> 'a
(** Retrieve some private data previously attached to the database handle.
  * If no data has been attached, raises [Not_found].
  *
  * NB. The pa_pgsql camlp4 extension uses this for its own purposes, which
  * means that in most programs you will not be able to attach private data
  * to the database handle.
  *)

val uuid : 'a t -> string
(** Return the identifier generated for this connection when it was
  * opened.  It is mainly useful for debugging and profiling.
  *)

type pa_pg_data = (string, bool) Hashtbl.t
(** When using pa_pgsql, database handles have type
  * [PGOCaml.pa_pg_data PGOCaml.t]
  *)

(** {6 Low level query interface - DO NOT USE DIRECTLY} *)

type oid = int32 [@@deriving show]
(** A PostgreSQL object identifier, e.g. the OID of a table or of a type. *)

type param = string option (* None is NULL. *)
type result = string option (* None is NULL. *)
type row = result list (* One row is a list of fields. *)

val prepare : 'a t -> query:string -> ?name:string -> ?types:oid list -> unit -> unit monad
(** [prepare conn ~query ?name ?types ()] prepares the statement [query]
  * and optionally names it [name] and sets the parameter types to [types].
  * If no name is given, then the "unnamed" statement is overwritten.  If
  * no types are given, then the PostgreSQL engine infers types.
  * Synchronously checks for errors.
  *)

(** [execute_rev] takes the same arguments as {!execute}.
  * NOTE(review): judging by its name it presumably returns the result
  * rows in reverse order — confirm against the implementation.
  *)
val execute_rev : 'a t -> ?name:string -> ?portal:string -> params:param list -> unit -> row list monad
val execute : 'a t -> ?name:string -> ?portal:string -> params:param list -> unit -> row list monad
(** [execute conn ?name ~params ()] executes the named or unnamed
  * statement [name], with the given parameters [params],
  * returning the result rows (if any).
  *
  * There are several steps involved at the protocol layer:
  * (1) a "portal" is created from the statement, binding the
  * parameters in the statement (Bind).
  * (2) the portal is executed (Execute).
  * (3) we synchronise the connection (Sync).
  *
  * The optional [?portal] parameter may be used to name the portal
  * created in step (1) above (otherwise the unnamed portal is used).
  * This is only important if you want to call {!PGOCaml.describe_portal}
  * to find out the result types.
  *)

val cursor : 'a t -> ?name:string -> ?portal:string -> params:param list -> (row -> unit monad) -> unit monad
(** [cursor conn ?name ?portal ~params f] executes the named or unnamed
  * statement like {!execute}, but calls [f] on each result row instead
  * of returning the rows as a list.
  * NOTE(review): presumably rows are handed to [f] as they are received
  * rather than accumulated first — confirm against the implementation.
  *)

val close_statement : 'a t -> ?name:string -> unit -> unit monad
(** [close_statement conn ?name ()] closes a prepared statement and frees
  * up any resources.
  *)

val close_portal : 'a t -> ?portal:string -> unit -> unit monad
(** [close_portal conn ?portal ()] closes a portal and frees up any resources.
  *)

val inject : 'a t -> ?name:string -> string -> row list monad
(** [inject conn ?name query] executes the statement [query]
  * and optionally names it [name] and gives the result.
  *)

val alter : 'a t -> ?name:string -> string -> unit monad
(** [alter conn ?name query] executes the statement [query]
  * and optionally names it [name]. Same as inject but ignoring the result.
  *)

type result_description = {
  name : string;			(** Field name. *)
  table : oid option;			(** OID of table. *)
  column : int option;			(** Column number of field in table. *)
  field_type : oid;			(** The type of the field. *)
  length : int;				(** Length of the field. *)
  modifier : int32;			(** Type modifier. *)
}[@@deriving show]
type row_description = result_description list [@@deriving show]

type param_description = {
  param_type : oid;			(** The type of the parameter. *)
}
type params_description = param_description list

val describe_statement : 'a t -> ?name:string -> unit -> (params_description * row_description option) monad
(** [describe_statement conn ?name ()] describes the named or unnamed
  * statement's parameter types and result types.
  *)

val describe_portal : 'a t -> ?portal:string -> unit -> row_description option monad
(** [describe_portal conn ?portal ()] describes the named or unnamed
  * portal's result types.
  *)

(** {6 Low level type conversion functions - DO NOT USE DIRECTLY} *)

val name_of_type : oid -> string
(** Returns the OCaml equivalent type name to the PostgreSQL type [oid].
  * For instance, [name_of_type (Int32.of_int 23)] returns ["int32"] because
  * the OID for PostgreSQL's internal [int4] type is [23].  As another
  * example, [name_of_type (Int32.of_int 25)] returns ["string"].
  *)

type inet = Unix.inet_addr * int
type timestamptz = Calendar.t * Time_Zone.t
type int16 = int
type bytea = string (* XXX *)
type point = float * float
type hstore = (string * string option) list
type numeric = string
type uuid = string
type jsonb = string

type bool_array = bool option list
type int16_array = int16 option list
type int32_array = int32 option list
type int64_array = int64 option list
type string_array = string option list
type float_array = float option list
type timestamp_array = Calendar.t option list
type uuid_array = string option list

(** The following conversion functions are used by pa_pgsql to convert
  * values in and out of the database.
  *)

val string_of_oid : oid -> string
val string_of_bool : bool -> string
val string_of_int : int -> string
val string_of_int16 : int16 -> string
val string_of_int32 : int32 -> string
val string_of_int64 : int64 -> string
val string_of_float : float -> string
val string_of_point : point -> string
val string_of_hstore : hstore -> string
val string_of_numeric : numeric -> string
val string_of_uuid : uuid -> string
val string_of_jsonb : jsonb -> string
val string_of_inet : inet -> string
val string_of_timestamp : Calendar.t -> string
val string_of_timestamptz : timestamptz -> string
val string_of_date : Date.t -> string
val string_of_time : Time.t -> string
val string_of_interval : Calendar.Period.t -> string
val string_of_bytea : bytea -> string
val string_of_string : string -> string
val string_of_unit : unit -> string

val string_of_bool_array : bool_array -> string
val string_of_int16_array : int16_array -> string
val string_of_int32_array : int32_array -> string
val string_of_int64_array : int64_array -> string
val string_of_string_array : string_array -> string
val string_of_bytea_array : string_array -> string
val string_of_float_array : float_array -> string
val string_of_timestamp_array : timestamp_array -> string
val string_of_arbitrary_array : ('a -> string) -> 'a option list -> string
val string_of_uuid_array : string_array -> string

val comment_src_loc : unit -> bool
(** NOTE(review): presumably tells pa_pgsql whether to embed source-location
  * comments in the SQL it generates — confirm against the implementation.
  *)

val find_custom_typconvs
  :  ?typnam:string
  -> ?lookin:string
  -> ?colnam:string
  -> ?argnam:string
  -> unit
  -> ((string * string) option, string) Rresult.result
(** Looks up a custom type conversion for pa_pgsql, given optionally a type
  * name [typnam], a place to look [lookin], a column name [colnam] and an
  * argument name [argnam].
  * NOTE(review): presumably [Ok (Some (to_db, from_db))] names the two
  * conversion functions, [Ok None] means "no custom conversion" and
  * [Error msg] reports a problem — confirm against the implementation.
  *)

val oid_of_string : string -> oid
val bool_of_string : string -> bool
val int_of_string : string -> int
val int16_of_string : string -> int16
val int32_of_string : string -> int32
val int64_of_string : string -> int64
val float_of_string : string -> float
val point_of_string : string -> point
val hstore_of_string : string -> hstore
val numeric_of_string : string -> numeric
val uuid_of_string : string -> uuid
val jsonb_of_string : string -> jsonb
val inet_of_string : string -> inet
val timestamp_of_string : string -> Calendar.t
val timestamptz_of_string : string -> timestamptz
val date_of_string : string -> Date.t
val time_of_string : string -> Time.t
val interval_of_string : string -> Calendar.Period.t
val bytea_of_string : string -> bytea
val unit_of_string : string -> unit

val bool_array_of_string : string -> bool_array
val int16_array_of_string : string -> int16_array
val int32_array_of_string : string -> int32_array
val int64_array_of_string : string -> int64_array
val string_array_of_string : string -> string_array
val float_array_of_string : string -> float_array
val timestamp_array_of_string : string -> timestamp_array
val arbitrary_array_of_string : (string -> 'a) -> string -> 'a option list

val bind : 'a monad -> ('a -> 'b monad) -> 'b monad
(** Bind of the underlying monad. *)

val return : 'a -> 'a monad
(** Return of the underlying monad. *)
end


module Make (Thread : THREAD) = struct

open Thread

(* Concrete connection parameters, as produced by [describe_connection]. *)
type connection_desc = {
  user: string;
  port: int;				(* TCP port; also names the Unix socket file. *)
  password: string;
  host: [ `Hostname of string | `Unix_domain_socket_dir of string];
  (* Either a host to reach over TCP, or the directory that holds the
   * server's Unix domain socket. *)
  database: string
}

(* A database handle: the channels over the connection's socket plus
 * per-connection bookkeeping. *)
type 'a t = {
  ichan : in_channel;			(* In_channel wrapping socket. *)
  chan : out_channel;			(* Out_channel wrapping socket. *)
  (* Slot behind [set_private_data]/[private_data]; presumably [None]
   * until data is attached. *)
  mutable private_data : 'a option;
  uuid : string;			(* UUID for this connection. *)
}

(* Every operation runs in the monad supplied by [Thread]. *)
type 'a monad = 'a Thread.t

(* Transaction isolation levels accepted by [begin_work]. *)
type isolation = [ `Serializable | `Repeatable_read | `Read_committed | `Read_uncommitted ]

(* Transaction access modes accepted by [begin_work]. *)
type access = [ `Read_write | `Read_only ]

(* Errors detected by the library itself (bad arguments, oversized or
 * malformed protocol messages, ...). *)
exception Error of string

(* Errors reported by the back-end: a printable summary plus every
 * (field type, value) pair of the ErrorResponse (see [pg_error]). *)
exception PostgreSQL_Error of string * (char * string) list

(* Compile-time switch: when [true], every message sent to or received
 * from the back-end is traced on stderr. *)
let debug_protocol : bool = false

(*----- Code to generate messages for the back-end. -----*)

(* A message under construction is a pair of its body buffer and an
 * optional one-byte type tag.  [new_message typ] starts an ordinary
 * message tagged with [typ]. *)
let new_message typ =
  (Buffer.create 128, Some typ)

(* StartUpMessage and SSLRequest carry no type byte, so they are started
 * with this function instead. *)
let new_start_message () =
  (Buffer.create 128, None)

(* Append the byte [i] to a message body.  [Char.chr] deliberately lets
 * values outside [0..255] raise [Invalid_argument]. *)
let add_byte (buf, _) i =
  let byte = Char.chr i in
  Buffer.add_char buf byte

(* Append the character [c] to a message body. *)
let add_char (buf, _typ) c = Buffer.add_char buf c

(* Append [i] to a message body as a big-endian unsigned 16-bit integer.
 * Raises [Error] unless 0 <= i <= 65535. *)
let add_int16 (buf, _) i =
  if not (0 <= i && i <= 65_535) then
    raise (Error "PGOCaml: int16 is outside range [0..65535].");
  let high = (i lsr 8) land 0xff
  and low = i land 0xff in
  Buffer.add_char buf (Char.unsafe_chr high);
  Buffer.add_char buf (Char.unsafe_chr low)

(* Append [i] to a message body as a 32-bit big-endian integer, most
 * significant byte first.  Each byte is extracted in Int32 arithmetic. *)
let add_int32 (buf, _) i =
  let byte_at shift =
    Int32.to_int (Int32.logand (Int32.shift_right_logical i shift) 0xffl)
  in
  List.iter
    (fun shift -> Buffer.add_char buf (Char.unsafe_chr (byte_at shift)))
    [24; 16; 8; 0]

(* Append [str] to a message body with no terminator.  Protocol strings
 * are NUL-terminated, so a string containing NUL is rejected with
 * [Error], as is one longer than 0x3fff_ffff bytes. *)
let add_string_no_trailing_nil (buf, _) str =
  match String.contains str '\000', String.length str > 0x3fff_ffff with
  | true, _ ->
      raise (Error (sprintf "PGOCaml: string contains ASCII NIL character: %S" str))
  | false, true ->
      raise (Error "PGOCaml: string is too long.")
  | false, false ->
      Buffer.add_string buf str

(* Append [str] followed by the NUL terminator. *)
let add_string msg str =
  add_string_no_trailing_nil msg str;
  add_char msg '\000'

(* Write a message built with [new_message] / [new_start_message] to the
 * connection's output channel.
 *
 * The frame is: the optional one-byte type tag (absent for
 * StartUpMessage and SSLRequest), then a length word written with
 * [output_binary_int] that counts itself (4 bytes) and the body but not
 * the tag, then the body.  Output is not flushed here;
 * [receive_raw_message] flushes before reading a reply.
 *
 * Fails (in the monad) with [Error] if the frame would be 1 GB or larger.
 *)
let send_message { chan; _ } (buf, typ) =
  (* Get the length in bytes, including the length word itself. *)
  let len = 4 + Buffer.length buf in

  (* The length must fit in 30 bits: messages are limited to 1 GB, which
   * should be enough for anyone :-)
   * The error goes through the monad with [fail], as in
   * [receive_raw_message] and [pg_error], rather than being raised
   * synchronously from a function that returns a [Thread.t]. *)
  if Int64.of_int len >= 0x4000_0000L then
    fail (Error "PGOCaml: message is larger than 1 GB")
  else begin
    if debug_protocol then
      eprintf "> %s%d %S\n%!"
        (match typ with
         | None -> ""
         | Some c -> sprintf "%c " c)
        len (Buffer.contents buf);

    (* Write the type byte, if this kind of message has one. *)
    (match typ with
     | None -> Thread.return ()
     | Some c -> output_char chan c
    ) >>= fun () ->

    (* Write the length field. *)
    output_binary_int chan len >>= fun () ->

    (* Write the buffer. *)
    output_string chan (Buffer.contents buf)
  end

(* Upper bound, in bytes, on the body of a message accepted from the
 * back-end; [receive_raw_message] skips longer messages and reports an
 * error.  Defaults to [Sys.max_string_length]. *)
let max_message_length : int ref = ref Sys.max_string_length

(* Receive a single raw message from the back-end and return
 * [(type_byte, body)].
 *
 * Pending output is flushed first, so that a request written by
 * [send_message] actually reaches the server.  Then the one-byte message
 * type, the length word (which counts itself) and the body are read.
 *
 * Fails with [Error] if:
 *  - the length word is below 4, which no well-formed message has and
 *    which means the stream is corrupt or out of sync; or
 *  - the body is longer than [!max_message_length].  In that case the
 *    body is read and discarded first, so the stream stays in sync.
 *)
let receive_raw_message { ichan; chan; _ } =
  (* Flush output buffer. *)
  flush chan >>= fun () ->

  input_char ichan >>= fun typ ->
  input_binary_int ichan >>= fun len ->

  (* Discount the length word itself. *)
  let len = len - 4 in

  if len < 0 then
    (* Without this check [Bytes.create] below would raise
     * Invalid_argument outside the monad (e.g. on a sign-extended
     * huge length word). *)
    fail (Error "PGOCaml: back-end message has an invalid length")
  else if len > !max_message_length then (
    (* Skip the message so we stay in synch with the stream. *)
    let bufsize = 65_536 in
    let buf = Bytes.create bufsize in
    let rec loop n =
      if n > 0 then begin
        let m = min n bufsize in
        really_input ichan buf 0 m >>= fun () ->
        loop (n - m)
      end else
        return ()
    in
    loop len >>= fun () ->

    fail (Error
            "PGOCaml: back-end message is longer than max_message_length")
  ) else (
    (* Read the binary message content. *)
    let msg = Bytes.create len in
    really_input ichan msg 0 len >>= fun () ->
    return (typ, Bytes.to_string msg)
  )

(* A decoded back-end message, as returned by [parse_backend_message].
 * Payload layouts mirror what that function reads from the wire. *)
type msg_t =
  | AuthenticationOk
  | AuthenticationKerberosV5
  | AuthenticationCleartextPassword
  | AuthenticationCryptPassword of string      (* 2-byte salt *)
  | AuthenticationMD5Password of string        (* 4-byte salt *)
  | AuthenticationSCMCredential
  | BackendKeyData of int32 * int32            (* process id, secret key *)
  | BindComplete
  | CloseComplete
  | CommandComplete of string
  | DataRow of (int * string) list
      (* (length, bytes) per column; a NULL column is (-1, "") *)
  | EmptyQueryResponse
  | ErrorResponse of (char * string) list      (* (field type, value) *)
  | NoData
  | NoticeResponse of (char * string) list     (* (field type, value) *)
  | NotificationResponse                       (* payload is not kept *)
  | ParameterDescription of int32 list         (* parameter type OIDs *)
  | ParameterStatus of string * string         (* name, value *)
  | ParseComplete
  | ReadyForQuery of char                      (* e.g. 'I', 'T' or 'E' *)
  | RowDescription of (string * int32 * int * int32 * int * int32 * int) list
      (* per field: name, table OID, column number, type OID, type size,
       * type modifier, format code *)
  | UnknownMessage of char * string            (* type byte, raw body *)

(* Render a back-end message as text for the [debug_protocol] trace.
 * Each constructor is printed under its own name (the Crypt and SCM
 * authentication requests used to be printed under the names of other
 * constructors). *)
let string_of_msg_t msg =
  let join f items = String.concat "; " (List.map f items) in
  let field (k, v) = sprintf "%c, %S" k v in
  match msg with
  | AuthenticationOk -> "AuthenticationOk"
  | AuthenticationKerberosV5 -> "AuthenticationKerberosV5"
  | AuthenticationCleartextPassword -> "AuthenticationCleartextPassword"
  | AuthenticationCryptPassword str ->
      sprintf "AuthenticationCryptPassword %S" str
  | AuthenticationMD5Password str ->
      sprintf "AuthenticationMD5Password %S" str
  | AuthenticationSCMCredential -> "AuthenticationSCMCredential"
  | BackendKeyData (i1, i2) ->
      sprintf "BackendKeyData %ld, %ld" i1 i2
  | BindComplete -> "BindComplete"
  | CloseComplete -> "CloseComplete"
  | CommandComplete str ->
      sprintf "CommandComplete %S" str
  | DataRow fields ->
      sprintf "DataRow [%s]"
        (join (fun (len, bytes) -> sprintf "%d, %S" len bytes) fields)
  | EmptyQueryResponse -> "EmptyQueryResponse"
  | ErrorResponse strs ->
      sprintf "ErrorResponse [%s]" (join field strs)
  | NoData -> "NoData"
  | NoticeResponse strs ->
      sprintf "NoticeResponse [%s]" (join field strs)
  | NotificationResponse -> "NotificationResponse"
  | ParameterDescription fields ->
      sprintf "ParameterDescription [%s]"
        (join (fun oid -> sprintf "%ld" oid) fields)
  | ParameterStatus (s1, s2) ->
      sprintf "ParameterStatus %S, %S" s1 s2
  | ParseComplete -> "ParseComplete"
  | ReadyForQuery c ->
      sprintf "ReadyForQuery %s"
        (match c with
         | 'I' -> "Idle"
         | 'T' -> "inTransaction"
         | 'E' -> "Error"
         | c -> sprintf "unknown(%c)" c)
  | RowDescription fields ->
      sprintf "RowDescription [%s]"
        (join (fun (name, table, col, oid, len, modifier, format) ->
                 sprintf "%s %ld %d %ld %d %ld %d"
                   name table col oid len modifier format) fields)
  | UnknownMessage (typ, msg) ->
      sprintf "UnknownMessage %c, %S" typ msg

(* Decode one raw back-end message, as returned by [receive_raw_message],
 * into a [msg_t].
 *
 * [typ] is the message type byte and [msg] the message body (without the
 * type byte and the length word).  Integers in the body are big-endian.
 * Unrecognised message types and authentication subtypes are returned
 * as [UnknownMessage (typ, msg)].
 *
 * Raises [Error] if the body is shorter than its contents require, or if
 * a DataRow field is too large to hold in a string.
 *)
let parse_backend_message (typ, msg) =
  let pos = ref 0 in
  let len = String.length msg in

  (* Functions to grab the next object from the string 'msg'.  [where]
   * names the caller in the "short message" error. *)
  let get_char where =
    if !pos < len then (
      let r = msg.[!pos] in
      incr pos;
      r
    ) else
      raise (Error ("PGOCaml: parse_backend_message: " ^ where ^
                    ": short message"))
  in
  let get_byte where = Char.code (get_char where) in
  (* Big-endian unsigned 16-bit integer.  The high byte must be shifted
   * left (shifting it right discarded it, so any field count or column
   * number above 255 was misread).  Fields that the protocol defines as
   * signed Int16, such as a negative type size in RowDescription, come
   * back as their unsigned bit pattern, e.g. -1 as 65535. *)
  let get_int16 () =
    let r0 = get_byte "get_int16" in
    let r1 = get_byte "get_int16" in
    (r0 lsl 8) lor r1
  in
  let get_int32 () =
    let r0 = get_byte "get_int32" in
    let r1 = get_byte "get_int32" in
    let r2 = get_byte "get_int32" in
    let r3 = get_byte "get_int32" in
    let r = Int32.of_int r0 in
    let r = Int32.shift_left r 8 in
    let r = Int32.logor r (Int32.of_int r1) in
    let r = Int32.shift_left r 8 in
    let r = Int32.logor r (Int32.of_int r2) in
    let r = Int32.shift_left r 8 in
    let r = Int32.logor r (Int32.of_int r3) in
    r
  in
  (* NUL-terminated string; the terminator is consumed but not returned. *)
  let get_string () =
    let buf = Buffer.create 16 in
    let rec loop () =
      let c = get_char "get_string" in
      if c <> '\000' then (
        Buffer.add_char buf c;
        loop ()
      ) else
        Buffer.contents buf
    in
    loop ()
  in
  (* The next [n] bytes, taken as one slice rather than byte by byte; the
   * bounds check raises the same "short message" error as [get_char]. *)
  let get_n_bytes n =
    if n > len - !pos then
      raise (Error "PGOCaml: parse_backend_message: get_n_bytes: short message");
    let r = String.sub msg !pos n in
    pos := !pos + n;
    r
  in
  let get_char () = get_char "get_char" in
  (* (field type, value) pairs terminated by a zero byte, as found in
   * ErrorResponse and NoticeResponse. *)
  let get_fields () =
    let rec loop acc =
      match get_char () with
      | '\000' -> List.rev acc          (* end of list *)
      | field_type ->
          let value = get_string () in
          loop ((field_type, value) :: acc)
    in
    loop []
  in

  let msg =
    match typ with
    | 'R' ->
        (match get_int32 () with
         | 0l -> AuthenticationOk
         | 2l -> AuthenticationKerberosV5
         | 3l -> AuthenticationCleartextPassword
         | 4l ->
             let salt = String.init 2 (fun _ -> get_char ()) in
             AuthenticationCryptPassword salt
         | 5l ->
             let salt = String.init 4 (fun _ -> get_char ()) in
             AuthenticationMD5Password salt
         | 6l -> AuthenticationSCMCredential
         | _ -> UnknownMessage (typ, msg))

    | 'E' -> ErrorResponse (get_fields ())

    | 'N' -> NoticeResponse (get_fields ())

    | 'A' -> NotificationResponse

    | 'Z' -> ReadyForQuery (get_char ())

    | 'K' ->
        let pid = get_int32 () in
        let key = get_int32 () in
        BackendKeyData (pid, key)

    | 'S' ->
        let param = get_string () in
        let value = get_string () in
        ParameterStatus (param, value)

    | '1' -> ParseComplete

    | '2' -> BindComplete

    | '3' -> CloseComplete

    | 'C' -> CommandComplete (get_string ())

    | 'D' ->
        let nr_fields = get_int16 () in
        let fields = ref [] in
        for _ = 0 to nr_fields - 1 do
          let len = get_int32 () in
          let field =
            (* A negative length marks a NULL value. *)
            if len < 0l then (-1, "")
            else (
              if len >= 0x4000_0000l then
                raise (Error "PGOCaml: result field is too long");
              let len = Int32.to_int len in
              if len > Sys.max_string_length then
                raise (Error "PGOCaml: result field is too wide for string");
              let bytes = get_n_bytes len in
              len, bytes
            ) in
          fields := field :: !fields
        done;
        DataRow (List.rev !fields)

    | 'I' -> EmptyQueryResponse

    | 'n' -> NoData

    | 'T' ->
        let nr_fields = get_int16 () in
        let fields = ref [] in
        for _ = 0 to nr_fields - 1 do
          let name = get_string () in
          let table = get_int32 () in
          let column = get_int16 () in
          let oid = get_int32 () in
          let length = get_int16 () in
          let modifier = get_int32 () in
          let format = get_int16 () in
          fields := (name, table, column, oid, length, modifier, format)
                    :: !fields
        done;
        RowDescription (List.rev !fields)

    | 't' ->
        let nr_fields = get_int16 () in
        let fields = ref [] in
        for _ = 0 to nr_fields - 1 do
          let oid = get_int32 () in
          fields := oid :: !fields
        done;
        ParameterDescription (List.rev !fields)

    | _ -> UnknownMessage (typ, msg) in

  if debug_protocol then eprintf "< %s\n%!" (string_of_msg_t msg);

  msg

(* Return the next back-end message the caller actually cares about.
 * ParameterStatus, NoticeResponse and NotificationResponse may arrive
 * asynchronously at any point, so they are dropped here and reading
 * continues with the following message. *)
let rec receive_message conn =
  receive_raw_message conn >>= fun raw ->
  let parsed = parse_backend_message raw in
  match parsed with
  | ParameterStatus _ | NoticeResponse _ | NotificationResponse ->
      receive_message conn
  | _ -> return parsed

(* Send [msg], then wait for and return the next reply, skipping
 * asynchronous messages (see [receive_message]). *)
let send_recv conn msg =
  let await_reply () = receive_message conn in
  send_message conn msg >>= await_reply

(* Verbosity of diagnostics printed on stderr by [print_ErrorResponse]:
 * 0 = silent, 1 = short error summaries, 2 or more = every field.
 * Defaults to 1. *)
let verbose : int ref = ref 1

(* Severity level carried by an ErrorResponse or NoticeResponse. *)
type severity = ERROR | FATAL | PANIC | WARNING | NOTICE | DEBUG | INFO | LOG

(* Decode the severity from an error/notice field list.  The 'V' field
 * (introduced with PostgreSQL 9.6, never localized) is preferred; older
 * servers only send 'S'.  Raises [Not_found] if neither field is present,
 * or if the value is not one of the known severity names. *)
let get_severity fields =
  let name =
    match List.assoc 'V' fields with
    | v -> v
    | exception Not_found -> List.assoc 'S' fields
  in
  List.assoc name
    [ ("ERROR", ERROR); ("FATAL", FATAL); ("PANIC", PANIC);
      ("WARNING", WARNING); ("NOTICE", NOTICE); ("DEBUG", DEBUG);
      ("INFO", INFO); ("LOG", LOG) ]

(* Upper-case name of a severity, as sent by the back-end. *)
let show_severity (s : severity) =
  match s with
  | ERROR -> "ERROR" | FATAL -> "FATAL" | PANIC -> "PANIC"
  | WARNING -> "WARNING" | NOTICE -> "NOTICE" | DEBUG -> "DEBUG"
  | INFO -> "INFO" | LOG -> "LOG"

(* Print an error/notice field list on stderr according to [!verbose]:
 *   0 or less - nothing;
 *   1         - "SEVERITY: CODE: MESSAGE", only for ERROR, FATAL, PANIC;
 *   2 or more - that line for any severity, followed by one "T: value"
 *               line per field other than 'S', 'C' and 'M'.
 * An unrecognised or missing severity is shown as UNKNOWN.  If the 'C' or
 * 'M' field is missing, a warning is printed instead of the summary. *)
let print_ErrorResponse fields =
  let print_summary () =
    let severity =
      match get_severity fields with
      | s -> Some s
      | exception Not_found -> None
    in
    let severity_string =
      match severity with
      | Some s -> show_severity s
      | None -> "UNKNOWN"
    in
    let code = List.assoc 'C' fields in
    let message = List.assoc 'M' fields in
    let serious =
      match severity with
      | Some (ERROR | FATAL | PANIC) -> true
      | Some (WARNING | NOTICE | DEBUG | INFO | LOG) | None -> false
    in
    if !verbose <> 1 || serious then
      eprintf "%s: %s: %s\n%!" severity_string code message
  in
  if !verbose >= 1 then begin
    try print_summary ()
    with Not_found ->
      eprintf
        "WARNING: 'Always present' field is missing in error message\n%!"
  end;
  if !verbose >= 2 then
    List.iter
      (fun (field_type, field) ->
         match field_type with
         | 'S' | 'C' | 'M' -> ()
         | _ -> eprintf "%c: %s\n%!" field_type field)
      fields

(* Send a Sync message: type byte 'S' with an empty body. *)
let sync_msg conn =
  new_message 'S' |> send_message conn

(* Handle an ErrorResponse received anywhere: print it (subject to
 * [!verbose]) and fail with [PostgreSQL_Error (summary, fields)], where
 * [summary] is "SEVERITY: CODE: MESSAGE".
 *
 * If [conn] is given, the connection is first resynchronised by reading
 * and discarding messages up to and including the next ReadyForQuery. *)
let pg_error ?conn fields =
  print_ErrorResponse fields;
  let severity_string =
    match get_severity fields with
    | severity -> show_severity severity
    | exception Not_found -> "UNKNOWN"
  in
  let summary =
    match List.assoc 'C' fields, List.assoc 'M' fields with
    | code, message -> sprintf "%s: %s: %s" severity_string code message
    | exception Not_found ->
        "WARNING: 'Always present' field is missing in error message"
  in
  let rec skip_to_ready conn =
    receive_message conn >>= function
    | ReadyForQuery _ -> return ()
    | _ -> skip_to_ready conn
  in
  let resync =
    match conn with
    | None -> return ()
    | Some conn -> skip_to_ready conn
  in
  resync >>= fun () ->
  fail (PostgreSQL_Error (summary, fields))

(*----- Profiling. -----*)

(* Outcome of a profiled operation: its value, or the exception it failed
 * with.  This lets [profile_op] log the outcome before passing it on. *)
type 'a retexn = Ret of 'a | Exn of exn

(* [profile_op uuid op detail f] runs [f ()].
 *
 * If the PGPROFILING environment variable names a file that can be
 * opened for appending, one CSV row is appended to it once [f] has
 * completed:
 *   "1" (row format version); uuid; op; elapsed milliseconds;
 *   "ok" or the printed exception; then the [detail] columns.
 * The result of [f] is then returned, or its exception passed on with
 * [fail].  Without profiling, [f ()] is simply returned.
 *
 * profile_op :
 *   string -> string -> string list -> (unit -> 'a Thread.t) -> 'a Thread.t
 *)
let profile_op uuid op detail f =
  let chan =
    try
      let filename = Sys.getenv "PGPROFILING" in
      let flags = [ Open_wronly; Open_append; Open_creat ] in
      Some (open_out_gen flags 0o644 filename)
    with
    | Not_found
    | Sys_error _ -> None in
  match chan with
  | None -> f ()                        (* No profiling - just run it. *)
  | Some chan ->                        (* Profiling. *)
      let start_time = Unix.gettimeofday () in
      catch
        (fun () -> f () >>= fun x -> return (Ret x))
        (fun exn -> return (Exn exn)) >>= fun ret ->
      let end_time = Unix.gettimeofday () in

      let elapsed_time_ms = int_of_float (1000. *. (end_time -. start_time)) in
      let outcome =
        match ret with
        | Ret _ -> "ok"
        | Exn exn -> Printexc.to_string exn
      in
      let row =
        [ "1";                          (* Version number. *)
          uuid;
          op;
          string_of_int elapsed_time_ms;
          outcome ] @ detail in

      (* Lock the output channel while we write the row, to prevent
       * corruption from multiple writers.  Closing the channel releases
       * the lock; make sure that happens, and that the descriptor is not
       * leaked, even if locking or writing fails. *)
      (try
         let fd = Unix.descr_of_out_channel chan in
         Unix.lockf fd Unix.F_LOCK 0;
         Csv.output_all (Csv.to_channel chan) [row];
         close_out chan
       with exn ->
         close_out_noerr chan;
         raise exn);

      (* Return result or re-raise the exception. *)
      match ret with
      | Ret r -> return r
      | Exn exn -> fail exn

(*----- Connection. -----*)

(* Address of the server's Unix domain socket: the file ".s.PGSQL.<port>"
 * inside directory [dir]. *)
let pgsql_socket dir port =
  Unix.ADDR_UNIX (sprintf "%s/.s.PGSQL.%d" dir port)

(* Work out the concrete connection parameters.
 *
 * An explicitly supplied argument always wins.  Otherwise the fallbacks
 * are tried in this order:
 *   user     - PGUSER, then the login name of the effective uid, then
 *              PGOCaml_config.default_user;
 *   password - PGPASSWORD, then PGOCaml_config.default_password;
 *   database - PGDATABASE, then the user name;
 *   host     - PGHOST, then the Unix domain socket directory
 *              PGOCaml_config.default_unix_domain_socket_dir;
 *   port     - PGPORT (ignored if not an integer), then
 *              PGOCaml_config.default_port.
 * A host (argument or PGHOST) starting with '/' is taken as a Unix domain
 * socket directory rather than a host name.
 *
 * Raises [Failure] if both [host] and [unix_domain_socket_dir] are given.
 *)
let describe_connection ?host ?port ?user ?password ?database
    ?unix_domain_socket_dir () =
  (* Value of environment variable [name], or [fallback ()] if it is unset. *)
  let env_or name fallback =
    match Sys.getenv name with
    | value -> value
    | exception Not_found -> fallback ()
  in
  let user =
    match user with
    | Some u -> u
    | None ->
        env_or "PGUSER" (fun () ->
          match Unix.getpwuid (Unix.geteuid ()) with
          | entry -> entry.Unix.pw_name
          | exception Not_found -> PGOCaml_config.default_user)
  in
  let password =
    match password with
    | Some p -> p
    | None -> env_or "PGPASSWORD" (fun () -> PGOCaml_config.default_password)
  in
  let database =
    match database with
    | Some d -> d
    | None -> env_or "PGDATABASE" (fun () -> user)
  in
  let classify s =
    if String.length s > 0 && s.[0] = '/' then `Unix_domain_socket_dir s
    else `Hostname s
  in
  let host =
    match host, unix_domain_socket_dir with
    | Some _, Some _ ->
        raise (Failure "describe_connection: it's invalid to specify both a HOST and a unix domain socket directory")
    | Some h, None -> classify h
    | None, Some dir -> `Unix_domain_socket_dir dir
    | None, None ->
        (match Sys.getenv "PGHOST" with
         | h -> classify h
         | exception Not_found ->
             (* fall back on Unix domain socket. *)
             `Unix_domain_socket_dir PGOCaml_config.default_unix_domain_socket_dir)
  in
  let port =
    match port with
    | Some p -> p
    | None ->
        (match int_of_string (Sys.getenv "PGPORT") with
         | p -> p
         | exception (Not_found | Failure _) -> PGOCaml_config.default_port)
  in
  { user; host; port; database; password }

(** Human-readable rendering of a connection descriptor for logs and error
  * reports.  The password is always masked, and a Unix domain socket
  * connection is shown simply as host "unix".
  *)
let connection_desc_to_string { host; port; user; database; _ } =
  let host_label =
    match host with
    | `Hostname name -> name
    | `Unix_domain_socket_dir _ -> "unix"
  in
  (* we don't want to be dumping passwords into error logs *)
  let masked_password = "*****" in
  Printf.sprintf
    "host=%s, port=%s, user=%s, password=%s, database=%s"
    host_label (string_of_int port) user masked_password database

(* Open a connection to the server and run the startup/authentication
 * handshake until ReadyForQuery.  Connection parameters come from [desc]
 * when given, otherwise from [describe_connection].
 * Raises [Error] synchronously if the hostname resolves to no address.
 * Every other failure (no address reachable, unsupported authentication,
 * server ErrorResponse) is reported through the monad, and the channel
 * is closed if the handshake fails after the socket was opened. *)
let connect ?host ?port ?user ?password ?database ?unix_domain_socket_dir ?desc
    () =
  let { user; host; port; database; password } =
    match desc with
    | None -> describe_connection ?host ?port ?user ?password ?database ?unix_domain_socket_dir ()
    | Some desc -> desc
  in
  (* Make the socket address(es).  A hostname may resolve to several
   * addresses; they are tried in order below. *)
  let sockaddrs =
    match host with
    | `Hostname hostname ->
      let addrs = Unix.getaddrinfo hostname (sprintf "%d" port) [Unix.AI_SOCKTYPE(Unix.SOCK_STREAM)] in
      if addrs = [] then
        raise (Error ("PGOCaml: unknown host: " ^ hostname))
      else
        List.map (fun {Unix.ai_addr = sockaddr; _} -> sockaddr) addrs
    | `Unix_domain_socket_dir udsd -> (* Unix domain socket. *)
      [pgsql_socket udsd port] in

  (* Create a universally unique identifier for this connection.  This
   * is mainly for debugging and profiling.
   *)
  let uuid =
    (*
     * On Windows, the result of Unix.getpid is largely meaningless (it's not unique)
     * and, more importantly, Unix.getppid is not implemented.
     *)
    let ppid =
      try
        Unix.getppid ()
      with Invalid_argument _ -> 0
    in
    sprintf "%s %d %d %g %s %g"
      (Unix.gethostname ())
      (Unix.getpid ())
      ppid
      (Unix.gettimeofday ())
      Sys.executable_name
      ((Unix.times ()).Unix.tms_utime) in
  let uuid = Digest.to_hex (Digest.string uuid) in

  (* Open a socket to the first address that accepts a connection.  Only
   * Unix errors move on to the next address.  Every failure is reported
   * with [fail] rather than raised, so handlers installed with [catch]
   * see it whatever the THREAD implementation. *)
  let rec open_first_reachable = function
    | [] -> fail (Error "PGOCaml: Could not connect to database")
    | sockaddr :: rest ->
      catch
        (fun () -> open_connection sockaddr)
        (function
          | Unix.Unix_error _ -> open_first_reachable rest
          | exn -> fail exn)
  in

  let do_connect () =
    (* The socket is opened here, so the attempt is part of the profiled
     * "connect" operation. *)
    open_first_reachable sockaddrs >>= fun (ichan, chan) ->
    catch (fun () ->
      (* Create the connection structure. *)
      let conn = { ichan = ichan;
                   chan = chan;
                   private_data = None;
                   uuid = uuid } in

      (* Send the StartUpMessage (196608 = 3 lsl 16, presumably protocol
       * version 3.0).  NB. At present we do not support SSL. *)
      let msg = new_start_message () in
      add_int32 msg 196608l;
      add_string msg "user"; add_string msg user;
      add_string msg "database"; add_string msg database;
      add_byte msg 0;

      (* Loop around here until the database gives a ReadyForQuery message. *)
      let rec loop msg =
        (match msg with
         | Some msg -> send_recv conn msg
         | None -> receive_message conn) >>= fun msg ->

        match msg with
        | ReadyForQuery _ -> return () (* Finished connecting! *)
        | BackendKeyData _ ->
          (* XXX We should save this key. *)
          loop None
        | AuthenticationOk -> loop None
        | AuthenticationKerberosV5 ->
          fail (Error "PGOCaml: Kerberos authentication not supported")
        | AuthenticationCleartextPassword ->
          let msg = new_message 'p' in (* PasswordMessage *)
          add_string msg password;
          loop (Some msg)
        | AuthenticationCryptPassword _salt ->
          (* Crypt password not supported because there is no crypt(3) function
           * in OCaml.
           *)
          fail (Error "PGOCaml: crypt password authentication not supported")
        | AuthenticationMD5Password salt ->
          (* Reply with "md5" followed by
           * hex (md5 (hex (md5 (password ^ user)) ^ salt)). *)
          let password = "md5" ^ Digest.to_hex (Digest.string (Digest.to_hex (Digest.string (password ^ user)) ^ salt)) in
          let msg = new_message 'p' in (* PasswordMessage *)
          add_string msg password;
          loop (Some msg)
        | AuthenticationSCMCredential ->
          fail (Error "PGOCaml: SCM Credential authentication not supported")
        | ErrorResponse err ->
          pg_error err
        | _ ->
          (* Silently ignore unknown or unexpected message types. *)
          loop None
      in
      loop (Some msg) >>= fun () ->

      return conn)
      (* Don't leak the socket if the handshake fails. *)
      (fun e -> close_in ichan >>= fun () -> fail e)
  in
  let detail = [
    "user"; user;
    "database"; database;
    "host"; begin match host with `Unix_domain_socket_dir _ -> "unix" | `Hostname s -> s end;
    "port"; string_of_int port;
    "prog"; Sys.executable_name
  ] in
  profile_op uuid "connect" detail do_connect

(* Close [conn]: send Terminate ('X'), flush, then close the input channel.
 * An error while sending is held back until the channel has been closed
 * and is then re-reported with [fail]. *)
let close conn =
  let send_terminate () =
    let msg = new_message 'X' in
    send_message conn msg >>= fun () ->
    flush conn.chan
  in
  let do_close () =
    catch
      (fun () -> send_terminate () >>= fun () -> return None)
      (fun exn -> return (Some exn))
    >>= fun outcome ->
    (* Closes the underlying socket too. *)
    close_in conn.ichan >>= fun () ->
    match outcome with
    | Some exn -> fail exn
    | None -> return ()
  in
  profile_op conn.uuid "close" [] do_close

(* Attach caller-defined data to [conn], replacing any previous value. *)
let set_private_data conn data = conn.private_data <- Some data

(* Data previously stored with [set_private_data].
 * Raises [Not_found] if none has been set. *)
let private_data conn =
  match conn.private_data with
  | Some data -> data
  | None -> raise Not_found

(* Identifier generated for [conn] when it was opened. *)
let uuid conn = conn.uuid

(* A table mapping strings to boolean flags.
 * NOTE(review): the name suggests state used by the pa_pg syntax
 * extension; how it is populated is not visible here. *)
type pa_pg_data = (string, bool) Hashtbl.t

(* Round-trip check: send Sync and consume replies until the backend
 * reports ReadyForQuery.  An ErrorResponse is reported via [pg_error]. *)
let ping conn =
  let rec await_ready () =
    receive_message conn >>= function
    | ReadyForQuery _ -> return ()
    | ErrorResponse err -> pg_error ~conn err
    | _ -> await_ready ()
  in
  let do_ping () = sync_msg conn >>= await_ready in
  profile_op conn.uuid "ping" [] do_ping

(* [true] if [ping] succeeds; [false] if it fails for any reason. *)
let alive conn =
  let succeeded () = ping conn >>= fun () -> return true in
  let failed _ = return false in
  catch succeeded failed

(* A PostgreSQL object identifier, e.g. a type OID (see name_of_type). *)
type oid = int32 [@@deriving show]

(* Parameters and results travel in text form (see iter_execute);
 * None stands for SQL NULL. *)
type param = string option
type result = string option
(* One row of results, one entry per column. *)
type row = result list

(* Prepare [query] as statement [name] (default: the unnamed statement)
 * with optional parameter type OIDs [types].  Sends Parse ('P') then Sync
 * and waits for ReadyForQuery.  An ErrorResponse is reported via
 * [pg_error]; any other unexpected reply fails with [Error]. *)
let prepare conn ~query ?(name = "") ?(types = []) () =
  let parse_msg () =
    let msg = new_message 'P' in
    add_string msg name;
    add_string msg query;
    add_int16 msg (List.length types);
    List.iter (fun oid -> add_int32 msg oid) types;
    msg
  in
  let rec await_ready () =
    receive_message conn >>= fun msg ->
    match msg with
    | ParseComplete -> await_ready ()
    | ReadyForQuery _ -> return () (* Finished! *)
    | ErrorResponse err -> pg_error ~conn err
    | _ ->
      fail (Error ("PGOCaml: unknown response from parse: " ^
                   string_of_msg_t msg))
  in
  let do_prepare () =
    send_message conn (parse_msg ()) >>= fun () ->
    sync_msg conn >>= await_ready
  in
  profile_op conn.uuid "prepare" [ "query"; query; "name"; name ] do_prepare

(* Bind [params] to prepared statement [name] as portal [portal], execute
 * it, Sync, and feed every returned row to [proc] until ReadyForQuery.
 * Parameters and results are exchanged as text; NULL is None. *)
let iter_execute conn name portal params proc () =
  (* Bind *)
  let bind_msg = new_message 'B' in
  add_string bind_msg portal;
  add_string bind_msg name;
  add_int16 bind_msg 0; (* Send all parameters as text. *)
  add_int16 bind_msg (List.length params);
  List.iter
    (function
      | None -> add_int32 bind_msg 0xffff_ffffl (* NULL *)
      | Some value ->
        add_int32 bind_msg (Int32.of_int (String.length value));
        add_string_no_trailing_nil bind_msg value)
    params;
  add_int16 bind_msg 0; (* Send back all results as text. *)
  send_message conn bind_msg >>= fun () ->

  (* Execute *)
  let exec_msg = new_message 'E' in
  add_string exec_msg portal;
  add_int32 exec_msg 0l; (* no limit on rows *)
  send_message conn exec_msg >>= fun () ->

  (* Sync *)
  sync_msg conn >>= fun () ->

  (* A DataRow field with negative length is NULL. *)
  let decode_field (len, bytes) =
    if len < 0 then None
    else if len = 0 then Some ""
    else Some bytes
  in
  (* Process replies until ReadyForQuery: rows, no data, or an error. *)
  let rec drain () =
    (* NB: receive_message flushes the output connection. *)
    receive_message conn >>= fun msg ->
    match msg with
    | DataRow fields -> proc (List.map decode_field fields) >>= drain
    | BindComplete | CommandComplete _ | EmptyQueryResponse | NoData -> drain ()
    | ReadyForQuery _ -> return () (* Finished! *)
    | ErrorResponse err -> pg_error ~conn err (* Error *)
    | _ ->
      fail
        (Error ("PGOCaml: unknown response message: " ^
                string_of_msg_t msg))
  in
  drain ()

(* Run [iter_execute], collecting every row.  Rows are prepended as they
 * arrive, so the list is newest-first; [rev = true] restores arrival
 * order. *)
let do_execute conn name portal params rev () =
  let acc = ref [] in
  let collect row = acc := row :: !acc; return () in
  iter_execute conn name portal params collect () >>= fun () ->
  return (if rev then List.rev !acc else !acc)

(* Execute a prepared statement, returning its rows newest-first (no
 * List.rev is performed). *)
let execute_rev conn ?(name = "") ?(portal = "") ~params () =
  profile_op conn.uuid "execute" [ "name"; name; "portal"; portal ]
    (do_execute conn name portal params false)

(* Execute a prepared statement, returning its rows in arrival order. *)
let execute conn ?(name = "") ?(portal = "") ~params () =
  profile_op conn.uuid "execute" [ "name"; name; "portal"; portal ]
    (do_execute conn name portal params true)

(* Execute a prepared statement, passing each row to [proc] as it arrives
 * instead of accumulating a list. *)
let cursor conn ?(name = "") ?(portal = "") ~params proc =
  profile_op conn.uuid "cursor" [ "name"; name; "portal"; portal ]
    (iter_execute conn name portal params proc)

(* Prepare [query] as the unnamed statement and execute it without
 * parameters, discarding any rows. *)
let run_unnamed_command conn query =
  prepare conn ~query () >>= fun () ->
  execute conn ~params:[] () >>= fun _ ->
  return ()

(* Start a transaction.  Each optional argument adds the matching clause to
 * BEGIN WORK; an omitted argument adds nothing. *)
let begin_work ?isolation ?access ?deferrable conn =
  let isolation_clause = function
    | `Serializable -> " isolation level serializable"
    | `Repeatable_read -> " isolation level repeatable read"
    | `Read_committed -> " isolation level read committed"
    | `Read_uncommitted -> " isolation level read uncommitted"
  in
  let access_clause = function
    | `Read_write -> " read write"
    | `Read_only -> " read only"
  in
  let deferrable_clause d = if d then " deferrable" else " not deferrable" in
  let clause render = function None -> "" | Some v -> render v in
  let query =
    "begin work"
    ^ clause isolation_clause isolation
    ^ clause access_clause access
    ^ clause deferrable_clause deferrable
  in
  run_unnamed_command conn query

let commit conn = run_unnamed_command conn "commit"

let rollback conn = run_unnamed_command conn "rollback"

(* Run [f conn] inside a transaction: commit if it succeeds, roll back and
 * re-report its exception if it (or the commit) fails.  The original
 * exception is always the one reported.  A failure of the rollback itself
 * (e.g. because the connection is gone) is deliberately ignored; otherwise
 * it would mask the error that caused the rollback. *)
let transact conn ?isolation ?access ?deferrable f =
  begin_work ?isolation ?access ?deferrable conn >>= fun () ->
  catch
    (fun () ->
       f conn >>= fun r ->
       commit conn >>= fun () ->
       return r
    )
    (fun e ->
       catch (fun () -> rollback conn) (fun _ -> return ()) >>= fun () ->
       fail e
    )

(* [serial conn name] is currval(name) for sequence [name], as an int64.
 * Uses (and overwrites) the unnamed prepared statement.  Raises
 * [Failure "hd"] if the query returns no row or column. *)
let serial conn name =
  prepare conn ~query:"select currval ($1)" () >>= fun () ->
  execute conn ~params:[Some name] () >>= fun rows ->
  let first_column = List.hd (List.hd rows) in
  (* NB. According to the manual, the return type of currval is
   * always a bigint, whether or not the column is serial or bigserial.
   *)
  return (Int64.of_string (Option.get first_column))

(* As [serial], truncated to int32. *)
let serial4 conn name =
  let to_int32 v = return (Int64.to_int32 v) in
  serial conn name >>= to_int32

let serial8 = serial

(* Send Close ('C') for the object of kind [kind] ('S' for a prepared
 * statement, 'P' for a portal) named [ident].  Then Sync and consume
 * replies until ReadyForQuery. *)
let close_object conn kind ident =
  let msg = new_message 'C' in
  add_char msg kind;
  add_string msg ident;
  send_message conn msg >>= fun () ->
  sync_msg conn >>= fun () ->
  let rec await_ready () =
    receive_message conn >>= fun msg ->
    match msg with
    | CloseComplete -> await_ready ()
    | ReadyForQuery _ -> return () (* Finished! *)
    | ErrorResponse err -> pg_error ~conn err
    | _ ->
      fail (Error ("PGOCaml: unknown response from close: " ^
                   string_of_msg_t msg))
  in
  await_ready ()

(* Close prepared statement [name] (default: the unnamed statement). *)
let close_statement conn ?(name = "") () = close_object conn 'S' name

(* Close portal [portal] (default: the unnamed portal). *)
let close_portal conn ?(portal = "") () = close_object conn 'P' portal

(* Prepare [query] (as [name] if given), execute it without parameters,
 * close the statement and return the rows. *)
let inject db ?name query =
  let run () = execute db ?name ~params:[] () in
  let cleanup () = close_statement db ?name () in
  prepare db ~query ?name () >>= run >>= fun rows ->
  cleanup () >>= fun () ->
  return rows

(* Like [inject], but discards the rows. *)
let alter db ?name query =
  inject db ?name query >>= fun _rows ->
  return ()

(* Description of one result column, built from a RowDescription message
 * by describe_statement and describe_portal. *)
type result_description = {
  name : string;        (* Column name as reported by the server. *)
  table : oid option;   (* Source table OID; None when the server sent 0. *)
  column : int option;  (* Column number in that table; None when the server sent 0. *)
  field_type : oid;     (* Type OID (see name_of_type). *)
  length : int;         (* NOTE(review): presumably the type's size as reported by the server. *)
  modifier : int32;     (* Type modifier, e.g. NUMERIC precision. *)
}[@@deriving show]
type row_description = result_description list [@@deriving show]

(* Type of one statement parameter, from a ParameterDescription message. *)
type param_description = {
  param_type : oid;
}
type params_description = param_description list

(* Consume the ReadyForQuery ending a describe exchange, then yield [ret].
 * Any other message fails with [Error]. *)
let expect_rfq conn ret =
  receive_message conn >>= function
  | ReadyForQuery _ -> return ret
  | msg ->
    fail (Error ("PGOCaml: unknown response from describe: " ^ string_of_msg_t msg))

(* Convert one RowDescription field tuple into a [result_description];
 * a zero table OID or column number becomes None. *)
let result_description_of_field (name, table, column, oid, length, modifier, _) =
  { name;
    table = (if table = 0l then None else Some table);
    column = (if column = 0 then None else Some column);
    field_type = oid;
    length;
    modifier }

(* Describe prepared statement [name]: its parameter types, and its result
 * columns (None when the statement returns no rows). *)
let describe_statement conn ?(name = "") () =
  let msg = new_message 'D' in
  add_char msg 'S';
  add_string msg name;
  send_message conn msg >>= fun () ->
  sync_msg conn >>= fun () ->
  let unexpected msg =
    fail (Error ("PGOCaml: unknown response from describe: " ^
                 string_of_msg_t msg))
  in
  (* First reply: ParameterDescription. *)
  (receive_message conn >>= function
    | ParameterDescription oids ->
      return (List.map (fun oid -> { param_type = oid }) oids)
    | ErrorResponse err -> pg_error ~conn err
    | msg -> unexpected msg)
  >>= fun params ->
  (* Second reply: RowDescription, or NoData. *)
  (receive_message conn >>= function
    | RowDescription fields ->
      return (params, Some (List.map result_description_of_field fields))
    | NoData -> return (params, None)
    | ErrorResponse err -> pg_error ~conn err
    | msg -> unexpected msg)
  >>= expect_rfq conn

(* Describe portal [portal]: its result columns, or None if it returns no
 * rows. *)
let describe_portal conn ?(portal = "") () =
  let msg = new_message 'D' in
  add_char msg 'P';
  add_string msg portal;
  send_message conn msg >>= fun () ->
  sync_msg conn >>= fun () ->
  (receive_message conn >>= function
    | RowDescription fields ->
      return (Some (List.map result_description_of_field fields))
    | NoData -> return None
    | ErrorResponse err -> pg_error ~conn err
    | msg ->
      fail (Error ("PGOCaml: unknown response from describe: " ^
                   string_of_msg_t msg)))
  >>= expect_rfq conn

(*----- Type conversion. -----*)

(* Map a PostgreSQL type OID to the name of its OCaml conversion, which
 * matches the suffixes of the string_of_X / X_of_string functions in this
 * module (e.g. "int32").  NOTE(review): presumably consumed by the syntax
 * extension to pick converters.
 *
 * For certain types, more information is available by looking at the
 * modifier field as well as just the OID.  For example, for NUMERIC the
 * modifier tells us the precision.  However we don't always have the
 * modifier field available - in particular for parameters - so only the
 * OID is consulted.
 *
 * Raises [Error] for an unknown OID; for unknown types, look at
 * <postgresql/catalog/pg_type.h>.
 *)
let name_of_type = function
  (* Scalars *)
  | 16_l -> "bool"                (* BOOLEAN *)
  | 17_l -> "bytea"               (* BYTEA *)
  | 20_l -> "int64"               (* INT8 *)
  | 21_l -> "int16"               (* INT2 *)
  | 23_l -> "int32"               (* INT4 *)
  | 600_l -> "point"              (* POINT *)
  | 700_l | 701_l -> "float"      (* FLOAT4, FLOAT8 *)
  | 869_l -> "inet"               (* INET *)
  | 1082_l -> "date"              (* DATE *)
  | 1083_l -> "time"              (* TIME *)
  | 1114_l -> "timestamp"         (* TIMESTAMP *)
  | 1184_l -> "timestamptz"       (* TIMESTAMP WITH TIME ZONE *)
  | 1186_l -> "interval"          (* INTERVAL *)
  | 2278_l -> "unit"              (* VOID *)
  | 2950_l -> "uuid"              (* UUID *)
  (* TEXT, JSON, CHAR(n), VARCHAR(n), NUMERIC, JSONB: all plain strings *)
  | 25_l | 114_l | 1042_l | 1043_l | 1700_l | 3802_l -> "string"
  (* Arrays *)
  | 1000_l -> "bool_array"        (* BOOLEAN[] *)
  | 1001_l -> "bytea_array"       (* BYTEA[] *)
  | 1005_l -> "int16_array"       (* INT2[] *)
  | 1007_l -> "int32_array"       (* INT4[] *)
  | 1016_l -> "int64_array"       (* INT8[] *)
  | 1021_l | 1022_l -> "float_array"  (* FLOAT4[], FLOAT8[] *)
  | 1115_l -> "timestamp_array"   (* TIMESTAMP[] *)
  | 2951_l -> "uuid_array"        (* UUID[] *)
  (* JSON[], TEXT[], CHAR[], VARCHAR[], JSONB[] *)
  | 119_l | 1009_l | 1014_l | 1015_l | 3807_l -> "string_array"
  | oid ->
    raise (Error ("PGOCaml: unknown type for OID " ^ Int32.to_string oid))

(* OCaml representations of PostgreSQL values handled by the converters
 * below. *)

(* Address and prefix length (see string_of_inet / inet_of_string). *)
type inet = Unix.inet_addr * int
(* Timestamp plus the time zone it is expressed in. *)
type timestamptz = Calendar.t * Time_Zone.t
type int16 = int
(* Raw binary data; hex-encoded on the way out (string_of_bytea). *)
type bytea = string
(* (x, y) *)
type point = float * float
(* Key/value pairs; a None value is an hstore NULL. *)
type hstore = (string * string option) list
(* NUMERIC, UUID and JSONB are kept as their PostgreSQL text form. *)
type numeric = string
type uuid = string
type jsonb = string

(* One-dimensional arrays; None elements are SQL NULLs. *)
type bool_array = bool option list
type int16_array = int16 option list
type int32_array = int32 option list
type int64_array = int64 option list
type string_array = string option list
type float_array = float option list
type timestamp_array = Calendar.t option list
type uuid_array = uuid option list

(* Serialise an hstore value as "key"=>"value", "key2"=>NULL, ...
 * Keys and values are double-quoted.  Any double quote or backslash inside
 * them is backslash-escaped; without that, such characters would end the
 * quoted token early and corrupt the literal.  A None value is written as
 * NULL. *)
let string_of_hstore hstore =
  let buf = Buffer.create 64 in
  let add_quoted s =
    Buffer.add_char buf '"';
    for i = 0 to String.length s - 1 do
      let c = s.[i] in
      if c = '"' || c = '\\' then Buffer.add_char buf '\\';
      Buffer.add_char buf c
    done;
    Buffer.add_char buf '"'
  in
  List.iteri
    (fun i (key, value) ->
       if i > 0 then Buffer.add_string buf ", ";
       add_quoted key;
       Buffer.add_string buf "=>";
       match value with
       | Some v -> add_quoted v
       | None -> Buffer.add_string buf "NULL")
    hstore;
  Buffer.contents buf

(* NUMERIC, UUID and JSONB values are already held in PostgreSQL text
 * form, so serialising them is the identity. *)
let string_of_numeric : string -> string = fun text -> text

let string_of_uuid : string -> string = fun text -> text

let string_of_jsonb : string -> string = fun text -> text

(* Render an INET value.  When [mask] covers the whole address (32 bits for
 * IPv4, 128 for IPv6) the result is the bare address; for a shorter
 * non-negative mask it is "addr/mask".  Any other mask raises
 * [Failure "string_of_inet"]. *)
let string_of_inet (addr, mask) =
  let full_width =
    match Unix.domain_of_sockaddr (Unix.ADDR_INET (addr, 1)) with
    | Unix.PF_INET6 -> 128
    | _ -> 32
  in
  let text = Unix.string_of_inet_addr addr in
  if mask = full_width then text
  else if mask < 0 || mask >= full_width then failwith "string_of_inet"
  else text ^ "/" ^ string_of_int mask

(* Scalar serialisers producing PostgreSQL text-format values. *)
let string_of_oid oid = Int32.to_string oid
let string_of_bool b = if b then "t" else "f"
let string_of_int n = Stdlib.string_of_int n
let string_of_int16 n = Stdlib.string_of_int n
let string_of_int32 n = Int32.to_string n
let string_of_int64 n = Int64.to_string n
let string_of_float = string_of_float
(* POINT is written as "(x,y)". *)
let string_of_point (x, y) =
  Printf.sprintf "(%s,%s)" (string_of_float x) (string_of_float y)
let string_of_timestamp = Printer.Calendar.to_string

(* TIMESTAMP WITH TIME ZONE: the timestamp followed by its UTC offset in
 * whole hours, always signed and zero-padded, e.g. "+00" or "-05". *)
let string_of_timestamptz (cal, tz) =
  let offset =
    match tz with
    | Time_Zone.UTC -> 0
    | Time_Zone.Local -> Time_Zone.gap Time_Zone.UTC Time_Zone.Local
    | Time_Zone.UTC_Plus gap -> gap
  in
  let suffix =
    if offset < 0 then Printf.sprintf "-%02d" (- offset)
    else Printf.sprintf "+%02d" offset
  in
  Printer.Calendar.to_string cal ^ suffix

let string_of_date = Printer.Date.to_string
let string_of_time = Printer.Time.to_string

(* INTERVAL from CalendarLib's (years, months, days, seconds) split. *)
let string_of_interval p =
  let (years, months, days, seconds) = Calendar.Period.ymds p in
  Printf.sprintf "%d years %d mons %d days %d seconds" years months days seconds

let string_of_unit () = ""

(* Build a one-dimensional PostgreSQL array literal.  Each present element
 * is wrapped in double quotes and each None becomes NULL, e.g.
 * [Some "a"; None] gives {"a",NULL}.
 *
 * NB. It is the responsibility of the caller of this function to
 * properly escape array elements.
 *)
let string_of_any_array xs =
  let render = function
    | Some x -> "\"" ^ x ^ "\""
    | None -> "NULL"
  in
  "{" ^ String.concat "," (List.map render xs) ^ "}"

(* Apply [f] under [Some]; [None] passes through unchanged. *)
let option_map f opt =
  match opt with
  | None -> None
  | Some x -> Some (f x)

(* Backslash-escape every double quote and backslash in [str], so that it
 * can be placed inside a double-quoted array element. *)
let escape_string str =
  let out = Buffer.create 128 in
  String.iter
    (fun c ->
       if c = '"' || c = '\\' then Buffer.add_char out '\\';
       Buffer.add_char out c)
    str;
  Buffer.contents out

(* Array serialisers: convert each present element with the given scalar
 * serialiser and assemble with string_of_any_array.  String-like elements
 * are escaped first. *)
let string_of_arbitrary_array f a = string_of_any_array (List.map (option_map f) a)
let string_of_bool_array a = string_of_arbitrary_array string_of_bool a
let string_of_int16_array a = string_of_arbitrary_array Stdlib.string_of_int a
let string_of_int32_array a = string_of_arbitrary_array Int32.to_string a
let string_of_int64_array a = string_of_arbitrary_array Int64.to_string a
let string_of_string_array a = string_of_arbitrary_array escape_string a
let string_of_float_array a = string_of_arbitrary_array string_of_float a
let string_of_timestamp_array a = string_of_arbitrary_array string_of_timestamp a
let string_of_uuid_array a = string_of_arbitrary_array escape_string a

(* Read the PGCOMMENT_SRC_LOC setting: "yes"/"1"/"on" give true,
 * "no"/"0"/"off" give false, and any other value raises [Failure].  When
 * the variable is unset, PGOCaml_config.default_comment_src_loc is used.
 * NOTE(review): presumably controls whether generated queries carry a
 * source-location comment. *)
let comment_src_loc () =
  let parse = function
    | "yes" | "1" | "on" -> true
    | "no" | "0" | "off" -> false
    | other ->
      failwith (Printf.sprintf "Unrecognized option for 'PGCOMMENT_SRC_LOC': %s" other)
  in
  match Sys.getenv_opt "PGCOMMENT_SRC_LOC" with
  | None -> PGOCaml_config.default_comment_src_loc
  | Some setting -> parse setting

open Sexplib.Std

(* The converter pair attached to a custom rule; find_custom_typconvs
 * returns it as (serialize, deserialize).
 * NOTE(review): presumably the names of OCaml conversion functions. *)
type custom_rule_payload =
  { serialize: string
  ; deserialize: string
  }
  [@@deriving sexp]

(* An atomic rule: matches when the corresponding name (type, column or
 * argument) is supplied and equal to the string; see eval_rule_spec. *)
type custom_rule_spec =
  | Typnam of string
  | Colnam of string
  | Argnam of string
  [@@deriving sexp]

(* Boolean combination of atomic rules. *)
type rule_logic =
  | And of rule_logic list
  | Or of rule_logic list
  | Rule of custom_rule_spec
  | True
  | False
  [@@deriving sexp]

(* Decide whether [logic] applies to the given type/column/argument names.
 * [Rule (Typnam s)] holds when [typnam] is supplied and equals [s]
 * (likewise Colnam/colnam and Argnam/argnam); an absent name never
 * matches.  [And] holds when all sub-rules hold and [Or] when at least one
 * does.  So [And []] is true and [Or []] is false instead of raising
 * Match_failure.  Evaluation short-circuits. *)
let rec eval_rule_spec ?typnam ?colnam ?argnam logic =
  let name_is expected = function
    | Some actual -> actual = expected
    | None -> false
  in
  match logic with
  | True -> true
  | False -> false
  | Rule (Typnam s) -> name_is s typnam
  | Rule (Colnam s) -> name_is s colnam
  | Rule (Argnam s) -> name_is s argnam
  | And logics -> List.for_all (eval_rule_spec ?typnam ?colnam ?argnam) logics
  | Or logics -> List.exists (eval_rule_spec ?typnam ?colnam ?argnam) logics

(* A rule and the converters to use when it matches. *)
type custom_rule = rule_logic * custom_rule_payload [@@deriving sexp]

(* Contents of a custom-converters configuration file (an s-expression
 * list, loaded by find_custom_typconvs). *)
type custom_rules_conf = custom_rule list [@@deriving sexp]

(* Look up the custom converter rule matching the given type, column
 * and/or argument names.
 *
 * Rules are read from the sexp file [lookin] when it is given.  Otherwise
 * they come from the file named by $PGCUSTOM_CONVERTERS_CONFIG, which is
 * read once when this value is defined; with the variable unset there are
 * no rules.
 *
 * Returns:
 *   - [Ok None] when no rule matches;
 *   - [Ok (Some (serialize, deserialize))] when exactly one rule matches;
 *   - [Error _] when the rules file cannot be loaded or parsed, or when
 *     several rules match.
 *)
let find_custom_typconvs =
  (* Rresult is opened for its >>= and so that Ok/Error below denote
   * result constructors rather than this module's Error exception. *)
  let open Rresult in
  let loadconvs fname =
    try
      Ok (Sexplib.Sexp.load_sexp_conv_exn fname custom_rules_conf_of_sexp)
    with
    | exn ->
      Error (
        Printf.sprintf
          "Error parsing custom typeconvs file %s (cwd: %s): %s"
          fname
          (Unix.getcwd ())
          (Printexc.to_string exn))
  in
  let default_custom_converters =
    match Sys.getenv_opt "PGCUSTOM_CONVERTERS_CONFIG" with
    | None -> Ok []
    | Some f -> loadconvs f
  in
  fun ?typnam ?lookin ?colnam ?argnam () ->
    (match lookin with
     | Some fname -> loadconvs fname
     | None -> default_custom_converters)
    >>= fun custom_converters ->
    let matching =
      List.filter
        (fun (logic, _) -> eval_rule_spec ?typnam ?colnam ?argnam logic)
        custom_converters
    in
    match matching with
    | [] -> Ok None
    | [ (_rulespec, v) ] -> Ok (Some (v.serialize, v.deserialize))
    | _ :: _ :: _ -> Error "Converter collision"

(* BYTEA in PostgreSQL's hex format: a leading \x followed by two hex
 * digits per byte. *)
let string_of_bytea b =
  let `Hex b_hex = Hex.of_string b in  "\\x" ^ b_hex

(* BYTEA[]: each element's hex text begins with a backslash.  Inside a
 * double-quoted array element a backslash escapes the next character, so
 * it must itself be escaped (as string_of_string_array does).  Otherwise
 * the server would read the element as x followed by the digits, i.e.
 * literal text rather than hex-encoded bytes. *)
let string_of_bytea_array a =
  string_of_any_array
    (List.map (option_map (fun b -> escape_string (string_of_bytea b))) a)

(* Scalar parsers for PostgreSQL text-format values. *)
let string_of_string (s : string) = s
let oid_of_string s = Int32.of_string s
(* Accepts t/true and f/false; anything else raises [Error]. *)
let bool_of_string = function
  | "t" | "true" -> true
  | "f" | "false" -> false
  | other -> raise (Error ("PGOCaml: not a boolean: " ^ other))
let int_of_string = Stdlib.int_of_string
let int16_of_string = Stdlib.int_of_string
let int32_of_string s = Int32.of_string s
let int64_of_string s = Int64.of_string s
let float_of_string = float_of_string

(* Parse hstore text of the form "k1"=>"v1", "k2"=>NULL into key/value
 * pairs.  Keys and values must be double-quoted.  Inside quotes a
 * backslash makes the following character literal.  A value spelled NULL
 * (unquoted) becomes None.  Pairs must be separated by exactly a comma and
 * one space.
 *
 * NOTE(review): pairs come back in reverse order of appearance, because
 * the accumulator in parse_main is never reversed.  Callers should not
 * rely on ordering.
 * NOTE(review): a mismatch detected by [expect] raises [Error _], but
 * input that ends early surfaces as Stream.Failure from Stream.next. *)
let hstore_of_string str =
  (* Consume [target] character by character; stop and raise at the first
   * mismatch. *)
  let expect target stream =
    if List.exists (fun c -> c <> Stream.next stream) target
    then raise (Error ("PGOCaml: unexpected input in hstore_of_string")) in
  (* A double-quoted token, returned without its quotes and escapes. *)
  let parse_quoted stream =
    let rec loop accum stream = match Stream.next stream with
      | '"'  -> String.implode (List.rev accum)
      | '\\' -> loop (Stream.next stream :: accum) stream
      | x    -> loop (x :: accum) stream in
    expect ['"'] stream;
    loop [] stream in
  (* Either the bare word NULL or a quoted string. *)
  let parse_value stream = match Stream.peek stream with
    | Some 'N' -> (expect ['N'; 'U'; 'L'; 'L'] stream; None)
    | _        -> Some (parse_quoted stream) in
  let parse_mapping stream =
    let key = parse_quoted stream in
    expect ['='; '>'] stream;
    let value = parse_value stream in
    (key, value) in
  (* Empty input gives []; otherwise mappings separated by a comma and a
   * space, accumulated newest-first. *)
  let parse_main stream =
    let rec loop accum stream =
      let mapping = parse_mapping stream in
      match Stream.peek stream with
        | Some _ -> (expect [','; ' '] stream; loop (mapping :: accum) stream)
        | None   -> mapping :: accum in
    match Stream.peek stream with
      | Some _ -> loop [] stream
      | None   -> [] in
  parse_main (Stream.of_string str)

(* NUMERIC, UUID and JSONB arrive as text and are kept verbatim. *)
let numeric_of_string : string -> string = fun text -> text

let uuid_of_string : string -> string = fun text -> text

let jsonb_of_string : string -> string = fun text -> text

(* Parse INET text such as "192.168.0.1", "10.0.0.0/8" or "::1/128" into
 * (address, prefix length).  Without an explicit "/mask" the prefix is
 * the full address width: 32 if the first separator is '.', else 128. *)
let inet_of_string =
  let re =
    let open Re in
    (* Group 1: the address.  Group 2: its first separator ('.' or ':').
     * Group 3: the optional mask after '/'. *)
    let address =
      seq [ rep (compl [set ":./"]); group (set ":."); rep1 (compl [char '/']) ]
    in
    compile (seq [ group address; opt (seq [char '/'; group (rep1 any)]) ])
  in
  fun str ->
    let groups = Re.exec re str in
    let addr = Unix.inet_addr_of_string (Re.Group.get groups 1) in
    let mask =
      match Re.Group.get groups 3 with
      | m -> m
      | exception Not_found -> "" (* optional match *)
    in
    if mask <> "" then (addr, int_of_string mask)
    else if Re.Group.get groups 2 = "." then (addr, 32)
    else (addr, 128)

(* Parse a POINT in text form "(x,y)" into a pair of floats.  Spaces or
 * tabs may surround each coordinate.  A coordinate may be:
 *   - a decimal with optional sign, fraction and signed exponent;
 *   - NaN (either case on the Ns);
 *   - an optionally signed Infinity.
 * The regex is not anchored, so surrounding text is not rejected.  Any
 * failure (no match, unparsable float) becomes [Failure "point_of_string"]. *)
let point_of_string =
  let point_re =
    let open Re in
    (* [p] surrounded by optional spaces/tabs. *)
    let space p =
      let space = rep (set " \t") in
      seq [ space ; p ; space ] in
    let sign = opt (set "+-") in
    let num = seq [ sign ; rep1 digit ; opt (char '.') ; rep digit
                  ; opt (seq [ set "Ee"; set "+-"; rep1 digit ]) ] in
    let nan = seq [ set "Nn"; char 'a'; set "Nn" ] in
    let inf = seq [ sign ; set "Ii" ; str "nfinity" ] in
    let float_pat = Re.alt [num ; nan ; inf ] in
    (* Group 1: x; group 2: y. *)
    [ char '(' ; space (group float_pat) ; char ','
    ; space (group float_pat) ; char ')' ]
    |> seq
    |> compile in
  fun str ->
    try
      let subs = Re.exec point_re str in
      (float_of_string (Re.Group.get subs 1), float_of_string (Re.Group.get subs 2))
    with
    | _ -> failwith "point_of_string"

let date_of_string = Printer.Date.from_string

(* If [str] has a '.' at index [pos], i.e. a fractional-seconds part starts
 * there, keep only its first [pos] characters; otherwise return [str]. *)
let drop_fraction_at pos str =
  if String.length str > pos && str.[pos] = '.' then String.sub str 0 pos
  else str

(* TIME, ignoring any fractional seconds after "HH:MM:SS". *)
let time_of_string str = Printer.Time.from_string (drop_fraction_at 8 str)

(* TIMESTAMP, ignoring any fractional seconds after
 * "YYYY-MM-DD HH:MM:SS". *)
let timestamp_of_string str =
  Printer.Calendar.from_string (drop_fraction_at 19 str)

(* TIMESTAMP WITH TIME ZONE: a timestamp whose last three characters may
 * be a signed whole-hour UTC offset such as "+01" or "-05".  With no such
 * suffix the zone is taken to be Local (a best guess). *)
let timestamptz_of_string str =
  let n = String.length str in
  let has_offset =
    n >= 3 && (match str.[n-3] with '+' | '-' -> true | _ -> false)
  in
  if has_offset then begin
    let cal = timestamp_of_string (String.sub str 0 (n-3)) in
    let sign = if str.[n-3] = '-' then -1 else 1 in
    let hours = int_of_string (String.sub str (n-2) 2) in
    (cal, Time_Zone.UTC_Plus (sign * hours))
  end else
    (timestamp_of_string str, Time_Zone.Local (* best guess? *))

(* INTERVAL text in PostgreSQL's default ("postgres") output style, e.g.
 * "1 year 2 mons 3 days 04:05:06", "-1 days +01:00:00" or "00:00:00.5".
 * The whole string must match.  Groups:
 *   1 years, 2 months, 3 days (each optionally signed);
 *   4 sign of the time part, 5 hours, 6 minutes, 7 seconds.
 * Fractional seconds are accepted but not captured. *)
let re_interval =
  let open Re in
  let signed_int = seq [ opt (set "+-") ; rep1 digit ] in
  let time_period unit_name =
    [ group signed_int ; space ; str unit_name ; opt (char 's') ]
    |> seq |> opt in
  let digit2 = [digit ; digit ] |> seq |> group in
  let time =
    seq [ group (opt (set "+-")) ; group (rep1 digit) ; char ':' ; digit2
        ; opt (seq [ char ':' ; digit2 ; opt (seq [ char '.' ; rep1 digit ]) ]) ] in
  [ bos
  ; time_period "year"
  ; rep space
  ; time_period "mon"
  ; rep space
  ; time_period "day"
  ; rep space
  ; opt time
  ; eos ]
  |> seq
  |> compile

(* Parse INTERVAL text (see re_interval) into a Calendar.Period.
 * The time part's sign applies to hours, minutes and seconds alike.
 * Fractional seconds are dropped.  Text that does not match fails with
 * [Failure] instead of silently yielding a zero period. *)
let interval_of_string =
  (* Integer value of group [i], or 0 if it did not take part in the
   * match.  A leading '+' is stripped before conversion. *)
  let int_opt subs i =
    match Re.Group.get subs i with
    | s when String.length s > 0 && s.[0] = '+' ->
      int_of_string (String.sub s 1 (String.length s - 1))
    | s -> int_of_string s
    | exception Not_found -> 0 in
  fun str ->
    match Re.exec re_interval str with
    | exception Not_found -> failwith ("interval_of_string: bad interval: " ^ str)
    | sub ->
      let time_sign =
        match Re.Group.get sub 4 with
        | "-" -> -1
        | _ -> 1
        | exception Not_found -> 1 in
      Calendar.Period.make
        (int_opt sub 1) (* year *)
        (int_opt sub 2) (* month *)
        (int_opt sub 3) (* day *)
        (time_sign * int_opt sub 5) (* hour *)
        (time_sign * int_opt sub 6) (* min *)
        (time_sign * int_opt sub 7) (* sec *)

(* VOID results carry no information: ignore the argument. *)
let unit_of_string = fun _ -> ()

(* Split a one-dimensional PostgreSQL array literal such as
 * {a,NULL,"b,c"} into its elements.
 *   - An unquoted NULL becomes None.
 *   - Surrounding double quotes are removed, so a quoted "NULL" stays the
 *     string NULL.
 *   - A backslash makes the next character literal, so this function also
 *     takes care of unescaping returned elements.
 * Asserts that [str] starts with '{' and ends with '}'. *)
let any_array_of_string str =
  let n = String.length str in
  assert (str.[0] = '{');
  assert (str.[n-1] = '}');
  let body = String.sub str 1 (n - 2) in
  (* Characters of the current element.  Unescaped double quotes are kept
   * so that a quoted NULL can be told apart from SQL NULL. *)
  let field = Buffer.create 128 in
  let finish_field acc =
    let raw = Buffer.contents field in
    Buffer.clear field;
    let value =
      if raw = "NULL" then None
      else
        let len = String.length raw in
        if len >= 2 && raw.[0] = '"' then Some (String.sub raw 1 (len - 2))
        else Some raw
    in
    value :: acc
  in
  let acc = ref [] in
  let in_quotes = ref false in
  let after_backslash = ref false in
  for i = 0 to String.length body - 1 do
    let c = body.[i] in
    if !after_backslash then begin
      (* An escaped character is taken literally. *)
      Buffer.add_char field c;
      after_backslash := false
    end else
      match c with
      | '\\' -> after_backslash := true
      | '"' -> Buffer.add_char field '"'; in_quotes := not !in_quotes
      | ',' when not !in_quotes -> acc := finish_field !acc
      | _ -> Buffer.add_char field c
  done;
  if Buffer.length field > 0 then acc := finish_field !acc;
  List.rev !acc

(* Typed array parsers: split with any_array_of_string, then convert each
 * non-NULL element with the given scalar parser. *)
let arbitrary_array_of_string f str = List.map (option_map f) (any_array_of_string str)
let bool_array_of_string str = arbitrary_array_of_string bool_of_string str
let int16_array_of_string str = arbitrary_array_of_string Stdlib.int_of_string str
let int32_array_of_string str = arbitrary_array_of_string Int32.of_string str
let int64_array_of_string str = arbitrary_array_of_string Int64.of_string str
let string_array_of_string str = any_array_of_string str
let float_array_of_string str = arbitrary_array_of_string float_of_string str
let timestamp_array_of_string str = arbitrary_array_of_string timestamp_of_string str

(* Character classes and digit values used by the bytea decoders. *)

(* First digit of a three-digit octal byte escape: 0-3. *)
let is_first_oct_digit c = '0' <= c && c <= '3'
let is_oct_digit c = '0' <= c && c <= '7'
let oct_val c = Char.code c - Char.code '0'

let is_hex_digit c =
  match c with
  | '0'..'9' | 'a'..'f' | 'A'..'F' -> true
  | _ -> false

(* Value 0-15 of a hex digit; [Failure "hex_val"] for anything else. *)
let hex_val c =
  match c with
  | '0'..'9' -> Char.code c - Char.code '0'
  | 'a'..'f' -> Char.code c - Char.code 'a' + 10
  | 'A'..'F' -> Char.code c - Char.code 'A' + 10
  | _ -> failwith "hex_val"

(* Deserialiser for the 'hex' bytea format introduced in PostgreSQL 9.0.
 * The first two characters (the \x marker) are skipped without being
 * checked.  The rest is read as pairs of hex digits.  A pair that is not
 * two hex digits is skipped, as is a trailing unpaired character. *)
let bytea_of_string_hex str =
  let len = String.length str in
  let out = Buffer.create ((len - 2) / 2) in
  let rec decode pos =
    if pos + 1 < len then begin
      let hi = str.[pos] and lo = str.[pos + 1] in
      if is_hex_digit hi && is_hex_digit lo then
        Buffer.add_char out (Char.chr ((hex_val hi lsl 4) + hex_val lo));
      decode (pos + 2)
    end
  in
  decode 2;
  Buffer.contents out

(* Deserialiser for the old 'escape' bytea format used in PostgreSQL < 9.0.
 *   - A doubled backslash decodes to one backslash.
 *   - A backslash followed by three octal digits (the first 0-3) decodes
 *     to the byte with that value.
 *   - A backslash followed by anything else is dropped.
 *   - Every other character is copied through unchanged. *)
let bytea_of_string_escape str =
  let len = String.length str in
  let out = Buffer.create len in
  let rec scan i =
    if i < len then begin
      if str.[i] <> '\\' then begin
        Buffer.add_char out str.[i];
        scan (i + 1)
      end else begin
        let j = i + 1 in
        if j < len && str.[j] = '\\' then begin
          Buffer.add_char out '\\';
          scan (j + 1)
        end else if j + 2 < len
                && is_first_oct_digit str.[j]
                && is_oct_digit str.[j + 1]
                && is_oct_digit str.[j + 2] then begin
          let byte =
            (oct_val str.[j] lsl 6) + (oct_val str.[j + 1] lsl 3) + oct_val str.[j + 2]
          in
          Buffer.add_char out (Char.chr byte);
          scan (j + 3)
        end else
          scan j
      end
    end
  in
  scan 0;
  Buffer.contents out

(* Decode BYTEA text.  PostgreSQL 9.0 introduced the 'hex' format, marked
 * by a leading \x.  Data with that prefix goes to the hex parser;
 * anything else goes to the parser for the older 'escape' format. *)
let bytea_of_string str =
  let hex_format = String.starts_with str "\\x" in
  if not hex_format then bytea_of_string_escape str
  else bytea_of_string_hex str

(* Monadic bind and return of the underlying THREAD implementation,
 * re-exported under plain names. *)
let bind m f = m >>= f
let return x = Thread.return x

end
