Skip to main content

saf_core/
program.rs

1//! Multi-module program linking.
2//!
3//! `AirProgram` represents a whole program composed of multiple linked
4//! `AirModule`s. The `LinkTable` resolves cross-module references
5//! (extern declarations matched to definitions).
6
7use std::collections::{BTreeMap, BTreeSet};
8
9use serde::{Deserialize, Serialize};
10
11use crate::air::{AirBundle, AirFunction, AirModule};
12use crate::id::make_id;
13use crate::ids::{FunctionId, ModuleId, ProgramId, ValueId};
14
15/// Cross-module symbol resolution.
16///
17/// Maps extern declarations in one module to their definitions in another.
18#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
19pub struct LinkTable {
20    /// Extern function declaration `FunctionId` -> defining function's `FunctionId`.
21    pub function_resolutions: BTreeMap<FunctionId, FunctionId>,
22
23    /// Extern global `ValueId` -> defining global's `ValueId`.
24    pub global_resolutions: BTreeMap<ValueId, ValueId>,
25
26    /// Functions with conflicting definitions across modules.
27    pub conflicts: Vec<LinkConflict>,
28}
29
30/// A conflicting symbol found during linking.
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32pub struct LinkConflict {
33    /// The symbol name that has multiple definitions.
34    pub name: String,
35
36    /// Modules that define this symbol.
37    pub defining_modules: Vec<ModuleId>,
38}
39
40/// Tracks what changed between two analysis runs of the same program.
41#[derive(Debug, Clone, Default, PartialEq)]
42pub struct ProgramDiff {
43    /// Modules present in the new run but not the previous.
44    pub added_modules: Vec<ModuleId>,
45
46    /// Modules present in the previous run but not the new.
47    pub removed_modules: Vec<ModuleId>,
48
49    /// Modules whose fingerprint changed between runs.
50    pub changed_modules: Vec<ModuleId>,
51
52    /// Modules whose fingerprint is identical to the previous run.
53    pub unchanged_modules: Vec<ModuleId>,
54
55    /// Functions that appear in added or changed modules but not in previous.
56    pub added_functions: BTreeSet<FunctionId>,
57
58    /// Functions that were in removed or changed modules but not in new.
59    pub removed_functions: BTreeSet<FunctionId>,
60}
61
62/// How to split a single pre-linked input into logical modules.
63#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
64pub enum SplitStrategy {
65    /// Group functions by their `source_files` metadata.
66    BySourceFile,
67
68    /// Each function becomes its own module.
69    ByFunction,
70
71    /// Keep as a single monolithic module (today's behavior).
72    Monolithic,
73
74    /// `BySourceFile` if source metadata present, else `ByFunction`.
75    #[default]
76    Auto,
77}
78
79/// A whole program composed of multiple linked modules.
80#[derive(Debug, Clone)]
81pub struct AirProgram {
82    /// Deterministic ID derived from the sorted set of module fingerprints.
83    pub id: ProgramId,
84
85    /// Individual compilation units (one per input file).
86    pub modules: Vec<AirModule>,
87
88    /// Cross-module symbol resolution.
89    pub link_table: LinkTable,
90}
91
92impl AirProgram {
93    /// Link multiple bundles into a unified program.
94    ///
95    /// Resolves extern function declarations to their definitions across modules.
96    /// Produces a `LinkTable` mapping declaration IDs to definition IDs.
97    pub fn link(bundles: Vec<AirBundle>) -> Self {
98        let modules: Vec<AirModule> = bundles.into_iter().map(|b| b.module).collect();
99
100        // Compute deterministic ProgramId from sorted module IDs
101        let mut module_ids: Vec<u128> = modules.iter().map(|m| m.id.raw()).collect();
102        module_ids.sort_unstable();
103        let id_bytes: Vec<u8> = module_ids.iter().flat_map(|id| id.to_le_bytes()).collect();
104        let id = ProgramId::new(make_id("program", &id_bytes));
105
106        let link_table = Self::resolve_symbols(&modules);
107
108        Self {
109            id,
110            modules,
111            link_table,
112        }
113    }
114
115    /// Produce a flattened `AirModule` for backward compatibility with
116    /// the existing single-module analysis pipeline.
117    ///
118    /// Merges all functions and globals into one module, rewriting
119    /// extern declarations that have definitions in other modules.
120    #[must_use]
121    pub fn merged_view(&self) -> AirModule {
122        let mut merged = AirModule::new(crate::ids::ModuleId::new(self.id.raw()));
123        merged.name = Some("merged".to_string());
124
125        // Collect all defined function names to skip duplicate declarations
126        let mut defined_functions: BTreeSet<String> = BTreeSet::new();
127        for module in &self.modules {
128            for func in &module.functions {
129                if !func.is_declaration {
130                    defined_functions.insert(func.name.clone());
131                }
132            }
133        }
134
135        // Merge functions: include definitions and unresolved declarations only
136        for module in &self.modules {
137            for func in &module.functions {
138                if func.is_declaration && defined_functions.contains(&func.name) {
139                    // Skip this declaration — a definition exists in another module
140                    continue;
141                }
142                merged.add_function(func.clone());
143            }
144        }
145
146        // Rewrite CallDirect callee IDs: replace declaration IDs with definition IDs
147        if !self.link_table.function_resolutions.is_empty() {
148            for func in &mut merged.functions {
149                for block in &mut func.blocks {
150                    for inst in &mut block.instructions {
151                        if let crate::air::Operation::CallDirect { callee } = &mut inst.op {
152                            if let Some(def_id) = self.link_table.function_resolutions.get(callee) {
153                                *callee = *def_id;
154                            }
155                        }
156                    }
157                }
158            }
159        }
160
161        // Merge globals: include definitions and unresolved declarations only
162        let mut defined_globals: BTreeSet<String> = BTreeSet::new();
163        for module in &self.modules {
164            for global in &module.globals {
165                if global.init.is_some() {
166                    defined_globals.insert(global.name.clone());
167                }
168            }
169        }
170        for module in &self.modules {
171            for global in &module.globals {
172                if global.init.is_none() && defined_globals.contains(&global.name) {
173                    continue;
174                }
175                merged.globals.push(global.clone());
176            }
177        }
178
179        // Merge source files (deduplicated by path)
180        let mut seen_paths: BTreeSet<String> = BTreeSet::new();
181        for module in &self.modules {
182            for sf in &module.source_files {
183                if seen_paths.insert(sf.path.clone()) {
184                    merged.source_files.push(sf.clone());
185                }
186            }
187        }
188
189        // Merge type hierarchies
190        for module in &self.modules {
191            merged.type_hierarchy.extend(module.type_hierarchy.clone());
192        }
193
194        // Merge type tables
195        for module in &self.modules {
196            for (tid, ty) in &module.types {
197                merged.types.entry(*tid).or_insert_with(|| ty.clone());
198            }
199        }
200
201        // Merge constants
202        for module in &self.modules {
203            for (vid, c) in &module.constants {
204                merged.constants.entry(*vid).or_insert_with(|| c.clone());
205            }
206        }
207
208        // Use largest target_pointer_width
209        merged.target_pointer_width = self
210            .modules
211            .iter()
212            .map(|m| m.target_pointer_width)
213            .max()
214            .unwrap_or(8);
215
216        merged.rebuild_function_index();
217        merged
218    }
219
220    /// Resolve extern declarations across modules.
221    fn resolve_symbols(modules: &[AirModule]) -> LinkTable {
222        let mut table = LinkTable::default();
223
224        // Index: function name -> Vec<(ModuleId, FunctionId, is_declaration)>
225        let mut func_index: BTreeMap<String, Vec<(ModuleId, FunctionId, bool)>> = BTreeMap::new();
226        for module in modules {
227            for func in &module.functions {
228                func_index.entry(func.name.clone()).or_default().push((
229                    module.id,
230                    func.id,
231                    func.is_declaration,
232                ));
233            }
234        }
235
236        // For each function name, find definitions and declarations
237        for (name, entries) in &func_index {
238            let definitions: Vec<_> = entries.iter().filter(|e| !e.2).collect();
239            let declarations: Vec<_> = entries.iter().filter(|e| e.2).collect();
240
241            if definitions.len() > 1 {
242                table.conflicts.push(LinkConflict {
243                    name: name.clone(),
244                    defining_modules: definitions.iter().map(|d| d.0).collect(),
245                });
246            }
247
248            if let Some(def) = definitions.first() {
249                for decl in &declarations {
250                    table.function_resolutions.insert(decl.1, def.1);
251                }
252            }
253        }
254
255        // Same for globals: name -> Vec<(ModuleId, ValueId, has_init)>
256        let mut global_index: BTreeMap<String, Vec<(ModuleId, ValueId, bool)>> = BTreeMap::new();
257        for module in modules {
258            for global in &module.globals {
259                global_index.entry(global.name.clone()).or_default().push((
260                    module.id,
261                    global.id,
262                    global.init.is_some(),
263                ));
264            }
265        }
266
267        for entries in global_index.values() {
268            let definitions: Vec<_> = entries.iter().filter(|e| e.2).collect();
269            let declarations: Vec<_> = entries.iter().filter(|e| !e.2).collect();
270
271            if let Some(def) = definitions.first() {
272                for decl in &declarations {
273                    table.global_resolutions.insert(decl.1, def.1);
274                }
275            }
276        }
277
278        table
279    }
280
281    /// Compute what changed between two programs.
282    ///
283    /// Compares modules by `ModuleId`. Modules with the same ID but different
284    /// content should have different `ModuleId`s (since IDs are content-derived).
285    pub fn diff(previous: &[ModuleId], current: &[ModuleId]) -> ProgramDiff {
286        let prev_set: BTreeSet<ModuleId> = previous.iter().copied().collect();
287        let curr_set: BTreeSet<ModuleId> = current.iter().copied().collect();
288
289        ProgramDiff {
290            added_modules: curr_set.difference(&prev_set).copied().collect(),
291            removed_modules: prev_set.difference(&curr_set).copied().collect(),
292            changed_modules: Vec::new(), // fingerprint-based diff is done at a higher level
293            unchanged_modules: curr_set.intersection(&prev_set).copied().collect(),
294            added_functions: BTreeSet::new(),
295            removed_functions: BTreeSet::new(),
296        }
297    }
298}
299
300/// Split a monolithic module into per-source-file modules.
301///
302/// Groups functions by their `Span.file_id` field (which maps to a `SourceFile`
303/// entry). Functions without span information go into an "unknown" module.
304/// If no functions have span info, returns the original module unchanged
305/// in a single-element Vec.
306pub fn split_module(module: AirModule, strategy: SplitStrategy) -> Vec<AirModule> {
307    match strategy {
308        SplitStrategy::Monolithic => vec![module],
309        SplitStrategy::ByFunction => split_by_function(module),
310        SplitStrategy::BySourceFile => split_by_source_file(module),
311        SplitStrategy::Auto => {
312            // Check if any function has span info
313            let has_spans = module.functions.iter().any(|f| f.span.is_some());
314            if has_spans {
315                split_by_source_file(module)
316            } else {
317                split_by_function(module)
318            }
319        }
320    }
321}
322
323fn split_by_function(module: AirModule) -> Vec<AirModule> {
324    if module.functions.is_empty() {
325        return vec![module];
326    }
327
328    let mut result = Vec::new();
329    for func in &module.functions {
330        let mod_id = ModuleId::new(make_id("module-split-fn", func.name.as_bytes()));
331        let mut sub = AirModule::new(mod_id);
332        sub.name = Some(func.name.clone());
333        sub.target_pointer_width = module.target_pointer_width;
334        // Copy types and constants needed by this function
335        sub.types.clone_from(&module.types);
336        sub.constants.clone_from(&module.constants);
337        sub.add_function(func.clone());
338        result.push(sub);
339    }
340
341    // Globals go into the first module
342    if !module.globals.is_empty() {
343        if let Some(first) = result.first_mut() {
344            first.globals.clone_from(&module.globals);
345        }
346    }
347
348    result
349}
350
351fn split_by_source_file(module: AirModule) -> Vec<AirModule> {
352    // Build file_id -> source_file path mapping
353    let file_paths: BTreeMap<crate::ids::FileId, String> = module
354        .source_files
355        .iter()
356        .map(|sf| (sf.id, sf.path.clone()))
357        .collect();
358
359    // Group functions by their source file
360    let mut groups: BTreeMap<String, Vec<AirFunction>> = BTreeMap::new();
361    for func in &module.functions {
362        let key = func
363            .span
364            .as_ref()
365            .and_then(|s| file_paths.get(&s.file_id))
366            .cloned()
367            .unwrap_or_else(|| "<unknown>".to_string());
368        groups.entry(key).or_default().push(func.clone());
369    }
370
371    if groups.len() <= 1 {
372        // All functions from same file (or no span info) — don't split
373        return vec![module];
374    }
375
376    let mut result = Vec::new();
377    for (path, functions) in &groups {
378        let mod_id = ModuleId::new(make_id("module-split-src", path.as_bytes()));
379        let mut sub = AirModule::new(mod_id);
380        sub.name = Some(path.clone());
381        sub.target_pointer_width = module.target_pointer_width;
382        sub.types.clone_from(&module.types);
383        sub.constants.clone_from(&module.constants);
384        for f in functions {
385            sub.add_function(f.clone());
386        }
387        result.push(sub);
388    }
389
390    // Distribute globals: for now, all globals go into the first module
391    if !module.globals.is_empty() {
392        if let Some(first) = result.first_mut() {
393            first.globals.clone_from(&module.globals);
394        }
395    }
396
397    result
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::air::{AirBundle, AirFunction, AirGlobal, AirModule, Constant};
    use crate::id::make_id;
    use crate::ids::{FunctionId, ModuleId, ObjId, ValueId};

    /// ID given to a test function. Declarations and definitions of the same
    /// name get different IDs so resolution tests can tell them apart.
    fn function_id(name: &str, is_declaration: bool) -> FunctionId {
        let domain = if is_declaration { "fn-decl" } else { "fn" };
        FunctionId::new(make_id(domain, name.as_bytes()))
    }

    fn make_function(name: &str, is_declaration: bool) -> AirFunction {
        let mut f = AirFunction::new(function_id(name, is_declaration), name.to_string());
        f.is_declaration = is_declaration;
        f
    }

    fn make_module(name: &str, functions: Vec<AirFunction>) -> AirModule {
        let id = ModuleId::new(make_id("module", name.as_bytes()));
        let mut m = AirModule::new(id);
        m.name = Some(name.to_string());
        for f in functions {
            m.add_function(f);
        }
        m
    }

    fn make_bundle(name: &str, functions: Vec<AirFunction>) -> AirBundle {
        AirBundle::new("test".to_string(), make_module(name, functions))
    }

    /// ID given to a test global. Extern declarations (no initializer) and
    /// definitions of the same name get different IDs.
    fn global_id(name: &str, has_init: bool) -> ValueId {
        let domain = if has_init { "val" } else { "val-decl" };
        ValueId::new(make_id(domain, name.as_bytes()))
    }

    fn make_global(name: &str, has_init: bool) -> AirGlobal {
        let obj = ObjId::new(make_id("obj", name.as_bytes()));
        let mut g = AirGlobal::new(global_id(name, has_init), obj, name.to_string());
        if has_init {
            g.init = Some(Constant::Int { value: 0, bits: 32 });
        }
        g
    }

    #[test]
    fn link_empty_program() {
        let program = AirProgram::link(vec![]);
        assert!(program.modules.is_empty());
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_single_module() {
        let bundle = make_bundle(
            "main",
            vec![
                make_function("main", false),
                make_function("printf", true), // extern
            ],
        );
        let program = AirProgram::link(vec![bundle]);

        assert_eq!(program.modules.len(), 1);
        // printf is defined nowhere, so nothing resolves
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_resolves_extern_functions() {
        let main_bundle = make_bundle(
            "main",
            vec![
                make_function("main", false),
                make_function("helper", true), // extern declaration
            ],
        );
        let lib_bundle = make_bundle(
            "lib",
            vec![
                make_function("helper", false), // definition
            ],
        );

        let program = AirProgram::link(vec![main_bundle, lib_bundle]);

        assert_eq!(program.modules.len(), 2);
        // helper declaration in main -> helper definition in lib
        let resolutions = &program.link_table.function_resolutions;
        assert_eq!(resolutions.len(), 1);
        assert_eq!(
            resolutions.get(&function_id("helper", true)),
            Some(&function_id("helper", false))
        );
    }

    #[test]
    fn link_resolves_extern_globals() {
        let mut main_mod = make_module("main", vec![make_function("main", false)]);
        main_mod.globals.push(make_global("config", false));
        let mut lib_mod = make_module("lib", vec![]);
        lib_mod.globals.push(make_global("config", true));

        let program = AirProgram::link(vec![
            AirBundle::new("test".to_string(), main_mod),
            AirBundle::new("test".to_string(), lib_mod),
        ]);

        let resolutions = &program.link_table.global_resolutions;
        assert_eq!(resolutions.len(), 1);
        assert_eq!(
            resolutions.get(&global_id("config", false)),
            Some(&global_id("config", true))
        );
    }

    #[test]
    fn link_unresolved_extern_stays_unresolved() {
        let bundle = make_bundle(
            "main",
            vec![
                make_function("main", false),
                make_function("printf", true), // libc — no definition anywhere
            ],
        );
        let program = AirProgram::link(vec![bundle]);

        // printf has no definition, so no resolution
        assert!(program.link_table.function_resolutions.is_empty());
    }

    #[test]
    fn link_detects_conflicting_definitions() {
        let a = make_bundle("a", vec![make_function("foo", false)]);
        let b = make_bundle("b", vec![make_function("foo", false)]);

        let program = AirProgram::link(vec![a, b]);

        assert_eq!(program.link_table.conflicts.len(), 1);
        assert_eq!(program.link_table.conflicts[0].name, "foo");
        assert_eq!(program.link_table.conflicts[0].defining_modules.len(), 2);
    }

    #[test]
    fn merged_view_contains_all_definitions() {
        let main_bundle = make_bundle(
            "main",
            vec![
                make_function("main", false),
                make_function("helper", true), // extern
            ],
        );
        let lib_bundle = make_bundle(
            "lib",
            vec![
                make_function("helper", false), // definition
            ],
        );

        let program = AirProgram::link(vec![main_bundle, lib_bundle]);
        let merged = program.merged_view();

        // Should have: main (def) + helper (def), NOT helper (decl)
        assert_eq!(merged.functions.len(), 2);
        let names: BTreeSet<_> = merged.functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains("main"));
        assert!(names.contains("helper"));
        // The helper should be the definition, not the declaration
        let helper = merged.function_by_name("helper").unwrap();
        assert!(!helper.is_declaration);
    }

    #[test]
    fn merged_view_keeps_unresolved_declarations() {
        let bundle = make_bundle(
            "main",
            vec![
                make_function("main", false),
                make_function("printf", true), // extern, no definition
            ],
        );

        let program = AirProgram::link(vec![bundle]);
        let merged = program.merged_view();

        // printf stays as a declaration since no module defines it
        assert_eq!(merged.functions.len(), 2);
        let printf = merged.function_by_name("printf").unwrap();
        assert!(printf.is_declaration);
    }

    #[test]
    fn merged_view_resolves_globals() {
        let mut main_mod = make_module("main", vec![make_function("main", false)]);
        main_mod.globals.push(make_global("config", false)); // declaration

        let mut lib_mod = make_module("lib", vec![]);
        lib_mod.globals.push(make_global("config", true)); // definition (has init)

        let program = AirProgram::link(vec![
            AirBundle::new("test".to_string(), main_mod),
            AirBundle::new("test".to_string(), lib_mod),
        ]);
        let merged = program.merged_view();

        // Only the definition should appear, not the declaration
        let config_globals: Vec<_> = merged
            .globals
            .iter()
            .filter(|g| g.name == "config")
            .collect();
        assert_eq!(config_globals.len(), 1);
        assert!(config_globals[0].init.is_some());
    }

    #[test]
    fn program_id_is_deterministic() {
        let bundles1 = vec![
            make_bundle("a", vec![make_function("fa", false)]),
            make_bundle("b", vec![make_function("fb", false)]),
        ];
        let bundles2 = vec![
            make_bundle("b", vec![make_function("fb", false)]),
            make_bundle("a", vec![make_function("fa", false)]),
        ];

        let p1 = AirProgram::link(bundles1);
        let p2 = AirProgram::link(bundles2);

        // Same modules in different order should produce same ProgramId
        assert_eq!(p1.id, p2.id);
    }

    #[test]
    fn program_id_differs_for_different_modules() {
        let p1 = AirProgram::link(vec![make_bundle("a", vec![make_function("fa", false)])]);
        let p2 = AirProgram::link(vec![make_bundle("b", vec![make_function("fb", false)])]);
        assert_ne!(p1.id, p2.id);
    }

    #[test]
    fn diff_identifies_added_modules() {
        let m1 = ModuleId::new(1);
        let m2 = ModuleId::new(2);
        let m3 = ModuleId::new(3);

        let diff = AirProgram::diff(&[m1, m2], &[m1, m2, m3]);

        assert_eq!(diff.added_modules, vec![m3]);
        assert!(diff.removed_modules.is_empty());
        assert_eq!(diff.unchanged_modules.len(), 2);
    }

    #[test]
    fn diff_identifies_removed_modules() {
        let m1 = ModuleId::new(1);
        let m2 = ModuleId::new(2);

        let diff = AirProgram::diff(&[m1, m2], &[m1]);

        assert!(diff.added_modules.is_empty());
        assert_eq!(diff.removed_modules, vec![m2]);
        assert_eq!(diff.unchanged_modules.len(), 1);
    }

    #[test]
    fn diff_of_identical_inputs_is_all_unchanged() {
        let ids = [ModuleId::new(1), ModuleId::new(2)];

        let diff = AirProgram::diff(&ids, &ids);

        assert!(diff.added_modules.is_empty());
        assert!(diff.removed_modules.is_empty());
        assert_eq!(diff.unchanged_modules, ids.to_vec());
    }

    #[test]
    fn split_monolithic_returns_original() {
        let module = make_module(
            "test",
            vec![make_function("a", false), make_function("b", false)],
        );
        let result = split_module(module, SplitStrategy::Monolithic);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].functions.len(), 2);
    }

    #[test]
    fn split_by_function_creates_one_module_per_function() {
        let module = make_module(
            "test",
            vec![
                make_function("a", false),
                make_function("b", false),
                make_function("c", false),
            ],
        );
        let result = split_module(module, SplitStrategy::ByFunction);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].functions.len(), 1);
        assert_eq!(result[0].functions[0].name, "a");
    }

    #[test]
    fn split_by_function_keeps_target_pointer_width() {
        let mut module = make_module(
            "test",
            vec![make_function("a", false), make_function("b", false)],
        );
        module.target_pointer_width = 4;

        let result = split_module(module, SplitStrategy::ByFunction);

        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|m| m.target_pointer_width == 4));
    }

    #[test]
    fn split_auto_uses_by_function_when_no_spans() {
        let module = make_module(
            "test",
            vec![make_function("a", false), make_function("b", false)],
        );
        let result = split_module(module, SplitStrategy::Auto);
        // No spans → ByFunction → 2 modules
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn split_empty_module_returns_single() {
        let module = make_module("test", vec![]);
        let result = split_module(module, SplitStrategy::ByFunction);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn split_preserves_globals_in_first_module() {
        let mut module = make_module(
            "test",
            vec![make_function("a", false), make_function("b", false)],
        );
        module.globals.push(make_global("g", true));

        let result = split_module(module, SplitStrategy::ByFunction);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].globals.len(), 1);
        assert!(result[1].globals.is_empty());
    }
}
694}