1use proc_macro2::Span;
2use quote::{quote, ToTokens};
3use syn::*;
4
5use diplomat_core::ast;
6
7mod enum_convert;
8mod transparent_convert;
9
10fn cfgs_to_stream(attrs: &[Attribute]) -> proc_macro2::TokenStream {
11 attrs
12 .iter()
13 .fold(quote!(), |prev, attr| quote!(#prev #attr))
14}
15
16fn gen_params_at_boundary(param: &ast::Param, expanded_params: &mut Vec<FnArg>) {
17 match ¶m.ty {
18 ast::TypeName::StrReference(
19 ..,
20 ast::StringEncoding::UnvalidatedUtf8
21 | ast::StringEncoding::UnvalidatedUtf16
22 | ast::StringEncoding::Utf8,
23 )
24 | ast::TypeName::PrimitiveSlice(..)
25 | ast::TypeName::StrSlice(..) => {
26 let data_type = if let ast::TypeName::PrimitiveSlice(.., prim) = ¶m.ty {
27 ast::TypeName::Primitive(*prim).to_syn().to_token_stream()
28 } else if let ast::TypeName::StrReference(
29 _,
30 ast::StringEncoding::UnvalidatedUtf8 | ast::StringEncoding::Utf8,
31 ) = ¶m.ty
32 {
33 quote! { u8 }
34 } else if let ast::TypeName::StrReference(_, ast::StringEncoding::UnvalidatedUtf16) =
35 ¶m.ty
36 {
37 quote! { u16 }
38 } else if let ast::TypeName::StrSlice(ast::StringEncoding::Utf8) = ¶m.ty {
39 quote! { &str }
41 } else if let ast::TypeName::StrSlice(ast::StringEncoding::UnvalidatedUtf8) = ¶m.ty
42 {
43 quote! { &[u8] }
45 } else if let ast::TypeName::StrSlice(ast::StringEncoding::UnvalidatedUtf16) = ¶m.ty
46 {
47 quote! { &[u16] }
49 } else {
50 unreachable!()
51 };
52 expanded_params.push(FnArg::Typed(PatType {
53 attrs: vec![],
54 pat: Box::new(Pat::Ident(PatIdent {
55 attrs: vec![],
56 by_ref: None,
57 mutability: None,
58 ident: Ident::new(&format!("{}_diplomat_data", param.name), Span::call_site()),
59 subpat: None,
60 })),
61 colon_token: syn::token::Colon(Span::call_site()),
62 ty: Box::new(
63 parse2({
64 if let ast::TypeName::PrimitiveSlice(
65 Some((_, ast::Mutability::Mutable)) | None,
66 _,
67 )
68 | ast::TypeName::StrReference(None, ..) = ¶m.ty
69 {
70 quote! { *mut #data_type }
71 } else {
72 quote! { *const #data_type }
73 }
74 })
75 .unwrap(),
76 ),
77 }));
78
79 expanded_params.push(FnArg::Typed(PatType {
80 attrs: vec![],
81 pat: Box::new(Pat::Ident(PatIdent {
82 attrs: vec![],
83 by_ref: None,
84 mutability: None,
85 ident: Ident::new(&format!("{}_diplomat_len", param.name), Span::call_site()),
86 subpat: None,
87 })),
88 colon_token: syn::token::Colon(Span::call_site()),
89 ty: Box::new(
90 parse2(quote! {
91 usize
92 })
93 .unwrap(),
94 ),
95 }));
96 }
97 o => {
98 expanded_params.push(FnArg::Typed(PatType {
99 attrs: vec![],
100 pat: Box::new(Pat::Ident(PatIdent {
101 attrs: vec![],
102 by_ref: None,
103 mutability: None,
104 ident: Ident::new(param.name.as_str(), Span::call_site()),
105 subpat: None,
106 })),
107 colon_token: syn::token::Colon(Span::call_site()),
108 ty: Box::new(o.to_syn()),
109 }));
110 }
111 }
112}
113
114fn gen_params_invocation(param: &ast::Param, expanded_params: &mut Vec<Expr>) {
115 match ¶m.ty {
116 ast::TypeName::StrReference(..)
117 | ast::TypeName::PrimitiveSlice(..)
118 | ast::TypeName::StrSlice(..) => {
119 let data_ident =
120 Ident::new(&format!("{}_diplomat_data", param.name), Span::call_site());
121 let len_ident = Ident::new(&format!("{}_diplomat_len", param.name), Span::call_site());
122
123 let tokens = if let ast::TypeName::PrimitiveSlice(lm, _) = ¶m.ty {
124 match lm {
125 Some((_, ast::Mutability::Mutable)) => quote! {
126 if #len_ident == 0 {
127 &mut []
128 } else {
129 unsafe { core::slice::from_raw_parts_mut(#data_ident, #len_ident) }
130 }
131 },
132 Some((_, ast::Mutability::Immutable)) => quote! {
133 if #len_ident == 0 {
134 &[]
135 } else {
136 unsafe { core::slice::from_raw_parts(#data_ident, #len_ident) }
137 }
138 },
139 None => quote! {
140 if #len_ident == 0 {
141 Default::default()
142 } else {
143 unsafe { alloc::boxed::Box::from_raw(core::ptr::slice_from_raw_parts_mut(#data_ident, #len_ident)) }
144 }
145 },
146 }
147 } else if let ast::TypeName::StrReference(Some(_), encoding) = ¶m.ty {
148 let encode = match encoding {
149 ast::StringEncoding::Utf8 => quote! {
150 unsafe { core::str::from_utf8_unchecked(core::slice::from_raw_parts(#data_ident, #len_ident)) }
152 },
153 _ => quote! {
154 unsafe { core::slice::from_raw_parts(#data_ident, #len_ident) }
155 },
156 };
157 quote! {
158 if #len_ident == 0 {
159 Default::default()
160 } else {
161 #encode
162 }
163 }
164 } else if let ast::TypeName::StrReference(None, encoding) = ¶m.ty {
165 let encode = match encoding {
166 ast::StringEncoding::Utf8 => quote! {
167 unsafe { core::str::from_boxed_utf8_unchecked(alloc::boxed::Box::from_raw(core::ptr::slice_from_raw_parts_mut(#data_ident, #len_ident))) }
168 },
169 _ => quote! {
170 unsafe { alloc::boxed::Box::from_raw(core::ptr::slice_from_raw_parts_mut(#data_ident, #len_ident)) }
171 },
172 };
173 quote! {
174 if #len_ident == 0 {
175 Default::default()
176 } else {
177 #encode
178 }
179 }
180 } else if let ast::TypeName::StrSlice(_) = ¶m.ty {
181 quote! {
182 if #len_ident == 0 {
183 &[]
184 } else {
185 unsafe { core::slice::from_raw_parts(#data_ident, #len_ident) }
186 }
187 }
188 } else {
189 unreachable!();
190 };
191 expanded_params.push(parse2(tokens).unwrap());
192 }
193 ast::TypeName::Result(_, _, _) => {
194 let param = ¶m.name;
195 expanded_params.push(parse2(quote!(#param.into())).unwrap());
196 }
197 _ => {
198 expanded_params.push(Expr::Path(ExprPath {
199 attrs: vec![],
200 qself: None,
201 path: Ident::new(param.name.as_str(), Span::call_site()).into(),
202 }));
203 }
204 }
205}
206
207fn gen_custom_type_method(strct: &ast::CustomType, m: &ast::Method) -> Item {
208 let self_ident = Ident::new(strct.name().as_str(), Span::call_site());
209 let method_ident = Ident::new(m.name.as_str(), Span::call_site());
210 let extern_ident = Ident::new(m.full_path_name.as_str(), Span::call_site());
211
212 let mut all_params = vec![];
213 m.params.iter().for_each(|p| {
214 gen_params_at_boundary(p, &mut all_params);
215 });
216
217 let mut all_params_invocation = vec![];
218 m.params.iter().for_each(|p| {
219 gen_params_invocation(p, &mut all_params_invocation);
220 });
221
222 let this_ident = Pat::Ident(PatIdent {
223 attrs: vec![],
224 by_ref: None,
225 mutability: None,
226 ident: Ident::new("this", Span::call_site()),
227 subpat: None,
228 });
229
230 if let Some(self_param) = &m.self_param {
231 all_params.insert(
232 0,
233 FnArg::Typed(PatType {
234 attrs: vec![],
235 pat: Box::new(this_ident.clone()),
236 colon_token: syn::token::Colon(Span::call_site()),
237 ty: Box::new(self_param.to_typename().to_syn()),
238 }),
239 );
240 }
241
242 let lifetimes = {
243 let lifetime_env = &m.lifetime_env;
244 if lifetime_env.is_empty() {
245 quote! {}
246 } else {
247 quote! { <#lifetime_env> }
248 }
249 };
250
251 let method_invocation = if m.self_param.is_some() {
252 quote! { #this_ident.#method_ident }
253 } else {
254 quote! { #self_ident::#method_ident }
255 };
256
257 let (return_tokens, maybe_into) = if let Some(return_type) = &m.return_type {
258 if let ast::TypeName::Result(ok, err, true) = return_type {
259 let ok = ok.to_syn();
260 let err = err.to_syn();
261 (
262 quote! { -> diplomat_runtime::DiplomatResult<#ok, #err> },
263 quote! { .into() },
264 )
265 } else if let ast::TypeName::Ordering = return_type {
266 let return_type_syn = return_type.to_syn();
267 (quote! { -> #return_type_syn }, quote! { as i8 })
268 } else if let ast::TypeName::Option(ty) = return_type {
269 match ty.as_ref() {
270 ast::TypeName::Box(..) | ast::TypeName::Reference(..) => {
272 let return_type_syn = return_type.to_syn();
273 (quote! { -> #return_type_syn }, quote! {})
274 }
275 _ => {
277 let ty = ty.to_syn();
278 (
279 quote! { -> diplomat_runtime::DiplomatResult<#ty, ()> },
280 quote! { .ok_or(()).into() },
281 )
282 }
283 }
284 } else {
285 let return_type_syn = return_type.to_syn();
286 (quote! { -> #return_type_syn }, quote! {})
287 }
288 } else {
289 (quote! {}, quote! {})
290 };
291
292 let writeable_flushes = m
293 .params
294 .iter()
295 .filter(|p| p.is_writeable())
296 .map(|p| {
297 let p = &p.name;
298 quote! { #p.flush(); }
299 })
300 .collect::<Vec<_>>();
301
302 let cfg = cfgs_to_stream(&m.attrs.cfg);
303
304 if writeable_flushes.is_empty() {
305 Item::Fn(syn::parse_quote! {
306 #[no_mangle]
307 #cfg
308 extern "C" fn #extern_ident#lifetimes(#(#all_params),*) #return_tokens {
309 #method_invocation(#(#all_params_invocation),*) #maybe_into
310 }
311 })
312 } else {
313 Item::Fn(syn::parse_quote! {
314 #[no_mangle]
315 #cfg
316 extern "C" fn #extern_ident#lifetimes(#(#all_params),*) #return_tokens {
317 let ret = #method_invocation(#(#all_params_invocation),*);
318 #(#writeable_flushes)*
319 ret #maybe_into
320 }
321 })
322 }
323}
324
325struct AttributeInfo {
326 repr: bool,
327 opaque: bool,
328 is_out: bool,
329}
330
331impl AttributeInfo {
332 fn extract(attrs: &mut Vec<Attribute>) -> Self {
333 let mut repr = false;
334 let mut opaque = false;
335 let mut is_out = false;
336 attrs.retain(|attr| {
337 let ident = &attr.path().segments.iter().next().unwrap().ident;
338 if ident == "repr" {
339 repr = true;
340 return true;
342 } else if ident == "diplomat" {
343 if attr.path().segments.len() == 2 {
344 let seg = &attr.path().segments.iter().nth(1).unwrap().ident;
345 if seg == "opaque" {
346 opaque = true;
347 return false;
348 } else if seg == "out" {
349 is_out = true;
350 return false;
351 } else if seg == "rust_link"
352 || seg == "out"
353 || seg == "attr"
354 || seg == "skip_if_ast"
355 || seg == "abi_rename"
356 {
357 return false;
360 } else if seg == "enum_convert" || seg == "transparent_convert" {
361 return true;
364 } else {
365 panic!("Only #[diplomat::opaque] and #[diplomat::rust_link] are supported")
366 }
367 } else {
368 panic!("#[diplomat::foo] attrs have a single-segment path name")
369 }
370 }
371 true
372 });
373
374 Self {
375 repr,
376 opaque,
377 is_out,
378 }
379 }
380}
381
382fn gen_bridge(mut input: ItemMod) -> ItemMod {
383 let module = ast::Module::from_syn(&input, true);
384 let _attrs = AttributeInfo::extract(&mut input.attrs);
386 let (brace, mut new_contents) = input.content.unwrap();
387
388 new_contents.push(parse2(quote! { use diplomat_runtime::*; }).unwrap());
389
390 new_contents.iter_mut().for_each(|c| match c {
391 Item::Struct(s) => {
392 let info = AttributeInfo::extract(&mut s.attrs);
393
394 if !info.opaque {
398 let copy = if !info.is_out {
399 quote!(#[derive(Clone, Copy)])
401 } else {
402 quote!()
403 };
404
405 let repr = if !info.repr {
406 quote!(#[repr(C)])
407 } else {
408 quote!()
409 };
410
411 *s = syn::parse_quote! {
412 #repr
413 #copy
414 #s
415 }
416 }
417 }
418
419 Item::Enum(e) => {
420 let info = AttributeInfo::extract(&mut e.attrs);
421 if info.opaque {
422 panic!("#[diplomat::opaque] not allowed on enums")
423 }
424 for v in &mut e.variants {
425 let info = AttributeInfo::extract(&mut v.attrs);
426 if info.opaque {
427 panic!("#[diplomat::opaque] not allowed on enum variants");
428 }
429 }
430 *e = syn::parse_quote! {
431 #[repr(C)]
432 #[derive(Clone, Copy)]
433 #e
434 };
435 }
436
437 Item::Impl(i) => {
438 for item in &mut i.items {
439 if let syn::ImplItem::Fn(ref mut m) = *item {
440 let info = AttributeInfo::extract(&mut m.attrs);
441 if info.opaque {
442 panic!("#[diplomat::opaque] not allowed on methods")
443 }
444 }
445 }
446 }
447 _ => (),
448 });
449
450 for custom_type in module.declared_types.values() {
451 custom_type.methods().iter().for_each(|m| {
452 new_contents.push(gen_custom_type_method(custom_type, m));
453 });
454
455 let destroy_ident = Ident::new(custom_type.dtor_name().as_str(), Span::call_site());
456
457 let type_ident = custom_type.name().to_syn();
458
459 let (lifetime_defs, lifetimes) = if let Some(lifetime_env) = custom_type.lifetimes() {
460 (
461 quote! { <#lifetime_env> },
462 lifetime_env.lifetimes_to_tokens(),
463 )
464 } else {
465 (quote! {}, quote! {})
466 };
467
468 let cfg = cfgs_to_stream(&custom_type.attrs().cfg);
469
470 new_contents.push(Item::Fn(syn::parse_quote! {
473 #[no_mangle]
474 #cfg
475 extern "C" fn #destroy_ident#lifetime_defs(this: Box<#type_ident#lifetimes>) {}
476 }));
477 }
478
479 ItemMod {
480 attrs: input.attrs,
481 vis: input.vis,
482 mod_token: input.mod_token,
483 ident: input.ident,
484 content: Some((brace, new_contents)),
485 semi: input.semi,
486 unsafety: None,
487 }
488}
489
490#[proc_macro_attribute]
492pub fn bridge(
493 _attr: proc_macro::TokenStream,
494 input: proc_macro::TokenStream,
495) -> proc_macro::TokenStream {
496 let expanded = gen_bridge(parse_macro_input!(input));
497 proc_macro::TokenStream::from(expanded.to_token_stream())
498}
499
500#[proc_macro_attribute]
509pub fn enum_convert(
510 attr: proc_macro::TokenStream,
511 input: proc_macro::TokenStream,
512) -> proc_macro::TokenStream {
513 let input_cached: proc_macro2::TokenStream = input.clone().into();
518 let expanded =
519 enum_convert::gen_enum_convert(parse_macro_input!(attr), parse_macro_input!(input));
520
521 let full = quote! {
522 #expanded
523 #input_cached
524 };
525 proc_macro::TokenStream::from(full.to_token_stream())
526}
527
528#[proc_macro_attribute]
534pub fn transparent_convert(
535 _attr: proc_macro::TokenStream,
536 input: proc_macro::TokenStream,
537) -> proc_macro::TokenStream {
538 let input_cached: proc_macro2::TokenStream = input.clone().into();
543 let expanded = transparent_convert::gen_transparent_convert(parse_macro_input!(input));
544
545 let full = quote! {
546 #expanded
547 #input_cached
548 };
549 proc_macro::TokenStream::from(full.to_token_stream())
550}
551
#[cfg(test)]
mod tests {
    use std::fs;
    use std::process::Command;

    use quote::ToTokens;
    use syn::parse_quote;
    use tempfile::tempdir;

    use super::gen_bridge;

    /// Formats `code` with the `rustfmt` binary (through a temporary file) and returns
    /// the formatted source.
    ///
    /// # Panics
    ///
    /// If `rustfmt` cannot be launched or exits unsuccessfully. A failed format would
    /// otherwise leave the file untouched and surface only as a confusing snapshot diff.
    fn rustfmt_code(code: &str) -> String {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("temp.rs");
        fs::write(&file_path, format!("{code}\n")).unwrap();

        let status = Command::new("rustfmt")
            .arg(&file_path)
            .status()
            .expect("failed to launch rustfmt");
        assert!(status.success(), "rustfmt exited with {status}");

        let data = fs::read_to_string(&file_path).unwrap();
        dir.close().unwrap();
        data
    }

    #[test]
    fn method_taking_str() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn from_str(s: &DiplomatStr) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn method_taking_slice() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn from_slice(s: &[f64]) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn method_taking_mutable_slice() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn fill_slice(s: &mut [f64]) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn method_taking_owned_slice() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn fill_slice(s: Box<[u16]>) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn method_taking_owned_str() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn something_with_str(s: Box<str>) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn mod_with_enum() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    enum Abc {
                        A,
                        B = 123,
                    }

                    impl Abc {
                        pub fn do_something(&self) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn mod_with_writeable_result() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn to_string(&self, to: &mut DiplomatWriteable) -> Result<(), ()> {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn mod_with_rust_result() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        pub fn bar(&self) -> Result<(), ()> {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn multilevel_borrows() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    #[diplomat::opaque]
                    struct Foo<'a>(&'a str);

                    #[diplomat::opaque]
                    struct Bar<'b, 'a: 'b>(&'b Foo<'a>);

                    struct Baz<'x, 'y> {
                        foo: &'y Foo<'x>,
                    }

                    impl<'a> Foo<'a> {
                        pub fn new(x: &'a str) -> Box<Foo<'a>> {
                            unimplemented!()
                        }

                        pub fn get_bar<'b>(&'b self) -> Box<Bar<'b, 'a>> {
                            unimplemented!()
                        }

                        pub fn get_baz<'b>(&'b self) -> Baz<'b, 'a> {
                            Bax { foo: self }
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn self_params() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    #[diplomat::opaque]
                    struct RefList<'a> {
                        data: &'a i32,
                        next: Option<Box<Self>>,
                    }

                    impl<'b> RefList<'b> {
                        pub fn extend(&mut self, other: &Self) -> Self {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn cfged_method() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    impl Foo {
                        #[cfg(feature = "foo")]
                        pub fn bar(s: u8) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));

        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    struct Foo {}

                    #[cfg(feature = "bar")]
                    impl Foo {
                        #[cfg(feature = "foo")]
                        pub fn bar(s: u8) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }

    #[test]
    fn cfgd_struct() {
        insta::assert_snapshot!(rustfmt_code(
            &gen_bridge(parse_quote! {
                mod ffi {
                    #[diplomat::opaque]
                    #[cfg(feature = "foo")]
                    struct Foo {}
                    #[cfg(feature = "foo")]
                    impl Foo {
                        pub fn bar(s: u8) {
                            unimplemented!()
                        }
                    }
                }
            })
            .to_token_stream()
            .to_string()
        ));
    }
}