open OUnit

(** [assert_equal_a expected actual] asserts that two pairs of [Cstrs]
    values (as returned by [Subtype.test]) are equal.

    Each component is compared with [Cstrs.equal]. On failure, both pairs are
    rendered with [Cstrs.to_string] in the form [asts={...} cstrs={...}].
    The first component is labelled [asts] (presumably the assumption set —
    confirm against [Subtype]) and the second [cstrs]. *)
let assert_equal_a =
  let pair_equal (asts_l, cstrs_l) (asts_r, cstrs_r) =
    Cstrs.equal asts_l asts_r && Cstrs.equal cstrs_l cstrs_r
  in
  let show_pair (asts, cstrs) =
    let asts_str = Cstrs.to_string asts in
    let cstrs_str = Cstrs.to_string cstrs in
    String.concat ""
      [ "\nasts={"; asts_str; " }\ncstrs={"; cstrs_str; "}\n" ]
  in
  (* Left partially applied so callers can still pass OUnit's remaining
     optional arguments before the two positional ones. *)
  assert_equal ~cmp:pair_equal ~printer:show_pair

(** Alias for [Cstrs.empty], used as the empty first/second argument and
    expected result throughout these tests. *)
let empty = Cstrs.empty

(** A type is a subtype of itself. Checking [t <= t] against an empty
    signature database must add nothing to either output set. *)
let test_t1_t1 _ =
  let holds_trivially lhs rhs =
    let result =
      Subtype.test SigDB.empty (Cstrs.Subtype (lhs, rhs)) empty empty
    in
    assert_equal_a (empty, empty) result
  in
  holds_trivially (Type.var 1) (Type.var 1);
  holds_trivially (Type.base "Integer") (Type.base "Integer")

(** Every type is a subtype of [Type.top]. Both a type variable and a base
    type are checked, and each check must leave both output sets empty. *)
let test_top _ =
  let below_top ty =
    let result =
      Subtype.test SigDB.empty (Cstrs.Subtype (ty, Type.top)) empty empty
    in
    assert_equal_a (empty, empty) result
  in
  below_top (Type.var 1);
  below_top (Type.base "TCPSocket")

(** [Type.top] is not a subtype of a type variable. [Subtype.test] must
    raise [Subtype.IncompatibleSignatures] carrying the offending
    constraint. *)
let test_top_fail _ =
  (* A fresh constraint value is built for the expected exception and for
     the call, exactly as the original two separate constructions did. *)
  let top_le_var () = Cstrs.Subtype (Type.top, Type.var 12) in
  let expected = Subtype.IncompatibleSignatures (top_le_var ()) in
  assert_raises expected
    (fun () -> Subtype.test SigDB.empty (top_le_var ()) empty empty)

(** Query [var 1 <= Integer] when that same constraint is already in the
    first input set and the signature database is empty. The first set must
    come back unchanged and the second must stay empty. *)
let test_not_in_db _ =
  (* Build the constraint afresh at each use rather than sharing one value. *)
  let var_le_int () = Cstrs.Subtype (Type.var 1, Type.base "Integer") in
  let assumed = Cstrs.add (var_le_int ()) empty in
  let result = Subtype.test SigDB.empty (var_le_int ()) assumed empty in
  assert_equal_a (assumed, empty) result

(** [var 1] and [var 2] both map to signatures with a single no-argument
    method [foo], returning [var 10] and [var 20] respectively.

    Checking [var 1 <= var 2] is expected to put that pair in the first
    output set and [var 10 <= var 20] (the [foo] return types) in the
    second. The return types are not checked recursively here. *)
let test_no_recursive_subtype1 _ =
  let sig_with_foo ret = Sig.create [ ("foo", Sig.Entry.create [] ret) ] in
  let s1 = sig_with_foo (Type.var 10) in
  let s2 = sig_with_foo (Type.var 20) in
  let db = SigDB.of_list [ (Type.var 1, s1); (Type.var 2, s2) ] in
  let actual =
    Subtype.test db (Cstrs.Subtype (Type.var 1, Type.var 2)) empty empty
  in
  let expected =
    ( Cstrs.of_list [ Cstrs.Subtype (Type.var 1, Type.var 2) ],
      Cstrs.of_list [ Cstrs.Subtype (Type.var 10, Type.var 20) ] )
  in
  assert_equal_a expected actual
      
(** [var 1] has [foo : Numeric -> var 10] and [bar : String -> var 20].
    [var 2] has only [foo : Integer -> var 30]. [Integer <= Numeric] is
    supplied in the first input set (presumably because method arguments are
    compared contravariantly — confirm against [Subtype]).

    Checking [var 1 <= var 2] is expected to:
    - add [var 1 <= var 2] to the first set, alongside the pre-supplied
      [Integer <= Numeric];
    - put only [var 10 <= var 30] in the second set. Nothing from [bar]
      appears, since [var 2] does not declare it. *)
let test_no_recursive_subtype2 _ =
  let entry arg ret = Sig.Entry.create [ arg ] ret in
  let s1 =
    Sig.create
      [ ("foo", entry (Type.base "Numeric") (Type.var 10));
        ("bar", entry (Type.base "String") (Type.var 20)) ]
  in
  let s2 = Sig.create [ ("foo", entry (Type.base "Integer") (Type.var 30)) ] in
  let db = SigDB.of_list [ (Type.var 1, s1); (Type.var 2, s2) ] in
  let asts =
    Cstrs.of_list [ Cstrs.Subtype (Type.base "Integer", Type.base "Numeric") ]
  in
  let actual =
    Subtype.test db (Cstrs.Subtype (Type.var 1, Type.var 2)) asts empty
  in
  let expected =
    ( Cstrs.of_list
        [ Cstrs.Subtype (Type.var 1, Type.var 2);
          Cstrs.Subtype (Type.base "Integer", Type.base "Numeric") ],
      Cstrs.of_list [ Cstrs.Subtype (Type.var 10, Type.var 30) ] )
  in
  assert_equal_a expected actual

(** The OUnit suite for [Subtype], one labelled test case per scenario above. *)
let suite =
  let cases =
    [ ("t1 <= t1", test_t1_t1);
      ("t1 <= top", test_top);
      ("top <!= t1", test_top_fail);
      ("var <= var", test_not_in_db);
      ("non recursive subtype 1", test_no_recursive_subtype1);
      ("non recursive subtype 2", test_no_recursive_subtype2) ]
  in
  "Test Subtype" >::: List.map (fun (label, run) -> label >:: run) cases


