(* vim: fileencoding=utf8 ft=ocaml et sw=2 ts=4 sts=4 *)

open Core

(* Show "prompt: " on stdout, read one line from stdin and echo it back as
   "Got: <line>".  If stdin has no more lines, "No input!" goes to stderr
   instead. *)
let do_echo_prompt () =
  Out_channel.output_string stdout "prompt: ";
  Out_channel.flush stdout;
  match In_channel.input_line Stdio.stdin with
  | Some answer ->
    Out_channel.output_string stdout ("Got: " ^ answer ^ "\n");
    Out_channel.flush stdout
  | None ->
    Out_channel.output_string stderr "No input!\n"
;;

(* Run [do_echo_prompt] twice, announcing the rounds as "1) " and "2) ". *)
let twoprompt () =
  List.iter ["1) "; "2) "] ~f:(fun round_label ->
    Out_channel.output_string stdout round_label;
    do_echo_prompt ())
;;

(* Render a list of strings as OCaml string literals (quoted and escaped
   with [%S]) separated by single spaces, e.g. ["a"; "b c"] becomes
   {|"a" "b c"|}.  The empty list renders as "". *)
let pp_string_list l = String.concat ~sep:" " (List.map l ~f:(sprintf "%S"));;

(* A pair of SGR parameter strings (e.g. "32") used by [pp_writer]:
   [normal] is emitted as "ESC [ normal m" before each printable character,
   [escaped] as "ESC [ escaped m" before the escape sequence shown for a
   non-printable one. *)
type color_pair = {normal: string; escaped: string} [@@deriving show];;
(* Colours for both directions of traffic: [i] styles what is copied to the
   child's stdin, [o] what is read back from its stdout. *)
type color_scheme = {i: color_pair; o:color_pair} [@@deriving show];;
(* Options parsed from the command line by [cmd] and consumed by [main]. *)
type cmd_args = {
  hideendl: bool;  (* render only a line break, not the EOL bytes themselves *)
  endl: string;  (* end-of-line sequence; "" disables line splitting *)
  color: color_scheme option;  (* None: no per-character colouring *)
  exe: string list;  (* element 0 is the program; the whole list is passed as its argument vector *)
} [@@deriving show];;

(* Format a [color_pair] as "{normal=...; escaped=...}", with both fields
   shown as quoted OCaml string literals. *)
let print_color_pair (pair : color_pair) =
  sprintf "{normal=%S; escaped=%S}" pair.normal pair.escaped
;;
(* Format a [color_scheme] as "{i=...; o=...}", rendering each direction
   with [print_color_pair]. *)
let print_color_scheme (scheme : color_scheme) =
  let input_part = print_color_pair scheme.i
  and output_part = print_color_pair scheme.o in
  sprintf "{i=%s; o=%s}" input_part output_part
;;
(* Print every field of a [cmd_args] record on stdout, one "name: value"
   line per field; a missing colour scheme is shown as "none". *)
let print_cmd_args args =
  let colors_text =
    match args.color with
    | Some scheme -> print_color_scheme scheme
    | None -> "none"
  in
  printf "hide: %B\n" args.hideendl;
  printf "endl: %S\n" args.endl;
  printf "colors: %s\n" colors_text;
  printf "exe: %s\n" (pp_string_list args.exe)
;;

(* Write [str] to stdout and flush immediately. *)
let po str = printf "%s%!" str
;;

(* Write [str] to stderr and flush immediately. *)
let pe str = eprintf "%s%!" str
;;

(* Presumably meant to track which colour the pending output is currently
   in; NOTE(review): only [NoFormat] is ever assigned by [pp_writer]. *)
type display_mode = NoFormat | Normal | Escaped;;
(* Per-printer state: text waiting to be written to stderr, plus the
   display mode it was left in. *)
type pretty_state = {to_print: Buffer.t; mutable dm: display_mode};;
(* Sets of characters; [pp_writer] uses one to test whether a byte occurs in
   a multi-byte end-of-line sequence. *)
module CharSet = Set.Make(Char);;

(* Build a callback that pretty-prints raw I/O chunks onto stderr.

   [pp_writer hideendl endl colors] returns [fun len buffer -> ...], which
   renders the first [len] bytes of [buffer]:

   - printable ASCII (' ' to '~') is copied as-is; any other byte is shown
     as an escape: [\0], [\t], [\n], [\r] or [\xNN];
   - with [colors = Some {normal; escaped}], every character is preceded by
     the SGR sequence "ESC [ normal m" (printable) or "ESC [ escaped m"
     (escaped), and each written chunk ends with the reset "ESC [ 0 m".
     With [None], no escape sequences are emitted at all;
   - whenever the bytes just seen complete [endl], a real newline is added
     so the rendering breaks lines where the data does.  If [hideendl] is
     set, the [endl] bytes themselves are not rendered, only the newline.
     An empty [endl] disables line splitting.

   Pending text is written to stderr and flushed at the end of every call
   that produced any.  A call with [len = 0] (end of input) also releases
   any bytes held back while matching a multi-byte hidden [endl]. *)
let pp_writer hideendl endl colors =
  let s = {to_print = Buffer.create 80; dm = NoFormat} in
  let add_char = Buffer.add_char s.to_print
  and add_string = Buffer.add_string s.to_print in
  (* Typed comparisons: these work whether or not the Core in use shadows
     polymorphic comparison with int-only operators. *)
  let is_printable c = Char.(c >= ' ' && c <= '~') in
  let add_escaped c =
    match c with
    | '\x00' -> add_string "\\0"
    | '\t' -> add_string "\\t"
    | '\n' -> add_string "\\n"
    | '\r' -> add_string "\\r"
    | _ -> bprintf s.to_print "\\x%02x" (Char.to_int c)
  in
  (* Only reset the terminal's attributes if colours may have been set;
     with --no-color the output must stay free of escape sequences. *)
  let reset = match colors with None -> "" | Some _ -> "\x1b[0m" in
  let write_out () =
    if Buffer.length s.to_print > 0 then (
      add_string reset;
      Out_channel.output_buffer stderr s.to_print;
      Out_channel.flush stderr;
      Buffer.clear s.to_print;
      s.dm <- NoFormat
    )
  in
  let pp_char =
    match colors with
    | None ->
      fun c -> if is_printable c then add_char c else add_escaped c
    | Some {normal; escaped} ->
      fun c ->
        if is_printable c then (
          add_string "\x1b[";
          add_string normal;
          add_char 'm';
          add_char c
        ) else (
          bprintf s.to_print "\x1b[%sm" escaped;
          add_escaped c
        )
  in
  match String.length endl, hideendl with
  | 0, _ ->
    (* No end-of-line sequence: just render every byte. *)
    fun (len : int) (buffer : Bytes.t) ->
      for n = 0 to len - 1 do
        pp_char (Bytes.get buffer n)
      done;
      write_out ()
  | 1, false ->
    (* Single-byte terminator, shown: render it, then break the line. *)
    let endl = String.get endl 0 in
    fun (len : int) (buffer : Bytes.t) ->
      for n = 0 to len - 1 do
        let c = Bytes.get buffer n in
        pp_char c;
        if Char.equal c endl then add_char '\n'
      done;
      write_out ()
  | 1, true ->
    (* Single-byte terminator, hidden: replace it with a line break. *)
    let endl = String.get endl 0 in
    fun (len : int) (buffer : Bytes.t) ->
      for n = 0 to len - 1 do
        let c = Bytes.get buffer n in
        if Char.equal c endl then add_char '\n' else pp_char c
      done;
      write_out ()
  | l, false ->
    (* Multi-byte terminator, shown: every byte is rendered immediately.
       A ring of the last [l] bytes detects when [endl] has just been seen. *)
    let ring = Ring.create l in
    fun (len : int) (buffer : Bytes.t) ->
      for n = 0 to len - 1 do
        let c = Bytes.get buffer n in
        pp_char c;
        ignore (Ring.add_char ring c);
        if Ring.compare ring endl = 0 then (
          add_char '\n';
          Ring.clear ring
        )
      done;
      write_out ()
  | l, true ->
    (* Multi-byte terminator, hidden: bytes that occur in [endl] are held
       back in the ring until we know whether they complete [endl] (then
       only a newline is emitted) or not (then they are rendered). *)
    let ring = Ring.create l
    and endl_chars = String.fold ~init:CharSet.empty ~f:CharSet.add endl in
    (* Render whatever the ring is holding back, then empty it. *)
    let release_ring () =
      Sequence.iter ~f:pp_char (Ring.fwd ring);
      Ring.clear ring
    in
    fun (len : int) (buffer : Bytes.t) ->
      if len = 0 then (
        (* End of input: nothing can complete [endl] any more. *)
        release_ring ();
        write_out ()
      ) else (
        for n = 0 to len - 1 do
          let c = Bytes.get buffer n in
          if CharSet.mem endl_chars c then (
            (* NOTE(review): assumes [Ring.add_char] returns [Some b] when
               the ring was full and its oldest byte [b] was pushed out to
               make room.  [b] can no longer start a match, so render it now.
               [c] itself stays in the ring and must not be rendered yet. *)
            (match Ring.add_char ring c with
             | Some evicted -> pp_char evicted
             | None -> ());
            if Ring.compare ring endl = 0 then (
              add_char '\n';
              Ring.clear ring
            )
          ) else (
            if Ring.length ring > 0 then release_ring ();
            pp_char c
          )
        done;
        write_out ()
      )
;;

(* Pump bytes from the Lwt file descriptor [r] into the Lwt_io output
   channel [w] until a read returns no data.  [buffer] (at least [bufsize]
   bytes long) is reused as scratch space for every read.  Each chunk is
   first handed to [cb] as [cb length buffer] and then written out in full.
   When input ends, [cb] is called one last time with length 0 and [w] is
   closed; [r] is left open.  The trailing [()] parameter lets the function
   serve directly as an Lwt continuation. *)
let rec copy_blocks bufsize buffer r w cb () =
  let forward count =
    cb count buffer;
    Lwt.bind (Lwt_io.write_from_exactly w buffer 0 count)
      (copy_blocks bufsize buffer r w cb)
  and finish () =
    cb 0 buffer;
    Lwt_io.close w
  in
  Lwt.bind (Lwt_unix.read r buffer 0 bufsize) (fun count ->
    if count <= 0 then finish () else forward count)
;;

(* Copy everything from the Lwt descriptor [fd_in] to [fd_out] in chunks of
   up to 16 KiB, passing every chunk (and the final zero-length call) to
   [cb] on the way.  [fd_out] is wrapped in an Lwt_io output channel, and
   [copy_blocks] closes that channel at end of input. *)
let copy_cb fd_in fd_out cb =
  let chunk_size = 16 * 1024 in
  let scratch = Bytes.create chunk_size in
  let out_chan = Lwt_io.of_fd ~mode:Lwt_io.Output fd_out in
  copy_blocks chunk_size scratch fd_in out_chan cb ()
;;

(* A chunk callback (same shape as the ones given to [copy_cb]) that ignores
   its arguments and does nothing. *)
let cb_null (_ : int) (_ : Bytes.t) = ()
;;

(* Run [exe] under the pretty-printer.

   The child's stdin and stdout are redirected to pipes.  Our own stdin is
   copied into the first pipe, and the child's stdout is copied back to our
   stdout.  Every chunk in either direction is also rendered onto stderr by
   a [pp_writer]: the [i] colours are used for input, the [o] colours for
   output.  Once both copies have finished, the process exits with the
   child's exit status.

   Diagnostics (pipe descriptors, child pid) are written to stderr.  Stdout
   carries the child's output and must stay byte-for-byte transparent.

   NOTE(review): [Lwt.join] also waits for the stdin copy, which only ends
   at end of file on our stdin.  An interactive session therefore keeps
   running after the child exits until stdin is closed. *)
let main {hideendl; endl; color; exe} =
  (* Report a missing command clearly instead of failing in List.nth_exn. *)
  let prog =
    match exe with
    | prog :: _ -> prog
    | [] ->
      eprintf "No command to execute\n%!";
      exit 1
  in
  Lwt_unix.set_default_async_method Lwt_unix.Async_none;
  let fd0 = Lwt_unix.stdin in
  let fd1 = Lwt_unix.stdout in
  let stdin_r, stdin_w = Lwt_unix.pipe_out () in
  let stdout_r, stdout_w = Lwt_unix.pipe_in () in
  eprintf "pipes: (%d %d) (%d %d)\n%!"
    (Unix.File_descr.to_int stdin_r)
    (Unix.File_descr.to_int (Lwt_unix.unix_file_descr stdin_w))
    (Unix.File_descr.to_int (Lwt_unix.unix_file_descr stdout_r))
    (Unix.File_descr.to_int stdout_w);
  (* Presumably: in the child, the pipe ends replace fds 0/1 and our ends
     are closed (semantics of Subprocess.Redirect/Close not visible here). *)
  let pid = Subprocess.create_process
    ~redirects:[
      Lwt_unix.unix_file_descr fd0, Subprocess.Redirect stdin_r;
      Lwt_unix.unix_file_descr fd1, Subprocess.Redirect stdout_w;
      Lwt_unix.unix_file_descr stdin_w, Subprocess.Close;
      Lwt_unix.unix_file_descr stdout_r, Subprocess.Close;
    ]
    prog (List.to_array exe)
  in
  eprintf "pid: %d\n%!" pid;
  Lwt_unix.set_blocking ~set_flags:true fd0 false;
  Lwt_unix.set_blocking ~set_flags:true fd1 false;
  (* Drop the parent's copies of the child's pipe ends.  Otherwise reading
     stdout_r would never see end of file while we still hold stdout_w. *)
  Unix.close stdin_r;
  Unix.close stdout_w;
  let i_cb = pp_writer hideendl endl (Option.map color ~f:(fun cs -> cs.i))
  and o_cb = pp_writer hideendl endl (Option.map color ~f:(fun cs -> cs.o))
  in
  Lwt_main.run (Lwt.join [
    copy_cb fd0 stdin_w i_cb;
    copy_cb stdout_r fd1 o_cb;
  ]);
  let _, status = Subprocess.waitpid [] pid in
  exit (Subprocess.wait_estatus status)
;;

(* Command-line interface.  Flags are parsed into a [cmd_args] record and
   handed to [main].  The command to run comes either after "--" or as
   trailing anonymous arguments; "--" wins if both are present. *)
let cmd =
  let open Command.Let_syntax in
  (* Turn a [no_arg] flag into [Some value] when present and [None]
     otherwise, which is the shape [choose_one] expects. *)
  let flag_const value =
    Let_syntax.map ~f:(fun present -> if present then Some value else None)
  in
  (* Split the -C argument on spaces, dropping empty fields so that
     repeated spaces do not produce empty colour codes. *)
  let split_colors s =
    String.split s ~on:' '
    |> List.filter ~f:(fun code -> not (String.is_empty code))
  in
  (* The first two codes style the child's input (printable, escaped), the
     next two its output.  Extra codes are ignored.  Too few codes is
     reported clearly instead of failing inside List.nth_exn. *)
  let color_scheme_of_codes = function
    | i_normal :: i_escaped :: o_normal :: o_escaped :: _ ->
      { i = {normal = i_normal; escaped = i_escaped};
        o = {normal = o_normal; escaped = o_escaped} }
    | codes ->
      failwith (sprintf "-C expects 4 colour codes, got %d: %s"
                  (List.length codes) (pp_string_list codes))
  in
  Command.basic ~summary:"Pretty print I/O to a spawned command."
  [%map_open
    let hide = flag "-H" ~aliases:["--hide-newlines"] no_arg
      ~doc:"Suppress printing codes of the line-terminating sequence"
    and endl = choose_one  ~if_nothing_chosen:(`Default_to "\n")  [
      flag "-n" no_arg  ~doc:"Use \\n as end of line"
        |> flag_const "\n";
      flag "-r" no_arg  ~doc:"Use \\r as end of line"
        |> flag_const "\r";
      flag "-D" no_arg  ~doc:"Use \\r\\n as end of line"
        |> flag_const "\r\n";
      flag "-E" (optional string)  ~doc:"EOL Set end of line sequence";
      flag "-N" no_arg  ~doc:"Don't recognize end of line sequences"
        |> flag_const "";
      ]
    and color = choose_one ~if_nothing_chosen:(`Default_to (
        Some ["32";"36";"31";"35"]
      )) [
       flag "-c" ~aliases:["--nocolor"; "--no-color"] no_arg
        ~doc:"Don't use colors to distinguish input from output"
        |> flag_const None;
       flag "-C" ~aliases:["--colors"; "--colours"] (optional string)
        ~doc:("COLORS Four space-separated SGR codes: input text, "
              ^ "input escapes, output text, output escapes")
        |> map ~f:(fun x -> match x with
          | Some s -> Some (Some (split_colors s))
          | None -> None);
      ]
    and exe = flag "--" escape ~doc:"EXE Command to execute"
    and exe2 = anon (sequence ("EXE"%:string))
    in
    fun () ->
      let args = {
        hideendl = hide;
        endl;
        color = (match color with
          | None -> None
          | Some codes -> Some (color_scheme_of_codes codes));
        exe = (match exe with
          | None -> exe2
          | Some l -> l);
      }
      in
      (* print_cmd_args args; *)
      (* pe (show_cmd_args args); *)
      main args
  ]
;;

(* Program entry point: parse the command line and dispatch to [cmd]. *)
let () =
  let version = "ppio-ocaml-0.0" in
  Command.run ~version cmd
;;