aws_lc_rs/
ptr.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0 OR ISC
3
4use crate::aws_lc::{
5    BN_free, CMAC_CTX_free, ECDSA_SIG_free, EC_GROUP_free, EC_KEY_free, EC_POINT_free,
6    EVP_AEAD_CTX_free, EVP_CIPHER_CTX_free, EVP_PKEY_CTX_free, EVP_PKEY_free, OPENSSL_free,
7    RSA_free, BIGNUM, CMAC_CTX, ECDSA_SIG, EC_GROUP, EC_KEY, EC_POINT, EVP_AEAD_CTX,
8    EVP_CIPHER_CTX, EVP_PKEY, EVP_PKEY_CTX, RSA,
9};
10use std::marker::PhantomData;
11
/// A [`ManagedPointer`] specialised to a raw `*mut T`; the pointer is handed to
/// `Pointer::free` when the wrapper is dropped.
pub(crate) type LcPtr<T> = ManagedPointer<*mut T>;
/// A [`DetachablePointer`] specialised to a raw `*mut T`; the pointer can be taken back out
/// with `detach`, in which case the wrapper no longer frees it.
pub(crate) type DetachableLcPtr<T> = DetachablePointer<*mut T>;
14
15#[derive(Debug)]
16pub(crate) struct ManagedPointer<P: Pointer> {
17    pointer: P,
18}
19
20impl<P: Pointer> ManagedPointer<P> {
21    #[inline]
22    pub fn new<T: IntoPointer<P>>(value: T) -> Result<Self, ()> {
23        if let Some(pointer) = value.into_pointer() {
24            Ok(Self { pointer })
25        } else {
26            Err(())
27        }
28    }
29
30    pub unsafe fn as_slice(&self, len: usize) -> &[P::T] {
31        core::slice::from_raw_parts(self.pointer.as_const_ptr(), len)
32    }
33}
34
35impl<P: Pointer> Drop for ManagedPointer<P> {
36    #[inline]
37    fn drop(&mut self) {
38        self.pointer.free();
39    }
40}
41
42impl<'a, P: Pointer> From<&'a ManagedPointer<P>> for ConstPointer<'a, P::T> {
43    fn from(ptr: &'a ManagedPointer<P>) -> ConstPointer<'a, P::T> {
44        ConstPointer {
45            ptr: ptr.pointer.as_const_ptr(),
46            _lifetime: PhantomData,
47        }
48    }
49}
50
51impl<P: Pointer> ManagedPointer<P> {
52    #[inline]
53    pub fn as_const(&self) -> ConstPointer<'_, P::T> {
54        self.into()
55    }
56
57    #[inline]
58    pub fn as_const_ptr(&self) -> *const P::T {
59        self.pointer.as_const_ptr()
60    }
61
62    pub fn project_const_lifetime<'a, C>(
63        &'a self,
64        f: unsafe fn(&'a Self) -> *const C,
65    ) -> Result<ConstPointer<'a, C>, ()> {
66        let ptr = unsafe { f(self) };
67        if ptr.is_null() {
68            return Err(());
69        }
70        Ok(ConstPointer {
71            ptr,
72            _lifetime: PhantomData,
73        })
74    }
75
76    #[inline]
77    pub fn as_mut_ptr(&mut self) -> *mut P::T {
78        self.pointer.as_mut_ptr()
79    }
80}
81
82impl<P: Pointer> DetachablePointer<P> {
83    #[inline]
84    pub fn as_mut_ptr(&mut self) -> *mut P::T {
85        self.pointer.as_mut().unwrap().as_mut_ptr()
86    }
87}
88
89#[derive(Debug)]
90#[allow(clippy::module_name_repetitions)]
91pub(crate) struct DetachablePointer<P: Pointer> {
92    pointer: Option<P>,
93}
94
95impl<P: Pointer> DetachablePointer<P> {
96    #[inline]
97    pub fn new<T: IntoPointer<P>>(value: T) -> Result<Self, ()> {
98        if let Some(pointer) = value.into_pointer() {
99            Ok(Self {
100                pointer: Some(pointer),
101            })
102        } else {
103            Err(())
104        }
105    }
106
107    #[inline]
108    pub fn detach(mut self) -> P {
109        self.pointer.take().unwrap()
110    }
111}
112
113impl<P: Pointer> From<DetachablePointer<P>> for ManagedPointer<P> {
114    #[inline]
115    fn from(mut dptr: DetachablePointer<P>) -> Self {
116        match dptr.pointer.take() {
117            Some(pointer) => ManagedPointer { pointer },
118            None => {
119                // Safety: pointer is only None when DetachableLcPtr is detached or dropped
120                unreachable!()
121            }
122        }
123    }
124}
125
126impl<P: Pointer> Drop for DetachablePointer<P> {
127    #[inline]
128    fn drop(&mut self) {
129        if let Some(mut pointer) = self.pointer.take() {
130            pointer.free();
131        }
132    }
133}
134
/// Non-owning `*const T` tagged with a lifetime `'a`.
///
/// Nothing is freed when a `ConstPointer` is dropped; this type has no `Drop` impl.
#[derive(Debug)]
pub(crate) struct ConstPointer<'a, T> {
    ptr: *const T,
    _lifetime: PhantomData<&'a T>,
}

impl<T> ConstPointer<'static, T> {
    /// Wraps `ptr` as a `ConstPointer` with the `'static` lifetime.
    ///
    /// # Errors
    /// Returns `Err(())` when `ptr` is null.
    ///
    /// # Safety
    /// NOTE(review): presumably the caller must guarantee the pointee really lives for
    /// `'static`; nothing in this function checks it - confirm against callers.
    pub unsafe fn new_static(ptr: *const T) -> Result<Self, ()> {
        if ptr.is_null() {
            Err(())
        } else {
            Ok(Self {
                ptr,
                _lifetime: PhantomData,
            })
        }
    }
}

impl<T> ConstPointer<'_, T> {
    /// Calls `f` on `self` and wraps the pointer it returns in a `ConstPointer` whose
    /// lifetime is bound to the borrow of `self`.
    ///
    /// # Errors
    /// Returns `Err(())` when `f` returns a null pointer.
    pub fn project_const_lifetime<'a, C>(
        &'a self,
        f: unsafe fn(&'a Self) -> *const C,
    ) -> Result<ConstPointer<'a, C>, ()> {
        let projected = unsafe { f(self) };
        if projected.is_null() {
            Err(())
        } else {
            Ok(ConstPointer {
                ptr: projected,
                _lifetime: PhantomData,
            })
        }
    }

    /// Returns the wrapped raw pointer.
    pub fn as_const_ptr(&self) -> *const T {
        let Self { ptr, .. } = self;
        *ptr
    }
}
172
/// Operations the wrapper types in this file need from a pointer value: releasing it and
/// exposing it as a raw const or mut pointer.
pub(crate) trait Pointer {
    /// The pointee type.
    type T;

    /// Releases whatever this pointer refers to.
    fn free(&mut self);
    /// Returns the pointer as `*const Self::T`.
    fn as_const_ptr(&self) -> *const Self::T;
    /// Returns the pointer as `*mut Self::T`.
    fn as_mut_ptr(&mut self) -> *mut Self::T;
}
180
/// Fallible conversion of a value into a pointer of type `P`.
pub(crate) trait IntoPointer<P> {
    /// Returns `Some(pointer)` on success, `None` when no usable pointer is available.
    fn into_pointer(self) -> Option<P>;
}

impl<T> IntoPointer<*mut T> for *mut T {
    /// Passes a non-null pointer through unchanged and maps a null pointer to `None`.
    #[inline]
    fn into_pointer(self) -> Option<*mut T> {
        if self.is_null() {
            return None;
        }
        Some(self)
    }
}
195
/// Implements `Pointer` for `*mut $ty`, using `$free` as the release function.
///
/// `$free` is called inside an `unsafe` block with the raw pointer cast to whatever
/// pointer type `$free` takes.
macro_rules! create_pointer {
    ($ty:ty, $free:path) => {
        impl Pointer for *mut $ty {
            type T = $ty;

            #[inline]
            fn free(&mut self) {
                let raw: *mut $ty = *self;
                unsafe {
                    $free(raw.cast());
                }
            }

            #[inline]
            fn as_const_ptr(&self) -> *const Self::T {
                // `*mut` coerces to `*const` at the return site.
                let raw: *mut $ty = *self;
                raw
            }

            #[inline]
            fn as_mut_ptr(&mut self) -> *mut Self::T {
                let raw: *mut $ty = *self;
                raw
            }
        }
    };
}
221
// `OPENSSL_free` and the other `XXX_free` functions zeroize the memory when it is freed. This
// differs from the functions of the same name in OpenSSL, which generally do not zero memory.
//
// Each invocation below implements `Pointer` for `*mut <type>` with the listed free function.
create_pointer!(u8, OPENSSL_free);
create_pointer!(EC_GROUP, EC_GROUP_free);
create_pointer!(EC_POINT, EC_POINT_free);
create_pointer!(EC_KEY, EC_KEY_free);
create_pointer!(ECDSA_SIG, ECDSA_SIG_free);
create_pointer!(BIGNUM, BN_free);
create_pointer!(EVP_PKEY, EVP_PKEY_free);
create_pointer!(EVP_PKEY_CTX, EVP_PKEY_CTX_free);
create_pointer!(RSA, RSA_free);
create_pointer!(EVP_AEAD_CTX, EVP_AEAD_CTX_free);
create_pointer!(EVP_CIPHER_CTX, EVP_CIPHER_CTX_free);
create_pointer!(CMAC_CTX, CMAC_CTX_free);
237
#[cfg(test)]
mod tests {
    use crate::aws_lc::BIGNUM;
    use crate::ptr::{DetachablePointer, ManagedPointer};

    /// The derived `Debug` output of both wrappers names the type and its `pointer` field.
    #[test]
    fn test_debug() {
        let seed = 100u64;
        let detachable_ptr: DetachablePointer<*mut BIGNUM> =
            DetachablePointer::try_from(seed).unwrap();
        let detachable_repr = format!("{detachable_ptr:?}");
        assert!(detachable_repr.contains("DetachablePointer { pointer: Some("));

        // Detaching hands the raw pointer over, so the managed wrapper becomes its owner.
        let raw = detachable_ptr.detach();
        let lc_ptr = ManagedPointer::new(raw).unwrap();
        let managed_repr = format!("{lc_ptr:?}");
        assert!(managed_repr.contains("ManagedPointer { pointer:"));
    }
}