diff --git a/dscheck.opam b/dscheck.opam index c90f6fd..8ec0706 100644 --- a/dscheck.opam +++ b/dscheck.opam @@ -8,8 +8,6 @@ bug-reports: "https://github.com/ocaml-multicore/dscheck/issues" depends: [ "ocaml" {>= "5.0.0"} "dune" {>= "2.9"} - "containers" - "oseq" "odoc" {with-doc} ] build: [ diff --git a/dune-project b/dune-project index c39a362..09ee8d1 100644 --- a/dune-project +++ b/dune-project @@ -3,12 +3,12 @@ (generate_opam_files true) (name dscheck) -(source (github sadiqj/dscheck)) +(source (github ocaml-multicore/dscheck)) (authors "Sadiq Jaffer") (maintainers "Sadiq Jaffer") (package (name dscheck) (synopsis "Traced Atomics") - (depends (ocaml (>= 5.0.0)) dune containers oseq)) + (depends (ocaml (>= 5.0.0)) dune)) diff --git a/src/dune b/src/dune index 82e41aa..eef9f13 100644 --- a/src/dune +++ b/src/dune @@ -1,3 +1,2 @@ (library - (public_name dscheck) - (libraries containers oseq)) + (public_name dscheck)) diff --git a/src/tracedAtomic.ml b/src/tracedAtomic.ml index 0180d64..e7e7e63 100644 --- a/src/tracedAtomic.ml +++ b/src/tracedAtomic.ml @@ -1,7 +1,29 @@ open Effect open Effect.Shallow -type 'a t = 'a Atomic.t * int +module Uid = struct + type t = int list + let compare = List.compare Int.compare + let to_string t = String.concat "," (List.map string_of_int t) + + let rec alpha i = + let head = Char.chr (Char.code 'A' + i mod 26) in + let i = i / 26 in + head :: if i > 0 then alpha i else [] + + let pretty lst = + match List.concat (List.map alpha lst) with + | [] -> "_" + | [x] -> String.make 1 x + | lst -> + let arr = Array.of_list lst in + "(" ^ String.init (Array.length arr) (Array.get arr) ^ ")" +end + +module IdSet = Set.Make (Uid) +module IdMap = Map.Make (Uid) + +type 'a t = 'a Atomic.t * Uid.t type _ Effect.t += | Make : 'a -> 'a t Effect.t @@ -10,42 +32,28 @@ type _ Effect.t += | Exchange : ('a t * 'a) -> 'a Effect.t | CompareAndSwap : ('a t * 'a * 'a) -> bool Effect.t | FetchAndAdd : (int t * int) -> int Effect.t + | Spawn : (unit -> unit) -> unit Effect.t -module IntSet = Set.Make( - struct - let compare = Stdlib.compare - type t = int - end ) +type cas_result = Unknown | Success | Failed -module IntMap = Map.Make( - struct - type t = int - let compare = Int.compare - end - ) - -let _string_of_set s = - IntSet.fold (fun y x -> (string_of_int y) ^ "," ^ x) s "" +type atomic_op = + | Start | Make | Get | Set | Exchange | FetchAndAdd | Spawn + | CompareAndSwap of cas_result ref -type atomic_op = Start | Make | Get | Set | Exchange | CompareAndSwap | FetchAndAdd - -let atomic_op_str x = - match x with - | Start -> "start" - | Make -> "make" - | Get -> "get" - | Set -> "set" - | Exchange -> "exchange" - | CompareAndSwap -> "compare_and_swap" - | FetchAndAdd -> "fetch_and_add" +let atomic_op_equal a b = match a, b with + | CompareAndSwap _, CompareAndSwap _ -> true + | _ -> a = b let tracing = ref false let finished_processes = ref 0 type process_data = { + uid : Uid.t; + mutable domain_generator : int; + mutable atomic_generator : int; mutable next_op: atomic_op; - mutable next_repr: int option; + mutable next_repr: Uid.t option; mutable resume_func : (unit, unit) handler -> unit; mutable finished : bool; } @@ -54,13 +62,12 @@ let every_func = ref (fun () -> ()) let final_func = ref (fun () -> ()) (* Atomics implementation *) -let atomics_counter = ref 1 +let atomics_counter = Atomic.make (-1) let make v = if !tracing then perform (Make v) else begin - let i = !atomics_counter in - atomics_counter := !atomics_counter + 1; - (Atomic.make v, i) + let i = Atomic.fetch_and_add atomics_counter (- 1) in + (Atomic.make v, [i]) end let get r = if !tracing then perform (Get r) else match r with | (v,_) -> Atomic.get v @@ -82,25 +89,30 @@ let incr r = ignore (fetch_and_add r 1) let decr r = ignore (fetch_and_add r (-1)) (* Tracing infrastructure *) -let processes = CCVector.create () +let processes = ref IdMap.empty +let push_process p = + assert (not (IdMap.mem p.uid !processes)) ; + processes := IdMap.add p.uid p !processes + +let get_process id = IdMap.find id !processes let update_process_data process_id f op repr = - let process_rec = CCVector.get processes process_id in + let process_rec = get_process process_id in process_rec.resume_func <- f; process_rec.next_repr <- repr; process_rec.next_op <- op let finish_process process_id = - let process_rec = CCVector.get processes process_id in + let process_rec = get_process process_id in process_rec.finished <- true; finished_processes := !finished_processes + 1 -let handler current_process_id runner = +let handler current_process runner = { retc = (fun _ -> ( - finish_process current_process_id; + finish_process current_process.uid; runner ())); exnc = (fun s -> raise s); effc = @@ -109,139 +121,372 @@ let handler current_process_id runner = | Make v -> Some (fun (k : (a, _) continuation) -> - let i = !atomics_counter in + let i = current_process.atomic_generator :: current_process.uid in + current_process.atomic_generator <- current_process.atomic_generator + 1 ; let m = (Atomic.make v, i) in - atomics_counter := !atomics_counter + 1; - update_process_data current_process_id (fun h -> continue_with k m h) Make (Some i); + update_process_data current_process.uid (fun h -> continue_with k m h) Make (Some i); runner ()) | Get (v,i) -> Some (fun (k : (a, _) continuation) -> - update_process_data current_process_id (fun h -> continue_with k (Atomic.get v) h) Get (Some i); + update_process_data current_process.uid (fun h -> continue_with k (Atomic.get v) h) Get (Some i); runner ()) | Set ((r,i), v) -> Some (fun (k : (a, _) continuation) -> - update_process_data current_process_id (fun h -> continue_with k (Atomic.set r v) h) Set (Some i); + update_process_data current_process.uid (fun h -> continue_with k (Atomic.set r v) h) Set (Some i); runner ()) | Exchange ((a,i), b) -> Some (fun (k : (a, _) continuation) -> - update_process_data current_process_id (fun h -> continue_with k (Atomic.exchange a b) h) Exchange (Some i); + update_process_data current_process.uid (fun h -> continue_with k (Atomic.exchange a b) h) Exchange (Some i); runner ()) | CompareAndSwap ((x,i), s, v) -> Some (fun (k : (a, _) continuation) -> - update_process_data current_process_id (fun h -> - continue_with k (Atomic.compare_and_set x s v) h) CompareAndSwap (Some i); + let res = ref Unknown in + update_process_data current_process.uid (fun h -> + let ok = Atomic.compare_and_set x s v in + res := if ok then Success else Failed ; + continue_with k ok h) (CompareAndSwap res) (Some i); runner ()) | FetchAndAdd ((v,i), x) -> Some (fun (k : (a, _) continuation) -> - update_process_data current_process_id (fun h -> + update_process_data current_process.uid (fun h -> continue_with k (Atomic.fetch_and_add v x) h) FetchAndAdd (Some i); runner ()) + | Spawn f -> + let uid = current_process.domain_generator :: current_process.uid in + current_process.domain_generator <- current_process.domain_generator + 1 ; + Some + (fun (k : (a, _) continuation) -> + update_process_data current_process.uid (fun h -> + let fiber_f h = continue_with (fiber f) () h in + push_process + { uid; domain_generator = 0; atomic_generator = 0; next_op = Start; next_repr = None; resume_func = fiber_f; finished = false } ; + continue_with k () h) Spawn (Some uid); + runner ()) | _ -> None); } let spawn f = - let fiber_f h = - continue_with (fiber f) () h in - CCVector.push processes - { next_op = Start; next_repr = None; resume_func = fiber_f; finished = false } - -let rec last_element l = - match l with - | h :: [] -> h - | [] -> assert(false) - | _ :: tl -> last_element tl + if !tracing then perform (Spawn f) else failwith "spawn outside tracing" + +type proc_rec = { proc_id: Uid.t; op: atomic_op; obj_ptr : Uid.t option } +type state_cell = { + procs : proc_rec IdMap.t; + run : proc_rec; + enabled : IdSet.t; + mutable backtrack : proc_rec list IdMap.t; +} -type proc_rec = { proc_id: int; op: atomic_op; obj_ptr : int option } -type state_cell = { procs: proc_rec list; run_proc: int; run_op: atomic_op; run_ptr: int option; enabled : IntSet.t; mutable backtrack : IntSet.t } +module Step = struct + type t = proc_rec + let compare a b = Uid.compare a.proc_id b.proc_id + + let atomic_op_str op arg = + let arg = match op, arg with + | _, None -> "?" + | Spawn, Some domain_id -> Uid.pretty domain_id + | _, Some ptr -> Uid.to_string ptr + in + match op with + | Start -> "start" + | Spawn -> "spawn() = " ^ arg + | Make -> "make() = " ^ arg + | Get -> "get(" ^ arg ^ ")" + | Set -> "set(" ^ arg ^ ")" + | Exchange -> "exchange(" ^ arg ^ ")" + | FetchAndAdd -> "fetch_and_add(" ^ arg ^ ")" + | CompareAndSwap cas -> + begin match !cas with + | Unknown -> "compare_and_swap(" ^ arg ^ ")" + | Success -> "compare_and_swap(" ^ arg ^ ")" + | Failed -> "compare_and_no_swap(" ^ arg ^ ")" + end + + let to_string t = + Printf.sprintf "%s %s" + (Uid.pretty t.proc_id) + (atomic_op_str t.op t.obj_ptr) +end + +module T = Trie.Make (Step) let num_runs = ref 0 (* we stash the current state in case a check fails and we need to log it *) let schedule_for_checks = ref [] -let do_run init_func init_schedule = - init_func (); (*set up run *) - tracing := true; +let max_depth = ref 0 +let depth = ref 0 +let incr_depth () = + depth := !depth + 1 ; + if !depth > !max_depth then max_depth := !depth + +let group_by fn = function + | [] -> [] + | first :: rest -> + let rec go previous previouses = function + | [] -> [previous::previouses] + | x::xs when fn x previous -> go x (previous::previouses) xs + | x::xs -> (previous::previouses) :: go x [] xs + in + List.rev (go first [] rest) + +let print_header () = + Format.printf "%5s %6s %5s/%-5s %s@." + "run" "todo" "depth" "max" "latest trace" + +let print_stats trie = + Format.printf "%#4dk %#6d %#5d/%-#5d " + (!num_runs / 1000) + (T.nb_todos trie) + (Option.value (T.min_depth trie) ~default:0) + !max_depth ; + List.iter + (function + | steps when List.compare_length_with steps 3 <= 0 -> + List.iter (fun step -> Format.printf "%s" (Uid.pretty step.proc_id)) steps + | (step :: _) as steps -> + Format.printf "(%s%i)" (Uid.pretty step.proc_id) (List.length steps) + | _ -> assert false) + (group_by (fun a b -> a.proc_id = b.proc_id) !schedule_for_checks) ; + Format.printf "%!" + +let print_trace () = + List.iter + (fun step -> Format.printf " %s@." (Step.to_string step)) + (List.rev !schedule_for_checks) + + +let setup_run func init_schedule trie = + processes := IdMap.empty ; + finished_processes := 0; + num_runs := !num_runs + 1; + if !num_runs mod 1000 == 0 then begin + Format.printf "%c[2K\r%!" (Char.chr 0x1b) ; + print_stats trie + end ; schedule_for_checks := init_schedule; - (* cache the number of processes in case it's expensive*) - let num_processes = CCVector.length processes in - (* current number of ops we are through the current run *) + depth := 0 ; + tracing := true; + let uid = [] in + let fiber_f h = continue_with (fiber func) () h in + push_process + { uid; domain_generator = 0; atomic_generator = 0; next_op = Start; next_repr = None; resume_func = fiber_f; finished = false } ; + tracing := false + +let do_run init_schedule = let rec run_trace s () = tracing := false; !every_func (); tracing := true; match s with - | [] -> if !finished_processes == num_processes then begin + | [] -> if !finished_processes == IdMap.cardinal !processes then begin tracing := false; !final_func (); tracing := true end - | (process_id_to_run, next_op, next_ptr) :: schedule -> begin - if !finished_processes == num_processes then - (* this should never happen *) - failwith("no enabled processes") - else - begin - let process_to_run = CCVector.get processes process_id_to_run in - assert(process_to_run.next_op = next_op); - assert(process_to_run.next_repr = next_ptr); - process_to_run.resume_func (handler process_id_to_run (run_trace schedule)) - end - end + | { proc_id = process_id_to_run ; op = next_op ; obj_ptr = next_ptr } :: schedule -> + incr_depth () ; + let process_to_run = get_process process_id_to_run in + assert(not process_to_run.finished); + assert(atomic_op_equal process_to_run.next_op next_op); + assert(process_to_run.next_repr = next_ptr); + process_to_run.resume_func (handler process_to_run (run_trace schedule)) in tracing := true; - run_trace init_schedule (); - finished_processes := 0; + run_trace (List.rev init_schedule) (); tracing := false; - num_runs := !num_runs + 1; - if !num_runs mod 1000 == 0 then - Printf.printf "run: %d\n" !num_runs; - let procs = CCVector.mapi (fun i p -> { proc_id = i; op = p.next_op; obj_ptr = p.next_repr }) processes |> CCVector.to_list in - let current_enabled = CCVector.to_seq processes - |> OSeq.zip_index - |> Seq.filter (fun (_,proc) -> not proc.finished) - |> Seq.map (fun (id,_) -> id) - |> IntSet.of_seq in - CCVector.clear processes; - atomics_counter := 1; - match last_element init_schedule with - | (run_proc, run_op, run_ptr) -> - { procs; enabled = current_enabled; run_proc; run_op; run_ptr; backtrack = IntSet.empty } - -let rec explore func state clock last_access = - let s = last_element state in - List.iter (fun proc -> - let j = proc.proc_id in - let i = Option.bind proc.obj_ptr (fun ptr -> IntMap.find_opt ptr last_access) |> Option.value ~default:0 in - if i != 0 then begin - let pre_s = List.nth state (i-1) in - if IntSet.mem j pre_s.enabled then - pre_s.backtrack <- IntSet.add j pre_s.backtrack - else - pre_s.backtrack <- IntSet.union pre_s.backtrack pre_s.enabled + let procs = + IdMap.mapi (fun i p -> { proc_id = i; op = p.next_op; obj_ptr = p.next_repr }) !processes + in + let current_enabled = + IdMap.to_seq !processes + |> Seq.filter (fun (_, p) -> not p.finished) + |> Seq.map (fun (id, _) -> id) + |> IdSet.of_seq + in + let last_run = List.hd init_schedule in + { procs; enabled = current_enabled; run = last_run; backtrack = IdMap.empty } + +type category = + | Ignore + | Read + | Write + | Read_write + +let categorize = function + | Spawn | Start | Make -> Ignore + | Get -> Read + | Set -> Write + | Exchange | FetchAndAdd -> Read_write + | CompareAndSwap res -> + begin match !res with + | Unknown -> failwith "CAS unknown outcome" (* should not happen *) + | Success -> Read_write + | Failed -> Read_write end - ) s.procs; - if IntSet.cardinal s.enabled > 0 then begin - let p = IntSet.min_elt s.enabled in - let dones = ref IntSet.empty in - s.backtrack <- IntSet.singleton p; - while IntSet.(cardinal (diff s.backtrack !dones)) > 0 do - let j = IntSet.min_elt (IntSet.diff s.backtrack !dones) in - dones := IntSet.add j !dones; - let j_proc = List.nth s.procs j in - let schedule = (List.map (fun s -> (s.run_proc, s.run_op, s.run_ptr)) state) @ [(j, j_proc.op, j_proc.obj_ptr)] in - let statedash = state @ [do_run func schedule] in - let state_time = (List.length statedash)-1 in - let new_last_access = match j_proc.obj_ptr with Some(ptr) -> IntMap.add ptr state_time last_access | None -> last_access in - let new_clock = IntMap.add j state_time clock in - explore func statedash new_clock new_last_access - done + +let rec list_findi predicate lst i = match lst with + | [] -> None + | x::_ when predicate i x -> Some (i, x) + | _::xs -> list_findi predicate xs (i + 1) + +let list_findi predicate lst = list_findi predicate lst 0 + +let mark_backtrack proc time state (last_read, last_write) = + let j = proc.proc_id in + let find ptr map = match IdMap.find_opt ptr map with + | None -> None + | Some lst -> + List.find_opt (fun (_, proc_id) -> proc_id <> j) lst + in + let find_loc proc = + match categorize proc.op, proc.obj_ptr with + | Ignore, _ -> None + | Read, Some ptr -> find ptr last_write + | (Write | Read_write), Some ptr -> max (find ptr last_read) (find ptr last_write) + | _ -> assert false + in + let rec find_replay_trace ~pre_s ~lower ~upper proc_id = + let replay_steps = List.filteri (fun k s -> k >= lower && k <= time - upper && s.run.proc_id = proc_id) state in + let replay_steps = List.rev_map (fun s -> s.run) replay_steps in + let causal p = match find_loc p with + | None -> true + | Some (k, _) -> k <= upper in + if List.for_all causal replay_steps + then if IdSet.mem proc_id pre_s.enabled + then Some replay_steps + else let is_parent k s = k > lower && k < time - upper && s.run.op = Spawn && s.run.obj_ptr = Some proc_id in + match list_findi is_parent state with + | None -> None + | Some (parent_i, spawn) -> + assert (parent_i > lower) ; + if pre_s.run.proc_id = spawn.run.proc_id + then None + else begin match find_replay_trace ~pre_s ~lower:parent_i ~upper spawn.run.proc_id with + | None -> None + | Some spawn_steps -> Some (spawn_steps @ replay_steps) + end + else None + in + match find_loc proc with + | None -> () + | Some (i, _) -> + assert (i < time) ; + let pre_s = List.nth state (time - i + 1) in + match find_replay_trace ~pre_s ~lower:0 ~upper:i proc.proc_id with + | None -> () + | Some [] -> failwith "empty replay steps" + | Some ((first_step :: _) as replay_steps) -> + let j = first_step.proc_id in + if match IdMap.find_opt j pre_s.backtrack with + | None -> true + | Some lst -> + assert (List.length lst <= List.length replay_steps) ; + false + then pre_s.backtrack <- IdMap.add j replay_steps pre_s.backtrack + +let round_robin previous = + let urgent = + IdMap.fold + (fun proc_id step acc -> + assert (proc_id = step.proc_id) ; + if IdSet.mem proc_id previous.enabled + then match step.op with + | Start | Spawn | Make -> proc_id :: acc + | _ -> acc + else acc) + previous.procs + [] + in + match urgent with + | [] -> + let id = previous.run.proc_id in + begin match IdSet.find_first_opt (fun j -> j > id) previous.enabled with + | None -> IdSet.min_elt_opt previous.enabled + | found -> found + end + | [single] -> Some single + | first :: _ -> Some first (* which one is best? *) + +let rec explore func time state trie current_schedule (last_read, last_write) = + let s = List.hd state in + assert (IdMap.is_empty s.backtrack) ; + let todo_next = + match T.next trie with + | None -> + begin match round_robin s with + | None -> + assert (IdSet.is_empty s.enabled) ; + None + | Some p -> + let init_step = IdMap.find p s.procs in + Some (init_step, T.todo) + end + | Some (step, new_trie) -> + Some (step, new_trie) + in + match todo_next with + | None -> false, T.ok 1 + | Some (new_step, new_trie) -> + begin try + let new_schedule = new_step :: current_schedule in + let schedule = + schedule_for_checks := new_schedule; + [new_step] + in + let step = do_run schedule in + let new_state = + { step with run = new_step ; backtrack = IdMap.empty } + :: state + in + let new_time = time + 1 in + if T.nb_oks new_trie = 0 && not (T.has_error new_trie) + then mark_backtrack step.run new_time new_state (last_read, last_write); + let add ptr map = + IdMap.update + ptr + (function None -> Some [new_time, step.run.proc_id] + | Some steps -> Some ((new_time, step.run.proc_id) :: steps)) + map + in + let new_last_access = + match categorize step.run.op, step.run.obj_ptr with + | Ignore, _ -> last_read, last_write + | Read, Some ptr -> add ptr last_read, last_write + | Write, Some ptr -> last_read, add ptr last_write + | Read_write, Some ptr -> add ptr last_read, add ptr last_write + | _ -> assert false + in + let has_error, child = + explore func new_time new_state new_trie new_schedule new_last_access + in + let trie = + IdMap.fold + (fun _ backstep trie -> + T.insert_todos ~skip_edge:new_step backstep trie) + s.backtrack + trie + in + let trie = T.add new_step child trie in + let trie = T.simplify trie in + has_error, trie + with exn -> + let msg = Printexc.to_string exn in + let trie = T.add new_step (T.error msg) trie in + print_header () ; + print_stats trie ; + + Format.printf "@.@.ERROR: %s@." msg ; + Printexc.print_backtrace stdout ; + print_trace () ; + + true, trie end let every f = @@ -254,22 +499,40 @@ let check f = let tracing_at_start = !tracing in tracing := false; if not (f ()) then begin - Printf.printf "Found assertion violation at run %d:\n" !num_runs; - List.iter (fun s -> - begin match s with - | (last_run_proc, last_run_op, last_run_ptr) -> begin - let last_run_ptr = Option.map string_of_int last_run_ptr |> Option.value ~default:"" in - Printf.printf "Process %d: %s %s\n" last_run_proc (atomic_op_str last_run_op) last_run_ptr - end; - end; - ) !schedule_for_checks; - assert(false) + Format.printf "Check: found assertion violation!@." ; + failwith "invalid check" end; tracing := tracing_at_start +let error_count = ref 0 + +let explore_one func trie = + let empty_schedule = [{ proc_id = [] ; op = Start ; obj_ptr = None }] in + setup_run func empty_schedule trie ; + let empty_state = do_run empty_schedule :: [] in + let empty_last_access = IdMap.empty, IdMap.empty in + explore func 1 empty_state trie empty_schedule empty_last_access + +let rec explore_all ~count ~errors func trie = + if !error_count >= errors + || !num_runs >= count + || T.nb_todos trie = 0 + then trie (* graphviz_output ?graphviz trie *) + else begin + let has_error, trie = explore_one func trie in + if has_error then error_count := !error_count + 1 ; + explore_all ~count ~errors func trie + end -let trace func = - let empty_state = do_run func [(0, Start, None)] :: [] in - let empty_clock = IntMap.empty in - let empty_last_access = IntMap.empty in - explore func empty_state empty_clock empty_last_access +let graphviz_output ?graphviz trie = + match graphviz with + | None -> () + | Some filename -> T.graphviz ~filename trie + +let trace ?(count = max_int) ?(errors = 1) ?graphviz func = + print_header () ; + num_runs := 0 ; + let trie = explore_all ~count ~errors func T.todo in + graphviz_output ?graphviz trie ; + Format.printf "@.Found %#i errors after %#i runs.@." !error_count !num_runs ; + () diff --git a/src/tracedAtomic.mli b/src/tracedAtomic.mli index 3f6ee63..c4a1c2e 100644 --- a/src/tracedAtomic.mli +++ b/src/tracedAtomic.mli @@ -46,8 +46,13 @@ val incr : int t -> unit val decr : int t -> unit (** [decr r] atomically decrements the value of [r] by [1]. *) -val trace : (unit -> unit) -> unit -(** start the simulation trace *) +val trace : ?count:int -> ?errors:int -> ?graphviz:string -> (unit -> unit) -> unit +(** [trace fn] starts the simulation trace on the function [fn]. + + - [?count] is the max number of traces to test (defaults to infinity) + - [?errors] is the max number of errors to find (defaults to 1) + - [?graphviz] is a filename to use for outputting a dot file (defaults to no output) + *) val spawn : (unit -> unit) -> unit (** spawn [f] as a new 'thread' *) @@ -59,4 +64,4 @@ val final : (unit -> unit) -> unit (** run [f] after all processes complete *) val every : (unit -> unit) -> unit -(** run [f] between every possible interleaving *) \ No newline at end of file +(** run [f] between every possible interleaving *) diff --git a/src/trie.ml b/src/trie.ml new file mode 100644 index 0000000..24bfb34 --- /dev/null +++ b/src/trie.ml @@ -0,0 +1,331 @@ +module type EDGE = sig + type t + val compare : t -> t -> int + val to_string : t -> string +end + +module Make (Edge : EDGE) = struct + + module M = Map.Make (Edge) + + type edge = Edge.t + + type t = + { error : bool + ; is_todo : bool + ; todo : int option + ; nb : int + ; nb_ok : int + ; s : s + } + + and s = + | Ok + | Skip + | Todo + | Error of string + | Branch of t M.t + + let min_depth t = t.todo + let nb_todos t = t.nb + let nb_oks t = t.nb_ok + let has_error t = t.error + + let min_todo a b = match a, b with + | None, t | t, None -> t + | Some a, Some b -> Some (min a b) + + let incr_todo = function None -> None | Some t -> Some (t + 1) + + let ok nb_ok = { error = false ; is_todo = false ; todo = None ; nb = 0 ; nb_ok ; s = Ok } + let skip = { (ok 0) with s = Skip } + let error msg = { (ok 0) with error = true ; s = Error msg } + let todo = { error = false ; is_todo = true ; todo = Some 0 ; nb = 1 ; nb_ok = 0 ; s = Todo } + + let branch children = + let error, is_todo, todo, nb, nb_ok = + M.fold + (fun _ child (acc_error, acc_is_todo, acc_todo, acc_nb, acc_nb_ok) -> + child.error || acc_error, + (child.s = Skip || child.is_todo) && acc_is_todo, + min_todo child.todo acc_todo, + acc_nb + child.nb, + acc_nb_ok + child.nb_ok) + children + (false, true, None, 0, 0) + in + { error ; is_todo ; todo = incr_todo todo ; nb ; nb_ok ; s = Branch children } + + let simplify t = + match t.s with + | _ when t.error || t.todo <> None -> t + | Branch children -> + let ok_children, has_skip = ref 0, ref false in + M.iter + (fun _ c -> match c.s with + | Ok -> incr ok_children + | Skip -> has_skip := true + | Branch _ -> + assert (c.todo = None) ; + assert (c.error = false) ; + assert (c.nb_ok > 0) + | _ -> assert false) + children ; + if !ok_children = 1 && !has_skip + then t + else ok t.nb_ok + | _ -> t + + let add edge child t = + match t.s with + | Ok | Skip | Error _ -> assert false + | Todo -> branch (M.singleton edge child) + | Branch m -> branch (M.add edge child m) + + let todolist ~skip_edge edges = + List.fold_right + (fun edge t -> branch (M.add skip_edge skip (M.singleton edge t))) + edges + todo + + let insert_todos ~skip_edge todos t = + match t.s, todos with + | Todo, [] + | Skip, _ + | (Ok | Error _), _ + | Branch _, [] -> + t + | Todo, todos -> todolist ~skip_edge todos + | Branch m, edge::todos -> + if M.mem edge m + then t + else begin + let t = todolist ~skip_edge todos in + let m = M.add edge t m in + assert (M.mem skip_edge m) ; + branch m + end + + let next t = match t.todo, t.s with + | None, _ -> None + | Some _, Todo -> None + | Some todo_distance, Branch m -> + let candidates = + M.fold + (fun edge child acc -> + match child.todo with + | Some distance (* when distance = todo_distance - 1 *) -> + let weight = 1.0 /. float (100 * child.nb + distance - todo_distance + 1) in + (weight, edge, child) :: acc + | _ -> acc) + m + [] + in + let candidates = + if not t.error + then candidates + else match List.filter (fun (_, _, child) -> child.error || child.is_todo) candidates with + | [] -> candidates + | cs -> cs + in + begin match candidates with + | [(_, edge, child)] -> Some (edge, child) + | _ -> + let total = List.fold_left (fun acc (w, _, _) -> acc +. w) 0.0 candidates in + let selection = Random.float total in + let rec go selection = function + | [] -> failwith "Trie: candidate not found" + | [(_, edge, child)] -> Some (edge, child) + | (weight, edge, child) :: _ when selection <= weight -> Some (edge, child) + | (weight, _, _) :: rest -> go (selection -. weight) rest + in + go selection candidates + end + | _ -> assert false + + + let graphviz ~filename t = + let h = open_out filename in + let gen = ref 0 in + let fresh () = + let u = !gen in + gen := u + 1 ; + u + in + let refcounts = Hashtbl.create 16 in + let refcount uid = + match Hashtbl.find_opt refcounts uid with + | None -> 0 + | Some count -> count + in + let incr_refcount uid = + let rc = refcount uid + 1 in + assert (rc >= 1) ; + Hashtbl.replace refcounts uid rc + in + let decr_refcount uid = + let rc = refcount uid - 1 in + assert (rc >= 1) ; + Hashtbl.replace refcounts uid rc + in + let cache = Hashtbl.create 16 in + let id_magic = Char.chr 0 in + let rec go ~print ~rc t = + let s = t.s in + let label = + match s with + | Ok -> "Ok" + | Skip -> "Skip" + | Todo -> "TODO" + | Error msg -> "ERROR: " ^ msg + | Branch _ -> "" + in + let has_error = + if t.error + then "style=filled fillcolor=black fontcolor=white" + else if t.todo = None + then "style=filled fillcolor=green" + else if t.is_todo + then "style=filled fillcolor=grey" + else "" + in + let style = if label = "" then "shape=point" else Printf.sprintf "label=%S" label in + let buf = Buffer.create 16 in + Printf.bprintf buf "%c [%s %s] ;\n" id_magic style has_error ; + let ids_children = ref [] in + begin match s with + | _ when not t.error -> () + | Branch children -> + M.iter + (fun edge child -> + if child.s = Skip + then () + else begin + let id_edge = edge_to_string ~print ~rc edge child in + let arrow_weight = + if t.error && child.error + then "[weight=1000 penwidth=4] " + else "" + in + Printf.bprintf buf "%c -> e%i %s ;\n" id_magic id_edge arrow_weight ; + ids_children := id_edge :: !ids_children ; + end + ) + children + | _ -> () + end ; + let self = Buffer.contents buf in + match Hashtbl.find cache self with + | (uid, printed) -> + if not print + then (if rc then (incr_refcount uid ; List.iter decr_refcount !ids_children)) + else if not printed + then begin + Hashtbl.replace cache self (uid, true) ; + let self = String.concat ("n" ^ string_of_int uid) (String.split_on_char id_magic self) in + Printf.fprintf h "%s" self ; + end ; + uid + | exception Not_found -> + let uid = fresh () in + Hashtbl.add cache self (uid, print) ; + if not print + then (if rc then incr_refcount uid) + else begin + let self = String.concat ("n" ^ string_of_int uid) (String.split_on_char id_magic self) in + Printf.fprintf h "%s" self ; + end ; + uid + + and skip_forward acc t = + if refcount (go ~print:false ~rc:false t) > 1 + then List.rev acc, t + else + let skip = match t.s with + | Branch children -> + M.fold + (fun edge child acc -> if child.s = Skip then acc else (edge, child) :: acc) + children + [] + | _ -> [] + in + match skip with + | [edge, single] -> skip_forward (edge :: acc) single + | _ -> (List.rev acc, t) + + and edge_to_string ~print ~rc edge child = + let edges, child = + if print + then skip_forward [edge] child + else [edge], child + in + let buf = Buffer.create 16 in + let label = String.concat "\\l" (List.map Edge.to_string edges) in + let has_error = + if child.error + then "style=filled fillcolor=black fontcolor=white" + else if child.todo = None + then "style=filled fillcolor=lightgreen" + else if child.is_todo + then "style=filled fillcolor=grey" + else "" + in + let arrow_weight = + if child.error + then "[weight=1000 penwidth=4] " + else "" + in + Printf.bprintf buf "%c [shape=rectangle label=\"%s\\l\" %s] ;\n" id_magic label has_error ; + let id_child = + if not child.error + then begin + if child.nb > 1 then begin + Printf.bprintf buf "%c_todo [label=%S fillcolor=grey style=filled];\n" + id_magic + (Printf.sprintf "TODO %#i" child.nb) ; + Printf.bprintf buf "%c -> %c_todo;\n" id_magic id_magic ; + end ; + if child.nb_ok > 0 then begin + Printf.bprintf buf "%c_ok [label=%S fillcolor=lightgreen style=filled];\n" + id_magic + (Printf.sprintf "OK %#i" child.nb_ok) ; + Printf.bprintf buf "%c -> %c_ok;\n" id_magic id_magic ; + end ; + None + end + else begin + let id_child = go ~print ~rc child in + Printf.bprintf buf "%c -> n%i %s;\n" id_magic id_child arrow_weight ; + Some id_child + end + in + let self = Buffer.contents buf in + match Hashtbl.find cache self with + | (uid, printed) -> + if not print + then (if rc then ( incr_refcount uid ; Option.iter decr_refcount id_child ) ) + else if not printed + then begin + Hashtbl.replace cache self (uid, true) ; + let self = String.concat ("e" ^ string_of_int uid) (String.split_on_char id_magic self) in + Printf.fprintf h "%s" self ; + end ; + uid + | exception Not_found -> + let uid = fresh () in + Hashtbl.add cache self (uid, print) ; + if not print + then (if rc then incr_refcount uid) + else begin + let self = String.concat ("e" ^ string_of_int uid) (String.split_on_char id_magic self) in + Printf.fprintf h "%s" self ; + end ; + uid + in + Printf.fprintf h "digraph dscheck {\n" ; + Printf.fprintf h "node [fontname=%S] ;\n" "Courier New" ; + let _ : int = go ~print:false ~rc:true t in + let _ : int = go ~print:true ~rc:false t in + Printf.fprintf h "}\n" ; + close_out h +end diff --git a/tests/dune b/tests/dune index 2c7a92f..899b715 100644 --- a/tests/dune +++ b/tests/dune @@ -2,3 +2,13 @@ (names test_list test_naive_counter) (modules test_list test_naive_counter) (libraries dscheck)) + +(executable + (name test_simple) + (modules test_simple) + (libraries dscheck)) + +(executable + (name test_spinlock) + (modules test_spinlock) + (libraries dscheck)) diff --git a/tests/test_list.ml b/tests/test_list.ml index 38bb645..adc77f3 100644 --- a/tests/test_list.ml +++ b/tests/test_list.ml @@ -4,17 +4,19 @@ module Atomic = Dscheck.TracedAtomic type conc_list = { value: int; next: conc_list option } -let rec add_node list_head n = +let rec add_node ~bug list_head n = (* try to add a new node to head *) let old_head = Atomic.get list_head in let new_node = { value = n ; next = (Some old_head) } in (* introduce bug *) - if Atomic.get list_head = old_head then begin + if bug && Atomic.get list_head = old_head then begin Atomic.set list_head new_node; true end + else if Atomic.compare_and_set list_head old_head new_node + then true else - add_node list_head n + add_node ~bug list_head n let check_node list_head n = let rec check_from_node node = @@ -28,14 +30,14 @@ let check_node list_head n = (* try to find the node *) check_from_node (Atomic.get list_head) -let add_and_check list_head n () = - assert(add_node list_head n); +let add_and_check ~bug list_head n () = + assert(add_node ~bug list_head n); assert(check_node list_head n) -let create_test upto () = +let create_test ~buggy upto () = let list_head = Atomic.make { value = 0 ; next = None } in for x = 1 to upto do - Atomic.spawn (add_and_check list_head x); + Atomic.spawn (add_and_check ~bug:(x = buggy) list_head x); done; Atomic.final (fun () -> for x = 1 to upto do @@ -44,4 +46,4 @@ let create_test upto () = ) let () = - Atomic.trace (create_test 8) + Atomic.trace (create_test ~buggy:2 4) diff --git a/tests/test_simple.ml b/tests/test_simple.ml new file mode 100644 index 0000000..c2179aa --- /dev/null +++ b/tests/test_simple.ml @@ -0,0 +1,18 @@ +module Atomic = Dscheck.TracedAtomic + +(* Example from ocaml-parrafuzz presentation at the OCaml Workshop 2021: + https://www.youtube.com/watch?v=GZsUoSaIpIs *) + +let test i = + let x = Atomic.make i in + let y = Atomic.make 0 in + Atomic.spawn (fun () -> + Atomic.spawn (fun () -> if Atomic.get x = 10 then Atomic.set y 2)) ; + Atomic.set x 0 ; + Atomic.set y 1 ; + Atomic.final (fun () -> Atomic.check (fun () -> Atomic.get y <> 2)) + +let () = + Atomic.trace (fun () -> test 0) ; + Printf.printf "\n-----------------------------\n\n%!" ; + Atomic.trace (fun () -> test 10) diff --git a/tests/test_spinlock.ml b/tests/test_spinlock.ml new file mode 100644 index 0000000..c703d84 --- /dev/null +++ b/tests/test_spinlock.ml @@ -0,0 +1,10 @@ +module Atomic = Dscheck.TracedAtomic + +let test () = + let lock = Atomic.make true in + let result = Atomic.make 0 in + Atomic.spawn (fun () -> while Atomic.get lock do Atomic.incr result done) ; + Atomic.spawn (fun () -> Atomic.set lock false) ; + Atomic.final (fun () -> Atomic.check (fun () -> Atomic.get result < 10)) + +let () = Atomic.trace test