1use std::fmt;
21use std::future::Future;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24
25use tokio::task;
26
27use crate::{Error, ErrorKind, Result};
28
/// Handle to a task spawned through a [`RuntimeHandle`].
///
/// Awaiting it yields the task's output, or an [`ErrorKind::Unexpected`] error
/// (carrying the tokio join error as its source) if the task could not be joined.
pub struct JoinHandle<T>(task::JoinHandle<T>);
38
// `Future::poll` below projects through the pin with `Pin::get_mut`, which
// requires `Self: Unpin`. The only field is the tokio handle, which `poll` pins
// with `Pin::new` (itself only valid for `Unpin` targets), so this is sound.
impl<T> Unpin for JoinHandle<T> {}
40
41impl<T: Send + 'static> Future for JoinHandle<T> {
42 type Output = crate::Result<T>;
43
44 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45 Pin::new(&mut self.get_mut().0).poll(cx).map(|r| {
46 r.map_err(|e| Error::new(ErrorKind::Unexpected, "spawned task failed").with_source(e))
47 })
48 }
49}
50
#[derive(Clone)]
/// A cloneable handle for spawning work onto one tokio runtime.
///
/// It stores a `tokio::runtime::Handle`, not the runtime itself. Tasks spawned
/// after the originating runtime has been dropped resolve to an error when
/// awaited.
pub struct RuntimeHandle {
    // Handle of the tokio runtime that receives everything spawned through `self`.
    handle: tokio::runtime::Handle,
}
61
62impl fmt::Debug for RuntimeHandle {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("RuntimeHandle").finish()
65 }
66}
67
68impl RuntimeHandle {
69 fn from_tokio_handle(handle: tokio::runtime::Handle) -> Self {
70 Self { handle }
71 }
72
73 pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
75 where
76 F: Future + Send + 'static,
77 F::Output: Send + 'static,
78 {
79 JoinHandle(self.handle.spawn(future))
80 }
81
82 pub fn spawn_blocking<F, T>(&self, f: F) -> JoinHandle<T>
84 where
85 F: FnOnce() -> T + Send + 'static,
86 T: Send + 'static,
87 {
88 JoinHandle(self.handle.spawn_blocking(f))
89 }
90}
91
#[derive(Clone)]
/// A pair of [`RuntimeHandle`]s: one for IO work, one for CPU work.
///
/// [`Runtime::new`], [`Runtime::current`] and [`Runtime::try_current`] point
/// both handles at the same tokio runtime. [`Runtime::new_with_split`] points
/// them at two different runtimes.
pub struct Runtime {
    // Handle returned by `io()`.
    io: RuntimeHandle,
    // Handle returned by `cpu()`; may be the same runtime as `io`.
    cpu: RuntimeHandle,
}
112
113impl fmt::Debug for Runtime {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 f.debug_struct("Runtime").finish()
116 }
117}
118
119impl Runtime {
120 pub fn new(runtime: &tokio::runtime::Runtime) -> Self {
122 let handle = RuntimeHandle::from_tokio_handle(runtime.handle().clone());
123 Self {
124 io: handle.clone(),
125 cpu: handle,
126 }
127 }
128
129 pub fn new_with_split(
131 io_runtime: &tokio::runtime::Runtime,
132 cpu_runtime: &tokio::runtime::Runtime,
133 ) -> Self {
134 Self {
135 io: RuntimeHandle::from_tokio_handle(io_runtime.handle().clone()),
136 cpu: RuntimeHandle::from_tokio_handle(cpu_runtime.handle().clone()),
137 }
138 }
139
140 pub fn current() -> Self {
149 Self::try_current().expect(
150 "Runtime::current() called outside a tokio runtime context. \
151 Call it from within #[tokio::main] / #[tokio::test], or construct \
152 a Runtime explicitly via Runtime::new / Runtime::new_with_split.",
153 )
154 }
155
156 pub fn try_current() -> Result<Self> {
159 let handle = tokio::runtime::Handle::try_current().map_err(|e| {
160 Error::new(
161 ErrorKind::Unexpected,
162 "no tokio runtime in context; call Runtime::try_current() \
163 from within a tokio runtime, or construct a Runtime explicitly \
164 via Runtime::new / Runtime::new_with_split",
165 )
166 .with_source(e)
167 })?;
168 let rh = RuntimeHandle::from_tokio_handle(handle);
169 Ok(Self {
170 io: rh.clone(),
171 cpu: rh,
172 })
173 }
174
175 pub fn io(&self) -> &RuntimeHandle {
177 &self.io
178 }
179
180 pub fn cpu(&self) -> &RuntimeHandle {
182 &self.cpu
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// A multi-threaded tokio runtime paired with a [`Runtime`] built from it,
    /// so both the IO and CPU handles target the same tokio runtime.
    struct TestRuntime {
        // Owns the tokio runtime. `rt` only holds handles to it, so this field
        // must stay alive while the test spawns work through `rt`.
        tokio: tokio::runtime::Runtime,
        rt: Runtime,
    }

    impl TestRuntime {
        fn new() -> Self {
            let tokio = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("Failed to build tokio runtime");
            let rt = Runtime::new(&tokio);
            Self { tokio, rt }
        }

        fn block_on<F: Future>(&self, f: F) -> F::Output {
            self.tokio.block_on(f)
        }
    }

    /// Builds a multi-threaded runtime whose threads are all named `name`, so a
    /// task can report which runtime executed it.
    fn named_runtime(name: &str) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .thread_name(name)
            .enable_all()
            .build()
            .expect("Failed to build tokio runtime")
    }

    /// Name of the thread the caller is running on, if it has one.
    fn current_thread_name() -> Option<String> {
        std::thread::current().name().map(str::to_owned)
    }

    #[test]
    fn test_runtime_spawn_io() {
        let h = TestRuntime::new();
        let handle = h.rt.io().spawn(async { 1 + 1 });
        assert_eq!(h.block_on(handle).unwrap(), 2);
    }

    #[test]
    fn test_runtime_spawn_cpu() {
        let h = TestRuntime::new();
        let handle = h.rt.cpu().spawn(async { 3 + 4 });
        assert_eq!(h.block_on(handle).unwrap(), 7);
    }

    #[test]
    fn test_runtime_spawn_blocking() {
        let h = TestRuntime::new();
        let handle = h.rt.cpu().spawn_blocking(|| 1 + 1);
        assert_eq!(h.block_on(handle).unwrap(), 2);
    }

    #[test]
    fn test_runtime_new_with_custom_runtime() {
        let h = TestRuntime::new();
        let handle = h.rt.io().spawn(async { 42 });
        assert_eq!(h.block_on(handle).unwrap(), 42);
    }

    /// Work spawned through `io()` and `cpu()` must actually run on the
    /// corresponding runtime, identified here by its thread names.
    #[test]
    fn test_runtime_split_uses_separate_handles() {
        let io_rt = named_runtime("test-io-worker");
        let cpu_rt = named_runtime("test-cpu-worker");
        let rt = Runtime::new_with_split(&io_rt, &cpu_rt);

        // Drive everything from an unrelated runtime, so the observed thread
        // names depend only on which handle the work was spawned through.
        let driver = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (io_thread, cpu_thread, blocking_thread) = driver.block_on(async {
            let io = rt.io().spawn(async { current_thread_name() }).await.unwrap();
            let cpu = rt.cpu().spawn(async { current_thread_name() }).await.unwrap();
            let blocking = rt.cpu().spawn_blocking(current_thread_name).await.unwrap();
            (io, cpu, blocking)
        });

        assert_eq!(io_thread.as_deref(), Some("test-io-worker"));
        assert_eq!(cpu_thread.as_deref(), Some("test-cpu-worker"));
        assert_eq!(blocking_thread.as_deref(), Some("test-cpu-worker"));
    }

    #[test]
    fn test_runtime_clone() {
        let h = TestRuntime::new();
        let rt2 = h.rt.clone();
        let handle = rt2.io().spawn(async { 5 });
        assert_eq!(h.block_on(handle).unwrap(), 5);
    }

    #[test]
    fn test_runtime_debug() {
        let h = TestRuntime::new();
        let debug_str = format!("{:?}", h.rt);
        assert!(debug_str.contains("Runtime"));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_try_current_in_runtime() {
        let rt = Runtime::try_current().expect("should find current runtime");
        let result = rt.io().spawn(async { 7 }).await.unwrap();
        assert_eq!(result, 7);
    }

    #[test]
    fn test_try_current_outside_runtime() {
        let err = Runtime::try_current().expect_err("must fail outside runtime");
        assert_eq!(err.kind(), ErrorKind::Unexpected);
    }

    /// A task that fails to join (here: it panics) surfaces as an
    /// `ErrorKind::Unexpected` error rather than propagating the panic.
    #[test]
    fn test_task_panic_maps_to_unexpected_error() {
        let h = TestRuntime::new();
        let handle = h.rt.cpu().spawn_blocking(|| -> i32 { panic!("boom") });
        let err = h
            .block_on(handle)
            .expect_err("a panicking task must surface as an error");
        assert_eq!(err.kind(), ErrorKind::Unexpected);
    }

    /// A `Runtime` only holds handles, so it can outlive the tokio runtime it
    /// was built from. Spawning afterwards must yield an error when awaited.
    #[test]
    fn test_spawn_after_runtime_drop_errors() {
        let driver = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let owned = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let rt = Runtime::new(&owned);
        // Shut down the runtime that `rt` points at.
        drop(owned);

        let handle = rt.io().spawn(async { 1 });
        // Await from an unrelated runtime, since `owned` no longer exists.
        let result = driver.block_on(handle);
        assert!(result.is_err(), "expected error after runtime shutdown");
    }
}