Rust 实战 - 使用套接字联网API(二)

上一节,我们已经实现了一个最小可运行版本。之所以使用Rust而不是C,是因为Rust具备了必要的抽象能力,还能获得跟C差不多的性能。这一节,我们对上一节的代码做必要的封装,顺便还能把unsafe的代码包装成safe的API。

我将上一节的源码放到了这里,你可以去查看。

还记得上一节,我们把使用到的libc中的函数socketbindconnect和结构体sockaddrsockaddr_inin_addr等,在Rust这边定义了出来。实际上,几乎libc中的函数,libc这个crate都帮我们定义好了。你可以去这里查看。编译器和标准库本身也使用了这个crate,我们也使用这个。

首先在Cargo.toml文件的[dependencies]下面加入libc = "0.2":

[dependencies]
libc = "0.2"

接着在main.rs文件上方加入use libc;,也可以use libc as c;。或者你直接简单粗暴use libc::*,并不推荐这样,除非你明确知道你使用的函数来自哪里。并将我们定义的与libc中对用的常量、函数、结构体删除。再添加libc::c::到我们使用那些常量、结构体、函数的地方。如果你是直接use libc::*,除了直接删除那部分代码外,几乎什么都不用做。目前的代码:

use std::ffi::c_void;
use libc as c;

fn main() {
    use std::io::Error;
    use std::mem;
    use std::thread;
    use std::time::Duration;

    thread::spawn(|| {

        // server
        unsafe {
            let socket = c::socket(c::AF_INET, c::SOCK_STREAM, c::IPPROTO_TCP);
            if socket < 0 {
                panic!("last OS error: {:?}", Error::last_os_error());
            }

            let servaddr = c::sockaddr_in {
                sin_family: c::AF_INET as u16,
                sin_port: 8080u16.to_be(),
                sin_addr: c::in_addr {
                    s_addr: u32::from_be_bytes([127, 0, 0, 1]).to_be()
                },
                sin_zero: mem::zeroed()
            };

            let result = c::bind(socket, &servaddr as *const c::sockaddr_in as *const c::sockaddr, mem::size_of_val(&servaddr) as u32);
            if result < 0 {
                println!("last OS error: {:?}", Error::last_os_error());
                c::close(socket);
            }

            c::listen(socket, 128);

            loop {
                let mut cliaddr: c::sockaddr_storage = mem::zeroed();
                let mut len = mem::size_of_val(&cliaddr) as u32;

                let client_socket = c::accept(socket, &mut cliaddr as *mut c::sockaddr_storage as *mut c::sockaddr, &mut len);
                if client_socket < 0 {
                    println!("last OS error: {:?}", Error::last_os_error());
                    break;
                }

                thread::spawn(move || {
                    loop {
                        let mut buf = [0u8; 64];
                        let n = c::read(client_socket, &mut buf as *mut _ as *mut c_void, buf.len());
                        if n <= 0 {
                            break;
                        }

                        println!("{:?}", String::from_utf8_lossy(&buf[0..n as usize]));

                        let msg = b"Hi, client!";
                        let n = c::write(client_socket, msg as *const _ as *const c_void, msg.len());
                        if n <= 0 {
                            break;
                        }
                    }

                    c::close(client_socket);
                });
            }

            c::close(socket);
        }

    });

    thread::sleep(Duration::from_secs(1));

    // client
    unsafe {
        let socket = c::socket(c::AF_INET, c::SOCK_STREAM, c::IPPROTO_TCP);
        if socket < 0 {
            panic!("last OS error: {:?}", Error::last_os_error());
        }

        let servaddr = c::sockaddr_in {
            sin_family: c::AF_INET as u16,
            sin_port: 8080u16.to_be(),
            sin_addr: c::in_addr {
                s_addr: u32::from_be_bytes([127, 0, 0, 1]).to_be()
            },
            sin_zero: mem::zeroed()
        };

        let result = c::connect(socket, &servaddr as *const c::sockaddr_in as *const c::sockaddr, mem::size_of_val(&servaddr) as u32);
        if result < 0 {
            println!("last OS error: {:?}", Error::last_os_error());
            c::close(socket);
        }

        let msg = b"Hello, server!";
        let n = c::write(socket, msg as *const _ as *const c_void, msg.len());
        if n <= 0 {
            println!("last OS error: {:?}", Error::last_os_error());
            c::close(socket);
        }

        let mut buf = [0u8; 64];
        let n = c::read(socket, &mut buf as *mut _ as *mut c_void, buf.len());
        if n <= 0 {
            println!("last OS error: {:?}", Error::last_os_error());
        }

        println!("{:?}", String::from_utf8_lossy(&buf[0..n as usize]));

        c::close(socket);
    }
}

你编译运行,应该能得到与上一节同样的结果。

接下来,我们尝试把上面代码中函数,封装成更具Rust风格的API,除了TCP外,也还要考虑之后把UDP、UNIX域和SCTP也增加进来。同时,我们跟标准库里 net相关的API保持一致的风格。我们暂时不考虑跨平台,只考虑Linux,因此可以大胆的将一些linux独有的API添加进来。

UNIX中一切皆文件,套接字也不例外。字节流套接字上的read和write函数所表现出来的行为,不同于通常的文件I/O。字节流套接字上调用read和write输入或输出字节数可能比请求的要少,这个现象的原因在于内核中用于套接字的缓冲区可能已经达到了极限。不过,这并不是我们正真关心的。我们来看看标准库中 File的实现:

pub struct File(FileDesc);

impl File {
    ...
    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
            self.0.read(buf)
    }

    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
            self.0.write(buf)
    }

    pub fn duplicate(&self) -> io::Result<File> {
            self.0.duplicate().map(File)
    }
    ...
}

File 是一个元组结构体,标准库已经实现了readwrite,以及duplicateduplicate很有用,用于复制出一个新的描述符。我们继续看File中"包裹的FileDesc:

pub struct FileDesc {
    fd: c_int,
}

impl File {
    ...
    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
            let ret = cvt(unsafe {
               libc::read(self.fd,
                       buf.as_mut_ptr() as *mut c_void,
                       cmp::min(buf.len(), max_len()))
            })?;
            Ok(ret as usize)
    }

    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
            let ret = cvt(unsafe {
                    libc::write(self.fd,
                        buf.as_ptr() as *const c_void,
                        cmp::min(buf.len(), max_len()))
            })?;
            Ok(ret as usize)
    }

    pub fn set_cloexec(&self) -> io::Result<()> {
            unsafe {
                    cvt(libc::ioctl(self.fd, libc::FIOCLEX))?;
                    Ok(())
            }
    }

    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
            unsafe {
                    let v = nonblocking as c_int;
                    cvt(libc::ioctl(self.fd, libc::FIONBIO, &v))?;
                    Ok(())
            }
    }
}

这一层应该是到头了,你可以看到,Rust中的File也是直接对libc的封装,不过你不用担心,一开始就提到,Rust 的ABI与C的ABI是兼容的,也就意味着Rust和C互相调用是几乎是零开销的。FileDescreadwrite中的实现,与我们之前对sockfdreadwrite基本是一样的。除了readwrite外,还有两个很有用的方法set_cloexecset_nonblocking

我把“依附于”某个类型的函数叫做方法,与普通函数不同的是,依附于某个类型的函数,必须通过它所依附的类型调用。Rust通过这种方式来实现OOP,但是与某些语言的OOP不同的是,Rust的这种实现是零开销的。也就是,你将一些函数依附到某个类型上,并不会对运行时造成额外的开销,这些都在编译时去处理。

set_cloexec方法会对描述符设置FD_CLOEXEC。我们经常会碰到需要fork子进程的情况,而且子进程很可能会继续exec新的程序。对描述符设置FD_CLOEXEC,就意味着,我们fork子进程时,父子进程中相同的文件描述符指向系统文件表的同一项,但是,我们如果调用exec执行另一个程序,此时会用全新的程序替换子进程的正文。为了较少不必要的麻烦,我们以后要对打开的描述符设置FD_CLOEXEC,除非遇到特殊情况。

set_nonblocking用于将描述符设置为非阻塞模式,如果我们要使用poll、epoll等api的话。

既然标准库已经封装好了FileDesc,我想直接使用的,然而FileDesc在标准库之外是不可见的。如果使用File的话,set_cloexecset_nonblocking 还是要我们再写一次,但是File并不是“我自己”的类型,我没法直接给File附加方法,为此还需要一个额外的Tarit或者用一个“我自己”的类型,去包裹它。挺绕的。那既然这样,我们还是自己来吧。不过我们已经有了参考,可以将标准库里的FileDecs直接复制出来,然后去掉与Linux无关的代码,当然你也可以自由发挥一下。

要注意的是,这段代码中还调用了一个函数cvt,我们把相关代码也复制过来:

use std::io::{self, ErrorKind};

#[doc(hidden)]
pub trait IsMinusOne {
    fn is_minus_one(&self) -> bool;
}

macro_rules! impl_is_minus_one {
    ($($t:ident)*) => ($(impl IsMinusOne for $t {
        fn is_minus_one(&self) -> bool {
            *self == -1
        }
    })*)
}

impl_is_minus_one! { i8 i16 i32 i64 isize }

pub fn cvt<T: IsMinusOne>(t: T) -> io::Result<T> {
    if t.is_minus_one() {
        Err(io::Error::last_os_error())
    } else {
        Ok(t)
    }
}

pub fn cvt_r<T, F>(mut f: F) -> io::Result<T>
    where T: IsMinusOne,
          F: FnMut() -> T
{
    loop {
        match cvt(f()) {
            Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
            other => return other,
        }
    }
}

还记得上一节我们使用过的last_os_error()方法么,这段代码通过宏impl_is_minus_onei32等常见类型实现了IsMinusOne这个Tarit,然后我们就可以使用cvt函数更便捷得调用last_os_error()取得错误。 我将这段代码放到util.rs文件中,并在main.rs文件上方加入pub mod util;

然后再来看FileDesc最终的实现:

use std::mem;
use std::io;
use std::cmp;
use std::os::unix::io::FromRawFd;

use libc as c;

use crate::util::cvt;

#[derive(Debug)]
pub struct FileDesc(c::c_int);

pub fn max_len() -> usize {
    <c::ssize_t>::max_value() as usize
}

impl FileDesc {
    pub fn raw(&self) -> c::c_int {
        self.0
    }

    pub fn into_raw(self) -> c::c_int {
        let fd = self.0;
        mem::forget(self);
        fd
    }

    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        let ret = cvt(unsafe {
            c::read(
                self.0,
                buf.as_mut_ptr() as *mut c::c_void,
                cmp::min(buf.len(), max_len())
            )
        })?;

        Ok(ret as usize)
    }

    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
        let ret = cvt(unsafe {
            c::write(
                self.0,
                buf.as_ptr() as *const c::c_void,
                cmp::min(buf.len(), max_len())
            )
        })?;

        Ok(ret as usize)
    }

    pub fn get_cloexec(&self) -> io::Result<bool> {
        unsafe {
            Ok((cvt(libc::fcntl(self.0, c::F_GETFD))? & libc::FD_CLOEXEC) != 0)
        }
    }

    pub fn set_cloexec(&self) -> io::Result<()> {
        unsafe {
            cvt(c::ioctl(self.0, c::FIOCLEX))?;
            Ok(())
        }
    }

    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        unsafe {
            let v = nonblocking as c::c_int;
            cvt(c::ioctl(self.0, c::FIONBIO, &v))?;
            Ok(())
        }
    }

    pub fn duplicate(&self) -> io::Result<FileDesc> {
        cvt(unsafe { c::fcntl(self.0, c::F_DUPFD_CLOEXEC, 0) }).and_then(|fd| {
            let fd = FileDesc(fd);
            Ok(fd)
        })
    }
}

impl FromRawFd for FileDesc {
    unsafe fn from_raw_fd(fd: c::c_int) -> FileDesc {
        FileDesc(fd)
    }
}

impl Drop for FileDesc {
    fn drop(&mut self) {
        let _ = unsafe { c::close(self.0) };
    }
}

我已经将与Linux不相关的代码删除掉了。之所以原有duplicate那么冗长,是因为旧的Linux内核不支持F_DUPFD_CLOEXEC这个设置。fcntl这个函数,用来设置控制文件描述符的选项,我们稍后还会遇到用来设置和获取套接字的getsockoptsetsockopt。还有read_atwrite_at等实现比较复杂的函数,我们用不到,也将他们删除。还有impl<'a> Read for &'a FileDesc ,因为内部使了一个Unstable的API,我也将其去掉了。

我自由发挥了一下,把:

pub struct FileDesc {
    fd: c_int,
}

替换成了:

pub struct FileDesc(c::c_int);

它们是等效的。不知你注意到没有,我把pub fn new(...)函数给去掉了,因为这个函数是unsafe的----如果我们今后将这些代码作为库让别人使用的话,他可能传入了一个不存在的描述符,并由此可能引起程序崩溃----但他们并不一定知道。我们可以通过在这个函数前面加unsafe来告诉使用者这个函数是unsafe的: pub unsafe fn new(...)。不过,Rust的开发者们已经考虑到了这一点,我们用约定俗成的from_raw_fd来代替pub unsafe fn new(...),于是才有了下面这一段:

impl FromRawFd for FileDesc {
    unsafe fn from_raw_fd(fd: c::c_int) -> FileDesc {
        FileDesc(fd)
    }
}

最后,还利用Rust的drop实现了close函数,也就意味着,描述符离开作用域后,会自动close,就不再需要我们手动close了。与之先关的是into_raw方法,意思是把FileDesc转换为“未加工的”或者说是“裸的”描述符,也就是C的描述符。这个方法里面调用了forget,之后变量离开作用域后,就不会调用drop了。当你使用这个方法拿到描述符,使用完请不要忘记手动close或者再次from_raw_fd

pub fn into_raw(self) -> c::c_int {
        let fd = self.0;
        mem::forget(self);
        fd
}

我将这段代码放到了一个新的文件fd.rs中,并在main.rs文件上方加入pub mod fd;

接着,我们还需一个socket类型,将socketbindconnect等函数附加上去。这一步应该简单多了。同时你也会发现,我们已经把unsafe的代码,封装成了safe的代码。

use std::io;
use std::mem;
use std::os::unix::io::{RawFd, AsRawFd, FromRawFd};

use libc as c;

use crate::fd::FileDesc;
use crate::util::cvt;

pub struct Socket(FileDesc);

impl Socket {
    pub fn new(family: c::c_int, ty: c::c_int, protocol: c::c_int) -> io::Result<Socket> {
        unsafe {
            cvt(c::socket(family, ty | c::SOCK_CLOEXEC, protocol))
                .map(|fd| Socket(FileDesc::from_raw_fd(fd)))
        }
    }

    pub fn bind(&self, storage: *const c::sockaddr, len: c::socklen_t) -> io::Result<()> {
        self.setsockopt(c::SOL_SOCKET, c::SO_REUSEADDR, 1)?;

        cvt(unsafe { c::bind(self.0.raw(), storage, len) })?;

        Ok(())
    }

    pub fn listen(&self, backlog: c::c_int) -> io::Result<()> {
        cvt(unsafe { c::listen(self.0.raw(), backlog) })?;
        Ok(())
    }

    pub fn accept(&self, storage: *mut c::sockaddr, len: *mut c::socklen_t) -> io::Result<Socket> {
        let fd = cvt(unsafe { c::accept4(self.0.raw(), storage, len, c::SOCK_CLOEXEC) })?;
        Ok(Socket(unsafe { FileDesc::from_raw_fd(fd) }))
    }

    pub fn connect(&self, storage: *const c::sockaddr, len: c::socklen_t) -> io::Result<()> {
        cvt(unsafe { c::connect(self.0.raw(), storage, len) })?;
        Ok(())
    }

    pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.0.read(buf)
    }

    pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
        self.0.write(buf)
    }

    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.0.set_nonblocking(nonblocking)
    }

    pub fn get_cloexec(&self) -> io::Result<bool> {
        self.0.get_cloexec()
    }

    pub fn set_cloexec(&self) -> io::Result<()> {
        self.0.set_cloexec()
    }

    pub fn setsockopt<T>(&self, opt: libc::c_int, val: libc::c_int, payload: T) -> io::Result<()> {
        unsafe {
            let payload = &payload as *const T as *const libc::c_void;

            cvt(libc::setsockopt(
                self.0.raw(),
                opt,
                val,
                payload,
                mem::size_of::<T>() as libc::socklen_t
            ))?;

            Ok(())
        }
    }

    pub fn getsockopt<T: Copy>(&self, opt: libc::c_int, val: libc::c_int) -> io::Result<T> {
        unsafe {
            let mut slot: T = mem::zeroed();
            let mut len = mem::size_of::<T>() as libc::socklen_t;

            cvt(libc::getsockopt(
                self.0.raw(),
                opt,
                val,
                &mut slot as *mut T as *mut libc::c_void,
                &mut len
            ))?;

            assert_eq!(len as usize, mem::size_of::<T>());
            Ok(slot)
        }
    }
}

impl FromRawFd for Socket {
    unsafe fn from_raw_fd(fd: RawFd) -> Socket {
        Socket(FileDesc::from_raw_fd(fd))
    }
}

impl AsRawFd for Socket {
    fn as_raw_fd(&self) -> RawFd {
        self.0.raw()
    }
}

我已经将上一节中我们使用到的socket相关的主要的5个函数,外加readwrite,等几个描述符设置的函数,“依附”到了socket上。保存在 socket.rs 文件里。

要说明的是,我在newaccept方法中,通过flags直接为新创建的描述符设置了SOCK_CLOEXEC选项,如果不想一步设置的话,就需要创建出描述符后,再调用set_cloexec方法。bind中,在调用c::bind之前,我给套接字设置了个选项SO_REUSEADDR,意为允许重用本地地址,这里不展开讲,如果你细心的话就会发现,上一节的例子,如果没有正常关闭socket的话,就可能会出现error:98,Address already in use,等一会儿才会好。accept4不是个标准的方法,只有Linux才支持,我们暂时不考虑兼容性。setsockoptgetsockopt方法中涉及到了类型转换,结合前面的例子,这里应该难不倒你了。除了from_raw_fd,我还又给socket实现了又一个约定俗成的方法as_raw_fd

我已经将远吗放到了这里,你可以去查看。你还可以尝试将上一节的例子,修改成我们今天封装的socket。这一节到这里就结束了。

相关推荐