Rust 实战 - 使用套接字联网API (一)

虽然标准库已经封装好了 TcpListenerTcpStream 等基础api,但作为Rust 的爱好者,我们可以去一探究竟。本文假设你已经对 Rust 和 Linux 操作系统有了一定了解。

在 Linux 上 Rust 默认会链接的系统的 libc 以及一些其他的库,这就意味着,你可以直接使用libc中的函数。比如,你可以使用 gethostname 获取你电脑的 "hostname":

use std::os::raw::c_char;
use std::ffi::CStr;

extern {
    pub fn gethostname(name: *mut c_char, len: usize) -> i32;
}

fn main() {
    let len = 255;
    let mut buf = Vec::<u8>::with_capacity(len);
    let ptr = buf.as_mut_ptr() as *mut c_char;

    unsafe {
        gethostname(ptr, len);
        println!("{:?}", CStr::from_ptr(ptr));
    }
}

解释一下上面的代码。

extren 表示“外部块(External blocks)”,用来申明外部非 Rust 库中的符号。我们需要使用 Rust 以外的函数,比如 libc ,就需要在 extren 中将需要用到的函数定义出来,然后就可以像使用本地函数一样使用外部函数,编译器会负责帮我们转换,是不是很方便呢。但是,调用一个外部函数是unsafe的,编译器不能提供足够的保证,所以要放到unsafe块中。

如果外部函数有可变参数,可以这么申明:

extern {
    fn foo(x: i32, ...);
}

不过 Rust 中的函数目前还不支持可变参数。

实际上,这里应该是 extern "C" { .. },因为默认值就是"C",我们就可以将其省略。还有一些其他的可选值,因为这里不会用到,暂且不讨论,你可以去这儿这儿查看。

再来说说类型。“gethostname” 函数在 C 头文件中的原型是:

int gethostname(char *name, size_t len);

在 Linux 64位平台上,C中的int对应于Rust中的intsize_t对应Rust中的usize,但C中的char与Rust中的char是完全不同的,C中的char始终是i8或者u8,而 Rust 中的char是一个unicode标量值。你也可以去标准库查看。对于指针,Rust 中的裸指针 与C中的指针几乎是一样的,Rust的*mut对应C的普通指针,*const 对应C的const指针。因此我们将类型一一对应,函数的参数名称不要求一致。

pub fn gethostname(name: *mut i8, len: usize) -> i32;

但是,我们后面会使用CStr::from_ptr()将C中的字符串转换为 Rust 本地字符串,这个函数的定义是:

pub unsafe fn from_pt<'a>(ptr: *const c_char) -> &'a CStr

为了“好看”一点,我就写成了c_char,但是,c_char只是i8的别名,你写成i8也没有问题的。

type c_char = i8;

你可以看这里

不过,如果你要是考虑跨平台的话,可能需要吧 i32 换成 std::os::raw::c_int,并不是所有平台上C中的int都对应Rust中的i32。不过,如果你没有一一对应类型,一定程度上是可行的,如果没有发生越界的话。比如像这样:

use std::os::raw::c_char;
use std::ffi::CStr;

extern {
    pub fn gethostname(name: *mut c_char, len: u16) -> u16;
}

fn main() {
    let len = 255;
    let mut buf = Vec::<u8>::with_capacity(len);
    let ptr = buf.as_mut_ptr() as *mut c_char;

    unsafe {
        gethostname(ptr, len as u16);
        println!("{:?}", CStr::from_ptr(ptr));
    }
}

我把 size_tint 都对应成了 u16,这段代码是可以编译通过,并正确输出你的hostname的,但我建议,你最好是将类型一一对应上,以减少一些不必要的麻烦。当然,你把那个 *mut c_char 换成 *mut i32,也没问题,反正都是个指针,你可以试试:

use std::os::raw::c_char;
use std::ffi::CStr;

extern {
    pub fn gethostname(name: *mut i32, len: u16) -> u16;
}

fn main() {
    let len = 255;
    let mut buf = Vec::<u8>::with_capacity(len);
    let ptr = buf.as_mut_ptr() as *mut i32;

    unsafe {
        gethostname(ptr, len as u16);
        println!("{:?}", CStr::from_ptr(ptr as *const i8));
    }
}

你还可以把 Vec::<u8>换成Vec::<i32> 看看结果。

int gethostname(char *name, size_t len) 这个函数,是接收一个char数组和数组长度,也可以说成接收缓冲区和接收缓冲区的最大长度。我是创建了一个容量为255的Vec<u8>,将其可变指针转换为裸指针。你也可以创建可以长度为255的u8数组,也没有问题:

let len = 255;
    let mut buf = [0u8; 255];
    let ptr = buf.as_mut_ptr() as *mut i32;

    unsafe {
        gethostname(ptr, len as u16);
        println!("{:?}", CStr::from_ptr(ptr as *const i8));
    }

为什么这样可以,因为Rust的Slice和Vec的底层内存布局,跟C是一样的。(注意,Rust中Slice与Array的关系,就像&str与str的关系)。我们可以看看Vec和Slice在源码中的定义:

pub struct Vec<T> {
    buf: RawVec<T>,
    len: usize,
}

pub struct RawVec<T, A: Alloc = Global> {
    ptr: Unique<T>,
    cap: usize,
    a: A,
}

pub struct Unique<T: ?Sized> {
    pointer: *const T,
    _marker: PhantomData<T>,
}

struct FatPtr<T> {
    data: *const T,
    len: usize,
}

Vec是一个结构体,里面包含buflen两个字段,len用来表示Vec的长度,buf又指向另一个结构体RawVec,其中有三个字段,第三个字段a是一个Tarit,不占内存。cap用来表示Vec的容量,ptr指向另一个结构体Unique,其中的pointer字段就是一个裸指针了,_marker是给编译器看的一个标记,也不占内存,暂时不讨论这个,你可以去看文档。Slice的结构更简单,就一个裸指针和长度。

虽然RawVecUnique在标准库外部是不可见的,但我们还是能用一定的“手段”取出里面值,那就是定义一个内存布局跟Vec一样的结构体,“强行”转换。

#[derive(Debug)]
struct MyVec<T> {
    ptr: *mut T,
    cap: usize,
    len: usize
}

我定义了一个叫做MyVec的结构体,忽略了Vec中两个不占用内存的字段,他们的内存布局是相同的,在64位平台上都是24(ptr占8个,另外两个usize个8个)个字节。你可以试试:

#[derive(Debug)]
struct MyVec<T> {
    ptr: *mut T,
    cap: usize,
    len: usize
}

println!("{:?}", std::mem::size_of::<Vec<u8>>());
println!("{:?}", std::mem::size_of::<MyVec<u8>>());

我先创建一个Vec<u8>,拿到Vec<u8>的裸指针*const Vec<u8>,再将*const Vec<u8>转换为*const MyVec<u8>,之后,解引用,就能得到MyVec<u8>了。不过,解引裸指针是unsafe的,要谨慎!!! 你还可以看看标准库中讲述pointer的文档。

fn main() {
    let vec = Vec::<u8>::with_capacity(255);

    println!("vec ptr: {:?}", vec.as_ptr());

    #[derive(Debug)]
    struct MyVec<T> {
        ptr: *mut T,
        cap: usize,
        len: usize
    }

    let ptr: *const Vec<u8> = &vec;

    let my_vec_ptr: *const MyVec<u8> = ptr as _;

    unsafe {
        println!("{:?}", *my_vec_ptr);
    }
}

然后编译运行,是否可以看到类似下面的输出呢:

vec ptr: 0x557933de6b40
MyVec { ptr: 0x557933de6b40, cap: 255, len: 0 }

你可以看到,我们调用vec.as_ptr()得到的就是Vec内部的那个裸指针。

对于std::mem::size_of 相等的两个类型,你也可以使用std::mem::transmute 这个函数转换,跟上面的通过裸指针间接转换,几乎是等效的,只是会多加一个验证,如果两个类型size_of不相等的话,是无法通过编译的。这个函数是unsafe的。

你还可以继续尝试,比如把Vec<u8>转换为长度为3(或者更小更大)的usize数组,像是这样:

fn main() {
    let vec = Vec::<u8>::with_capacity(255);

    println!("vec ptr: {:?}", vec.as_ptr());

    let ptr: *const Vec<u8> = &vec;

    unsafe {
        let aaa_ptr: *const [usize; 2] = ptr as _;
        println!("{:?}", (*aaa_ptr)[0] as *const u8);
    }
}

不过,由于Rust中Vec的扩容机制,这段代码是存在一定问题的:

fn main() {
    let len = 255;
    let mut buf = Vec::<u8>::with_capacity(len);
    let ptr = buf.as_mut_ptr() as *mut c_char;

    unsafe {
        gethostname(ptr, len);
        println!("{:?}", CStr::from_ptr(ptr));
    }

    println!("{:?}", buf);
}

虽然获取到了正确的主机名,但是之后你打印buf会发现,buf是空的,这个问题留给你去探究。

你已经看到,Rust已经变得“不安全”,这又不小心又引入了另一个话题--《 Meet Safe and Unsafe》。不过,还是尽快回归正题,等之后有机会再说这个话题。

说起套接字API,主要包括TCP、UDP、SCTP相关的函数,I/O复用函数和高级I/O函数。其中大部分函数Rust标准里是没有的,如果标准库不能满足你的需求,你可以直接调用libc中的函数。实际上,标准库中,网络这一块,也基本是对libc中相关函数的封装。

先从TCP开始。TCP套接字编程主要会涉及到socketconnectbindlistenacceptclosegetsocknamegetpeername等函数。先来看看这些函数的定义:

// socket 函数用来指定期望的通信协议类型,并返回套接字描述符
int socket(int family, int type, int protocol); // 成功返回监听描述符。用来设置监听,出错为-1
// family是表示socket使用的协议类型,对于TCP,通常设置为 `AF_INET` 或`AF_INET6`,表示`IPv4`和`IPv6`
// type是创建的套接字类型,TCP是字节流套接字,所以这里设置为`SOCK_STREAM`,可选的值还有
// `SOCK_DGRAM`用于UDP,`SOCK_SEQPACKET`用于SCTP
// protocol协议的标识,可以设置为0,让系统选择默认值。可选的值有`IPPROTO_TCP`、`IPPROTO_UDP`、`IPPROTO_SCTP`

// connect 函数被客户端用来联立与TCP服务器的连接
int connect(int sockfd, const struct sockaddr *servaddr, socklen_t addrlen); // 成功返回0 ,出错为-1
// sockfd 是由 socket 函数返回的套接字描述符,第二和第三个参数分别指向一个指向套接字地址结构的指针和该指针的长度

// bind 函数把一个本地协议地址赋予一个套接字。
int bind(int sockfd, const struct sockaddr *myaddr,  socklen_t addrlen); // 成功返回0 ,出错为-1
// 第二个和第三个参数分别是指向特点于协议的地址结构的指针和指针的长度

// listen 函数把一个未连接的套接字转换成一个被动套接字,指示内核应接受指向该套接字的连接请求。
int listen(int sockfd, int backlog); // 成功返回0 ,出错为-1
// 第二个参数指定内核该为相应套接字排队的最大连接个数。

// accept 函数由TCP服务器调用,用于从已完成连接的队列头返回下一个已完成的连接。
int accept(int sockfd, struct sockaddr *cliaddr, socklen_t *addrlen); // 成功返回非负描述符,错误返回-1
// 第二个和第三个参数用来返回客户端的协议地址和该地址的大小

// close 用来关闭套接字,并终止TCP连接
int close(int sockfd); // 成功返回0 ,出错为-1

// getsockname 和 getpeername 函数返回与某个套接字关联的本地协议地址和外地协议地址
int getsockname(int sockfd,struct sockaddr *localaddr,socklen_t *addrlen); // 成功返回0 ,出错为-1
int getpeername(int sockfd,struct sockaddr *peeraddr,socklen_t *addelen); // 成功返回0 ,出错为-1

还有一对常见的函数,readwrite 用于读写数据。另外还有三对高级I/O函数,recv/sendreadv/writevrecvmsg/sendmsg等需要的时候再加。

ssize_t read(int fd, void *buf, size_t count);
ssize_t write(int fd, const void *buf, size_t count);

除了函数外,还有几个常量和sockaddr这个结构体。常量我们需要在Rust这边定义出来,只定义出需要的:

const AF_INET: i32 = 2;
const AF_INET6: i32 = 10;
const SOCK_STREAM: i32 = 1;
const IPPROTO_TCP: i32 = 6;

除了sockaddr外,还有几个与之相关的结构体,他们在C中的定是:

struct sockaddr
{
    unsigned short    int sa_family; // 地址族
    unsigned char     sa_data[14];  // 包含套接字中的目标地址和端口信息
};

struct sockaddr_in
{
    sa_family_t       sin_family;
    uint16_t          sin_port;
    struct in_addr    sin addr;
    char              sin_zero[8];
};

struct in_addr
{
    In_addr_t  s_addr;
};

struct sockaddr_in6
{
    sa_family_t       sin_family;
    in_port_t         sin6_port;
    uint32_t          sin6_flowinfo;
    struct in6_addr   sin6_addr; 
    uint32_t          sin6_scope_id;
};

struct in6_addr
{
    uint8_t           s6_addr[16]
};

struct sockaddr_storage {
    sa_family_t       ss_family;     // address family

    // all this is padding, implementation specific, ignore it:
    char              __ss_pad1[_SS_PAD1SIZE];
    int64_t           __ss_align;
    char              __ss_pad2[_SS_PAD2SIZE];
};

然后,需要在Rust中定义出布局相同的结构体:

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr {
    pub sa_family: u16,
    pub sa_data: [c_char; 14],
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr_in {
    pub sin_family: u16,
    pub sin_port: u16,
    pub sin_addr: in_addr,
    pub sin_zero: [u8; 8],
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct in_addr {
    pub s_addr: u32,
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr_in6 {
    pub sin6_family: u16,
    pub sin6_port: u16,
    pub sin6_flowinfo: u32,
    pub sin6_addr: in6_addr,
    pub sin6_scope_id: u32,
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct in6_addr {
    pub s6_addr: [u8; 16],
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr_storage {
    pub ss_family: u16,
    _unused: [u8; 126]
}

你需要在结构体前面加一个#[repr(C)]标签,以确保结构体的内存布局跟C一致,因为,Rust结构体的内存对齐规则,可能跟C是不一样的。#[derive(Debug, Clone, Copy)] 不是必须的。对于最后一个结构体sockaddr_storage,我也很迷,我不知道在Rust中如何定义出来,但是我知道它占128个字节,然后我就定义一个长度为126的u8数组,凑够128位。

接下来,继续把那几个函数定义出来:

extern {
    pub fn socket(fanily: i32, ty: i32, protocol: i32) -> i32;
    pub fn connect(sockfd: i32, servaddr: *const sockaddr, addrlen: u32) -> i32;
    pub fn bind(sockfd: i32, myaddr: *const sockaddr, addrlen: u32) -> i32;
    pub fn listen(sockfd: i32, backlog: i32);
    pub fn accept(sockfd: i32, cliaddr: *mut sockaddr, addrlen: u32) -> i32;
    pub fn close(sockfd: i32) -> i32;
    pub fn getsockname(sockfd: i32, localaddr: *mut sockaddr, addrlen: *mut u32) -> i32;
    pub fn getpeername(sockfd: i32, peeraddr: *mut sockaddr, addrlen: *mut u32) -> i32;
    pub fn read(fd: i32, buf: *mut std::ffi::c_void, count: usize) -> isize;
    pub fn write(fd: i32, buf: *const std::ffi::c_void, count: usize) -> isize;
}

对于readwrite 里的参数buf类型void, 可以使用标准库提供的 std::ffi::c_void,也可以是*mut u8/*const u8,像是下面这样:

pub fn read(fd: i32, buf: *mut u8, count: usize) -> isize;
pub fn write(fd: i32, buf: *const u8, count: usize) -> isize;

或者,既然void本身是个“动态类型”,也可以传个其他类型的指针进去的,之后你可以试试,不过可能会有点危险。

看看目前的代码:

use std::os::raw::c_char;
use std::ffi::c_void;

pub const AF_INET: i32 = 2;
pub const AF_INET6: i32 = 10;
pub const SOCK_STREAM: i32 = 1;
pub const IPPRPTO_TCP: i32 = 6;

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr {
    pub sa_family: u16,
    pub sa_data: [c_char; 14],
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr_in {
    pub sin_family: u16,
    pub sin_port: u16,
    pub sin_addr: in_addr,
    pub sin_zero: [u8; 8],
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct in_addr {
    pub s_addr: u32,
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct sockaddr_in6 {
    pub sin6_family: u16,
    pub sin6_port: u16,
    pub sin6_flowinfo: u32,
    pub sin6_addr: in6_addr,
    pub sin6_scope_id: u32,
}

#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct in6_addr {
    pub s6_addr: [u8; 16],
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct sockaddr_storage {
    pub ss_family: u16,
    _unused: [u8; 126]
}

extern {
    pub fn socket(fanily: i32, ty: i32, protocol: i32) -> i32;
    pub fn connect(sockfd: i32, servaddr: *const sockaddr, addrlen: u32) -> i32;
    pub fn bind(sockfd: i32, myaddr: *const sockaddr, addrlen: u32) -> i32;
    pub fn listen(sockfd: i32, backlog: i32);
    pub fn accept(sockfd: i32, cliaddr: *mut sockaddr, addrlen: *mut u32) -> i32;
    pub fn close(sockfd: i32) -> i32;
    pub fn getsockname(sockfd: i32, localaddr: *mut sockaddr, addrlen: *mut u32) -> i32;
    pub fn getpeername(sockfd: i32, peeraddr: *mut sockaddr, addrlen: *mut u32) -> i32;
    pub fn read(fd: i32, buf: *mut std::ffi::c_void, count: usize) -> isize;
    pub fn write(fd: i32, buf: *const std::ffi::c_void, count: usize) -> isize;
}

然后,我们可以写一个简单的服务器和客户端程序:服务器监听一个地址,客户端连接服务器,然后向服务器发送“Hello, server!”,服务器回应“Hi,client!”,客户端收到后断开连接。

fn main() {
    use std::io::Error;
    use std::mem;
    use std::thread;
    use std::time::Duration;

    thread::spawn(|| {

        // server
        unsafe {
            let socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
            if socket < 0 {
                panic!("last OS error: {:?}", Error::last_os_error());
            }

            let servaddr = sockaddr_in {
                sin_family: AF_INET as u16,
                sin_port: 8080u16.to_be(),
                sin_addr: in_addr {
                    s_addr: u32::from_be_bytes([127, 0, 0, 1]).to_be()
                },
                sin_zero: mem::zeroed()
            };

            let result = bind(socket, &servaddr as *const sockaddr_in as *const sockaddr, mem::size_of_val(&servaddr) as u32);
            if result < 0 {
                println!("last OS error: {:?}", Error::last_os_error());
                close(socket);
            }

            listen(socket, 128);

            loop {
                let mut cliaddr: sockaddr_storage = mem::zeroed();
                let mut len = mem::size_of_val(&cliaddr) as u32;

                let client_socket = accept(socket, &mut cliaddr as *mut sockaddr_storage as *mut sockaddr, &mut len);
                if client_socket < 0 {
                    println!("last OS error: {:?}", Error::last_os_error());
                    break;
                }

                thread::spawn(move || {
                    loop {
                        let mut buf = [0u8; 64];
                        let n = read(client_socket, &mut buf as *mut _ as *mut c_void, buf.len());
                        if n <= 0 {
                            break;
                        }

                        println!("{:?}", String::from_utf8_lossy(&buf[0..n as usize]));

                        let msg = b"Hi, client!";
                        let n = write(client_socket, msg as *const _ as *const c_void, msg.len());
                        if n <= 0 {
                            break;
                        }
                    }

                    close(client_socket);
                });
            }

            close(socket);
        }

    });

    thread::sleep(Duration::from_secs(1));

    // client
    unsafe {
        let socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
        if socket < 0 {
            panic!("last OS error: {:?}", Error::last_os_error());
        }

        let servaddr = sockaddr_in {
            sin_family: AF_INET as u16,
            sin_port: 8080u16.to_be(),
            sin_addr: in_addr {
                s_addr: u32::from_be_bytes([127, 0, 0, 1]).to_be()
            },
            sin_zero: mem::zeroed()
        };

        let result = connect(socket, &servaddr as *const sockaddr_in as *const sockaddr, mem::size_of_val(&servaddr) as u32);
        if result < 0 {
            println!("last OS error: {:?}", Error::last_os_error());
            close(socket);
        }

        let msg = b"Hello, server!";
        let n = write(socket, msg as *const _ as *const c_void, msg.len());
        if n <= 0 {
            println!("last OS error: {:?}", Error::last_os_error());
            close(socket);
        }

        let mut buf = [0u8; 64];
        let n = read(socket, &mut buf as *mut _ as *mut c_void, buf.len());
        if n <= 0 {
            println!("last OS error: {:?}", Error::last_os_error());
        }

        println!("{:?}", String::from_utf8_lossy(&buf[0..n as usize]));

        close(socket);
    }
}

调用外部函数是unsafe的,我为了简单省事,暂时把代码放到了一个大的unsafe {} 中,之后我们再把他们封装成safe的API。为了方便测试,我把服务器程序放到了一个线程里,然后等待1秒后,再让客户端建立连接。

std::io::Error::last_os_error 这个函数,是用来捕获函数操作失败后,内核反馈给我们的错误。

在调用bindconnect 函数时,先要创建sockaddr_in结构体,端口(sin_port)和IP地址(s_addr) 是网络字节序(big endian),于是我调用了u16u32to_be()方法将其转换为网络字节序。u32::from_be_bytes 函数是将[127u8, 0u8, 0u8, 1u8] 转换为u32整数,由于我们看到的已经是大端了,转换回去会变成小端,于是后面又调用了to_be(),你也可以直接u32::from_le_bytes([127, 0, 0, 1])。然后使用了std::mem::zeroed 函数创建一个[0u8; 8] 数组,你也可以直接[0u8; 8],在这里他们是等效的。接着,我们进行强制类型转换,将&sockaddr_in 转换为*const sockaddr_in类型,又继续转换为*const sockaddr,如果你理解了一开始“gethostname”那个例子话,这里应该很好理解。这里还可以简写成&servaddr as *const _ as *const _,编译器会自动推导类型。

在调用accept函数时,先创建了一个mut sockaddr_storage,同样进行类型转换。之所以用sockaddr_storage 而不是sockaddr_insockaddr_in6是因为sockaddr_storage这个通用结构足够大,能承载sockaddr_insockaddr_in6等任何套接字的地址结构,因此,我们如果把套接bind到一个IPv6地址上的话,这里的代码是不需要修改的。我还是用std::mem::zeroed 函数初始化sockaddr_storage,它的结构我也很迷惑,所以就借助了这个函数,这个函数是unsafe的,使用的时候要小心。你也可以继续尝试这个函数:

let mut a: Vec<u8> = unsafe { std::mem::zeroed() };
a.push(123);
println!("{:?}", a);

readwrite 时,同样要类型转换。

很多时候,类型根本“强不起来”。OK,这一节的内容就先到这里。

相关推荐