1use std::collections::{BTreeMap, BTreeSet};
8
9use serde::{Deserialize, Serialize};
10
11use crate::air::{AirBundle, AirFunction, AirModule};
12use crate::id::make_id;
13use crate::ids::{FunctionId, ModuleId, ProgramId, ValueId};
14
15#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
19pub struct LinkTable {
20 pub function_resolutions: BTreeMap<FunctionId, FunctionId>,
22
23 pub global_resolutions: BTreeMap<ValueId, ValueId>,
25
26 pub conflicts: Vec<LinkConflict>,
28}
29
30#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32pub struct LinkConflict {
33 pub name: String,
35
36 pub defining_modules: Vec<ModuleId>,
38}
39
40#[derive(Debug, Clone, Default, PartialEq)]
42pub struct ProgramDiff {
43 pub added_modules: Vec<ModuleId>,
45
46 pub removed_modules: Vec<ModuleId>,
48
49 pub changed_modules: Vec<ModuleId>,
51
52 pub unchanged_modules: Vec<ModuleId>,
54
55 pub added_functions: BTreeSet<FunctionId>,
57
58 pub removed_functions: BTreeSet<FunctionId>,
60}
61
62#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
64pub enum SplitStrategy {
65 BySourceFile,
67
68 ByFunction,
70
71 Monolithic,
73
74 #[default]
76 Auto,
77}
78
79#[derive(Debug, Clone)]
81pub struct AirProgram {
82 pub id: ProgramId,
84
85 pub modules: Vec<AirModule>,
87
88 pub link_table: LinkTable,
90}
91
92impl AirProgram {
93 pub fn link(bundles: Vec<AirBundle>) -> Self {
98 let modules: Vec<AirModule> = bundles.into_iter().map(|b| b.module).collect();
99
100 let mut module_ids: Vec<u128> = modules.iter().map(|m| m.id.raw()).collect();
102 module_ids.sort_unstable();
103 let id_bytes: Vec<u8> = module_ids.iter().flat_map(|id| id.to_le_bytes()).collect();
104 let id = ProgramId::new(make_id("program", &id_bytes));
105
106 let link_table = Self::resolve_symbols(&modules);
107
108 Self {
109 id,
110 modules,
111 link_table,
112 }
113 }
114
115 #[must_use]
121 pub fn merged_view(&self) -> AirModule {
122 let mut merged = AirModule::new(crate::ids::ModuleId::new(self.id.raw()));
123 merged.name = Some("merged".to_string());
124
125 let mut defined_functions: BTreeSet<String> = BTreeSet::new();
127 for module in &self.modules {
128 for func in &module.functions {
129 if !func.is_declaration {
130 defined_functions.insert(func.name.clone());
131 }
132 }
133 }
134
135 for module in &self.modules {
137 for func in &module.functions {
138 if func.is_declaration && defined_functions.contains(&func.name) {
139 continue;
141 }
142 merged.add_function(func.clone());
143 }
144 }
145
146 if !self.link_table.function_resolutions.is_empty() {
148 for func in &mut merged.functions {
149 for block in &mut func.blocks {
150 for inst in &mut block.instructions {
151 if let crate::air::Operation::CallDirect { callee } = &mut inst.op {
152 if let Some(def_id) = self.link_table.function_resolutions.get(callee) {
153 *callee = *def_id;
154 }
155 }
156 }
157 }
158 }
159 }
160
161 let mut defined_globals: BTreeSet<String> = BTreeSet::new();
163 for module in &self.modules {
164 for global in &module.globals {
165 if global.init.is_some() {
166 defined_globals.insert(global.name.clone());
167 }
168 }
169 }
170 for module in &self.modules {
171 for global in &module.globals {
172 if global.init.is_none() && defined_globals.contains(&global.name) {
173 continue;
174 }
175 merged.globals.push(global.clone());
176 }
177 }
178
179 let mut seen_paths: BTreeSet<String> = BTreeSet::new();
181 for module in &self.modules {
182 for sf in &module.source_files {
183 if seen_paths.insert(sf.path.clone()) {
184 merged.source_files.push(sf.clone());
185 }
186 }
187 }
188
189 for module in &self.modules {
191 merged.type_hierarchy.extend(module.type_hierarchy.clone());
192 }
193
194 for module in &self.modules {
196 for (tid, ty) in &module.types {
197 merged.types.entry(*tid).or_insert_with(|| ty.clone());
198 }
199 }
200
201 for module in &self.modules {
203 for (vid, c) in &module.constants {
204 merged.constants.entry(*vid).or_insert_with(|| c.clone());
205 }
206 }
207
208 merged.target_pointer_width = self
210 .modules
211 .iter()
212 .map(|m| m.target_pointer_width)
213 .max()
214 .unwrap_or(8);
215
216 merged.rebuild_function_index();
217 merged
218 }
219
220 fn resolve_symbols(modules: &[AirModule]) -> LinkTable {
222 let mut table = LinkTable::default();
223
224 let mut func_index: BTreeMap<String, Vec<(ModuleId, FunctionId, bool)>> = BTreeMap::new();
226 for module in modules {
227 for func in &module.functions {
228 func_index.entry(func.name.clone()).or_default().push((
229 module.id,
230 func.id,
231 func.is_declaration,
232 ));
233 }
234 }
235
236 for (name, entries) in &func_index {
238 let definitions: Vec<_> = entries.iter().filter(|e| !e.2).collect();
239 let declarations: Vec<_> = entries.iter().filter(|e| e.2).collect();
240
241 if definitions.len() > 1 {
242 table.conflicts.push(LinkConflict {
243 name: name.clone(),
244 defining_modules: definitions.iter().map(|d| d.0).collect(),
245 });
246 }
247
248 if let Some(def) = definitions.first() {
249 for decl in &declarations {
250 table.function_resolutions.insert(decl.1, def.1);
251 }
252 }
253 }
254
255 let mut global_index: BTreeMap<String, Vec<(ModuleId, ValueId, bool)>> = BTreeMap::new();
257 for module in modules {
258 for global in &module.globals {
259 global_index.entry(global.name.clone()).or_default().push((
260 module.id,
261 global.id,
262 global.init.is_some(),
263 ));
264 }
265 }
266
267 for entries in global_index.values() {
268 let definitions: Vec<_> = entries.iter().filter(|e| e.2).collect();
269 let declarations: Vec<_> = entries.iter().filter(|e| !e.2).collect();
270
271 if let Some(def) = definitions.first() {
272 for decl in &declarations {
273 table.global_resolutions.insert(decl.1, def.1);
274 }
275 }
276 }
277
278 table
279 }
280
281 pub fn diff(previous: &[ModuleId], current: &[ModuleId]) -> ProgramDiff {
286 let prev_set: BTreeSet<ModuleId> = previous.iter().copied().collect();
287 let curr_set: BTreeSet<ModuleId> = current.iter().copied().collect();
288
289 ProgramDiff {
290 added_modules: curr_set.difference(&prev_set).copied().collect(),
291 removed_modules: prev_set.difference(&curr_set).copied().collect(),
292 changed_modules: Vec::new(), unchanged_modules: curr_set.intersection(&prev_set).copied().collect(),
294 added_functions: BTreeSet::new(),
295 removed_functions: BTreeSet::new(),
296 }
297 }
298}
299
300pub fn split_module(module: AirModule, strategy: SplitStrategy) -> Vec<AirModule> {
307 match strategy {
308 SplitStrategy::Monolithic => vec![module],
309 SplitStrategy::ByFunction => split_by_function(module),
310 SplitStrategy::BySourceFile => split_by_source_file(module),
311 SplitStrategy::Auto => {
312 let has_spans = module.functions.iter().any(|f| f.span.is_some());
314 if has_spans {
315 split_by_source_file(module)
316 } else {
317 split_by_function(module)
318 }
319 }
320 }
321}
322
323fn split_by_function(module: AirModule) -> Vec<AirModule> {
324 if module.functions.is_empty() {
325 return vec![module];
326 }
327
328 let mut result = Vec::new();
329 for func in &module.functions {
330 let mod_id = ModuleId::new(make_id("module-split-fn", func.name.as_bytes()));
331 let mut sub = AirModule::new(mod_id);
332 sub.name = Some(func.name.clone());
333 sub.target_pointer_width = module.target_pointer_width;
334 sub.types.clone_from(&module.types);
336 sub.constants.clone_from(&module.constants);
337 sub.add_function(func.clone());
338 result.push(sub);
339 }
340
341 if !module.globals.is_empty() {
343 if let Some(first) = result.first_mut() {
344 first.globals.clone_from(&module.globals);
345 }
346 }
347
348 result
349}
350
351fn split_by_source_file(module: AirModule) -> Vec<AirModule> {
352 let file_paths: BTreeMap<crate::ids::FileId, String> = module
354 .source_files
355 .iter()
356 .map(|sf| (sf.id, sf.path.clone()))
357 .collect();
358
359 let mut groups: BTreeMap<String, Vec<AirFunction>> = BTreeMap::new();
361 for func in &module.functions {
362 let key = func
363 .span
364 .as_ref()
365 .and_then(|s| file_paths.get(&s.file_id))
366 .cloned()
367 .unwrap_or_else(|| "<unknown>".to_string());
368 groups.entry(key).or_default().push(func.clone());
369 }
370
371 if groups.len() <= 1 {
372 return vec![module];
374 }
375
376 let mut result = Vec::new();
377 for (path, functions) in &groups {
378 let mod_id = ModuleId::new(make_id("module-split-src", path.as_bytes()));
379 let mut sub = AirModule::new(mod_id);
380 sub.name = Some(path.clone());
381 sub.target_pointer_width = module.target_pointer_width;
382 sub.types.clone_from(&module.types);
383 sub.constants.clone_from(&module.constants);
384 for f in functions {
385 sub.add_function(f.clone());
386 }
387 result.push(sub);
388 }
389
390 if !module.globals.is_empty() {
392 if let Some(first) = result.first_mut() {
393 first.globals.clone_from(&module.globals);
394 }
395 }
396
397 result
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::air::{AirBundle, AirFunction, AirGlobal, AirModule, Constant};
    use crate::id::make_id;
    use crate::ids::{FunctionId, ModuleId, ObjId, ValueId};

    /// Function whose ID is derived from its name (so same-named functions share an ID).
    fn make_function(name: &str, is_declaration: bool) -> AirFunction {
        let fn_id = FunctionId::new(make_id("fn", name.as_bytes()));
        let mut func = AirFunction::new(fn_id, name.to_owned());
        func.is_declaration = is_declaration;
        func
    }

    /// Named module whose ID is derived from its name, holding `functions` in order.
    fn make_module(name: &str, functions: Vec<AirFunction>) -> AirModule {
        let mut module = AirModule::new(ModuleId::new(make_id("module", name.as_bytes())));
        module.name = Some(name.to_owned());
        for func in functions {
            module.add_function(func);
        }
        module
    }

    /// Wraps `make_module` in a bundle.
    fn make_bundle(name: &str, functions: Vec<AirFunction>) -> AirBundle {
        AirBundle::new(String::from("test"), make_module(name, functions))
    }

    /// Global named `name`; `has_init` gives it a 32-bit zero initializer.
    fn make_global(name: &str, has_init: bool) -> AirGlobal {
        let key = name.as_bytes();
        let mut global = AirGlobal::new(
            ValueId::new(make_id("val", key)),
            ObjId::new(make_id("obj", key)),
            name.to_owned(),
        );
        if has_init {
            global.init = Some(Constant::Int { value: 0, bits: 32 });
        }
        global
    }

    /// A `main` module that defines `main` and declares the external `printf`.
    fn main_with_printf() -> AirBundle {
        make_bundle(
            "main",
            vec![make_function("main", false), make_function("printf", true)],
        )
    }

    /// `main` declares `helper`; `lib` defines it.
    fn main_and_lib_bundles() -> Vec<AirBundle> {
        vec![
            make_bundle(
                "main",
                vec![make_function("main", false), make_function("helper", true)],
            ),
            make_bundle("lib", vec![make_function("helper", false)]),
        ]
    }

    /// Module "test" with one defined function per name.
    fn module_with(names: &[&str]) -> AirModule {
        make_module(
            "test",
            names.iter().map(|name| make_function(name, false)).collect(),
        )
    }

    #[test]
    fn link_empty_program() {
        let program = AirProgram::link(Vec::new());
        assert!(program.modules.is_empty());
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_single_module() {
        let program = AirProgram::link(vec![main_with_printf()]);
        assert_eq!(program.modules.len(), 1);
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_resolves_extern_functions() {
        let program = AirProgram::link(main_and_lib_bundles());
        assert_eq!(program.modules.len(), 2);

        let resolutions = &program.link_table.function_resolutions;
        assert_eq!(resolutions.len(), 1);
        let helper_decl = FunctionId::new(make_id("fn", b"helper"));
        assert!(resolutions.contains_key(&helper_decl));
    }

    #[test]
    fn link_unresolved_extern_stays_unresolved() {
        let program = AirProgram::link(vec![main_with_printf()]);
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_detects_conflicting_definitions() {
        let bundles: Vec<AirBundle> = ["a", "b"]
            .iter()
            .map(|module_name| make_bundle(module_name, vec![make_function("foo", false)]))
            .collect();
        let program = AirProgram::link(bundles);

        let conflicts = &program.link_table.conflicts;
        assert_eq!(conflicts.len(), 1);
        assert_eq!(conflicts[0].name, "foo");
        assert_eq!(conflicts[0].defining_modules.len(), 2);
    }

    #[test]
    fn merged_view_contains_all_definitions() {
        let merged = AirProgram::link(main_and_lib_bundles()).merged_view();
        assert_eq!(merged.functions.len(), 2);

        let names: BTreeSet<_> = merged.functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains("main"));
        assert!(names.contains("helper"));

        let helper = merged
            .function_by_name("helper")
            .expect("helper must be present in the merged view");
        assert!(!helper.is_declaration);
    }

    #[test]
    fn merged_view_keeps_unresolved_declarations() {
        let merged = AirProgram::link(vec![main_with_printf()]).merged_view();
        assert_eq!(merged.functions.len(), 2);

        let printf = merged
            .function_by_name("printf")
            .expect("printf declaration must be kept");
        assert!(printf.is_declaration);
    }

    #[test]
    fn merged_view_resolves_globals() {
        let mut main_mod = make_module("main", vec![make_function("main", false)]);
        main_mod.globals.push(make_global("config", false));
        let mut lib_mod = make_module("lib", Vec::new());
        lib_mod.globals.push(make_global("config", true));

        let bundles: Vec<AirBundle> = vec![main_mod, lib_mod]
            .into_iter()
            .map(|module| AirBundle::new("test".to_string(), module))
            .collect();
        let merged = AirProgram::link(bundles).merged_view();

        let configs: Vec<_> = merged.globals.iter().filter(|g| g.name == "config").collect();
        assert_eq!(configs.len(), 1);
        assert!(configs[0].init.is_some());
    }

    #[test]
    fn program_id_is_deterministic() {
        let bundle_a = || make_bundle("a", vec![make_function("fa", false)]);
        let bundle_b = || make_bundle("b", vec![make_function("fb", false)]);

        let forward = AirProgram::link(vec![bundle_a(), bundle_b()]);
        let reversed = AirProgram::link(vec![bundle_b(), bundle_a()]);
        assert_eq!(forward.id, reversed.id);
    }

    #[test]
    fn diff_identifies_added_modules() {
        let (m1, m2, m3) = (ModuleId::new(1), ModuleId::new(2), ModuleId::new(3));

        let diff = AirProgram::diff(&[m1, m2], &[m1, m2, m3]);
        assert_eq!(diff.added_modules, vec![m3]);
        assert!(diff.removed_modules.is_empty());
        assert_eq!(diff.unchanged_modules.len(), 2);
    }

    #[test]
    fn diff_identifies_removed_modules() {
        let (m1, m2) = (ModuleId::new(1), ModuleId::new(2));

        let diff = AirProgram::diff(&[m1, m2], &[m1]);
        assert!(diff.added_modules.is_empty());
        assert_eq!(diff.removed_modules, vec![m2]);
        assert_eq!(diff.unchanged_modules.len(), 1);
    }

    #[test]
    fn split_monolithic_returns_original() {
        let parts = split_module(module_with(&["a", "b"]), SplitStrategy::Monolithic);
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].functions.len(), 2);
    }

    #[test]
    fn split_by_function_creates_one_module_per_function() {
        let parts = split_module(module_with(&["a", "b", "c"]), SplitStrategy::ByFunction);
        assert_eq!(parts.len(), 3);

        let first = &parts[0];
        assert_eq!(first.functions.len(), 1);
        assert_eq!(first.functions[0].name, "a");
    }

    #[test]
    fn split_auto_uses_by_function_when_no_spans() {
        let parts = split_module(module_with(&["a", "b"]), SplitStrategy::Auto);
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn split_empty_module_returns_single() {
        let parts = split_module(module_with(&[]), SplitStrategy::ByFunction);
        assert_eq!(parts.len(), 1);
    }

    #[test]
    fn split_preserves_globals_in_first_module() {
        let mut module = module_with(&["a", "b"]);
        module.globals.push(make_global("g", true));

        let parts = split_module(module, SplitStrategy::ByFunction);
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].globals.len(), 1);
        assert!(parts[1].globals.is_empty());
    }
}