1use std::collections::BTreeMap;
8
9use crate::air::{AirType, StructField};
10use crate::ids::TypeId;
11
/// Rounds `value` up to the nearest multiple of `align`.
///
/// Correct for any non-zero `align`, not only powers of two, so a
/// non-power-of-two pointer width passed to the `*_with_ptr` functions still
/// produces consistent offsets.
///
/// Panics if `align` is zero. If the rounded result does not fit in `u64` it
/// panics when overflow checks are enabled (debug builds) and wraps otherwise.
const fn align_up(value: u64, align: u64) -> u64 {
    value.next_multiple_of(align)
}
18
19pub fn abi_alignment(ty: &AirType, types: &BTreeMap<TypeId, AirType>) -> Option<u64> {
40 abi_alignment_with_ptr(ty, types, 8)
41}
42
43pub fn abi_alignment_with_ptr(
49 ty: &AirType,
50 types: &BTreeMap<TypeId, AirType>,
51 ptr_width: u32,
52) -> Option<u64> {
53 match ty {
54 AirType::Pointer | AirType::Reference { .. } => Some(u64::from(ptr_width)),
55 AirType::Integer { bits } => {
56 let bytes = u64::from(*bits).div_ceil(8).clamp(1, 16);
58 Some(bytes.next_power_of_two())
59 }
60 AirType::Float { bits } => match bits {
61 32 => Some(4),
62 _ => Some(8),
64 },
65 AirType::Vector { element, lanes } => {
66 let elem_ty = types.get(element)?;
67 let elem_size = alloc_size_with_ptr(elem_ty, types, ptr_width)?;
68 let total = elem_size * u64::from(*lanes);
69 Some(total.next_power_of_two().min(64)) }
71 AirType::Array { element, .. } => {
72 let elem_ty = types.get(element)?;
73 abi_alignment_with_ptr(elem_ty, types, ptr_width)
74 }
75 AirType::Struct { fields, .. } => {
76 let mut max_align: u64 = 1;
77 for field in fields {
78 let field_ty = types.get(&field.field_type)?;
79 let align = abi_alignment_with_ptr(field_ty, types, ptr_width)?;
80 max_align = max_align.max(align);
81 }
82 Some(max_align)
83 }
84 AirType::Void | AirType::Function { .. } => Some(1),
85 AirType::Opaque => None,
86 }
87}
88
89pub fn alloc_size(ty: &AirType, types: &BTreeMap<TypeId, AirType>) -> Option<u64> {
110 alloc_size_with_ptr(ty, types, 8)
111}
112
113pub fn alloc_size_with_ptr(
120 ty: &AirType,
121 types: &BTreeMap<TypeId, AirType>,
122 ptr_width: u32,
123) -> Option<u64> {
124 match ty {
125 AirType::Pointer | AirType::Reference { .. } => Some(u64::from(ptr_width)),
126 AirType::Integer { bits } => Some(u64::from(*bits).div_ceil(8)),
127 AirType::Float { bits } => {
128 let bytes = u64::from(*bits) / 8;
129 Some(bytes.max(4))
130 }
131 AirType::Vector { element, lanes } => {
132 let elem_ty = types.get(element)?;
133 let elem_size = alloc_size_with_ptr(elem_ty, types, ptr_width)?;
134 Some(elem_size * u64::from(*lanes))
135 }
136 AirType::Array { element, count } => {
137 let n = (*count)?;
138 let elem_ty = types.get(element)?;
139 let elem_size = alloc_size_with_ptr(elem_ty, types, ptr_width)?;
140 Some(n * elem_size)
141 }
142 AirType::Struct { total_size, .. } => Some(*total_size),
143 AirType::Void | AirType::Function { .. } => Some(0),
144 AirType::Opaque => None,
145 }
146}
147
148pub fn compute_struct_layout(
155 fields: &[StructField],
156 types: &BTreeMap<TypeId, AirType>,
157) -> Option<(Vec<u64>, u64)> {
158 compute_struct_layout_with_ptr(fields, types, 8)
159}
160
161pub fn compute_struct_layout_with_ptr(
173 fields: &[StructField],
174 types: &BTreeMap<TypeId, AirType>,
175 ptr_width: u32,
176) -> Option<(Vec<u64>, u64)> {
177 let mut offset: u64 = 0;
178 let mut max_align: u64 = 1;
179 let mut offsets = Vec::with_capacity(fields.len());
180
181 for field in fields {
182 let field_ty = types.get(&field.field_type)?;
183 let align = abi_alignment_with_ptr(field_ty, types, ptr_width)?;
184 let size = alloc_size_with_ptr(field_ty, types, ptr_width)?;
185
186 offset = align_up(offset, align);
187 offsets.push(offset);
188 offset += size;
189 max_align = max_align.max(align);
190 }
191
192 offset = align_up(offset, max_align);
194
195 Some((offsets, offset))
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::TypeId;

    /// Type table containing the named scalar types the tests refer to.
    fn make_types() -> BTreeMap<TypeId, AirType> {
        let scalars: [(&[u8], AirType); 9] = [
            (b"ptr", AirType::Pointer),
            (b"i8", AirType::Integer { bits: 8 }),
            (b"i16", AirType::Integer { bits: 16 }),
            (b"i32", AirType::Integer { bits: 32 }),
            (b"i64", AirType::Integer { bits: 64 }),
            (b"i128", AirType::Integer { bits: 128 }),
            (b"f32", AirType::Float { bits: 32 }),
            (b"f64", AirType::Float { bits: 64 }),
            (b"void", AirType::Void),
        ];
        let mut types = BTreeMap::new();
        for (name, ty) in scalars {
            types.insert(TypeId::derive(name), ty);
        }
        types
    }

    /// Struct field of the named type with no explicit offset, size, or name.
    fn field(name: &[u8]) -> StructField {
        StructField {
            field_type: TypeId::derive(name),
            byte_offset: None,
            byte_size: None,
            name: None,
        }
    }

    /// Lays out `fields` with the default 8-byte pointer width; panics on `None`.
    fn layout(fields: &[StructField], types: &BTreeMap<TypeId, AirType>) -> (Vec<u64>, u64) {
        compute_struct_layout(fields, types).unwrap()
    }

    #[test]
    fn alignment_pointer() {
        assert_eq!(abi_alignment(&AirType::Pointer, &make_types()), Some(8));
    }

    #[test]
    fn alignment_integers() {
        let types = make_types();
        for (bits, expected) in [(1, 1), (8, 1), (16, 2), (32, 4), (64, 8), (128, 16)] {
            assert_eq!(
                abi_alignment(&AirType::Integer { bits }, &types),
                Some(expected),
                "alignment of i{}",
                bits
            );
        }
    }

    #[test]
    fn alignment_floats() {
        let types = make_types();
        assert_eq!(abi_alignment(&AirType::Float { bits: 32 }, &types), Some(4));
        assert_eq!(abi_alignment(&AirType::Float { bits: 64 }, &types), Some(8));
    }

    #[test]
    fn alignment_void() {
        assert_eq!(abi_alignment(&AirType::Void, &make_types()), Some(1));
    }

    #[test]
    fn alignment_opaque_returns_none() {
        assert_eq!(abi_alignment(&AirType::Opaque, &make_types()), None);
    }

    #[test]
    fn size_pointer() {
        assert_eq!(alloc_size(&AirType::Pointer, &make_types()), Some(8));
    }

    #[test]
    fn size_integers() {
        let types = make_types();
        for (bits, expected) in [(1, 1), (8, 1), (16, 2), (32, 4), (64, 8), (128, 16)] {
            assert_eq!(
                alloc_size(&AirType::Integer { bits }, &types),
                Some(expected),
                "size of i{}",
                bits
            );
        }
    }

    #[test]
    fn size_void() {
        assert_eq!(alloc_size(&AirType::Void, &make_types()), Some(0));
    }

    #[test]
    fn layout_simple_i32_ptr() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"i32"), field(b"ptr")], &types);
        assert_eq!(offsets, vec![0, 8]);
        assert_eq!(total, 16);
    }

    #[test]
    fn layout_three_i8() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"i8"), field(b"i8"), field(b"i8")], &types);
        assert_eq!(offsets, vec![0, 1, 2]);
        assert_eq!(total, 3);
    }

    #[test]
    fn layout_i8_i32_padding() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"i8"), field(b"i32")], &types);
        assert_eq!(offsets, vec![0, 4]);
        assert_eq!(total, 8);
    }

    #[test]
    fn layout_i64_i8_tail_padding() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"i64"), field(b"i8")], &types);
        assert_eq!(offsets, vec![0, 8]);
        assert_eq!(total, 16);
    }

    #[test]
    fn layout_single_ptr() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"ptr")], &types);
        assert_eq!(offsets, vec![0]);
        assert_eq!(total, 8);
    }

    #[test]
    fn layout_empty_struct() {
        let types = make_types();
        let (offsets, total) = layout(&[], &types);
        assert!(offsets.is_empty());
        assert_eq!(total, 0);
    }

    #[test]
    fn layout_with_opaque_field_returns_none() {
        let types = make_types();
        // "unknown" is not in the type table, so layout must fail.
        let fields = [field(b"i32"), field(b"unknown")];
        assert!(compute_struct_layout(&fields, &types).is_none());
    }

    #[test]
    fn layout_nested_array() {
        let mut types = make_types();
        types.insert(
            TypeId::derive(b"[10 x i32]"),
            AirType::Array {
                element: TypeId::derive(b"i32"),
                count: Some(10),
            },
        );
        let (offsets, total) = layout(&[field(b"[10 x i32]"), field(b"ptr")], &types);
        assert_eq!(offsets, vec![0, 40]);
        assert_eq!(total, 48);
    }

    #[test]
    fn layout_i16_i64_i8() {
        let types = make_types();
        let (offsets, total) = layout(&[field(b"i16"), field(b"i64"), field(b"i8")], &types);
        assert_eq!(offsets, vec![0, 8, 16]);
        assert_eq!(total, 24);
    }

    #[test]
    fn alignment_reference() {
        let types = make_types();
        for nullable in [false, true] {
            assert_eq!(
                abi_alignment(&AirType::Reference { nullable }, &types),
                Some(8)
            );
        }
    }

    #[test]
    fn size_reference() {
        let types = make_types();
        for nullable in [false, true] {
            assert_eq!(alloc_size(&AirType::Reference { nullable }, &types), Some(8));
        }
    }

    #[test]
    fn alignment_and_size_vector() {
        let types = make_types();
        let vec_ty = AirType::Vector {
            element: TypeId::derive(b"f32"),
            lanes: 4,
        };
        assert_eq!(alloc_size(&vec_ty, &types), Some(16));
        assert_eq!(abi_alignment(&vec_ty, &types), Some(16));
    }

    #[test]
    fn layout_32bit_pointer() {
        let types = make_types();
        assert_eq!(alloc_size_with_ptr(&AirType::Pointer, &types, 4), Some(4));
        assert_eq!(abi_alignment_with_ptr(&AirType::Pointer, &types, 4), Some(4));
    }

    #[test]
    fn layout_32bit_reference() {
        let types = make_types();
        let non_null = AirType::Reference { nullable: false };
        let nullable = AirType::Reference { nullable: true };
        assert_eq!(alloc_size_with_ptr(&non_null, &types, 4), Some(4));
        assert_eq!(abi_alignment_with_ptr(&nullable, &types, 4), Some(4));
    }

    #[test]
    fn layout_32bit_struct_with_pointer() {
        let types = make_types();
        let fields = [field(b"i32"), field(b"ptr")];
        let (offsets, total) = compute_struct_layout_with_ptr(&fields, &types, 4).unwrap();
        assert_eq!(offsets, vec![0, 4]);
        assert_eq!(total, 8);
    }

    #[test]
    fn layout_64bit_unchanged() {
        let types = make_types();
        // The default wrappers must agree with an explicit 8-byte pointer width.
        assert_eq!(
            alloc_size_with_ptr(&AirType::Pointer, &types, 8),
            alloc_size(&AirType::Pointer, &types)
        );
        assert_eq!(
            abi_alignment_with_ptr(&AirType::Pointer, &types, 8),
            abi_alignment(&AirType::Pointer, &types)
        );
    }
}