xref: /drstd/src/std/sync/barrier.rs (revision 86982c5e9b2eaa583327251616ee822c36288824)
1 #[cfg(test)]
2 mod tests;
3 
4 use crate::std::fmt;
5 use crate::std::sync::{Condvar, Mutex};
6 
7 /// A barrier enables multiple threads to synchronize the beginning
8 /// of some computation.
9 ///
10 /// # Examples
11 ///
12 /// ```
13 /// use std::sync::{Arc, Barrier};
14 /// use std::thread;
15 ///
16 /// let n = 10;
17 /// let mut handles = Vec::with_capacity(n);
18 /// let barrier = Arc::new(Barrier::new(n));
19 /// for _ in 0..n {
20 ///     let c = Arc::clone(&barrier);
21 ///     // The same messages will be printed together.
22 ///     // You will NOT see any interleaving.
23 ///     handles.push(thread::spawn(move|| {
24 ///         println!("before wait");
25 ///         c.wait();
26 ///         println!("after wait");
27 ///     }));
28 /// }
29 /// // Wait for other threads to finish.
30 /// for handle in handles {
31 ///     handle.join().unwrap();
32 /// }
33 /// ```
34 pub struct Barrier {
35     lock: Mutex<BarrierState>,
36     cvar: Condvar,
37     num_threads: usize,
38 }
39 
/// Mutable bookkeeping for a [`Barrier`]; only ever touched while the
/// barrier's mutex is held.
struct BarrierState {
    /// Number of threads that have called `wait` in the current round.
    /// Reset to zero when the round completes.
    count: usize,
    /// Round number, advanced (with wrapping) each time the barrier
    /// releases its waiters. A waiter snapshots this on arrival and keeps
    /// sleeping until it differs, which filters out spurious wakeups.
    generation_id: usize,
}
45 
/// The value handed back by [`Barrier::wait()`] once every thread of the
/// round has reached the [`Barrier`].
///
/// Use [`BarrierWaitResult::is_leader()`] to find out whether the current
/// thread was the one whose arrival released the round.
///
/// # Examples
///
/// ```
/// use std::sync::Barrier;
///
/// let barrier = Barrier::new(1);
/// let barrier_wait_result = barrier.wait();
/// ```
pub struct BarrierWaitResult(bool);
58 
59 impl fmt::Debug for Barrier {
60     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61         f.debug_struct("Barrier").finish_non_exhaustive()
62     }
63 }
64 
65 impl Barrier {
66     /// Creates a new barrier that can block a given number of threads.
67     ///
68     /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
69     /// up all threads at once when the `n`th thread calls [`wait()`].
70     ///
71     /// [`wait()`]: Barrier::wait
72     ///
73     /// # Examples
74     ///
75     /// ```
76     /// use std::sync::Barrier;
77     ///
78     /// let barrier = Barrier::new(10);
79     /// ```
80     #[must_use]
81     pub fn new(n: usize) -> Barrier {
82         Barrier {
83             lock: Mutex::new(BarrierState {
84                 count: 0,
85                 generation_id: 0,
86             }),
87             cvar: Condvar::new(),
88             num_threads: n,
89         }
90     }
91 
92     /// Blocks the current thread until all threads have rendezvoused here.
93     ///
94     /// Barriers are re-usable after all threads have rendezvoused once, and can
95     /// be used continuously.
96     ///
97     /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
98     /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
99     /// from this function, and all other threads will receive a result that
100     /// will return `false` from [`BarrierWaitResult::is_leader()`].
101     ///
102     /// # Examples
103     ///
104     /// ```
105     /// use std::sync::{Arc, Barrier};
106     /// use std::thread;
107     ///
108     /// let n = 10;
109     /// let mut handles = Vec::with_capacity(n);
110     /// let barrier = Arc::new(Barrier::new(n));
111     /// for _ in 0..n {
112     ///     let c = Arc::clone(&barrier);
113     ///     // The same messages will be printed together.
114     ///     // You will NOT see any interleaving.
115     ///     handles.push(thread::spawn(move|| {
116     ///         println!("before wait");
117     ///         c.wait();
118     ///         println!("after wait");
119     ///     }));
120     /// }
121     /// // Wait for other threads to finish.
122     /// for handle in handles {
123     ///     handle.join().unwrap();
124     /// }
125     /// ```
126     pub fn wait(&self) -> BarrierWaitResult {
127         let mut lock = self.lock.lock().unwrap();
128         let local_gen = lock.generation_id;
129         lock.count += 1;
130         if lock.count < self.num_threads {
131             let _guard = self
132                 .cvar
133                 .wait_while(lock, |state| local_gen == state.generation_id)
134                 .unwrap();
135             BarrierWaitResult(false)
136         } else {
137             lock.count = 0;
138             lock.generation_id = lock.generation_id.wrapping_add(1);
139             self.cvar.notify_all();
140             BarrierWaitResult(true)
141         }
142     }
143 }
144 
145 impl fmt::Debug for BarrierWaitResult {
146     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147         f.debug_struct("BarrierWaitResult")
148             .field("is_leader", &self.is_leader())
149             .finish()
150     }
151 }
152 
153 impl BarrierWaitResult {
154     /// Returns `true` if this thread is the "leader thread" for the call to
155     /// [`Barrier::wait()`].
156     ///
157     /// Only one thread will have `true` returned from their result, all other
158     /// threads will have `false` returned.
159     ///
160     /// # Examples
161     ///
162     /// ```
163     /// use std::sync::Barrier;
164     ///
165     /// let barrier = Barrier::new(1);
166     /// let barrier_wait_result = barrier.wait();
167     /// println!("{:?}", barrier_wait_result.is_leader());
168     /// ```
169     #[must_use]
170     pub fn is_leader(&self) -> bool {
171         self.0
172     }
173 }
174