Skip to main content

dfir_rs/util/
mod.rs

1//! Helper utilities for the DFIR syntax.
2#![warn(missing_docs)]
3
4pub mod accumulator;
5#[cfg(feature = "dfir_macro")]
6#[cfg_attr(docsrs, doc(cfg(feature = "dfir_macro")))]
7pub mod demux_enum;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29use std::net::SocketAddr;
30use std::num::NonZeroUsize;
31use std::task::{Context, Poll};
32
33use futures::Stream;
34use serde::de::DeserializeOwned;
35use serde::ser::Serialize;
36#[cfg(unix)]
37pub use socket::*;
38
/// Persist or delete tuples.
///
/// Each value is a request to either persist a tuple or delete the tuples matching it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Persistence<T> {
    /// Persist T values
    Persist(T),
    /// Delete all values that exactly match
    Delete(T),
}
46
/// Persist or delete key-value pairs.
///
/// Each value is a request to either persist a key-value pair or delete by key.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PersistenceKeyed<K, V> {
    /// Persist key-value pairs
    Persist(K, V),
    /// Delete all tuples that have the key K
    Delete(K),
}
54
55/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
56pub fn unbounded_channel<T>() -> (
57    tokio::sync::mpsc::UnboundedSender<T>,
58    tokio_stream::wrappers::UnboundedReceiverStream<T>,
59) {
60    let (send, recv) = tokio::sync::mpsc::unbounded_channel();
61    let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
62    (send, recv)
63}
64
65/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
66pub fn unsync_channel<T>(
67    capacity: Option<NonZeroUsize>,
68) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
69    unsync::mpsc::channel(capacity)
70}
71
72/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
73pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
74where
75    S: Stream,
76{
77    let mut stream = Box::pin(stream);
78    std::iter::from_fn(move || {
79        match stream
80            .as_mut()
81            .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
82        {
83            Poll::Ready(opt) => opt,
84            Poll::Pending => None,
85        }
86    })
87}
88
89/// Collects the immediately available items from the `Stream` into a `FromIterator` collection.
90///
91/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
92/// to retain ownership of your stream.
93pub fn collect_ready<C, S>(stream: S) -> C
94where
95    C: FromIterator<S::Item>,
96    S: Stream,
97{
98    assert!(
99        tokio::runtime::Handle::try_current().is_err(),
100        "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
101    );
102    ready_iter(stream).collect()
103}
104
105/// Collects the immediately available items from the `Stream` into a collection (`Default` + `Extend`).
106///
107/// This consumes the stream, use [`futures::StreamExt::by_ref()`] (or just `&mut ...`) if you want
108/// to retain ownership of your stream.
109pub async fn collect_ready_async<C, S>(stream: S) -> C
110where
111    C: Default + Extend<S::Item>,
112    S: Stream,
113{
114    use std::sync::atomic::Ordering;
115
116    // Yield to let any background async tasks send to the stream.
117    tokio::task::yield_now().await;
118
119    let got_any_items = std::sync::atomic::AtomicBool::new(true);
120    let mut unfused_iter =
121        ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
122    let mut out = C::default();
123    while got_any_items.swap(false, Ordering::Relaxed) {
124        out.extend(unfused_iter.by_ref());
125        // Tokio unbounded channel returns items in lenght-128 chunks, so we have to be careful
126        // that everything gets returned. That is why we yield here and loop.
127        tokio::task::yield_now().await;
128    }
129    out
130}
131
132/// Serialize a message to bytes using bincode.
133pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
134where
135    T: Serialize,
136{
137    bytes::Bytes::from(bincode::serialize(&msg).unwrap())
138}
139
140/// Serialize a message from bytes using bincode.
141pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
142where
143    T: DeserializeOwned,
144{
145    bincode::deserialize(msg.as_ref())
146}
147
/// Resolves an IP or hostname string to the first `ipv4` [`SocketAddr`] it maps to.
///
/// # Errors
///
/// Propagates the error from address resolution, and returns an `other` I/O error when none of
/// the resolved addresses is IPv4.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;

    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
158
159/// Returns a length-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
160/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
161#[cfg(not(target_arch = "wasm32"))]
162pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
163    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
164    udp_bytes(socket)
165}
166
167/// Returns a newline-delimited bytes `Sink`, `Stream`, and `SocketAddr` bound to the given address.
168/// The input `addr` may have a port of `0`, the returned `SocketAddr` will have the chosen port.
169#[cfg(not(target_arch = "wasm32"))]
170pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
171    let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
172    udp_lines(socket)
173}
174
175/// Returns a newline-delimited bytes `Sender`, `Receiver`, and `SocketAddr` bound to the given address.
176///
177/// The input `addr` may have a port of `0`, the returned `SocketAddr` will be the address of the newly bound endpoint.
178/// The inbound connections can be used in full duplex mode. When a `(T, SocketAddr)` pair is fed to the `Sender`
179/// returned by this function, the `SocketAddr` will be looked up against the currently existing connections.
180/// If a match is found then the data will be sent on that connection. If no match is found then the data is silently dropped.
181#[cfg(not(target_arch = "wasm32"))]
182pub async fn bind_tcp_bytes(
183    addr: SocketAddr,
184) -> (
185    unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
186    unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
187    SocketAddr,
188) {
189    bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
190        .await
191        .unwrap()
192}
193
194/// This is the same thing as `bind_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
195#[cfg(not(target_arch = "wasm32"))]
196pub async fn bind_tcp_lines(
197    addr: SocketAddr,
198) -> (
199    unsync::mpsc::Sender<(String, SocketAddr)>,
200    unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
201    SocketAddr,
202) {
203    bind_tcp(addr, tokio_util::codec::LinesCodec::new())
204        .await
205        .unwrap()
206}
207
208/// The inverse of [`bind_tcp_bytes`].
209///
210/// `(Bytes, SocketAddr)` pairs fed to the returned `Sender` will initiate new tcp connections to the specified `SocketAddr`.
211/// These connections will be cached and reused, so that there will only be one connection per destination endpoint. When the endpoint sends data back it will be available via the returned `Receiver`
212#[cfg(not(target_arch = "wasm32"))]
213pub fn connect_tcp_bytes() -> (
214    TcpFramedSink<bytes::Bytes>,
215    TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
216) {
217    connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
218}
219
220/// This is the same thing as `connect_tcp_bytes` except instead of using a length-delimited encoding scheme it uses new lines to separate frames.
221#[cfg(not(target_arch = "wasm32"))]
222pub fn connect_tcp_lines() -> (
223    TcpFramedSink<String>,
224    TcpFramedStream<tokio_util::codec::LinesCodec>,
225) {
226    connect_tcp(tokio_util::codec::LinesCodec::new())
227}
228
/// Sort a slice using a key fn which returns references.
///
/// The key type may be unsized, so the key fn can return `&str` or `&[T]` as well as references
/// to sized values. The sort is unstable: elements with equal keys may be reordered.
///
/// From addendum in
/// <https://stackoverflow.com/questions/56105305/how-to-sort-a-vec-of-structs-by-a-string-field>
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    // `?Sized` lets the key be a borrowed view such as `str`; `Ord::cmp` only needs references.
    K: Ord + ?Sized,
{
    slice.sort_unstable_by(|a, b| f(a).cmp(f(b)))
}
240
241/// Converts an iterator into a stream that emits `n` items at a time, yielding between each batch.
242///
243/// This is useful for breaking up a large iterator across several ticks: `source_iter(...)` always
244/// releases all items in the first tick. However using `iter_batches_stream` with `source_stream(...)`
245/// will cause `n` items to be released each tick. (Although more than that may be emitted if there
246/// are loops in the stratum).
247pub fn iter_batches_stream<I>(
248    iter: I,
249    n: usize,
250) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
251where
252    I: IntoIterator + Unpin,
253{
254    let mut count = 0;
255    let mut iter = iter.into_iter();
256    futures::stream::poll_fn(move |ctx| {
257        count += 1;
258        if n < count {
259            count = 0;
260            ctx.waker().wake_by_ref();
261            Poll::Pending
262        } else {
263            Poll::Ready(iter.next())
264        }
265    })
266}
267
#[cfg(test)]
mod test {
    use super::*;

    /// All 1000 queued items must be collected synchronously in one call.
    #[test]
    pub fn test_collect_ready() {
        let (sender, mut receiver) = unbounded_channel::<usize>();
        for item in 0..1000 {
            sender.send(item).unwrap();
        }
        let collected: Vec<usize> = collect_ready(&mut receiver);
        assert_eq!(1000, collected.len());
    }

    /// All 1000 queued items must be collected from within the async runtime.
    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Tokio unbounded channel returns items in 128 item long chunks, so we have to be careful that everything gets returned.
        let (sender, mut receiver) = unbounded_channel::<usize>();
        for item in 0..1000 {
            sender.send(item).unwrap();
        }
        let collected: Vec<usize> = collect_ready_async(&mut receiver).await;
        assert_eq!(1000, collected.len());
    }
}
289            1000,
290            collect_ready_async::<Vec<_>, _>(&mut recv).await.len()
291        );
292    }
293}