open ExtList
open Ast

module Env = struct
  (** [split4 lis] unzips a list of 4-tuples into four lists, keeping the
      original element order in each of them. *)
  let split4 lis =
    List.fold_right
      (fun (a, b, c, d) (aa, bb, cc, dd) -> (a :: aa, b :: bb, c :: cc, d :: dd))
      lis ([], [], [], [])

  (** An inference environment, as a 4-tuple of:
      - a signature database,
      - a type environment mapping variable names to types,
      - a set of type constraints,
      - a map from constraints to a [(string * int)] location
        (NOTE(review): presumably the (file, line) pair that the inference
        threads through; confirm). *)
  type t = SigDB.db * TypeEnv.env * Cstrs.t_set * (Cstrs.type_constraint, (string*int)) PMap.t

  (** Component accessors. *)
  let db (d, _, _, _) = d
  let env (_, e, _, _) = e
  let constraints (_, _, c, _) = c
  let lines (_, _, _, l) = l

  (** The environment with every component empty. *)
  let empty = (SigDB.empty, TypeEnv.empty, Cstrs.empty, PMap.empty)

  (** Equality of the database, type environment and constraint set.
      The location map is not compared. *)
  let equal (d1, e1, c1, _) (d2, e2, c2, _) =
    SigDB.equal d1 d2 && TypeEnv.equal e1 e2 && Cstrs.equal c1 c2

  (** Renders the database, type environment and constraints, one per line.
      The location map is not printed. *)
  let to_string ((d, e, c, _) : t) =
    SigDB.to_string d ^ "\n" ^ TypeEnv.to_string e ^ "\n" ^ Cstrs.to_string c

  (** A single addition to one component of an environment. *)
  type add_request =
    | AddEnv of (string * Type.rb_type)
    | AddDB of (Type.rb_type * Sig.sig_struct)
    | AddCstr of Cstrs.type_constraint
    | AddCstrSet of Cstrs.t_set

  (** [add req v] returns [v] with the addition described by [req] applied;
      [AddCstrSet] merges the given set into the existing constraints. *)
  let add req (d, e, c, l) =
    match req with
    | AddEnv (n, t) -> (d, TypeEnv.add n t e, c, l)
    | AddDB (t, s) -> (SigDB.add t s d, e, c, l)
    | AddCstr nc -> (d, e, Cstrs.add nc c, l)
    | AddCstrSet cs -> (d, e, Cstrs.merge [c; cs], l)

  (** A single removal from an environment. *)
  type remove_request =
    | RmEnv of string

  (** [remove req v] returns [v] with the removal described by [req] applied. *)
  let remove req (d, e, c, l) =
    match req with
    | RmEnv name -> (d, TypeEnv.remove name e, c, l)

  (** Monomorphic map: applies the type-to-type function [f] to every type in
      the database, type environment and constraint set. The location map is
      returned unchanged. (In this file [f] is always a substitution built
      with [Subst.create].) *)
  let map0 f (d, e, c, l) =
    (SigDB.maps f d),
    (TypeEnv.map f e),
    (Cstrs.maps f c),
    l

  (** Merges a list of environments component by component. For the location
      maps, a binding from an earlier environment in the list overrides one
      from a later environment. *)
  let merge vs =
    let merge_map m1 m2 = PMap.foldi PMap.add m1 m2 in
    let dbs, envs, cs, ls = split4 vs in
    (SigDB.merge dbs), (TypeEnv.merge envs), (Cstrs.merge cs),
    (List.fold_right merge_map ls PMap.empty)

  (** Most general unifier of the type environments of [vals]; delegates to
      [TypeEnv.mgu]. *)
  let mgu vals =
    let _, es, _, _ = split4 vals in
    TypeEnv.mgu es
end

(**
   [list_param_names np expr] returns the names of the method parameters
   that are referenced in the method body [expr], ordered by ascending
   variable id.

   Variables are collected from [VarExpr] nodes reachable through arrays,
   hashes, blocks, assignment right-hand sides, [if] branches, calls,
   [return] and newline wrappers. Other nodes (including nested definitions)
   are not entered. A variable counts as a parameter when its id satisfies
   [2 <= id < 2 + np] (NOTE(review): this assumes ids 0 and 1 are reserved
   and parameters occupy the next [np] ids; confirm against the parser).

   Parameters that never occur in [expr] are absent, so the result can be
   shorter than [np].
*)
let list_param_names np expr =
  (* Prepend [p] unless it is already present, so each (id, name) pair is
     listed once. *)
  let add p ps = if List.mem p ps then ps else p :: ps in
  let rec list_params expr ps =
    match expr with
    | VarExpr (n, _) -> add n ps
    | ArrayExpr es -> List.fold_right list_params es ps
    | HashExpr ees ->
        List.fold_right (fun (e1, e2) ps -> list_params e1 (list_params e2 ps)) ees ps
    | BlockExpr es -> List.fold_right list_params es ps
    | AssignExpr (_, _, e) -> list_params e ps
    | IfExpr (e1, e2, e3) -> List.fold_right list_params [e1; e2; e3] ps
    | CallExpr (e1, _, es) -> List.fold_right list_params (e1 :: es) ps
    | FCallExpr (_, es) -> List.fold_right list_params es ps
    | ReturnExpr e -> list_params e ps
    | NewlineExpr (_, e) -> list_params e ps
    | _ -> ps
  in
  if np = 0 then []
  else
    let is_param (id, _) = 2 <= id && id < 2 + np in
    List.map snd (List.filter is_param (List.sort ~cmp:compare (list_params expr [])))
     
(**
   [fresh_types n] returns a list of [n] fresh type variables, each created
   by [Type.fresh_var]. Returns [[]] when [n <= 0].

   Variables for later list positions are allocated before those for
   earlier positions. This is the order that right-to-left argument
   evaluation (what the standard compilers do) gave the expression
   [fresh_var () :: fresh_types (n - 1)]. The explicit [let] makes the
   order independent of OCaml's unspecified evaluation order.
*)
let rec fresh_types n =
  if n <= 0 then []
  else
    let rest = fresh_types (n - 1) in
    Type.fresh_var () :: rest

(**
   [ti expr] infers a type for [expr].

   Returns [(env, t)], where:
   - [t] is the type of [expr];
   - [env] collects the types assigned to free variables, the method
     signatures required of call receivers, and the subtype constraints
     generated along the way.

   @raise Failure for expression forms that are not handled here
   (e.g. [IfExpr], [ArrayExpr], [HashExpr], [ReturnExpr]).
*)
let ti expr =
  (* Strip [NewlineExpr] wrappers so [fold_block] can recognise assignments. *)
  let rec skip_newline = function
    | NewlineExpr (_, e) -> skip_newline e
    | e -> e
  in
  (* [l] is the current location; it is replaced at every [NewlineExpr] and
     otherwise only threaded through. *)
  let rec ti0 l expr =
    match expr with
    | NewlineExpr (l, expr) -> ti0 l expr
    | IntExpr _ -> (Env.empty, Type.base "Integer")
    | StringExpr _ -> (Env.empty, Type.base "String")
    | FloatExpr _ -> (Env.empty, Type.base "Float")
    | NilExpr -> (Env.empty, Type.base "NilClass")
    | VarExpr ((_, name), _) ->
        (* Each occurrence gets its own fresh type variable. Occurrences are
           reconciled via [Env.mgu]/[TypeEnv.mgu] where environments are
           combined. *)
        let t = Type.fresh_var () in
        (Env.add (Env.AddEnv (name, t)) Env.empty), t
    | AssignExpr (_, _, rhs) ->
        (* An assignment has the type of its right-hand side. The rebinding
           itself is handled by [fold_block]. *)
        ti0 l rhs
    | BlockExpr [] -> (Env.empty, Type.base "NilClass")
    | BlockExpr exprs ->
        (* Combines one statement [e1], inferred as [(v1, t1)], with the
           already-processed remainder of the block [(v0, t0)]. *)
        let fold_block (e1, (v1, t1)) (v0, t0) =
          let (v0, t0) =
            match e1 with
            | AssignExpr ((_, name), _, _) ->
                if TypeEnv.mem name (Env.env v0) then
                  (* Later uses of [name] take the right-hand side's type,
                     and [name] is no longer free in the block. *)
                  let s =
                    Subst.create
                      (Subst.Src.of_list [(TypeEnv.find name (Env.env v0), t1)])
                  in
                  ((Env.remove (Env.RmEnv name) (Env.map0 s v0)), s t0)
                else (v0, t0)
            | _ -> (v0, t0)
          in
          let s = Subst.create (Env.mgu [v0; v1]) in
          let vv = Env.merge (List.map (Env.map0 s) [v0; v1]) in
          (vv, s t0)
        in
        let pairs = List.map (fun e -> (skip_newline e, ti0 l e)) exprs in
        let last_pair = List.last pairs in
        let other_pairs = List.take (List.length pairs - 1) pairs in
        (* The type of a block is the type of its last statement. *)
        List.fold_right fold_block other_pairs (snd last_pair)
    | ClassExpr ((_, name), _) ->
        (* Binds [name] to the base type ["_" ^ name]; the body is not
           inspected. *)
        let env = Env.add (Env.AddEnv (name, Type.base ("_" ^ name))) Env.empty in
        (env, Type.base "NilClass")
    | FCallExpr (name, params) ->
        (* A receiver-less call is typed as a call on a dummy variable, and
           the dummy's type is then fixed to [Type.var 0]. This is the same
           type that [DefExpr] uses as the owner of method signatures.
           NOTE(review): 4649 looks like an arbitrary placeholder id;
           confirm it cannot clash with real variable ids. *)
        let dummy_name = "*dummy*" in
        let call_expr =
          let dummy_recv = VarExpr ((4649, dummy_name), LocalVar) in
          CallExpr (dummy_recv, name, params)
        in
        let (vc, tc) = ti0 l call_expr in
        let s =
          Subst.create
            (Subst.Src.of_list [(TypeEnv.find dummy_name (Env.env vc), Type.var 0)])
        in
        let vc = Env.map0 s vc in
        let vc = Env.remove (Env.RmEnv dummy_name) vc in
        (vc, tc)
    | CallExpr (recv, (_, name), params) ->
        let (valr, tr) = ti0 l recv in
        let (vals, ts) = List.split (List.map (ti0 l) params) in
        (* Unify the receiver's and arguments' type environments so that
           shared variables end up with a single type. *)
        let envs = (Env.env valr) :: (List.map Env.env vals) in
        let s = Subst.create (TypeEnv.mgu envs) in
        let ts = List.map s ts in
        let tr = s tr in
        let valr = Env.map0 s valr in
        let vals = List.map (Env.map0 s) vals in
        (* Fresh types for the method's parameters, result and receiver. *)
        let dom = List.map (fun _ -> Type.fresh_var ()) params in
        let rg = Type.fresh_var () in
        let recv_type = Type.fresh_var () in
        let recv_struct =
          let m = Sig.Entry.create dom rg in
          Sig.create [name, (Sig.Entry.map s m)]
        in
        (* Emits [Subtype (argument_type, parameter_type)] for each argument
           and [Subtype (tr, recv_type)] for the receiver. *)
        let constraints =
          let constraints_of_params =
            List.fold_right2
              (fun t1 t2 c -> Cstrs.add (Cstrs.Subtype (t2, t1)) c)
              dom ts Cstrs.empty
          in
          let constraints_of_recever = Cstrs.Subtype (tr, recv_type) in
          Cstrs.add constraints_of_recever constraints_of_params
        in
        let vall = Env.merge (valr :: vals) in
        let vall = Env.add (Env.AddDB (recv_type, recv_struct)) vall in
        let vall = Env.add (Env.AddCstrSet constraints) vall in
        (vall, rg)
    | DefExpr ((_, name), np, body) ->
        let (vb, tb) = ti0 l body in
        let param_names = list_param_names np body in
        let body_env = Env.env vb in
        (* A parameter that the body reassigns before reading it
           (e.g. [def f(a); a = 1; a; end]) has already been removed from the
           body's environment by [fold_block]. Give it an unconstrained fresh
           type instead of failing the lookup. *)
        let param_type pname =
          if TypeEnv.mem pname body_env then TypeEnv.find pname body_env
          else Type.fresh_var ()
        in
        let dom_type = List.map param_type param_names in
        let str_def =
          let entry = Sig.Entry.create dom_type tb in
          Sig.create [(name, entry)]
        in
        let vdef = Env.add (Env.AddDB ((Type.var 0), str_def)) vb in
        (* Parameters are bound by the definition, so they are not free in
           the result; only remove those actually present. *)
        let drop_param pname v =
          if TypeEnv.mem pname (Env.env v) then Env.remove (Env.RmEnv pname) v
          else v
        in
        let vdef = List.fold_right drop_param param_names vdef in
        (vdef, Type.base "NilClass")
    | _ -> failwith "Type inference failed [unsupported expression type]"
  in
  ti0 ("*", 0) expr
      
(**
   [is_solved (db, env, cstrs, l)] is [true] when no constraint in [cstrs]
   has both of its sides present in the signature database [db]. [env] and
   [l] are ignored.
*)
let is_solved (db, _, cstrs, _) =
  let known t = SigDB.mem t db in
  Cstrs.for_all
    (function
      | Cstrs.Subtype (lhs, rhs) -> not (known lhs) || not (known rhs))
    cstrs

(**
   [solve (db, env, cstrs, l)] iteratively processes the constraint set
   [cstrs] against the signature database [db].

   Each round:
   - calls [Cstrs.remove_transitives] with the predicate "type not in [db]"
     and keeps the second component of its result;
   - stops if [is_solved] holds for what remains;
   - otherwise folds [Subtype.test db] over every remaining constraint,
     threading an accumulator pair that starts from the current first
     component and an empty set, and starts the next round with the result.

   Returns the final [(first_component, remaining_constraints)] pair; the
   first component starts as [Cstrs.empty].
   NOTE(review): termination depends on [Subtype.test] eventually leaving
   no constraint with both sides in [db]; confirm.
*)
let solve (db, env, cstrs, l) =
  let rec loop (acc, pending) =
    let (_, pending) =
      Cstrs.remove_transitives (fun t -> not (SigDB.mem t db)) pending
    in
    if not (is_solved (db, env, pending, l)) then
      let step cstr (acc, next) = Subtype.test db cstr acc next in
      loop (Cstrs.fold step pending (acc, Cstrs.empty))
    else (acc, pending)
  in
  loop (Cstrs.empty, cstrs)

