Skip to main content

hydro_lang/deploy/
deploy_runtime_containerized_ecs.rs

1#![allow(
2    unused,
3    reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::time::Duration;
8
9use futures::{SinkExt, Stream, StreamExt};
10use sinktools::lazy::{LazySink, LazySource};
11use sinktools::lazy_sink_source::LazySinkSource;
12use stageleft::{QuotedWithContext, q};
13use tokio::net::TcpStream;
14use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
15use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
16use tracing::{Instrument, debug, instrument, span, trace, trace_span};
17
18pub use super::deploy_runtime_containerized::{
19    CHANNEL_MAGIC, CHANNEL_MUX_PORT, CHANNEL_PROTOCOL_VERSION, ChannelHandshake, ChannelMagic,
20    ChannelMux, ChannelProtocolVersion, SocketIdent, cluster_ids, get_or_init_channel_mux,
21    send_handshake,
22};
23use crate::location::dynamic::LocationId;
24use crate::location::member_id::TaglessMemberId;
25use crate::location::{LocationKey, MembershipEvent};
26
27pub fn deploy_containerized_o2o(
28    target_task_family: &str,
29    channel_name: &str,
30) -> (syn::Expr, syn::Expr) {
31    (
32        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
33            async move {
34                let channel_name = channel_name;
35                let target_task_family = target_task_family;
36                let task_id = self::resolve_task_family_to_task_id(target_task_family).await;
37                let ip = self::resolve_task_ip(&task_id).await;
38                let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
39                debug!(name: "connecting", %target, %target_task_family, %task_id, %channel_name);
40
41                let stream = TcpStream::connect(&target).await?;
42                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
43
44                self::send_handshake(&mut sink, channel_name, None).await?;
45
46                Result::<_, std::io::Error>::Ok(sink)
47            }
48        )))
49        .splice_untyped_ctx(&()),
50        q!(LazySource::new(move || Box::pin(async move {
51            let channel_name = channel_name;
52            let mux = self::get_or_init_channel_mux();
53            let mut rx = mux.register(channel_name.to_owned());
54
55            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
56                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
57            })?;
58
59            debug!(name: "o2o_channel_connected", %channel_name);
60
61            Result::<_, std::io::Error>::Ok(source)
62        })))
63        .splice_untyped_ctx(&()),
64    )
65}
66
67pub fn deploy_containerized_o2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
68    (
69        q!(sinktools::demux_map_lazy::<_, _, _, _>(
70            move |key: &TaglessMemberId| {
71                let key = key.clone();
72                let channel_name = channel_name.to_owned();
73
74                LazySink::<_, _, _, bytes::Bytes>::new(move || {
75                    Box::pin(async move {
76                        let task_id = key.get_container_name();
77                        let ip = self::resolve_task_ip(task_id).await;
78                        let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
79                        debug!(name: "connecting", %target, %task_id, channel_name = %channel_name);
80
81                        let stream = TcpStream::connect(&target).await?;
82                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
83
84                        self::send_handshake(&mut sink, &channel_name, None).await?;
85
86                        Result::<_, std::io::Error>::Ok(sink)
87                    })
88                })
89            }
90        ))
91        .splice_untyped_ctx(&()),
92        q!(LazySource::new(move || Box::pin(async move {
93            let channel_name = channel_name;
94            let mux = self::get_or_init_channel_mux();
95            let mut rx = mux.register(channel_name.to_owned());
96
97            let (_sender_id, source) = rx.recv().await.ok_or_else(|| {
98                std::io::Error::new(std::io::ErrorKind::ConnectionReset, "channel mux closed")
99            })?;
100
101            debug!(name: "o2m_channel_connected", %channel_name);
102
103            Result::<_, std::io::Error>::Ok(source)
104        })))
105        .splice_untyped_ctx(&()),
106    )
107}
108
109pub fn deploy_containerized_m2o(
110    target_task_family: &str,
111    channel_name: &str,
112) -> (syn::Expr, syn::Expr) {
113    (
114        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
115            Box::pin(async move {
116                let channel_name = channel_name;
117                let target_task_family = target_task_family;
118                let target_task_id = self::resolve_task_family_to_task_id(target_task_family).await;
119                let ip = self::resolve_task_ip(&target_task_id).await;
120                let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
121                debug!(name: "connecting", %target, %target_task_family, %target_task_id, %channel_name);
122
123                let stream = TcpStream::connect(&target).await?;
124                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
125
126                let self_task_id = self::get_self_task_id();
127                self::send_handshake(&mut sink, channel_name, Some(&self_task_id)).await?;
128
129                Result::<_, std::io::Error>::Ok(sink)
130            })
131        }))
132        .splice_untyped_ctx(&()),
133        q!(LazySource::new(move || Box::pin(async move {
134            let channel_name = channel_name;
135            let mux = self::get_or_init_channel_mux();
136            let mut rx = mux.register(channel_name.to_owned());
137
138            Result::<_, std::io::Error>::Ok(
139                futures::stream::unfold(rx, |mut rx| {
140                    Box::pin(async move {
141                        let (sender_id, source) = rx.recv().await?;
142                        let from_task_id = sender_id
143                            .expect("m2o sender must provide task ID");
144
145                        debug!(name: "m2o_channel_connected", %from_task_id);
146
147                        Some((
148                            source.map(move |v| {
149                                v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
150                            }),
151                            rx,
152                        ))
153                    })
154                })
155                .flatten_unordered(None),
156            )
157        })))
158        .splice_untyped_ctx(&()),
159    )
160}
161
162pub fn deploy_containerized_m2m(channel_name: &str) -> (syn::Expr, syn::Expr) {
163    (
164        q!(sinktools::demux_map_lazy::<_, _, _, _>(
165            move |key: &TaglessMemberId| {
166                let key = key.clone();
167                let channel_name = channel_name.to_owned();
168
169                LazySink::<_, _, _, bytes::Bytes>::new(move || {
170                    Box::pin(async move {
171                        let task_id = key.get_container_name();
172                        let ip = self::resolve_task_ip(task_id).await;
173                        let target = format!("{}:{}", ip, self::CHANNEL_MUX_PORT);
174                        debug!(name: "connecting", %target, %task_id, channel_name = %channel_name);
175
176                        let stream = TcpStream::connect(&target).await?;
177                        let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
178
179                        let self_task_id = self::get_self_task_id();
180                        self::send_handshake(&mut sink, &channel_name, Some(&self_task_id)).await?;
181
182                        Result::<_, std::io::Error>::Ok(sink)
183                    })
184                })
185            }
186        ))
187        .splice_untyped_ctx(&()),
188        q!(LazySource::new(move || Box::pin(async move {
189            let channel_name = channel_name;
190            let mux = self::get_or_init_channel_mux();
191            let mut rx = mux.register(channel_name.to_owned());
192
193            Result::<_, std::io::Error>::Ok(
194                futures::stream::unfold(rx, |mut rx| {
195                    Box::pin(async move {
196                        let (sender_id, source) = rx.recv().await?;
197                        let from_task_id = sender_id.expect("m2m sender must provide task ID");
198
199                        debug!(name: "m2m_channel_connected", %from_task_id);
200
201                        Some((
202                            source.map(move |v| {
203                                v.map(|v| {
204                                    (
205                                        TaglessMemberId::from_container_name(from_task_id.clone()),
206                                        v,
207                                    )
208                                })
209                            }),
210                            rx,
211                        ))
212                    })
213                })
214                .flatten_unordered(None),
215            )
216        })))
217        .splice_untyped_ctx(&()),
218    )
219}
220
221pub fn deploy_containerized_external_sink_source_ident(
222    bind_addr: String,
223    socket_ident: syn::Ident,
224) -> syn::Expr {
225    let socket_ident = SocketIdent { socket_ident };
226
227    q!(LazySinkSource::<
228        _,
229        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
230        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
231        bytes::Bytes,
232        std::io::Error,
233    >::new(async move {
234        let span = span!(tracing::Level::TRACE, "lazy_sink_source");
235        let guard = span.enter();
236        let bind_addr = bind_addr;
237        trace!(name: "attempting to accept from external", %bind_addr);
238        std::mem::drop(guard);
239        let (stream, peer) = socket_ident.accept().instrument(span.clone()).await?;
240        let guard = span.enter();
241
242        debug!(name: "external accepting", ?peer);
243        let (rx, tx) = stream.into_split();
244
245        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
246        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
247
248        Result::<_, std::io::Error>::Ok((fr, fw))
249    },))
250    .splice_untyped_ctx(&())
251}
252
253pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
254    q!(TaglessMemberId::from_container_name(
255        self::get_self_task_id()
256    ))
257}
258
259pub fn cluster_membership_stream<'a>(
260    location_id: &LocationId,
261) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
262{
263    let location_key = location_id.key();
264
265    q!(Box::new(self::ecs_membership_stream(
266        std::env::var("CLUSTER_NAME").unwrap(),
267        location_key
268    ))
269        as Box<
270            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
271        >)
272}
273
274#[instrument(skip_all, fields(%cluster_name, %location_key))]
275fn ecs_membership_stream(
276    cluster_name: String,
277    location_key: LocationKey,
278) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
279    use std::collections::HashSet;
280
281    use futures::stream::{StreamExt, once};
282
283    trace!(name: "ecs_membership_stream_created", %cluster_name, %location_key);
284
285    let ecs_poller_span = trace_span!("ecs_poller");
286
287    // Task family format: hy-{name_hint}-loc{idx}v{version}
288    // Example: hy-p1-loc2v1
289    let task_definition_arn_parser =
290        regex::Regex::new(r#"arn:aws:ecs:(?<region>.*):(?<account_id>.*):task-definition\/(?<container_id>hy-(?<type>[^-]+)-loc(?<location_idx>[0-9]+)v(?<location_version>[0-9]+)(?:-(?<instance_id>.*))?):.*"#).unwrap();
291
292    let poll_stream = futures::stream::unfold(
293        (HashSet::<String>::new(), cluster_name, location_key),
294        move |(known_tasks, cluster_name, location_key)| {
295            let task_definition_arn_parser = task_definition_arn_parser.clone();
296
297            async move {
298                let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
299                let ecs_client = aws_sdk_ecs::Client::new(&config);
300
301                let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
302                    Ok(tasks) => tasks,
303                    Err(e) => {
304                        trace!(name: "list_tasks_error", error = %e);
305                        tokio::time::sleep(Duration::from_secs(2)).await;
306                        return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
307                    }
308                };
309
310                let task_arns: Vec<String> =
311                    tasks.task_arns().iter().map(|s| s.to_string()).collect();
312
313                let mut events = Vec::new();
314                let mut current_tasks = HashSet::<String>::new();
315
316                if !task_arns.is_empty() {
317                    let task_details = match ecs_client
318                        .describe_tasks()
319                        .cluster(&cluster_name)
320                        .set_tasks(Some(task_arns.clone()))
321                        .send()
322                        .await
323                    {
324                        Ok(details) => details,
325                        Err(e) => {
326                            trace!(name: "describe_tasks_error", error = %e);
327                            tokio::time::sleep(Duration::from_secs(2)).await;
328                            return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
329                        }
330                    };
331
332                    for task in task_details.tasks() {
333                        let Some(last_status) = task.last_status() else {
334                            continue;
335                        };
336
337                        if last_status != "RUNNING" {
338                            continue;
339                        }
340
341                        let Some(task_def_arn) = task.task_definition_arn() else {
342                            continue;
343                        };
344
345                        let Some(captures) = task_definition_arn_parser.captures(task_def_arn)
346                        else {
347                            continue;
348                        };
349
350                        let Some(location_idx) = captures.name("location_idx") else {
351                            continue;
352                        };
353                        let Some(location_version) = captures.name("location_version") else {
354                            continue;
355                        };
356                        // Reconstruct the location key string and parse it
357                        let location_key_str =
358                            format!("loc{}v{}", location_idx.as_str(), location_version.as_str());
359                        let task_location_key: LocationKey = match location_key_str.parse() {
360                            Ok(key) => key,
361                            Err(_) => {
362                                continue;
363                            }
364                        };
365
366                        // Filter by location_id - only include tasks for this specific cluster
367                        if task_location_key != location_key {
368                            continue;
369                        }
370
371                        // Extract task ID from task ARN (last segment after final /)
372                        // Task ARN format: arn:aws:ecs:region:account:task/cluster-name/task-id
373                        let Some(task_arn) = task.task_arn() else {
374                            continue;
375                        };
376                        let Some(task_id) = task_arn.rsplit('/').next() else {
377                            continue;
378                        };
379
380                        // Use task_id as the member identifier
381                        current_tasks.insert(task_id.to_owned());
382                        if !known_tasks.contains(task_id) {
383                            trace!(name: "task_joined", %task_id);
384                            events.push((task_id.to_owned(), MembershipEvent::Joined));
385                        }
386                    }
387                }
388
389                #[expect(
390                    clippy::disallowed_methods,
391                    reason = "nondeterministic iteration order, container events are not deterministically ordered"
392                )]
393                for task_id in known_tasks.iter() {
394                    if !current_tasks.contains(task_id) {
395                        trace!(name: "task_left", %task_id);
396                        events.push((task_id.to_owned(), MembershipEvent::Left));
397                    }
398                }
399
400                tokio::time::sleep(Duration::from_secs(2)).await;
401
402                Some((events, (current_tasks, cluster_name, location_key)))
403            }
404            .instrument(ecs_poller_span.clone())
405        },
406    )
407    .flat_map(futures::stream::iter);
408
409    Box::pin(
410        poll_stream
411            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
412            .inspect(|(member_id, event)| trace!(name: "membership_event", ?member_id, ?event)),
413    )
414}
415
416/// Resolve a task ID to its private IP address via ECS API.
417async fn resolve_task_ip(task_id: &str) -> String {
418    let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
419
420    let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
421    let ecs_client = aws_sdk_ecs::Client::new(&config);
422
423    loop {
424        let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
425            Ok(t) => t,
426            Err(e) => {
427                trace!(name: "resolve_ip_list_error", %task_id, error = %e);
428                tokio::time::sleep(Duration::from_secs(1)).await;
429                continue;
430            }
431        };
432
433        let task_arns: Vec<_> = tasks.task_arns().to_vec();
434        if task_arns.is_empty() {
435            trace!(name: "resolve_ip_no_tasks", %task_id);
436            tokio::time::sleep(Duration::from_secs(1)).await;
437            continue;
438        }
439
440        let task_details = match ecs_client
441            .describe_tasks()
442            .cluster(&cluster_name)
443            .set_tasks(Some(task_arns))
444            .send()
445            .await
446        {
447            Ok(d) => d,
448            Err(e) => {
449                trace!(name: "resolve_ip_describe_error", %task_id, error = %e);
450                tokio::time::sleep(Duration::from_secs(1)).await;
451                continue;
452            }
453        };
454
455        // Find the task with matching task ID
456        for task in task_details.tasks() {
457            let Some(task_arn) = task.task_arn() else {
458                continue;
459            };
460            let current_task_id = task_arn.rsplit('/').next().unwrap_or_default();
461
462            if current_task_id == task_id
463                && let Some(ip) = task
464                    .attachments()
465                    .iter()
466                    .flat_map(|a| a.details())
467                    .find(|d| d.name() == Some("privateIPv4Address"))
468                    .and_then(|d| d.value())
469            {
470                trace!(name: "resolved_ip", %task_id, %ip);
471                return ip.to_owned();
472            }
473        }
474
475        trace!(name: "resolve_ip_not_found", %task_id);
476        tokio::time::sleep(Duration::from_secs(1)).await;
477    }
478}
479
480/// Resolve a task family name to its task ID via ECS API.
481/// Used for process-to-process connections where the target is known by task family at compile time.
482async fn resolve_task_family_to_task_id(task_family: &str) -> String {
483    let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
484
485    let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
486    let ecs_client = aws_sdk_ecs::Client::new(&config);
487
488    loop {
489        let tasks = match ecs_client
490            .list_tasks()
491            .cluster(&cluster_name)
492            .family(task_family)
493            .send()
494            .await
495        {
496            Ok(t) => t,
497            Err(e) => {
498                trace!(name: "resolve_family_list_error", %task_family, error = %e);
499                tokio::time::sleep(Duration::from_secs(1)).await;
500                continue;
501            }
502        };
503
504        let Some(task_arn) = tasks.task_arns().first() else {
505            trace!(name: "resolve_family_no_task", %task_family);
506            tokio::time::sleep(Duration::from_secs(1)).await;
507            continue;
508        };
509
510        // Extract task ID from ARN
511        let task_id = task_arn.rsplit('/').next().unwrap_or_default();
512        if !task_id.is_empty() {
513            trace!(name: "resolved_task_id", %task_family, %task_id);
514            return task_id.to_owned();
515        }
516
517        trace!(name: "resolve_family_invalid_arn", %task_family, %task_arn);
518        tokio::time::sleep(Duration::from_secs(1)).await;
519    }
520}
521
/// Returns the ECS task ID of the task this process is running in.
///
/// The ID is derived from the `ECS_CONTAINER_METADATA_URI_V4` environment
/// variable: the last path segment of the URI is taken (ignoring any trailing
/// `/`), and anything from the first `-` onward is dropped. On Fargate that
/// segment has the form `<task-id>-<suffix>`, while the task IDs taken from
/// task ARNs elsewhere in this file carry no suffix, so the suffix must be
/// removed for the two to compare equal.
///
/// # Panics
///
/// Panics if `ECS_CONTAINER_METADATA_URI_V4` is not set.
fn get_self_task_id() -> String {
    let metadata_uri = std::env::var("ECS_CONTAINER_METADATA_URI_V4")
        .expect("ECS_CONTAINER_METADATA_URI_V4 not set - are we running in ECS?");

    let last_segment = metadata_uri
        .trim_end_matches('/')
        .rsplit('/')
        .next()
        .expect("Invalid ECS metadata URI format");

    // `<task-id>-<suffix>` -> `<task-id>`; a segment without `-` is used as is.
    last_segment
        .split_once('-')
        .map_or(last_segment, |(task_id, _suffix)| task_id)
        .to_owned()
}