1use std::future::Future;
21use std::pin::Pin;
22use std::task::{Context, Poll};
23
24use tokio::task;
25
26pub struct JoinHandle<T>(task::JoinHandle<T>);
27
28impl<T> Unpin for JoinHandle<T> {}
29
30impl<T: Send + 'static> Future for JoinHandle<T> {
31 type Output = T;
32
33 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
34 match self.get_mut() {
35 JoinHandle(handle) => Pin::new(handle)
36 .poll(cx)
37 .map(|r| r.expect("tokio spawned task failed")),
38 }
39 }
40}
41
42#[allow(dead_code)]
43pub fn spawn<F>(f: F) -> JoinHandle<F::Output>
44where
45 F: std::future::Future + Send + 'static,
46 F::Output: Send + 'static,
47{
48 JoinHandle(task::spawn(f))
49}
50
51#[allow(dead_code)]
52pub fn spawn_blocking<F, T>(f: F) -> JoinHandle<T>
53where
54 F: FnOnce() -> T + Send + 'static,
55 T: Send + 'static,
56{
57 JoinHandle(task::spawn_blocking(f))
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// An async task's output is returned unwrapped when its handle is awaited.
    #[tokio::test]
    async fn test_tokio_spawn() {
        let sum = spawn(async { 1 + 1 }).await;
        assert_eq!(sum, 2);
    }

    /// A blocking closure's return value is returned unwrapped when its handle is awaited.
    #[tokio::test]
    async fn test_tokio_spawn_blocking() {
        let sum = spawn_blocking(|| 1 + 1).await;
        assert_eq!(sum, 2);
    }
}