1#![allow(
2 unused,
3 reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23 FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{debug, instrument, warn};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// TCP port on which each container's channel multiplexer listens for inbound channel
/// connections; senders connect to `{host}:CHANNEL_MUX_PORT`.
pub const CHANNEL_MUX_PORT: u16 = 10_000;
37
/// Magic number carried by the first frame of every mux connection: the ASCII bytes
/// `HYDRO_CH` read as a big-endian `u64`. The receiving mux rejects connections whose first
/// frame does not carry this value.
pub const CHANNEL_MAGIC: u64 = u64::from_be_bytes(*b"HYDRO_CH");
40
/// First frame of a mux connection (bincode-encoded), carrying [`CHANNEL_MAGIC`] so the
/// receiver can reject connections that are not speaking this protocol.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelMagic {
    /// Must equal [`CHANNEL_MAGIC`]; any other value causes the connection to be dropped.
    pub magic: u64,
}
49
/// Version of the mux preamble protocol. The receiving mux drops connections whose
/// [`ChannelProtocolVersion`] frame carries a different value.
pub const CHANNEL_PROTOCOL_VERSION: u64 = 1;
52
/// Second frame of a mux connection (bincode-encoded), sent right after [`ChannelMagic`].
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelProtocolVersion {
    /// Must equal [`CHANNEL_PROTOCOL_VERSION`]; a mismatch causes the connection to be dropped.
    pub version: u64,
}
61
/// Third frame of a mux connection (bincode-encoded), identifying which logical channel the
/// rest of the connection's frames belong to.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct ChannelHandshake {
    /// Channel name the receiving mux uses to route the connection to a registered consumer.
    pub channel_name: String,
    /// Identity of the sender. The m2o/m2m senders in this module set it to the value of the
    /// `CONTAINER_NAME` environment variable; the o2o/o2m senders leave it `None`.
    pub sender_id: Option<String>,
}
77
/// A connection routed by the mux: the handshake's `sender_id`, plus the frame reader for the
/// connection, already advanced past the three preamble frames.
type MuxConnection = (
    Option<String>,
    FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
);
83
84pub struct ChannelMux {
91 channels: std::sync::Mutex<HashMap<String, tokio::sync::mpsc::UnboundedSender<MuxConnection>>>,
93}
94
95impl Default for ChannelMux {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl ChannelMux {
102 pub fn new() -> Self {
103 Self {
104 channels: std::sync::Mutex::new(HashMap::new()),
105 }
106 }
107
108 pub fn register(
109 &self,
110 channel_name: String,
111 ) -> tokio::sync::mpsc::UnboundedReceiver<MuxConnection> {
112 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
113 let mut channels = self.channels.lock().unwrap();
114 channels.insert(channel_name, tx);
115 rx
116 }
117
118 pub async fn run_accept_loop(self: Arc<Self>, listener: TcpListener) {
119 loop {
120 let (stream, peer) = match listener.accept().await {
121 Ok(v) => v,
122 Err(e) => {
123 warn!(name: "accept_error", error = %e);
124 continue;
125 }
126 };
127 debug!(name: "mux_accepting", ?peer);
128
129 let mux = self.clone();
130 tokio::spawn(async move {
131 let (rx, _tx) = stream.into_split();
132 let mut source = FramedRead::new(rx, LengthDelimitedCodec::new());
133
134 let magic_frame = match source.next().await {
135 Some(Ok(frame)) => frame,
136 _ => {
137 warn!(name: "magic_failed", ?peer, "no magic frame");
138 return;
139 }
140 };
141
142 let magic: ChannelMagic = match bincode::deserialize(&magic_frame) {
143 Ok(m) => m,
144 Err(e) => {
145 warn!(name: "magic_deserialize_failed", ?peer, error = %e);
146 return;
147 }
148 };
149
150 if magic.magic != CHANNEL_MAGIC {
151 warn!(name: "bad_magic", ?peer, magic = magic.magic, expected = CHANNEL_MAGIC);
152 return;
153 }
154
155 let version_frame = match source.next().await {
156 Some(Ok(frame)) => frame,
157 _ => {
158 warn!(name: "version_failed", ?peer, "no version frame");
159 return;
160 }
161 };
162
163 let version: ChannelProtocolVersion = match bincode::deserialize(&version_frame) {
164 Ok(v) => v,
165 Err(e) => {
166 warn!(name: "version_deserialize_failed", ?peer, error = %e);
167 return;
168 }
169 };
170
171 if version.version != CHANNEL_PROTOCOL_VERSION {
172 warn!(name: "version_mismatch", ?peer, version = version.version, expected = CHANNEL_PROTOCOL_VERSION);
173 return;
174 }
175
176 let handshake_frame = match source.next().await {
177 Some(Ok(frame)) => frame,
178 _ => {
179 warn!(name: "handshake_failed", ?peer, "no handshake frame");
180 return;
181 }
182 };
183
184 let handshake: ChannelHandshake = match bincode::deserialize(&handshake_frame) {
185 Ok(h) => h,
186 Err(e) => {
187 warn!(name: "handshake_deserialize_failed", ?peer, error = %e);
188 return;
189 }
190 };
191
192 debug!(name: "handshake_received", ?peer, ?handshake);
193
194 let channels = mux.channels.lock().unwrap();
195 if let Some(tx_chan) = channels.get(&handshake.channel_name) {
196 let _ = tx_chan.send((handshake.sender_id, source));
197 } else {
198 warn!(
199 name: "unknown_channel",
200 channel_name = %handshake.channel_name,
201 ?peer,
202 "no registered consumer for channel"
203 );
204 }
205 });
206 }
207 }
208}
209
210pub fn get_or_init_channel_mux() -> Arc<ChannelMux> {
215 use std::sync::OnceLock;
216 static MUX: OnceLock<Arc<ChannelMux>> = OnceLock::new();
217
218 MUX.get_or_init(|| {
219 let mux = Arc::new(ChannelMux::new());
220 let mux_clone = mux.clone();
221
222 tokio::spawn(async move {
225 let bind_addr = format!("0.0.0.0:{}", CHANNEL_MUX_PORT);
226 debug!(name: "mux_listening", %bind_addr);
227 let listener = TcpListener::bind(&bind_addr)
228 .await
229 .expect("failed to bind channel mux listener");
230 mux_clone.run_accept_loop(listener).await;
231 });
232
233 mux
234 })
235 .clone()
236}
237
238pub async fn send_handshake(
241 sink: &mut FramedWrite<TcpStream, LengthDelimitedCodec>,
242 channel_name: &str,
243 sender_id: Option<&str>,
244) -> Result<(), std::io::Error> {
245 let magic = ChannelMagic {
246 magic: CHANNEL_MAGIC,
247 };
248 sink.send(bytes::Bytes::from(bincode::serialize(&magic).unwrap()))
249 .await?;
250
251 let version = ChannelProtocolVersion {
252 version: CHANNEL_PROTOCOL_VERSION,
253 };
254 sink.send(bytes::Bytes::from(bincode::serialize(&version).unwrap()))
255 .await?;
256
257 let handshake = ChannelHandshake {
258 channel_name: channel_name.to_owned(),
259 sender_id: sender_id.map(|s| s.to_owned()),
260 };
261 sink.send(bytes::Bytes::from(bincode::serialize(&handshake).unwrap()))
262 .await?;
263 Ok(())
264}
265
/// Builds the `(sink, source)` expression pair for a one-to-one channel between containers.
///
/// Both halves are `q!` quotations spliced into untyped `syn::Expr`s.
/// - Sink: lazily connects to `{target}:CHANNEL_MUX_PORT`, sends the preamble via
///   [`send_handshake`] with no sender id, and then writes length-delimited `bytes::Bytes`
///   frames.
/// - Source: registers `channel_name` with the process-wide mux from
///   [`get_or_init_channel_mux`] and yields the frames of the first connection routed to that
///   channel. It fails with `ConnectionReset` ("channel mux closed") if the registration's
///   channel closes before any connection arrives.
pub fn deploy_containerized_o2o(target: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                let channel_name = channel_name;
                let target = format!("{}:{}", target, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                self::send_handshake(&mut sink, channel_name, None).await?;

                Result::<_, std::io::Error>::Ok(sink)
            }
        )))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2o_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
299
/// Builds the `(sink, source)` expression pair for a one-to-many channel (one sender, a
/// cluster of receivers).
///
/// Both halves are `q!` quotations spliced into untyped `syn::Expr`s.
/// - Sink: a `demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it lazily connects to
///   `{key.get_container_name()}:CHANNEL_MUX_PORT` and sends the preamble via
///   [`send_handshake`] with no sender id.
/// - Source: registers `channel_name` with the process-wide mux and yields the frames of the
///   first connection routed to that channel. It fails with `ConnectionReset`
///   ("channel mux closed") if the registration's channel closes before any connection
///   arrives.
pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let target =
                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        self::send_handshake(&mut sink, &channel_name, None).await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2m_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
340
/// Builds the `(sink, source)` expression pair for a many-to-one channel (a cluster of
/// senders, one receiver).
///
/// Both halves are `q!` quotations spliced into untyped `syn::Expr`s.
/// - Sink: lazily connects to `{target_host}:CHANNEL_MUX_PORT` and sends the preamble via
///   [`send_handshake`], using the `CONTAINER_NAME` environment variable as the sender id. It
///   panics if that variable is unset.
/// - Source: registers `channel_name` with the process-wide mux and merges every connection
///   routed to that channel (`flatten_unordered`). Each frame result is tagged with the
///   sender's [`TaglessMemberId`]. The stream of connections ends when the registration's
///   channel closes. It panics if a connection's handshake has no sender id.
pub fn deploy_containerized_m2o(target_host: &str, channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
            Box::pin(async move {
                let channel_name = channel_name;
                let target = format!("{}:{}", target_host, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                let container_name = std::env::var("CONTAINER_NAME").unwrap();
                self::send_handshake(&mut sink, channel_name, Some(&container_name)).await?;

                Result::<_, std::io::Error>::Ok(sink)
            })
        }))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from = sender_id.expect("m2o sender must provide container name");

                        debug!(name: "m2o_channel_connected", %from);

                        Some((
                            source.map(move |v| {
                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
                            }),
                            rx,
                        ))
                    })
                })
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
386
/// Builds the `(sink, source)` expression pair for a many-to-many channel (cluster to
/// cluster).
///
/// Both halves are `q!` quotations spliced into untyped `syn::Expr`s.
/// - Sink: a `demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it lazily connects to
///   `{key.get_container_name()}:CHANNEL_MUX_PORT` and sends the preamble via
///   [`send_handshake`], using the `CONTAINER_NAME` environment variable as the sender id. It
///   panics if that variable is unset.
/// - Source: registers `channel_name` with the process-wide mux and merges every connection
///   routed to that channel (`flatten_unordered`). Each frame result is tagged with the
///   sender's [`TaglessMemberId`]. It panics if a connection's handshake has no sender id.
pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let target =
                            format!("{}:{}", key.get_container_name(), self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        let container_name = std::env::var("CONTAINER_NAME").unwrap();
                        self::send_handshake(&mut sink, &channel_name, Some(&container_name))
                            .await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from = sender_id.expect("m2m sender must provide container name");

                        debug!(name: "m2m_channel_connected", %from);

                        Some((
                            source.map(move |v| {
                                v.map(|v| (TaglessMemberId::from_container_name(from.clone()), v))
                            }),
                            rx,
                        ))
                    })
                })
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
440
/// Name of a [`TcpListener`] variable in generated code, usable as a free variable inside `q!`.
pub struct SocketIdent {
    /// Identifier of the listener variable; emitted verbatim when quoted.
    pub socket_ident: syn::Ident,
}
444
445impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
446 type O = TcpListener;
447
448 fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
449 where
450 Self: Sized,
451 {
452 let ident = self.socket_ident;
453
454 (
455 QuoteTokens {
456 prelude: None,
457 expr: Some(quote::quote! { #ident }),
458 },
459 (),
460 )
461 }
462}
463
/// Builds a `LazySinkSource` expression over a connection accepted from the listener named by
/// `socket_ident`.
///
/// On first use, the generated code accepts one connection on that listener and splits it into
/// read and write halves. Each half is wrapped in a length-delimited framed reader or writer
/// over `bytes::Bytes`. Errors from `accept` are returned as `std::io::Error`.
pub fn deploy_containerized_external_sink_source_ident(socket_ident: syn::Ident) -> syn::Expr {
    // Wrap the identifier so `q!` can reference it as a free variable typed `TcpListener`.
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        std::io::Error,
    >::new(async move {
        let (stream, peer) = socket_ident.accept().await?;
        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
485
/// Quotation standing in for the static list of cluster member ids in the containerized
/// runtime.
///
/// It evaluates to a one-element slice holding a placeholder id named
/// `"INVALID CONTAINER NAME cluster_ids"`. The backing array is `Box::leak`ed each time the
/// expression is evaluated. NOTE(review): presumably callers are expected to use
/// [`cluster_membership_stream`] for real membership; confirm.
pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
        "INVALID CONTAINER NAME cluster_ids"
    )]))
    .as_slice())
}
495
/// Quotation that evaluates to this container's own member id, built from the
/// `CONTAINER_NAME` environment variable. The generated code panics if the variable is unset.
#[cfg(feature = "docker_runtime")]
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        std::env::var("CONTAINER_NAME").unwrap()
    ))
}
502
/// Quotation that evaluates to a boxed stream of membership events for the cluster at
/// `location_id`, produced by `docker_membership_stream`.
///
/// The generated code reads the `DEPLOYMENT_INSTANCE` environment variable (panicking if it is
/// unset) and combines it with this location's key to identify the cluster's containers.
#[cfg(feature = "docker_runtime")]
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    let key = location_id.key();

    q!(Box::new(self::docker_membership_stream(
        std::env::var("DEPLOYMENT_INSTANCE").unwrap(),
        key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
518
/// Streams membership changes for the containers of one cluster location, derived from Docker.
///
/// A container belongs to the location when its name contains
/// `"{deployment_instance}-{location_key}"`. The returned stream works in two phases:
/// 1. It emits [`MembershipEvent::Joined`] for every matching container returned by
///    `list_containers`.
/// 2. It follows the Docker event stream, which a spawned task reads, mapping `start` to
///    `Joined` and `die` to `Left`.
///
/// A shared set of announced names deduplicates the two sources. A container is announced as
/// `Joined` at most once while present, and `Left` is emitted only for a container previously
/// announced as `Joined`. Failures to list containers or to read an individual Docker event are
/// logged and otherwise skipped.
///
/// # Panics
/// - Panics if the Docker client cannot be configured from local defaults.
/// - Must be called inside a Tokio runtime, because the event follower is started with
///   `tokio::spawn`.
#[cfg(feature = "docker_runtime")]
#[instrument(skip_all, fields(%deployment_instance, %location_key))]
fn docker_membership_stream(
    deployment_instance: String,
    location_key: LocationKey,
) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
    use std::collections::HashSet;
    use std::sync::{Arc, Mutex};

    use bollard::Docker;
    use bollard::query_parameters::{EventsOptions, ListContainersOptions};
    use tokio::sync::mpsc;

    // Every container of this location carries this fragment in its name. Compute it once
    // rather than once per Docker event.
    let name_fragment = format!("{deployment_instance}-{location_key}");

    let docker = Docker::connect_with_local_defaults()
        .unwrap()
        .with_timeout(Duration::from_secs(1));

    let (event_tx, event_rx) = mpsc::unbounded_channel::<(String, MembershipEvent)>();

    let events_docker = docker.clone();
    let events_name_fragment = name_fragment.clone();
    tokio::spawn(async move {
        let mut filters = HashMap::new();
        filters.insert("type".to_owned(), vec!["container".to_owned()]);
        filters.insert(
            "event".to_owned(),
            vec!["start".to_owned(), "die".to_owned()],
        );
        let event_options = Some(EventsOptions {
            filters: Some(filters),
            ..Default::default()
        });

        let mut events = events_docker.events(event_options);
        while let Some(event) = events.next().await {
            let event = match event {
                Ok(event) => event,
                Err(e) => {
                    warn!(name: "docker_event_error", error = %e);
                    continue;
                }
            };

            let Some(name) = event
                .actor
                .as_ref()
                .and_then(|a| a.attributes.as_ref())
                .and_then(|attrs| attrs.get("name"))
            else {
                continue;
            };

            if !name.contains(events_name_fragment.as_str()) {
                continue;
            }

            let membership_event = match event.action.as_deref() {
                Some("start") => MembershipEvent::Joined,
                Some("die") => MembershipEvent::Left,
                _ => continue,
            };

            if event_tx.send((name.clone(), membership_event)).is_err() {
                // The membership stream was dropped; stop following Docker events.
                break;
            }
        }
    });

    // Names already announced as `Joined`, shared by the snapshot and event phases.
    let seen_joined = Arc::new(Mutex::new(HashSet::<String>::new()));
    let seen_joined_snapshot = seen_joined.clone();
    let seen_joined_events = seen_joined;

    let snapshot_stream = futures::stream::once(async move {
        let mut filters = HashMap::new();
        filters.insert("name".to_owned(), vec![name_fragment]);
        let options = Some(ListContainersOptions {
            filters: Some(filters),
            ..Default::default()
        });

        let containers = match docker.list_containers(options).await {
            Ok(containers) => containers,
            Err(e) => {
                warn!(name: "docker_list_containers_failed", error = %e);
                Vec::new()
            }
        };

        // Taken after the only `.await`, so the guard is never held across a suspension.
        let mut seen = seen_joined_snapshot.lock().unwrap();
        containers
            .iter()
            .filter_map(|c| c.names.as_deref())
            .filter_map(|names| names.first())
            // Docker reports container names with a leading '/'.
            .map(|name| name.trim_start_matches('/'))
            .filter(|&name| seen.insert(name.to_owned()))
            .map(|name| (name.to_owned(), MembershipEvent::Joined))
            .collect::<Vec<_>>()
    })
    .flat_map(futures::stream::iter);

    let events_stream = tokio_stream::StreamExt::filter_map(
        tokio_stream::wrappers::UnboundedReceiverStream::new(event_rx),
        move |(name, event)| {
            let mut seen = seen_joined_events.lock().unwrap();
            match event {
                MembershipEvent::Joined => {
                    if seen.insert(name.clone()) {
                        Some((name, MembershipEvent::Joined))
                    } else {
                        None
                    }
                }
                MembershipEvent::Left => seen.take(&name).map(|name| (name, MembershipEvent::Left)),
            }
        },
    );

    Box::pin(
        snapshot_stream
            .chain(events_stream)
            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
            .inspect(|(member_id, event)| debug!(name: "membership_event", ?member_id, ?event)),
    )
}