一个简单的序列化协议

2021/11/06

我的一个朋友去年拿到了网易游戏的 offer, 但是入职之前网易要求完成一个小作业, 他来寻求我的帮助 …

起因

这个序列化协议来自于网易 2021 的新人培训, 要求新人按要求完成一个简单的序列化协议. 我觉得这个作业还蛮有趣的, 于是在这里记录一下. 题目要求实现一个 ProtoParser. 可以通过 buildDesc() 读取描述文本构建协议描述. 通过 dumps() 根据协议描述将 Python Object 转化为二进制序列, 通过 loads() 将二进制序列转为 Python Object.

class ProtoParser():
    ## 读取 proto 描述文件
    def buildDesc(self, filename):
        pass

    ## 序列化
    def dumps(self, obj):
        pass

    ## 反序列化
    def loads(self, binstr):
        pass

协议描述文本定义

协议描述文本表示大致如下

{
    变量类型 变量名;
    ...
    变量类型 变量名;
}

协议描述文本符合以下规则:

  1. 以 “{” 开始, 以 “}” 结束.
  2. 由若干个字段组成, 每个字段依次由变量类型, 变量名构成, 可表示为 “变量类型 变量名;";
  3. 基本变量类型和变量名之间, 至少存在 1 个空格, 其他情况下变量类型和变量名之间可存在零或多个空格;
  4. 每个字段以分号 “;” 结尾, 分号前可以存在零或多个空格;
  5. 字段和字段之间可能存在换行符, 制表符, 空格;
  6. 变量名的命名规则符合 C 语言规范;

如:

{
    string name;
    int32  id; uint16 level;
}

表示依次有 name, id, level 三个字段, “;” 为字段结束标记.

变量类型定义

  1. 变量类型分为基本类型和组合类型;
  2. 基本类型和组合类型也可以有其对应的数组类型;

基本类型

类型标识 含义 字节数
int8 8 位有符号整数 1
uint8 8 位无符号整数 1
int16 16 位有符号整数 2
uint16 16 位无符号整数 2
int32 32 位有符号整数 4
uint32 32 位无符号整数 4
float 单精度浮点 4
double 双精度浮点 8
bool 布尔 (1 真 0 假) 1
string 字符串 2 + L

组合类型

组合类型是若干个任意类型字段的组合体, 协议描述定义本身也是一个组合类型. 组合类型类似于 C 语言中的结构体, 以 “{” 开始, 以 “}” 结束, 中间可包括一个或者多个字段.

例如:

{
   {
       int32 id;
       string name;
   } pet;
   string name;
}

pet 字段就是一个组合类型, 它里面包含了 id 和 name 字段.

组合类型可嵌套, 即组合类型中的字段也可以是一个组合类型.

例如:

{
    {
       string name;
       int32  id;
       {
           string name;
           int32  id;
           {
               int32  id;
               string name;
           } skill;
       } pet;
       uint16 level;
    } player;
}

组合类型 player 中除了包含 name, id, level 基本类型定义的字段, 还包含了 pet 组合类型字段, 而 pet 组合类型也包含 skill 组合类型字段.

数组

可以用数组标识 “[]” 为基本/组合类型定义其数组. 数组分两种:

  1. 变长数组, 表示为 T[], T 表示变量类型;
  2. 定长数组, 表示为 T[N], T 表示变量类型, N 为正整数常量;

变长数组意味着其长度并非在描述文本中事先确定, 而是通过序列化的数据表达 (16 位无符号整数), 其长度不会超过 65535; 一些例子:

int16[5] x;    // x 为 5 个 int16 元素构成的数组
string[] strs; // strs 为元素个数不确定的字符串数组
{
    bool   flag;
    string name;
    int32  id;
} []AoS;       // AoS 为若干个 struct 构成的数组

思路

根据给定的代码框架, 可以比较容易的知道我们需要做哪些工作.

  1. 解析 proto 文件, 将各节点类型信息记录下来, 生成一种树形结构, 在序列化的时候使用.
  2. 在序列化的过程中, 遍历 obj, 根据上一步中存储的节点信息, 序列化 obj 的每一个字段.
  3. 反序列化的过程中, 根据第一步的树形结构, 恢复出 obj.

1. 解析 proto 文件

根据题目, 在协议的定义中, 变量类型分为基本类型和组合类型. 他们的区别仅仅在于组合类型的定义以 ‘{’ 开头, 以 ‘}’ 结束. 最顶层的协议定义实际上也是一个组合类型, 只不过没有对应的字段名字而已. 因此可以用递归下降的方式解析 proto 文件. 最后, 在一个组合类型中, 各种类型的信息可以用一个有序字典 OrderedDict 来存储 (Python3 的字典默认有序, Python2 需要导入 OrderedDict).

## 解析组合类型.
## content 为 proto 文件的内容, 类型为 str.
## cursor 为解析文本时, 字符的位置.
def parse_composite_ty(content, cursor):
    ## 组合类型的信息都存到这个有序字典中.
    ## 字典的 KEY 为字段名称 (在序列化的时候可以通过字段名称快速找到对应的类型信息)
    ## 字典的 VALUE 为字段的类型信息, 比如: 序列化的方式等.
    ty = OrderedDict()

    ## 消耗掉一个 '{'.
    cursor = eat_lbrace(content, cursor)
    ## 只要没有遇到 '}' 就一直向后解析.
    while content[cursor] != '}':
        field_ty = None

        if content[cursor] == '{':
            ## 如果又遇到了 '{', 说明这个类型的声明依旧是一个组合类型,
            ## 递归的解析下去.
            field_ty, cursor = parse_composite_ty(content, cursor)
        else:
            ## 如果没有遇到 '{', 说明是一个普通类型, 直接根据 token 解析它的类型.
            field_ty, cursor = parse_basic_ty(content, cursor)

        ## 根据题意, 在类型后面可能有 "[]" 或者 "[N]", 表示这个类型是一个数组,
        ## 需要判断一下并给出正确的类型.
        field_ty, cursor = maybe_array_ty(field_ty, content, cursor)

        ## 解析字段的名称.
        field_name, cursor = parse_name(content, cursor)

        ## 将字段的类型信息存到有序字典中.
        ty[field_name] = field_ty

        ## 消耗类型定义中的 ';'.
        cursor = eat_semicolon(content, cursor)

    ## 遇到了 '}', 说明当前组合类型的 Scope 结束了, 把 '}' 消耗掉即可.
    cursor = eat_rbrace(content, cursor)
    return ty, cursor

2. 序列化和反序列化

在构建出协议描述后, 序列化和反序列化就显得十分简单了. 只要遍历之前构造好的 OrderedDict, 按照相应的类型序列化/反序列化即可.

def serialize_helper(ty_tree, obj):
    binstr = ""
    ## 遍历上面得到的 OrderedDict.
    for k, ty in ty_tree.items():
        ## 判断是否为组合类型, 是的话递归地序列化嵌套字段
        ## 否则直接序列化即可.
        if is_composite_ty(ty):
            binstr += serialize_helper(ty, obj[k])
        else:
            binstr += ty.pack(obj[k])
    return binstr

def deserialize_helper(ty_tree, binstr):
    res = {}
    for k, ty in ty_tree.items():
        if is_composite_ty(ty):
            binstr, data = deserialize_helper(ty, binstr)
            res[k] = data
        else:
            binstr, data = ty.unpack(binstr)
            res[k] = data
    return (binstr, res)

完整代码

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import struct
from collections import OrderedDict
import binascii

ascii_table = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"
digit_table = "1234567890"

## Basic type.
class BasicType():
    def __init__(self, py_type, name, width, pack_fmt=None, pack_func=None, unpack_func=None):
        self.py_type = py_type
        self.name = name
        self.pack_fmt = pack_fmt
        self.pack_func = pack_func
        self.unpack_func = unpack_func
        self.width = width

    def is_scalar_ty(self):
        return self.width != -1

    def pack(self, v):
        if self.is_scalar_ty():
            return struct.pack(self.pack_fmt, v)
        return self.pack_func(v)

    def unpack(self, binstr):
        if self.is_scalar_ty():
            return unpack_and_consume_scalar(self.pack_fmt, binstr, self.width, self.py_type)
        return self.unpack_func(binstr)

## Array type.
class Array():
    def __init__(self, sized, inner_ty):
        ## -1 is for unsized arrays, since we have
        ## sized arrays whose size is 0. e.g., T[0]
        self.sized = sized
        self.inner_ty = inner_ty

    def is_sized(self):
        return self.sized != -1

    def pack(self, arr):
        binstr = ""
        if not self.is_sized():
            binstr = Uint16.pack(len(arr))
        for ele in arr:
            if is_composite_ty(ele):
                binstr += serialize_helper(self.inner_ty, ele)
            else:
                binstr += self.inner_ty.pack(ele)
        return binstr

    def unpack(self, binstr):
        array = []
        N = self.sized
        if N == -1:
            binstr, N = Uint16.unpack(binstr)
        for i in range(N):
            if is_composite_ty(self.inner_ty):
                binstr, data = deserialize_helper(self.inner_ty, binstr)
                array.append(data)
            else:
                binstr, data = self.inner_ty.unpack(binstr)
                array.append(data)
        return (binstr, tuple(array))

def unpack_and_consume_scalar(fmt, binstr, size, py_type):
    v = struct.unpack(fmt, binstr[:size])
    return (binstr[size:], py_type(v[0]))

def pack_str(s):
    binstr = Uint16.pack(len(s))
    for c in s:
        binstr += c
    return binstr

def unpack_str(binstr):
    binstr, size = Uint16.unpack(binstr)
    s = ""
    for i in range(size):
        s += binstr[i]
    return (binstr[size:], s)

## Basic types.
Undef  = BasicType(int,   'undef',  -1)
Int8   = BasicType(int,   'int8',   1,  pack_fmt="<b")
Uint8  = BasicType(int,   'uint8',  1,  pack_fmt="<B")
Int16  = BasicType(int,   'int16',  2,  pack_fmt="<h")
Uint16 = BasicType(int,   'uint16', 2,  pack_fmt="<H")
Int32  = BasicType(int,   'int32',  4,  pack_fmt="<i")
Uint32 = BasicType(int,   'uint32', 4,  pack_fmt="<I")
Float  = BasicType(float, 'float',  4,  pack_fmt="<f")
Double = BasicType(float, 'double', 8,  pack_fmt="<d")
Bool   = BasicType(bool,  'bool',   1,  pack_fmt="<?")
String = BasicType(str,   'string', -1, pack_func=pack_str, unpack_func=unpack_str)

def is_composite_ty(ty):
    return isinstance(ty, OrderedDict) or type(ty) == dict

def serialize_helper(ty_tree, obj):
    binstr = ""
    for k, ty in ty_tree.items():
        if is_composite_ty(ty):
            binstr += serialize_helper(ty, obj[k])
        else:
            binstr += ty.pack(obj[k])
    return binstr

def deserialize_helper(ty_tree, binstr):
    res = {}
    for k, ty in ty_tree.items():
        if is_composite_ty(ty):
            binstr, data = deserialize_helper(ty, binstr)
            res[k] = data
        else:
            binstr, data = ty.unpack(binstr)
            res[k] = data
    return (binstr, res)

def consume_blanks(content, cursor):
    if cursor >= len(content):
        return cursor
    while cursor < len(content) and (content[cursor] in ' \t\n\r'):
        cursor += 1
    return cursor

def consume_char(content, char, cursor):
    if cursor >= len(content) or content[cursor] != char:
        raise NameError("'%s' expected" % char)
    return cursor+1

def consume_token(content, cursor):
    if cursor >= len(content):
        raise NameError("token expected")
    begin = cursor
    while cursor < len(content) and (content[cursor] in ascii_table):
        cursor += 1
    return (content[begin: cursor], cursor)

def eat_char(content, char, cursor):
    cursor = consume_blanks(content, cursor)
    cursor = consume_char(content, char, cursor)
    cursor = consume_blanks(content, cursor)
    return cursor

def eat_lbrace(content, cursor):
    return eat_char(content, '{', cursor)

def eat_rbrace(content, cursor):
    return eat_char(content, '}', cursor)

def eat_lbrkt(content, cursor):
    return eat_char(content, '[', cursor)

def eat_rbrkt(content, cursor):
    return eat_char(content, ']', cursor)

def eat_semicolon(content, cursor):
    return eat_char(content, ';', cursor)

def eat_token(content, cursor):
    cursor = consume_blanks(content, cursor)
    token, cursor = consume_token(content, cursor)
    cursor = consume_blanks(content, cursor)
    return token, cursor

def parse_digits(content, cursor):
    digit = ''
    while content[cursor] in digit_table:
        digit += content[cursor]
        cursor += 1
    return int(digit), cursor

def maybe_array_ty(inner_ty, content, cursor):
    is_array = False
    sized = -1
    if content[cursor] == '[':
        is_array = True
        cursor = eat_lbrkt(content, cursor)
        if content[cursor] in digit_table:
            sized, cursor = parse_digits(content, cursor)
        cursor = eat_rbrkt(content, cursor)
    if is_array:
        return (Array(sized, inner_ty), cursor)
    return (inner_ty, cursor)

def parse_name(content, cursor):
    return eat_token(content, cursor)

def parse_basic_ty(content, cursor):
    token, cursor = eat_token(content, cursor)
    ty = None
    if token == 'string':
        ty = String
    elif token == 'int8':
        ty = Int8
    elif token == 'uint8':
        ty = Uint8
    elif token == 'int16':
        ty = Int16
    elif token == 'uint16':
        ty = Uint16
    elif token == 'int32':
        ty = Int32
    elif token == 'uint32':
        ty = Uint32
    elif token == 'bool':
        ty = Bool
    elif token == 'float':
        ty = Float
    elif token == 'double':
        ty = Double
    else:
        raise NameError("unknown type: %s" % token)

    return ty, cursor

def parse_composite_ty(content, cursor):
    cursor = eat_lbrace(content, cursor)
    ty = OrderedDict()

    while content[cursor] != '}':
        field_ty = None
        if content[cursor] == '{':
            ## Parse composite type recursively.
            field_ty, cursor = parse_composite_ty(content, cursor)
        else:
            field_ty, cursor = parse_basic_ty(content, cursor)

        field_ty, cursor = maybe_array_ty(field_ty, content, cursor)
        field_name, cursor = parse_name(content, cursor)
        ty[field_name] = field_ty
        cursor = eat_semicolon(content, cursor)

    cursor = eat_rbrace(content, cursor)
    return ty, cursor

def parse_desc(content):
    ty, cursor = parse_composite_ty(content, 0)
    if cursor != len(content):
        raise NameError("cannot consume the description file")
    return ty

class ProtoParser():
    def __init__(self):
        self.type_tree = OrderedDict()

    def buildDesc(self, filename):
        f = open(filename)
        content = f.read()
        self.type_tree = parse_desc(content)

    ## Serialize
    def dumps(self, obj):
        return binascii.hexlify(serialize_helper(self.type_tree, obj))

    ## Deserialize
    def loads(self, binstr):
        binstr = binascii.unhexlify(binstr)
        bstr, res = deserialize_helper(self.type_tree, binstr)
        if len(bstr) != 0:
            raise NameError("cannot consume all the binary string")
        return res
comments powered by Disqus