1#![allow(
2 unused,
3 reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::time::Duration;
8
9use futures::{SinkExt, Stream, StreamExt};
10use sinktools::lazy::{LazySink, LazySource};
11use sinktools::lazy_sink_source::LazySinkSource;
12use stageleft::{QuotedWithContext, q};
13use tokio::net::TcpStream;
14use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
15use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
16use tracing::{Instrument, debug, instrument, span, trace, trace_span};
17
18pub use super::deploy_runtime_containerized::{
19 CHANNEL_MAGIC, CHANNEL_MUX_PORT, CHANNEL_PROTOCOL_VERSION, ChannelHandshake, ChannelMagic,
20 ChannelMux, ChannelProtocolVersion, SocketIdent, cluster_ids, get_or_init_channel_mux,
21 send_handshake,
22};
23use crate::location::dynamic::LocationId;
24use crate::location::member_id::TaglessMemberId;
25use crate::location::{LocationKey, MembershipEvent};
26
/// Builds the quoted `(sink, source)` expression pair for a one-to-one channel between two
/// containerized processes.
///
/// Both expressions are `q!` quotations spliced with `splice_untyped_ctx(&())`; nothing runs
/// until the generated code lazily initializes them.
///
/// - Sink (sender side): resolves `target_task_family` to a task ID and resolves that task's
///   IP. It then opens a TCP connection to `CHANNEL_MUX_PORT` on that IP and wraps it in a
///   length-delimited `FramedWrite`. Finally it sends a handshake for `channel_name` with no
///   sender ID. TCP connect and handshake failures surface as `std::io::Error`.
/// - Source (receiver side): registers `channel_name` with this process's channel mux and
///   yields the first connection delivered for it. It fails with
///   `io::ErrorKind::ConnectionReset` ("channel mux closed") if the registration receiver
///   closes before any connection arrives.
pub fn deploy_containerized_o2o(
    target_task_family: &str,
    channel_name: &str,
) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                let channel_name = channel_name;
                let target_task_family = target_task_family;
                // Both resolution helpers retry internally until they succeed.
                let task_id = self::resolve_task_family_to_task_id(target_task_family).await;
                let ip = self::resolve_task_ip(&task_id).await;
                let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %target_task_family, %task_id, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // One-to-one: the receiver does not need to know who the sender is.
                self::send_handshake(&mut sink, channel_name, None).await?;

                Result::<_, std::io::Error>::Ok(sink)
            }
        )))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Only the first connection for this channel is used; the sender ID is ignored.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2o_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
66
/// Builds the quoted `(sink, source)` expression pair for a one-to-many channel, where a single
/// process sends to individual members of a containerized cluster.
///
/// - Sink: a `sinktools::demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it lazily
///   resolves the member's IP and treats `key.get_container_name()` as the ECS task ID. It then
///   connects over TCP to `CHANNEL_MUX_PORT` with length-delimited framing and sends a
///   handshake for `channel_name` with no sender ID.
/// - Source (on each member): registers `channel_name` with the local channel mux and yields
///   the first connection delivered for it. It fails with `ConnectionReset` if the mux
///   registration closes first.
pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                // One lazily-connected sink per destination member.
                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let task_id = key.get_container_name();
                        let ip = self::resolve_task_ip(task_id).await;
                        let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, %task_id, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        // Receivers in o2m do not need the sender's identity.
                        self::send_handshake(&mut sink, &channel_name, None).await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            // Only the first (single) sender's connection is consumed.
            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
            })?;

            debug!(name: "o2m_channel_connected", %channel_name);

            Result::<_, std::io::Error>::Ok(source)
        })))
        .splice_untyped_ctx(&()),
    )
}
108
/// Builds the quoted `(sink, source)` expression pair for a many-to-one channel, where every
/// member of a cluster sends to a single process.
///
/// - Sink (on each sender): resolves `target_task_family` to a task ID and resolves its IP.
///   It then connects over TCP to `CHANNEL_MUX_PORT` with length-delimited framing. The
///   handshake for `channel_name` carries this process's own task ID
///   (`get_self_task_id()`), so the receiver can attribute messages.
/// - Source (on the receiver): registers `channel_name` with the local channel mux and
///   accepts every incoming sender connection. Each connection's items are tagged with
///   `TaglessMemberId::from_container_name(sender_task_id)`, and all connections are merged
///   with `flatten_unordered`. The merged stream ends when the mux registration closes.
///
/// # Panics
///
/// The generated source panics ("m2o sender must provide task ID") if a connection's
/// handshake carried no sender ID.
pub fn deploy_containerized_m2o(
    target_task_family: &str,
    channel_name: &str,
) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
            Box::pin(async move {
                let channel_name = channel_name;
                let target_task_family = target_task_family;
                let target_task_id = self::resolve_task_family_to_task_id(target_task_family).await;
                let ip = self::resolve_task_ip(&target_task_id).await;
                let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
                debug!(name: "connecting", %target, %target_task_family, %target_task_id, %channel_name);

                let stream = TcpStream::connect(&target).await?;
                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // Identify ourselves so the single receiver can tell senders apart.
                let self_task_id = self::get_self_task_id();
                self::send_handshake(&mut sink, channel_name, Some(&self_task_id)).await?;

                Result::<_, std::io::Error>::Ok(sink)
            })
        }))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                // Each unfold step waits for the next sender connection; `recv()` returning
                // `None` ends the outer stream via `?`.
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from_task_id = sender_id
                            .expect("m2o sender must provide task ID");

                        debug!(name: "m2o_channel_connected", %from_task_id);

                        Some((
                            // Tag every item from this connection with its sender.
                            source.map(move |v| {
                                v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
                            }),
                            rx,
                        ))
                    })
                })
                // Merge all sender connections; no limit on concurrently polled streams.
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
161
/// Builds the quoted `(sink, source)` expression pair for a many-to-many channel between
/// members of containerized clusters.
///
/// - Sink: a `sinktools::demux_map_lazy` keyed by [`TaglessMemberId`]. For each key it lazily
///   resolves the member's IP, treating `key.get_container_name()` as the ECS task ID, and
///   connects over TCP to `CHANNEL_MUX_PORT` with length-delimited framing. It then sends a
///   handshake for `channel_name` carrying this process's own task ID.
/// - Source: registers `channel_name` with the local channel mux and accepts every incoming
///   connection. Each connection's items are tagged with the sender's
///   `TaglessMemberId::from_container_name(..)`, and all connections are merged with
///   `flatten_unordered`.
///
/// # Panics
///
/// The generated source panics ("m2m sender must provide task ID") if a connection's
/// handshake carried no sender ID.
pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(sinktools::demux_map_lazy::<_, _, _, _>(
            move |key: &TaglessMemberId| {
                let key = key.clone();
                let channel_name = channel_name.to_owned();

                // One lazily-connected sink per destination member.
                LazySink::<_, _, _, bytes::Bytes>::new(move || {
                    Box::pin(async move {
                        let task_id = key.get_container_name();
                        let ip = self::resolve_task_ip(task_id).await;
                        let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
                        debug!(name: "connecting", %target, %task_id, channel_name = %channel_name);

                        let stream = TcpStream::connect(&target).await?;
                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                        // Receivers need our identity to tag incoming items.
                        let self_task_id = self::get_self_task_id();
                        self::send_handshake(&mut sink, &channel_name, Some(&self_task_id)).await?;

                        Result::<_, std::io::Error>::Ok(sink)
                    })
                })
            }
        ))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let channel_name = channel_name;
            let mux = self::get_or_init_channel_mux();
            let mut rx = mux.register(channel_name.to_owned());

            Result::<_, std::io::Error>::Ok(
                // Each unfold step waits for the next sender connection; `recv()` returning
                // `None` ends the outer stream via `?`.
                futures::stream::unfold(rx, |mut rx| {
                    Box::pin(async move {
                        let (sender_id, source) = rx.recv().await?;
                        let from_task_id = sender_id.expect("m2m sender must provide task ID");

                        debug!(name: "m2m_channel_connected", %from_task_id);

                        Some((
                            // Tag every item from this connection with its sender.
                            source.map(move |v| {
                                v.map(|v| {
                                    (
                                        TaglessMemberId::from_container_name(from_task_id.clone()),
                                        v,
                                    )
                                })
                            }),
                            rx,
                        ))
                    })
                })
                // Merge all sender connections; no limit on concurrently polled streams.
                .flatten_unordered(None),
            )
        })))
        .splice_untyped_ctx(&()),
    )
}
220
/// Builds a quoted `LazySinkSource` that accepts one external TCP connection and exposes it as
/// a length-delimited `(FramedRead, FramedWrite)` pair of `bytes::Bytes` frames.
///
/// `socket_ident` names the socket to accept from in the generated code. `accept()` is
/// presumably called on a pre-bound listener (TODO confirm against `SocketIdent`).
/// `bind_addr` is only used for tracing in this block.
/// Accept failures are propagated as `std::io::Error`.
pub fn deploy_containerized_external_sink_source_ident(
    bind_addr: String,
    socket_ident: syn::Ident,
) -> syn::Expr {
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        std::io::Error,
    >::new(async move {
        let span = span!(tracing::Level::TRACE, "lazy_sink_source");
        let guard = span.enter();
        let bind_addr = bind_addr;
        trace!(name: "attempting to accept from external", %bind_addr);
        // An entered-span guard must not be held across `.await`; drop it and instrument the
        // accept future with the span instead.
        std::mem::drop(guard);
        let (stream, peer) = socket_ident.accept().instrument(span.clone()).await?;
        // Re-enter for the remaining synchronous work (no further awaits below).
        let guard = span.enter();

        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
252
/// Quoted expression evaluating to this process's own [`TaglessMemberId`].
///
/// The ID is built from `get_self_task_id()`, which reads the ECS container metadata URI.
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        self::get_self_task_id()
    ))
}
258
/// Quoted expression producing a boxed stream of membership events for the cluster at
/// `location_id`.
///
/// At runtime the generated code reads the ECS cluster name from the `CLUSTER_NAME`
/// environment variable and panics via `unwrap` if it is unset. It then delegates to
/// `ecs_membership_stream` filtered to this location's key.
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    let location_key = location_id.key();

    q!(Box::new(self::ecs_membership_stream(
        std::env::var("CLUSTER_NAME").unwrap(),
        location_key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
273
274#[instrument(skip_all, fields(%cluster_name, %location_key))]
275fn ecs_membership_stream(
276 cluster_name: String,
277 location_key: LocationKey,
278) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
279 use std::collections::HashSet;
280
281 use futures::stream::{StreamExt, once};
282
283 trace!(name: "ecs_membership_stream_created", %cluster_name, %location_key);
284
285 let ecs_poller_span = trace_span!("ecs_poller");
286
287 let task_definition_arn_parser =
290 regex::Regex::new(r#"arn:aws:ecs:(?<region>.*):(?<account_id>.*):task-definition\/(?<container_id>hy-(?<type>[^-]+)-loc(?<location_idx>[0-9]+)v(?<location_version>[0-9]+)(?:-(?<instance_id>.*))?):.*"#).unwrap();
291
292 let poll_stream = futures::stream::unfold(
293 (HashSet::<String>::new(), cluster_name, location_key),
294 move |(known_tasks, cluster_name, location_key)| {
295 let task_definition_arn_parser = task_definition_arn_parser.clone();
296
297 async move {
298 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
299 let ecs_client = aws_sdk_ecs::Client::new(&config);
300
301 let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
302 Ok(tasks) => tasks,
303 Err(e) => {
304 trace!(name: "list_tasks_error", error = %e);
305 tokio::time::sleep(Duration::from_secs(2)).await;
306 return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
307 }
308 };
309
310 let task_arns: Vec<String> =
311 tasks.task_arns().iter().map(|s| s.to_string()).collect();
312
313 let mut events = Vec::new();
314 let mut current_tasks = HashSet::<String>::new();
315
316 if !task_arns.is_empty() {
317 let task_details = match ecs_client
318 .describe_tasks()
319 .cluster(&cluster_name)
320 .set_tasks(Some(task_arns.clone()))
321 .send()
322 .await
323 {
324 Ok(details) => details,
325 Err(e) => {
326 trace!(name: "describe_tasks_error", error = %e);
327 tokio::time::sleep(Duration::from_secs(2)).await;
328 return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
329 }
330 };
331
332 for task in task_details.tasks() {
333 let Some(last_status) = task.last_status() else {
334 continue;
335 };
336
337 if last_status != "RUNNING" {
338 continue;
339 }
340
341 let Some(task_def_arn) = task.task_definition_arn() else {
342 continue;
343 };
344
345 let Some(captures) = task_definition_arn_parser.captures(task_def_arn)
346 else {
347 continue;
348 };
349
350 let Some(location_idx) = captures.name("location_idx") else {
351 continue;
352 };
353 let Some(location_version) = captures.name("location_version") else {
354 continue;
355 };
356 let location_key_str =
358 format!("loc{}v{}", location_idx.as_str(), location_version.as_str());
359 let task_location_key: LocationKey = match location_key_str.parse() {
360 Ok(key) => key,
361 Err(_) => {
362 continue;
363 }
364 };
365
366 if task_location_key != location_key {
368 continue;
369 }
370
371 let Some(task_arn) = task.task_arn() else {
374 continue;
375 };
376 let Some(task_id) = task_arn.rsplit('/').next() else {
377 continue;
378 };
379
380 current_tasks.insert(task_id.to_owned());
382 if !known_tasks.contains(task_id) {
383 trace!(name: "task_joined", %task_id);
384 events.push((task_id.to_owned(), MembershipEvent::Joined));
385 }
386 }
387 }
388
389 #[expect(
390 clippy::disallowed_methods,
391 reason = "nondeterministic iteration order, container events are not deterministically ordered"
392 )]
393 for task_id in known_tasks.iter() {
394 if !current_tasks.contains(task_id) {
395 trace!(name: "task_left", %task_id);
396 events.push((task_id.to_owned(), MembershipEvent::Left));
397 }
398 }
399
400 tokio::time::sleep(Duration::from_secs(2)).await;
401
402 Some((events, (current_tasks, cluster_name, location_key)))
403 }
404 .instrument(ecs_poller_span.clone())
405 },
406 )
407 .flat_map(futures::stream::iter);
408
409 Box::pin(
410 poll_stream
411 .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
412 .inspect(|(member_id, event)| trace!(name: "membership_event", ?member_id, ?event)),
413 )
414}
415
416async fn resolve_task_ip(task_id: &str) -> String {
418 let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
419
420 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
421 let ecs_client = aws_sdk_ecs::Client::new(&config);
422
423 loop {
424 let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
425 Ok(t) => t,
426 Err(e) => {
427 trace!(name: "resolve_ip_list_error", %task_id, error = %e);
428 tokio::time::sleep(Duration::from_secs(1)).await;
429 continue;
430 }
431 };
432
433 let task_arns: Vec<_> = tasks.task_arns().to_vec();
434 if task_arns.is_empty() {
435 trace!(name: "resolve_ip_no_tasks", %task_id);
436 tokio::time::sleep(Duration::from_secs(1)).await;
437 continue;
438 }
439
440 let task_details = match ecs_client
441 .describe_tasks()
442 .cluster(&cluster_name)
443 .set_tasks(Some(task_arns))
444 .send()
445 .await
446 {
447 Ok(d) => d,
448 Err(e) => {
449 trace!(name: "resolve_ip_describe_error", %task_id, error = %e);
450 tokio::time::sleep(Duration::from_secs(1)).await;
451 continue;
452 }
453 };
454
455 for task in task_details.tasks() {
457 let Some(task_arn) = task.task_arn() else {
458 continue;
459 };
460 let current_task_id = task_arn.rsplit('/').next().unwrap_or_default();
461
462 if current_task_id == task_id
463 && let Some(ip) = task
464 .attachments()
465 .iter()
466 .flat_map(|a| a.details())
467 .find(|d| d.name() == Some("privateIPv4Address"))
468 .and_then(|d| d.value())
469 {
470 trace!(name: "resolved_ip", %task_id, %ip);
471 return ip.to_owned();
472 }
473 }
474
475 trace!(name: "resolve_ip_not_found", %task_id);
476 tokio::time::sleep(Duration::from_secs(1)).await;
477 }
478}
479
480async fn resolve_task_family_to_task_id(task_family: &str) -> String {
483 let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
484
485 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
486 let ecs_client = aws_sdk_ecs::Client::new(&config);
487
488 loop {
489 let tasks = match ecs_client
490 .list_tasks()
491 .cluster(&cluster_name)
492 .family(task_family)
493 .send()
494 .await
495 {
496 Ok(t) => t,
497 Err(e) => {
498 trace!(name: "resolve_family_list_error", %task_family, error = %e);
499 tokio::time::sleep(Duration::from_secs(1)).await;
500 continue;
501 }
502 };
503
504 let Some(task_arn) = tasks.task_arns().first() else {
505 trace!(name: "resolve_family_no_task", %task_family);
506 tokio::time::sleep(Duration::from_secs(1)).await;
507 continue;
508 };
509
510 let task_id = task_arn.rsplit('/').next().unwrap_or_default();
512 if !task_id.is_empty() {
513 trace!(name: "resolved_task_id", %task_family, %task_id);
514 return task_id.to_owned();
515 }
516
517 trace!(name: "resolve_family_invalid_arn", %task_family, %task_arn);
518 tokio::time::sleep(Duration::from_secs(1)).await;
519 }
520}
521
/// Returns this process's task identifier: the final `/`-separated segment of the
/// `ECS_CONTAINER_METADATA_URI_V4` environment variable.
///
/// NOTE(review): confirm that this segment equals the task ID derived from task ARNs
/// elsewhere (membership stream, IP resolution). On some ECS launch types the metadata
/// URI's last segment may be a container-specific identifier rather than the bare task ID.
///
/// # Panics
///
/// Panics if `ECS_CONTAINER_METADATA_URI_V4` is unset or not valid Unicode.
fn get_self_task_id() -> String {
    let uri = std::env::var("ECS_CONTAINER_METADATA_URI_V4")
        .expect("ECS_CONTAINER_METADATA_URI_V4 not set - are we running in ECS?");

    // `rsplit` always yields at least one piece, so this never actually fails.
    let last_segment = uri.rsplit('/').next().map(str::to_owned);
    last_segment.expect("Invalid ECS metadata URI format")
}