xref: /relibc/src/platform/redox/socket.rs (revision e17c6049c6bb1cea5cd9baec3faff50e5b8e4f8b)
1 use alloc::vec::Vec;
2 use core::{cmp, mem, ptr, slice, str};
3 use syscall::{self, flag::*, Result};
4 
5 use super::{
6     super::{errno, types::*, Pal, PalSocket},
7     e, Sys,
8 };
9 use crate::header::{
10     arpa_inet::inet_aton,
11     netinet_in::{in_port_t, sockaddr_in, in_addr},
12     string::strnlen,
13     sys_socket::{constants::*, sa_family_t, sockaddr, socklen_t},
14     sys_time::timeval,
15     sys_un::sockaddr_un,
16 };
17 
/// Translates a C `sockaddr` into a Redox scheme path and opens a handle on it.
///
/// Arms:
/// - `bind "<fmt>"` / `connect "<fmt>"`: path-prefix helpers. Binding prefixes
///   the formatted path with `/`; connecting uses it as-is.
/// - `<mode> copy, socket, address, address_len`: `dup`s `socket` with the path
///   derived from `address` and evaluates to the new fd.
/// - `<mode> into, socket, address, address_len`: like `copy`, but then `dup2`s
///   the new handle over `socket` so the caller's descriptor number stays valid.
///   Evaluates to 0.
///
/// On any error the macro sets `errno` (directly or through `e`) and executes
/// `return -1` from the *enclosing function*. It must therefore only be used in
/// functions whose return type accepts the literal `-1`.
macro_rules! bind_or_connect {
    (bind $path:expr) => {
        concat!("/", $path)
    };
    (connect $path:expr) => {
        $path
    };
    ($mode:ident into, $socket:expr, $address:expr, $address_len:expr) => {{
        let fd = bind_or_connect!($mode copy, $socket, $address, $address_len);

        // Move the freshly addressed handle onto the caller's fd, then drop the temporary.
        let result = syscall::dup2(fd, $socket as usize, &[]);
        let _ = syscall::close(fd);
        if (e(result) as c_int) < 0 {
            return -1;
        }
        0
    }};
    ($mode:ident copy, $socket:expr, $address:expr, $address_len:expr) => {{
        if ($address_len as usize) < mem::size_of::<sa_family_t>() {
            errno = syscall::EINVAL;
            return -1;
        }

        let path = match (*$address).sa_family as c_int {
            AF_INET => {
                // Only a too-short address is invalid; callers commonly pass a
                // larger buffer such as `sockaddr_storage`.
                if ($address_len as usize) < mem::size_of::<sockaddr_in>() {
                    errno = syscall::EINVAL;
                    return -1;
                }
                let data = &*($address as *const sockaddr_in);
                // `s_addr` is in network byte order, so its in-memory bytes are
                // already in dotted-quad order.
                let addr = slice::from_raw_parts(
                    &data.sin_addr.s_addr as *const _ as *const u8,
                    mem::size_of_val(&data.sin_addr.s_addr),
                );
                let port = in_port_t::from_be(data.sin_port);
                let path = format!(
                    bind_or_connect!($mode "{}.{}.{}.{}:{}"),
                    addr[0],
                    addr[1],
                    addr[2],
                    addr[3],
                    port
                );

                path
            },
            AF_UNIX => {
                let data = &*($address as *const sockaddr_un);

                // NOTE: It's UB to access data in given address that exceeds
                // the given address length.

                let maxlen = cmp::min(
                    // Max path length of the full-sized struct
                    data.sun_path.len(),
                    // Length inferred from given addrlen (never underflows)
                    ($address_len as usize).saturating_sub(data.path_offset()),
                );
                let len = cmp::min(
                    // The maximum length of the address
                    maxlen,
                    // The first NUL byte, if any
                    strnlen(&data.sun_path as *const _, maxlen as size_t),
                );

                let addr = slice::from_raw_parts(
                    &data.sun_path as *const _ as *const u8,
                    len,
                );
                // A non-UTF-8 path is caller error, not a reason to abort the process.
                let path = match str::from_utf8(addr) {
                    Ok(path) => format!("{}", path),
                    Err(_) => {
                        errno = syscall::EINVAL;
                        return -1;
                    }
                };
                trace!("path: {:?}", path);

                path
            },
            _ => {
                errno = syscall::EAFNOSUPPORT;
                return -1;
            },
        };

        // Duplicate the socket with the derived path; the `into` arm moves the
        // copy back onto the original fd.
        let fd = e(syscall::dup($socket as usize, path.as_bytes()));
        if (fd as c_int) < 0 {
            return -1;
        }
        fd
    }};
}
109 
110 unsafe fn inner_af_unix(buf: &[u8], address: *mut sockaddr, address_len: *mut socklen_t) {
111     let data = &mut *(address as *mut sockaddr_un);
112 
113     data.sun_family = AF_UNIX as c_ushort;
114 
115     let path = slice::from_raw_parts_mut(
116         &mut data.sun_path as *mut _ as *mut u8,
117         data.sun_path.len(),
118     );
119 
120     let len = cmp::min(path.len(), buf.len());
121     path[..len].copy_from_slice(&buf[..len]);
122 
123     *address_len = len as socklen_t;
124 }
125 
126 unsafe fn inner_af_inet(
127     local: bool,
128     buf: &[u8],
129     address: *mut sockaddr,
130     address_len: *mut socklen_t,
131 ) {
132     let mut parts = buf.split(|c| *c == b'/');
133     if local {
134         // Skip the remote part
135         parts.next();
136     }
137     let mut unparsed_addr = Vec::from(parts.next().expect("missing address"));
138 
139     let sep = memchr::memchr(b':', &unparsed_addr).expect("missing port");
140     let (raw_addr, rest) = unparsed_addr.split_at_mut(sep);
141     let (colon, raw_port) = rest.split_at_mut(1);
142     let port = str::from_utf8(raw_port).expect("non-utf8 port").parse().expect("invalid port");
143 
144     // Make address be followed by a NUL-byte
145     colon[0] = b'\0';
146 
147     trace!("address: {:?}, port: {:?}", str::from_utf8(&raw_addr), port);
148 
149     let mut addr = in_addr::default();
150     assert_eq!(inet_aton(raw_addr.as_ptr() as *mut i8, &mut addr), 1, "inet_aton might be broken, failed to parse netstack address");
151 
152     let ret = sockaddr_in {
153         sin_family: AF_INET as sa_family_t,
154         sin_port: port,
155         sin_addr: addr,
156 
157         ..sockaddr_in::default()
158     };
159     let len = cmp::min(*address_len as usize, mem::size_of_val(&ret));
160 
161     ptr::copy_nonoverlapping(&ret as *const _ as *const u8, address as *mut u8, len);
162     *address_len = len as socklen_t;
163 }
164 
165 unsafe fn inner_get_name(
166     local: bool,
167     socket: c_int,
168     address: *mut sockaddr,
169     address_len: *mut socklen_t,
170 ) -> Result<usize> {
171     // Format: [udp|tcp:]remote/local, chan:path
172     let mut buf = [0; 256];
173     let len = syscall::fpath(socket as usize, &mut buf)?;
174     let buf = &buf[..len];
175 
176     if buf.starts_with(b"tcp:") || buf.starts_with(b"udp:") {
177         inner_af_inet(local, &buf[4..], address, address_len);
178     } else if buf.starts_with(b"chan:") {
179         inner_af_unix(&buf[5..], address, address_len);
180     } else {
181         // Socket doesn't belong to any scheme
182         panic!(
183             "socket {:?} doesn't match either tcp, udp or chan schemes",
184             str::from_utf8(buf)
185         );
186     }
187 
188     Ok(0)
189 }
190 
191 fn socket_kind(mut kind: c_int) -> (c_int, usize) {
192     let mut flags = O_RDWR;
193     if kind & SOCK_NONBLOCK == SOCK_NONBLOCK {
194         kind &= !SOCK_NONBLOCK;
195         flags |= O_NONBLOCK;
196     }
197     if kind & SOCK_CLOEXEC == SOCK_CLOEXEC {
198         kind &= !SOCK_CLOEXEC;
199         flags |= O_CLOEXEC;
200     }
201     (kind, flags)
202 }
203 
204 impl PalSocket for Sys {
205     unsafe fn accept(socket: c_int, address: *mut sockaddr, address_len: *mut socklen_t) -> c_int {
206         let stream = e(syscall::dup(socket as usize, b"listen")) as c_int;
207         if stream < 0 {
208             return -1;
209         }
210         if address != ptr::null_mut()
211             && address_len != ptr::null_mut()
212             && Self::getpeername(stream, address, address_len) < 0
213         {
214             return -1;
215         }
216         stream
217     }
218 
219     unsafe fn bind(socket: c_int, address: *const sockaddr, address_len: socklen_t) -> c_int {
220         bind_or_connect!(bind into, socket, address, address_len)
221     }
222 
223     unsafe fn connect(socket: c_int, address: *const sockaddr, address_len: socklen_t) -> c_int {
224         bind_or_connect!(connect into, socket, address, address_len)
225     }
226 
227     unsafe fn getpeername(
228         socket: c_int,
229         address: *mut sockaddr,
230         address_len: *mut socklen_t,
231     ) -> c_int {
232         e(inner_get_name(false, socket, address, address_len)) as c_int
233     }
234 
235     unsafe fn getsockname(
236         socket: c_int,
237         address: *mut sockaddr,
238         address_len: *mut socklen_t,
239     ) -> c_int {
240         e(inner_get_name(true, socket, address, address_len)) as c_int
241     }
242 
243     fn getsockopt(
244         socket: c_int,
245         level: c_int,
246         option_name: c_int,
247         option_value: *mut c_void,
248         option_len: *mut socklen_t,
249     ) -> c_int {
250         match level {
251             SOL_SOCKET => match option_name {
252                 SO_ERROR => {
253                     if option_value.is_null() {
254                         return e(Err(syscall::Error::new(syscall::EFAULT))) as c_int;
255                     }
256 
257                     if (option_len as usize) < mem::size_of::<c_int>() {
258                         return e(Err(syscall::Error::new(syscall::EINVAL))) as c_int;
259                     }
260 
261                     let error = unsafe { &mut *(option_value as *mut c_int) };
262                     //TODO: Socket nonblock connection error
263                     *error = 0;
264 
265                     return 0;
266                 }
267                 _ => (),
268             },
269             _ => (),
270         }
271 
272         eprintln!(
273             "getsockopt({}, {}, {}, {:p}, {:p})",
274             socket, level, option_name, option_value, option_len
275         );
276         e(Err(syscall::Error::new(syscall::ENOSYS))) as c_int
277     }
278 
279     fn listen(socket: c_int, backlog: c_int) -> c_int {
280         // Redox has no need to listen
281         0
282     }
283 
284     unsafe fn recvfrom(
285         socket: c_int,
286         buf: *mut c_void,
287         len: size_t,
288         flags: c_int,
289         address: *mut sockaddr,
290         address_len: *mut socklen_t,
291     ) -> ssize_t {
292         if flags != 0 {
293             errno = syscall::EOPNOTSUPP;
294             return -1;
295         }
296         if address == ptr::null_mut() || address_len == ptr::null_mut() {
297             Self::read(socket, slice::from_raw_parts_mut(buf as *mut u8, len))
298         } else {
299             let fd = e(syscall::dup(socket as usize, b"listen"));
300             if fd == !0 {
301                 return -1;
302             }
303             if Self::getpeername(fd as c_int, address, address_len) < 0 {
304                 let _ = syscall::close(fd);
305                 return -1;
306             }
307 
308             let ret = Self::read(fd as c_int, slice::from_raw_parts_mut(buf as *mut u8, len));
309             let _ = syscall::close(fd);
310             ret
311         }
312     }
313 
314     unsafe fn sendto(
315         socket: c_int,
316         buf: *const c_void,
317         len: size_t,
318         flags: c_int,
319         dest_addr: *const sockaddr,
320         dest_len: socklen_t,
321     ) -> ssize_t {
322         if flags != 0 {
323             errno = syscall::EOPNOTSUPP;
324             return -1;
325         }
326         if dest_addr == ptr::null() || dest_len == 0 {
327             Self::write(socket, slice::from_raw_parts(buf as *const u8, len))
328         } else {
329             let fd = bind_or_connect!(connect copy, socket, dest_addr, dest_len);
330             let ret = Self::write(fd as c_int, slice::from_raw_parts(buf as *const u8, len));
331             let _ = syscall::close(fd);
332             ret
333         }
334     }
335 
336     fn setsockopt(
337         socket: c_int,
338         level: c_int,
339         option_name: c_int,
340         option_value: *const c_void,
341         option_len: socklen_t,
342     ) -> c_int {
343         let set_timeout = |timeout_name: &[u8]| -> c_int {
344             if option_value.is_null() {
345                 return e(Err(syscall::Error::new(syscall::EFAULT))) as c_int;
346             }
347 
348             if (option_len as usize) < mem::size_of::<timeval>() {
349                 return e(Err(syscall::Error::new(syscall::EINVAL))) as c_int;
350             }
351 
352             let timeval = unsafe { &*(option_value as *const timeval) };
353 
354             let fd = e(syscall::dup(socket as usize, timeout_name));
355             if fd == !0 {
356                 return -1;
357             }
358 
359             let timespec = syscall::TimeSpec {
360                 tv_sec: timeval.tv_sec,
361                 tv_nsec: timeval.tv_usec * 1000,
362             };
363 
364             let ret = Self::write(fd as c_int, &timespec);
365 
366             let _ = syscall::close(fd);
367 
368             if ret >= 0 {
369                 0
370             } else {
371                 -1
372             }
373         };
374 
375         match level {
376             SOL_SOCKET => match option_name {
377                 SO_RCVTIMEO => return set_timeout(b"read_timeout"),
378                 SO_SNDTIMEO => return set_timeout(b"write_timeout"),
379                 _ => (),
380             },
381             _ => (),
382         }
383 
384         eprintln!(
385             "setsockopt({}, {}, {}, {:p}, {}) - unknown option",
386             socket, level, option_name, option_value, option_len
387         );
388         0
389     }
390 
391     fn shutdown(socket: c_int, how: c_int) -> c_int {
392         eprintln!("shutdown({}, {})", socket, how);
393         e(Err(syscall::Error::new(syscall::ENOSYS))) as c_int
394     }
395 
396     unsafe fn socket(domain: c_int, kind: c_int, protocol: c_int) -> c_int {
397         if domain != AF_INET && domain != AF_UNIX {
398             errno = syscall::EAFNOSUPPORT;
399             return -1;
400         }
401         // if protocol != 0 {
402         //     errno = syscall::EPROTONOSUPPORT;
403         //     return -1;
404         // }
405 
406         let (kind, flags) = socket_kind(kind);
407 
408         // The tcp: and udp: schemes allow using no path,
409         // and later specifying one using `dup`.
410         match (domain, kind) {
411             (AF_INET, SOCK_STREAM) => e(syscall::open("tcp:", flags)) as c_int,
412             (AF_INET, SOCK_DGRAM) => e(syscall::open("udp:", flags)) as c_int,
413             (AF_UNIX, SOCK_STREAM) => e(syscall::open("chan:", flags | O_CREAT)) as c_int,
414             _ => {
415                 errno = syscall::EPROTONOSUPPORT;
416                 -1
417             },
418         }
419     }
420 
421     fn socketpair(domain: c_int, kind: c_int, protocol: c_int, sv: &mut [c_int; 2]) -> c_int {
422         let (kind, flags) = socket_kind(kind);
423 
424         match (domain, kind) {
425             (AF_UNIX, SOCK_STREAM) => {
426                 let listener = e(syscall::open("chan:", flags | O_CREAT));
427                 if listener == !0 {
428                     return -1;
429                 }
430 
431                 // For now, chan: lets connects be instant, and instead blocks
432                 // on any I/O performed. So we don't need to mark this as
433                 // nonblocking.
434 
435                 let fd0 = e(syscall::dup(listener, b"connect"));
436                 if fd0 == !0 {
437                     let _ = syscall::close(listener);
438                     return -1;
439                 }
440 
441                 let fd1 = e(syscall::dup(listener, b"listen"));
442                 if fd1 == !0 {
443                     let _ = syscall::close(fd0);
444                     let _ = syscall::close(listener);
445                     return -1;
446                 }
447 
448                 sv[0] = fd0 as c_int;
449                 sv[1] = fd1 as c_int;
450                 0
451             },
452             _ => unsafe {
453                 eprintln!(
454                     "socketpair({}, {}, {}, {:p})",
455                     domain,
456                     kind,
457                     protocol,
458                     sv.as_mut_ptr()
459                 );
460                 errno = syscall::EPROTONOSUPPORT;
461                 -1
462             },
463         }
464     }
465 }
466