基于函数注解语法实现参数类型检查功能

​ 最近在学习FastAPI的使用,发现其极力推广使用类型注释,其中的pydantic 提供的 BaseModel 更能够进行一些参数验证。如果我们只写了一个函数,如何根据函数类型注释来实现检测功能呢?

函数注解(Function Annotations)

函数注解语法 可以让你在定义函数的时候对参数和返回值添加注解:

1
2
def foobar(a: int, b: "it's b", c: str = 5) -> tuple:
return a, b, c
  • a: int 这种是注解参数
  • c: str = 5 是注解有默认值的参数
  • -> tuple 是注解返回值。

注解的内容既可以是个类型也可以是个字符串,甚至表达式:

1
2
def foobar(a: 1+1) -> 2 * 2:
return a

那么如何获取我们定义的函数注解呢?至少有两种办法:

  • __annotations__:

    1
    2
    In [18]: foobar.__annotations__
    Out[18]: {'a': int, 'b': "it's b", 'c': str, 'return': tuple}
  • inspect.signature:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    In [22]: import inspect

    In [23]: sig = inspect.signature(foobar)

    # 通过签名获取函数参数
    In [24]: sig.parameters
    Out[24]:
    mappingproxy({'a': <Parameter "a: int">,
    'b': <Parameter "b: "it's b"">,
    'c': <Parameter "c: str = 5">})

    # 获取函数参数注解
    In [25]: for k, v in sig.parameters.items():
    ...: print('{k}: {a!r}'.format(k=k, a=v.annotation))
    ...:
    a: <class 'int'>
    b: "it's b"
    c: <class 'str'>

    # 返回值注解
    In [26]: sig.return_annotation
    Out[26]: tuple

既然可以得到函数中定义的注解,那么我们就可以用它进行参数类型检查了。

类型检查

Python 解释器并不会基于函数注解来自动进行类型检查,需要我们自己去实现类型检查功能:

1
2
3
4
5
6
In [27]: foobar.__annotations__
Out[27]: {'a': int, 'b': "it's b", 'c': str, 'return': tuple}

# 即使和类型注释的类型不一致 也是可以赋值的
In [28]: foobar(a='a', b=2, c=3)
Out[28]: ('a', 2, 3)

既然通过 inspect.signature 我们可以获取函数定义的参数的顺序以及函数注解, 那么我们就可以通过定义一个装饰器来检查传入函数的参数的类型是否跟函数注解相符, 这里实现的装饰器函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import collections
import functools
import inspect


def check(func):
msg = ('Expected type {expected!r} for argument {argument}, '
'but got type {got!r} with value {value!r}')
# 获取函数定义的参数
sig = inspect.signature(func)
parameters = sig.parameters # 参数有序字典
arg_keys = tuple(parameters.keys()) # 参数名称

@functools.wraps(func)
def wrapper(*args, **kwargs):
CheckItem = collections.namedtuple('CheckItem', ('anno', 'arg_name', 'value'))
check_list = []

# collect args *args 传入的参数以及对应的函数参数注解
for i, value in enumerate(args):
arg_name = arg_keys[i]
anno = parameters[arg_name].annotation
check_list.append(CheckItem(anno, arg_name, value))

# collect kwargs **kwargs 传入的参数以及对应的函数参数注解
for arg_name, value in kwargs.items():
anno = parameters[arg_name].annotation
check_list.append(CheckItem(anno, arg_name, value))

# check type
for item in check_list:
if not isinstance(item.value, item.anno):
error = msg.format(expected=item.anno, argument=item.arg_name,
got=type(item.value), value=item.value)
raise TypeError(error)

return func(*args, **kwargs)

return wrapper

下面来测试一下我们的装饰器

顺序传参测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
In [30]: @check
...: def foobar(a: int, b: str, c: float = 3.2) -> tuple:
...: return a, b, c
...:

In [31]: foobar(1, 'b')
Out[31]: (1, 'b', 3.2)

In [32]: foobar(1, 'b', 3.5)
Out[32]: (1, 'b', 3.5)

# 参数类型和类型注释不一致的时候 会报错
In [33]: foobar('a', 'b')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-33-35d6f7d34e34> in <module>
----> 1 foobar('a', 'b')

<ipython-input-29-fda063f559cc> in wrapper(*args, **kwargs)
33 error = msg.format(expected=item.anno, argument=item.arg_name,
34 got=type(item.value), value=item.value)
---> 35 raise TypeError(error)
36
37 return func(*args, **kwargs)

TypeError: Expected type <class 'int'> for argument a, but got type <class 'str'> with value 'a'

关键字传参:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
In [34]: foobar(b='b', a=2)
Out[34]: (2, 'b', 3.2)

In [35]: foobar(b='b', a=2, c=3.5)
Out[35]: (2, 'b', 3.5)

# 参数类型和类型注释不一致的时候 会报错
In [36]: foobar(a='foo', b='bar')
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-36-8764071d3767> in <module>
----> 1 foobar(a='foo', b='bar')

<ipython-input-29-fda063f559cc> in wrapper(*args, **kwargs)
33 error = msg.format(expected=item.anno, argument=item.arg_name,
34 got=type(item.value), value=item.value)
---> 35 raise TypeError(error)
36
37 return func(*args, **kwargs)

TypeError: Expected type <class 'int'> for argument a, but got type <class 'str'> with value 'foo'

借助于Function Annotations一个简单的参数类型检查的装饰器就这样实现了。

知识就是财富
如果您觉得文章对您有帮助, 欢迎请我喝杯水!