FengJian's Blog

iOS Golang node.js Developer


  • Home

  • About

  • Archives

手机端运行卷积神经网络实现文档检测功能(二) -- 从 VGG 到 MobileNetV2 知识梳理

Posted on 2018-06-02

前言

  • 这是 上一篇博客 的后续和补充,这次对边缘检测算法的升级优化,起源于一个意外事件,前一个版本是使用 TensorFlow 1.0 部署的, 并且是用 TF-Slim API 编写的代码,最近想使用 TensorFlow 1.7 重新部署一遍,本来以为是一件比较容易的事情,结果实操的时候才发现全是坑,首先遇到的就是废弃 API 的问题,TensorFlow 1.0 里面的某些 API 在 TensorFlow 1.7 里面已经是彻底废弃掉不能使用了,这个问题还好办,修改一下代码就行。后面遇到的一个问题就让我彻底傻眼了,用新的代码加载了旧的模型文件,想 Fine Tuning 一下,结果模型不收敛了,从零开始重新训练也是无法收敛,查了挺长时间也没定位到原因,所以,干脆重写一遍代码。
  • 反正都要重写代码了,那也就可以把最近一年学到的新东西融合进来,就当做是效果验证了。引入这些新的技术后,原始模型其实变化挺大的,而且用到的这些技术,又会牵扯出很多比较通用的基础知识,所以从这个角度来说,这篇文章要记录的重点并不是升级优化(升级后的模型,准确性和前一个版本相比并没有明显的区别,但是模型的体积从 4.4M 减小到了 1.6M ,网络的训练过程也比之前容易了许多),而是对 多个基础知识点的梳理和总结 。
  • 涉及到的知识点比较多,有工程层面的,也有理论算法层面的,和工程相关的内容会尽量用代码片段来展示,遇到理论知识,只会简单的介绍一下,划出重点,不会做数学层面的推导,同时,会在最后的『参考资料』章节中列出更多的参考内容。
  • 趁这个机会也把代码重新整理了一遍,放在了 github 上,https://github.com/fengjian0106/hed-tutorial-for-document-scanning

TensorFlow Code Style For CNN Net

之前的那个版本,选用 TF-Slim API 编写代码,就是因为这套 API 是比较优雅的,比如想调用一次最基本的卷积层运算,如果直接使用 tf.nn.conv2d 的话,代码会是下面这个样子:

1
2
3
4
5
6
7
input = ...
with tf.name_scope('conv1_1') as scope:
kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32, stddev=1e-1), name='weights')
conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32), trainable=True, name='biases')
bias = tf.nn.bias_add(conv, biases)
conv1 = tf.nn.relu(bias, name=scope)

如果用 TF-Slim API 编码的话,则会变成下面这种风格:

1
2
input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')

因为在各种卷积神经网络结构中,通常都会大量的使用卷积运算,构建很多卷积层,并且使用不同的配置参数,所以很明显,TF-Slim 风格的 API 可以很优雅的简化代码。

但是,在看过图像处理领域的一些论文和各种版本的参考代码之后,发现 TF-Slim 还是有一些局限性的。常规的卷积层操作,用 TF-Slim 是可以简化代码,但是神经网络这个领域发展的速度太快了,经常都会有新的论文发表出来,也就经常会遇到一些新的 layer 结构,TF-Slim 并不是总能很方便的表达出这些 layer,因此需要一种更低层一些、但是更灵活,同时还保持优雅的解决办法。

顺着这个思路,后来发现其实 tf.layers 这个 Module 就可以很好的满足前面提到的这些需求。

另外,这次遇到的在 TensorFlow 1.7 上旧模型不收敛的情况,虽然没有准确定位到原因、没找到解决办法,但是分析了一圈后,其实还是怀疑是因为使用 TF-Slim 而引出的问题,虽然 TF-Slim 简化了卷积层相关的代码,但是完整的代码中还是要使用 TensorFlow 中的其他 API 的,TF-Slim 封装出来的抽象度比较高,除了卷积操作的 API,它还封装了其他的一些 API,但是它的抽象设计和 TensorFlow 是有一种分裂感的,混合在一起编程时会觉得有点奇怪,我这次遇到的问题,也可能就是某些 API 使用的不正确而引起的(TF1.0时运行正常,TF1.7时运行不正常)。而 tf.layers 就不会有这种感觉,tf.layers 的抽象度比 TF-Slim 更低一些,它更像是 TensorFlow 的底层 API 的一个延展,并没有引入新的抽象度,这套 API 用起来就更舒服一些。

比如,升级前的 HED 网络,换用 tf.layers 后,代码是下面这个样子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
def vgg_style_hed(inputs, batch_size, is_training):
filter_initializer = tf.contrib.layers.xavier_initializer()

if const.use_kernel_regularizer:
weights_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0005)
else:
weights_regularizer = None

def _vgg_conv2d(inputs, filters, kernel_size):
use_bias = True
if const.use_batch_norm:
use_bias = False

outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## call relu after batch normalization
use_bias=use_bias,
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)
if const.use_batch_norm:
outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)
return outputs

def _max_pool2d(inputs):
outputs = tf.layers.max_pooling2d(inputs,
[2, 2],
strides=(2, 2),
padding='same')
return outputs

def _dsn_1x1_conv2d(inputs, filters):
use_bias = True
if const.use_batch_norm:
use_bias = False

kernel_size = [1, 1]
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## no activation
use_bias=use_bias,
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

if const.use_batch_norm:
outputs = tf.layers.batch_normalization(outputs, training=is_training)
## no activation

return outputs

def _output_1x1_conv2d(inputs, filters):
kernel_size = [1, 1]
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## no activation
use_bias=True, ## use bias
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

## no batch normalization
## no activation

return outputs

def _dsn_deconv2d_with_upsample_factor(inputs, filters, upsample_factor):
kernel_size = [2 * upsample_factor, 2 * upsample_factor]
outputs = tf.layers.conv2d_transpose(inputs,
filters,
kernel_size,
strides=(upsample_factor, upsample_factor),
padding='same',
activation=None, ## no activation
use_bias=True, ## use bias
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

## no batch normalization

return outputs


# ref https://github.com/s9xie/hed/blob/master/examples/hed/train_val.prototxt
with tf.variable_scope('hed', 'hed', [inputs]):
end_points = {}
net = inputs

with tf.variable_scope('conv1'):
net = _vgg_conv2d(net, 12, [3, 3])
net = _vgg_conv2d(net, 12, [3, 3])
dsn1 = net
net = _max_pool2d(net)

with tf.variable_scope('conv2'):
net = _vgg_conv2d(net, 24, [3, 3])
net = _vgg_conv2d(net, 24, [3, 3])
dsn2 = net
net = _max_pool2d(net)

with tf.variable_scope('conv3'):
net = _vgg_conv2d(net, 48, [3, 3])
net = _vgg_conv2d(net, 48, [3, 3])
net = _vgg_conv2d(net, 48, [3, 3])
dsn3 = net
net = _max_pool2d(net)

with tf.variable_scope('conv4'):
net = _vgg_conv2d(net, 96, [3, 3])
net = _vgg_conv2d(net, 96, [3, 3])
net = _vgg_conv2d(net, 96, [3, 3])
dsn4 = net
net = _max_pool2d(net)

with tf.variable_scope('conv5'):
net = _vgg_conv2d(net, 192, [3, 3])
net = _vgg_conv2d(net, 192, [3, 3])
net = _vgg_conv2d(net, 192, [3, 3])
dsn5 = net
# net = _max_pool2d(net) # no need this pool layer


## dsn layers
with tf.variable_scope('dsn1'):
dsn1 = _dsn_1x1_conv2d(dsn1, 1)
## no need deconv2d

with tf.variable_scope('dsn2'):
dsn2 = _dsn_1x1_conv2d(dsn2, 1)
dsn2 = _dsn_deconv2d_with_upsample_factor(dsn2, 1, upsample_factor = 2)

with tf.variable_scope('dsn3'):
dsn3 = _dsn_1x1_conv2d(dsn3, 1)
dsn3 = _dsn_deconv2d_with_upsample_factor(dsn3, 1, upsample_factor = 4)

with tf.variable_scope('dsn4'):
dsn4 = _dsn_1x1_conv2d(dsn4, 1)
dsn4 = _dsn_deconv2d_with_upsample_factor(dsn4, 1, upsample_factor = 8)

with tf.variable_scope('dsn5'):
dsn5 = _dsn_1x1_conv2d(dsn5, 1)
dsn5 = _dsn_deconv2d_with_upsample_factor(dsn5, 1, upsample_factor = 16)


##dsn fuse
with tf.variable_scope('dsn_fuse'):
dsn_fuse = tf.concat([dsn1, dsn2, dsn3, dsn4, dsn5], 3)
dsn_fuse = _output_1x1_conv2d(dsn_fuse, 1)

return dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5

上面这份代码里面的一些细节,会在后面的章节里详细介绍,并且会逐步的演化成 MobileNetV2 style 的 HED 网络。这里首先看一下代码的整体结构,相当于是套用了下面这种形式的模板:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def xx_net(inputs, batch_size, is_training):
filter_initializer = tf.contrib.layers.xavier_initializer()

def layer_for_type1(inputs, ...):
...
return outputs

def layer_for_type2(inputs, ...):
...
return outputs

...

def layer_for_typeN(inputs, ...):
...
return outputs


with tf.variable_scope('xx_net', 'xx_net', [inputs]):
end_points = {}
net = inputs

net = layer_for_type1(net, ...)
net = layer_for_type1(net, ...)
net = layer_for_type2(net, ...)
...
net = layer_for_typeN(net, ...)

return net, end_points

这种风格的代码,前面一部分就是定义实现不同功能的各种 layer,后面部分就是用各种 layer 来组装 net 的主体结构。layer 由嵌套函数定义,方便进行各种自定义的配置或组装,net 主体部分,跟 TF-Slim 的风格其实也是类似的,layer 之间的层级关系简单明了,更容易和论文中的配置表格或结构示意图对应起来。我在实现其他网络结构的时候,都是套用的这种代码结构,基本上都能满足灵活性和简洁性的需求。

矩阵初始化

矩阵的初始化方法有很多种,在 TensorFlow 里,常规初始化方法的效果对比可以看这篇文章 Weight Initialization,能使用 tf.truncated_normal 或 tf.truncated_normal_initializer 进行初始化,说明已经对这个问题有所掌握了,随着学习的深入,更推荐使用另外一种初始化方法 Xavier initialization ,使用起来也比较简单:

1
2
W = tf.get_variable('W', shape=[784, 256], 
initializer=tf.contrib.layers.xavier_initializer())

关于 Xavier initialization 的更多内容,请参考本文末尾部分列出的资料。

Batch Normalization

Batch Normalization – Lesson 这篇教程对 Batch Normalization 解释的比较清楚,通俗点描述,普通的 Normalization 是对神经网络的输入数据做归一化处理,把输入数据和输出数据的取值都缩放到一个范围内,通常都是 0.0 ~ 1.0 这个区间,而 Batch Normalization 则是把整体的神经网络结构看成是由很多不同的 layer 组成的,对每个 layer 的输入数据再做一次规范化的操作,因为只能在训练的过程中才能获取到每个 layer 上的 input data,而训练过程又是基于 batch 的,所以叫做 Batch Normalization。Batch Normalization 的具体数学公式,这里不详细描述了,有兴趣的读者请参考末尾部分列出的资料,下面仅从工程层面提出一些建议和要注意的细节点。

Batch Normalization 的优势很明显,尽量使用

Batch Normalization 的优势挺多的,比如可以加快模型收敛的速度、可以使用较高的 learning rates、可以降低权重矩阵初始化的难度、可以提高网络的训练效果等等,总而言之,就是要尽量的使用 Batch Normalization 技术。近几年新发表的很多论文中,也是经常看到 Batch Normalization 的身影。

TensorFlow 提供了相关的 API,在 layer 中添加 Batch Normalization 也就是一行代码的事,不过因为 Batch Normalization 里面有一部分参数也是需要参与反向传播过程进行训练的,所以构造优化器的时候,还要额外添加一些代码把 Batch Normalization 的权重参数也包含进去,类似下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
...
def _vgg_conv2d(inputs, filters, kernel_size):
use_bias = True
if const.use_batch_norm:
use_bias = False

outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## call relu after batch normalization
use_bias=use_bias,
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)
if const.use_batch_norm:
outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)
return outputs

...

with tf.variable_scope("adam_vars"):
if const.use_batch_norm == True:
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(cost)
else:
train_step = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(cost)
...

不需要使用 bias

从前面的代码片段可以看到,用了 Batch Normalization 后,就不再需要添加 bias 偏移向量了,Can not use both bias and batch normalization in convolution layers 这里有解释原因。

在什么位置添加 Batch Normalization

前面有一个典型的代码片段:

1
2
3
4
5
6
7
8
9
10
11
12
13
def _vgg_conv2d(inputs, filters, kernel_size):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## call relu after batch normalization
use_bias= False,
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)
return outputs

这里容易遇到一个陷进,我之前就掉进去过。在看其他代码和资料的时候,也经常看到 convolution + batch_normalization + relu 这种顺序的代码调用,如果理解的不透彻,很有可能会错误的认为在每一个 convolution layer 的后面都应该添加一个 tf.layers.batch_normalization 调用,但是实际上,如果当前 layer 已经是网络结构中最后的 layer 或者已经属于 output layer 了,其实是不应该再使用 Batch Normalization 的。按照定义,是在 layer 的 input 部分添加 Batch Normalization,而代码里看上去像是在 layer 的 output 上调用了一次 Batch Normalization,这只是为了在代码里让 layer 更容易连接起来,而且,如果是第一层 layer,它的输入就是已经归一化处理过的 input label 数据,这也是不需要 Batch Normalization 的,到了最后一层 layer 的时候,理论上来说是需要 Batch Normalization 的,只不过对应到代码上,最后这层 layer 的 Batch Normalization 是添加在倒数第二层 layer 的输出结果上的。所以,在前面 HED 的代码里,_dsn_deconv2d_with_upsample_factor 和 _output_1x1_conv2d 这两种 layer 的封装函数里都是没有 Batch Normalization 的。

另外,之前展示的代码都是把 batch_normalization 放在了 relu 激活函数的前面,网上的很多代码也是这样写的,其实把 batch_normalization 放在非线性函数的后面也是可以的,而且整体的准确率可能还会有一点点提升,BN – before or after ReLU? 这里有一个简单的数据对比,可以参考。总之,batch_normalization 和激活函数的先后顺序,是可以灵活选择的。

是否还需要使用 Regularizer

这也是一个容易混淆的地方,其实 Batch Normalization 和 Regularizer 是完全不一样的东西,Batch Normalization 针对的是 layer 的输入数据,而 Regularizer 针对的是 layer 里面的权重矩阵,前者是从数据层面来改善模型的效果,而后者则是通过改善模型自身来提升模型的效果,这两种技术是不冲突的,可以同时使用。

从卷积运算到各种卷积层

卷积运算

关于卷积的基本概念,A technical report on convolution arithmetic in the context of deep learning 这里有很直观的动画演示,比如下面这种就是最常见的卷积运算:


same padding no strides transposed

其他的学习资料里,通常也是基于一个普通的二维矩阵来描述卷积的运算规则,上图这个例子,就是在一个 shape 为 (height, width) 的矩阵上,使用一个 (3, 3) 的卷积核,然后得到一个 shape 同样为 (height, width) 的矩阵。

但是在神经网络领域里面,卷积层 的运算规则其实是比上面这种单纯的 卷积运算 稍微更复杂一些的。在神经网络里面,通常会使用一个 shape 为 (batch_size, height, width, channels) 的 Tensor 来表示图像,比如一个 RGBA 的图像,channels 就是 4,经过某种卷积层的运算后,得到一个新的 Tensor,这个新的 Tensor 的 channels 通常又会变成另外一个数值,可见,这个 channel 也是有一定的映射规则的,标准的卷积运算和 channel 结合起来,才构成了神经网络里面的卷积层。

在介绍具体的卷积层之前,先使用下面这种简单的示意图来表示一个卷积运算:


convolution operation

顺着示意图中箭头的方向,左侧是输入矩阵,中间是卷积核,右侧是输出矩阵。

标准卷积层

TensorFlow 框架里的标准卷积层的定义如下:

1
2
3
4
5
6
7
8
9
10
tf.nn.conv2d(  
input,
filter,
strides,
padding,
use_cudnn_on_gpu=True,
data_format='NHWC',
dilations=[1, 1, 1, 1],
name=None
)

因为这里主要是为了讨论 channel 的映射规则,所以假设采用 ‘SAME’ padding,并且 strides 设置为 1,这样的话,输入的 Tensor 和 输出的 Tensor 中,height 和 width 都是相同的值,输入的 Tensor 的 shape 是 (batch_size, height, width, in_channels),如果期望的输出 Tensor 的 shape 是 (batch_size, height, width, out_channels),则作为 filter 的 Tensor 的 shape 应该设置成 (filter_height, filter_width, in_channels, out_channels),其中的 filter_height 和 filter_width 就对应卷积核的 size,这个函数内部的完整计算过程,可以用下面这个示意图来表示:


conv2d

图中的 in_channels 等于 2,out_channels 等于 5,总共有 in_channels*out_channels = 10 个卷积核(同时还有 5 次矩阵加法操作),仔细看一下这个示意图就会意识到,每一个输出的矩阵都是由两个输入矩阵共同计算出来的,也就是说不同的输入 channel 会一起影响到每一个输出 channel,通道之间是有关联的。

One By One 卷积层

这种网络结构和前面介绍的标准卷积层其实是一样的,只不过 filter 的 shape 是 (1, 1, in_channels, out_channels),也就是说每一个卷积核都只是一个标量值,而非矩阵。表面上看这种结构有点违反『套路』,因为卷积的目的就是要利用周围像素的 加权和 来替代原始位置上的单个像素,或者说卷积每次关注的是一个区域的像素,而非只关注单个像素。

那 1x1 convolution 的目的是什么呢?前面已经提到了,神经网络里面的卷积层,既有卷积运算,也有 channel 之间的运算,所以 1x1 convolution 的重点就在于让不同的 channel 再结合一遍。类似的,也可以用一个简单的示意图表示这种网络结构:


one_by_one_conv2d

1x1 convolution 的效果,相当于对输入矩阵做了一个简单的标量乘法,它的参数量和计算量都比标准的卷积层少了很多。前面 HED 代码里的 _output_1x1_conv2d 就是一个 1x1 convolution,在后面的讨论中也会遇到多个例子。

Depthwise 卷积层

标准卷积层运算,不同的输入 channel 会共同参与计算每一个输出 channel,还有另外一种名为 depthwise convolution 的卷积层运算,channel 之间是完全独立的,TensorFlow 里面的定义如下:

1
2
3
4
5
6
7
8
9
tf.nn.depthwise_conv2d(
input,
filter,
strides,
padding,
rate=None,
name=None,
data_format=None
)

类似的,假设采用 ‘SAME’ padding,并且 strides 设置为 1,最后的三个参数使用默认值,这样的话,输入 Tensor 和 输出 Tensor 的 height 和 width 就会是相同的值,输入的 Tensor 的 shape 是 (batch_size, height, width, in_channels),filter Tensor 的 shape 是 (filter_height, filter_width, in_channels, channel_multiplier),则得到的输出 Tensor 的 shape 是 (batch_size, height, width, in_channels * channel_multiplier),这个函数内部的完整计算过程,可以用下面这个示意图来表示:


depthwise_conv2d

可以看到,输出 Tensor 的 channels 不能是任意值,只能是 in_channels 的整数倍,这也就是参数 channel_multiplier 的含义。

Separable 卷积层

depthwise convolution 中,channel 之间是完全不会产生互相影响的,这可能也意味着这种方式的模型的复杂度是不够的,所以在实际使用的过程中,separable convolution 是一个更合适的选择,对应的 TensorFlow API 如下:

1
2
3
4
5
6
7
8
9
10
tf.nn.separable_conv2d(
input,
depthwise_filter,
pointwise_filter,
strides,
padding,
rate=None,
name=None,
data_format=None
)

同样的,采用 ‘SAME’ padding,并且 strides 设置为 1,最后的三个参数使用默认值,这样的话,输入 Tensor 和 输出 Tensor 的 height 和 width 就会是相同的值。这个 API 的内部首先执行了一次 depthwise convolution,然后执行了一次 1x1 convolution(pointwise convolution),所以 depthwise_filter 的 shape 应该设置为 (filter_height, filter_width, in_channels, channel_multiplier),pointwise_filter 的 shape 应该设置为 (1, 1, channel_multiplier * in_channels, out_channels),示意图如下:


separable_conv2d

在使用相同的 in_channels 和 out_channels 参数时,tf.nn.separable_conv2d 的运算量会比 tf.nn.conv2d 更小。

Dilated Convolutions / Atrous Convolution / 扩张卷积 / 空洞卷积

前面看到的几种不同的卷积层函数里,可能会有一个参数 rate,如果设置了 rate 并且 rate > 1,则内部执行了另外一种名为 Dilated Convolutions 的卷积运算操作,这种卷积运算的动画示意图如下:


dilation

在做边缘检测任务的时候,并没有用到 Dilated Convolutions,但是这种卷积操作也是很常用的,比如在 DeepLab 网络结构的各个版本中,它都是一个很重要的组件,考虑到这篇文章里已经汇总了多种不同的常用卷积操作,出于完整性的考虑,所以也简单提及一下 Atrous Convolution,有兴趣的同学可以进一步深入了解。

转置卷积/反卷积的初始化

HED 网络中是会用到转置卷积层的,简单回忆一下,transposed convolution 的动画示意图如下:


no_padding_no_strides_transposed

前一篇文章里提到过,当时是使用了双线性放大矩阵(bilinear upsampling kernel)来对反卷积的 kernel 进行的初始化,因为 FCN 要求采用这种初始化方案(HED 的论文中并没有明确的要求使用双线性初始化)。这次重写代码的时候,转置卷积层也统一替换成了 Xavier initialization,仍然能够得到很好的训练效果,同时也严格参照了 HED 的参考代码对转置卷积层的 kernel size 进行设置,具体的参数都在前面的函数 _dsn_deconv2d_with_upsample_factor 里面。

如何初始化 transposed convolution 的卷积核,这个问题其实纠结了很长时间,而且在前一个版本的 HED 的代码中,也尝试过用 tf.truncated_normal 初始化 transposed convolution 的 kernel,当时的确是没有训练出想要的效果,所以有点迷信『双线性初始化』,后来在做 UNet 网络的时候,因为已经接触到 Xavier initialization 方案了,所以也尝试了用 Xavier 对反卷积的 kernel 进行初始化,得到的效果很好,所以才开始慢慢的不再强求于『双线性初始化』。

Google 了很多文章,仍然没有找到关于『双线性初始化』的权威解释,只是找到过一些零星的线索,比如有些模型里,会把 deconvolution 的 kernel 的 learning rate 设置为 0,同时采用双线性插值矩阵对该 kernel 进行初始化,相当于就是通过双线性插值算法对输入矩阵进行上采样(放大)。目前我个人的准则就是,除非论文中有明确的强调要采用某种特殊的初始化方法,否则还是首先使用常规的 Tensor 初始化方案。这篇文章的读者朋友们,如果对这个问题有更清晰的答案,也请指教一下,谢谢~

顺便再举个例子,Deconvolution and Checkerboard Artifacts 这里就是用 resize-convolution 替代了常规的 deconvolution。

从 VGG 到 ResNet、Inception、Xception

前面着重介绍了几种不同的卷积层运算方式,目的就是为了引出这篇文章 An Intuitive Guide to Deep Network Architectures。VGG 作为一个经典的分类网络模型,它的结构其实是很简单的,就是标准卷积层串联在一起,如果想进一步提高 VGG 网络的准确率,一个比较直观的想法就是串联更多的标准卷积层(让网络变得更深)、在每一层里增加更多的卷积核,想法看上去是对的,但是实际的效果很不好,因为这种方式增加了大量的参数,训练起来自然就更难,而且网络的深度加深后,还会引起一个 梯度消失 的问题,所以简单粗暴并不总是有效的,需要想其他的办法。前面给出链接的这篇文章里介绍的三个重要网络结构,ResNet、Inception 和 Xception,就是为了解决这些问题而发展起来的,这三种网络模型使用的 层结构,已经成为了卷积神经网络领域里面的基础技术手段。

关于 ResNet、Inception、Xception 的详细内容,刚才提到的这篇文章就是一个很好的总结,网上也有一份整理过的中文翻译 无需数学背景,读懂 ResNet、Inception 和 Xception 三大变革性架构,在文末的参考资料里面还会列出几篇很棒的文章或代码。

如果是我自己对这三种网络结构做一个简单的总结,我觉得主要是下面几点:

  • ResNet(残差网络) 使得训练更深的网络成为一种可能,既然很深的映射关系 Y = F(X) 不容易训练,那就改成训练 Y = F(X) + X,梯度消失问题就不再是一个障碍。
  • Inception 架构通过增加每一层网络的宽度(使用不同 size 的卷积核,按照卷积核的大小进行分组)来提高网络的准确性,同时为了控制整体的运算量,借助 1x1 convolution 先对每一层的输入 Tensor 进行一个降维操作,减少 input channel 的数量,然后再进入每一个分组,用不同大小的卷积核进行计算。
  • 残差架构可以和分组策略结合,比如 Inception-ResNet 网络结构。
  • Inception 里面分组的概念使用到极致,就是让每一个通道成为一个独立的分组,在每个 channel 上先分别进行标准的卷积运算,然后再利用 1x1 convolution 得到最终的输出 channel,就其实就是 separable convolution。

从 MobileNet V1 到 MobileNet V2

ResNet、Inception、Xception 追求的目标,就是在达到更高的准确率的前提下,尽量在模型大小、模型运算速度、模型训练速度这几个指标之间找一个平衡点,如果在准确性上允许一定的损失,但是追求更小的模型和更快的速度,这就直接催生了 MobileNet 或类似的以手机端或嵌入式端为运行环境的网络结构的出现。

MobileNet V1 和 MobileNet V2 都是基于 Depthwise Separable Convolution 构建的卷积层(类似 Xception,但是并不是和 Xception 使用的 Separable Convolution 完全一致),这是它满足体积小、速度快的一个关键因素,另外就是精心设计和试验调优出来的层结构,下面就对照论文给出两个版本的代码实现。

MobileNet V1

MobileNet V1 的整体结构其实并没有特别复杂的地方,和 VGG 类似,层和层之间就是普通的串联型的结构,有区别的地方主要在于 layer 的内部,如下图所示:


mobilenet_v1_layer_block

这个图中没有用箭头表示数据的传递方向,但是只要对卷积神经网络有初步的经验,就能看出来数据是从上往下传递的,左图是标准的卷积层操作,类似于前面 HED 网络中 _vgg_conv2d 函数的结构(回想一下前面说过的 Batch Normalization 和 relu 先后顺序的话题,虽然 Batch Normalization 可以放到激活函数的后面,但是很多论文里面都还是习惯性的放在激活函数的前面,所以这里的代码也会严格的遵照论文中的方式),右侧的图相当于 separable convolution,但是在中间是有两次 Batch Normalization 的。

论文中用一张如下的表格来描述了整体结构:


mobilenet_v1_body_architecture

下面是一份简单的代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
def mobilenet_v1(inputs, alpha, is_training):
if alpha not in [0.25, 0.50, 0.75, 1.0]:
raise ValueError('alpha must be one of'
'`0.25`, `0.50`, `0.75` or `1.0` only.')

filter_initializer = tf.contrib.layers.xavier_initializer()

def _conv2d(inputs, filters, kernel_size, stride, scope=''):
with tf.variable_scope(scope):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
strides=(stride, stride),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)
return outputs


def _mobilenet_v1_conv2d(inputs,
pointwise_conv_filters,
depthwise_conv_kernel_size,
stride, # stride is just for depthwise convolution
scope=''):
with tf.variable_scope(scope):
with tf.variable_scope('depthwise_conv'):
'''
tf.layers Module 里面有一个 tf.layers.separable_conv2d 函数,
但是它的内部调用流程是 depthwise convolution --> pointwise convolution --> activation func,
而 MobileNet V1 风格的卷积层的内部调用流程应该是
depthwise conv --> batch norm --> relu --> pointwise conv --> batch norm --> relu,
所以需要用其他的手段组装出想要的调用流程,
一种办法是使用 tf.nn.depthwise_conv2d,但是这个 API 比较底层,代码写起来很笨重。
后来找到了另外一种可行的办法,借助 tf.contrib.layers.separable_conv2d 函数,
tf.contrib.layers.separable_conv2d 的第二个参数 num_outputs 如果设置为 None,
则只会调用内部的 depthwise conv2d 部分,而不执行 pointwise conv2d 部分。
这样就可以组装出 MobileNet V1 需要的 layer 结构了。


TensorFlow 提供了四种 API,都命名为 separable_conv2d,但是又存在各种细微的差别,
有兴趣的读者可以自行阅读相关文档
tf.contrib.layers.separable_conv2d [Aliases tf.contrib.layers.separable_convolution2d]
VS
tf.keras.backend.separable_conv2d
VS
tf.layers.separable_conv2d
VS
tf.nn.separable_conv2d
'''
outputs = tf.contrib.layers.separable_conv2d(
inputs,
None, # ref https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
depthwise_conv_kernel_size,
depth_multiplier=1, # 按照论文的描述,这里设置成1
stride=(stride, stride),
padding='SAME',
activation_fn=None,
weights_initializer=filter_initializer,
biases_initializer=None)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)

with tf.variable_scope('pointwise_conv'):
# 论文中 alpha 参数的含义,就是在每一层的 pointwise conv 的位置按比例缩小输出 channels 的数量
pointwise_conv_filters = int(pointwise_conv_filters * alpha)
outputs = tf.layers.conv2d(outputs,
pointwise_conv_filters,
(1, 1),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)

return outputs

def _avg_pool2d(inputs, scope=''):
inputs_shape = inputs.get_shape().as_list()
assert len(inputs_shape) == 4

pool_height = inputs_shape[1]
pool_width = inputs_shape[2]

with tf.variable_scope(scope):
outputs = tf.layers.average_pooling2d(inputs,
[pool_height, pool_width],
strides=(1, 1),
padding='valid')

return outputs

'''
执行分类任务的网络结构,通常还可以作为实现其他任务的网络结构的 base architecture,
为了方便代码复用,这里只需要实现出卷积层构成的主体部分,
外部调用者根据各自的需求使用这里返回的 output 和 end_points。
比如对于分类任务,按照如下方式使用这个函数

image_height = 224
image_width = 224
image_channels = 3

x = tf.placeholder(tf.float32, [None, image_height, image_width, image_channels])
is_training = tf.placeholder(tf.bool, name='is_training')

output, net = mobilenet_v1(x, 1.0, is_training)
print('output shape is: %r' % (output.get_shape().as_list()))

output = tf.layers.flatten(output)
output = tf.layers.dense(output,
units=1024, # 1024 class
activation=None,
use_bias=True,
kernel_initializer=tf.contrib.layers.xavier_initializer())
print('output shape is: %r' % (output.get_shape().as_list()))
'''
with tf.variable_scope('mobilenet', 'mobilenet', [inputs]):
end_points = {}
net = inputs

net = _conv2d(net, 32, [3, 3], stride=2, scope='block0')
end_points['block0'] = net
net = _mobilenet_v1_conv2d(net, 64, [3, 3], stride=1, scope='block1')
end_points['block1'] = net

net = _mobilenet_v1_conv2d(net, 128, [3, 3], stride=2, scope='block2')
end_points['block2'] = net
net = _mobilenet_v1_conv2d(net, 128, [3, 3], stride=1, scope='block3')
end_points['block3'] = net

net = _mobilenet_v1_conv2d(net, 256, [3, 3], stride=2, scope='block4')
end_points['block4'] = net
net = _mobilenet_v1_conv2d(net, 256, [3, 3], stride=1, scope='block5')
end_points['block5'] = net

net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=2, scope='block6')
end_points['block6'] = net
net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=1, scope='block7')
end_points['block7'] = net
net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=1, scope='block8')
end_points['block8'] = net
net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=1, scope='block9')
end_points['block9'] = net
net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=1, scope='block10')
end_points['block10'] = net
net = _mobilenet_v1_conv2d(net, 512, [3, 3], stride=1, scope='block11')
end_points['block11'] = net

net = _mobilenet_v1_conv2d(net, 1024, [3, 3], stride=2, scope='block12')
end_points['block12'] = net
net = _mobilenet_v1_conv2d(net, 1024, [3, 3], stride=1, scope='block13')
end_points['block13'] = net

output = _avg_pool2d(net, scope='output')

return output, end_points

MobileNet V2

MobileNet V2 的改动就比较大了,首先引入了两种新的 layer 结构,如下图所示:


mobilenet_v2_layer_block

很明显的一个差异点,就是左边这种层结构引入了残差网络的手段,另外,这两种层结构中,在 depthwise convolution 之前又添加了一个 1x1 convolution 操作,在之前举得几个例子中,1x1 convolution 都是用来降维的,而在 MobileNet V2 里,这个位于 depthwise convolution 之前的 1x1 convolution 其实用来提升维度的,对应论文中 expansion factor 参数的含义,在 depthwise convolution 之后仍然还有一次 1x1 convolution 调用,但是这个 1x1 convolution 并不会跟随一个激活函数,只是一次线性变换,所以这里也不叫做 pointwise convolution,而是对应论文中的 1x1 projection convolution。

网络的整体结构由下面的表格描述:


mobilenet_v2_body_architecture

代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def mobilenet_v2_func_blocks(is_training):
filter_initializer = tf.contrib.layers.xavier_initializer()
activation_func = tf.nn.relu6

def conv2d(inputs, filters, kernel_size, stride, scope=''):
with tf.variable_scope(scope):
with tf.variable_scope('conv2d'):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
strides=(stride, stride),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = tf.nn.relu(outputs)
return outputs

def _1x1_conv2d(inputs, filters, stride):
kernel_size = [1, 1]
with tf.variable_scope('1x1_conv2d'):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
strides=(stride, stride),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
# no activation_func
return outputs

def expansion_conv2d(inputs, expansion, stride):
input_shape = inputs.get_shape().as_list()
assert len(input_shape) == 4
filters = input_shape[3] * expansion

kernel_size = [1, 1]
with tf.variable_scope('expansion_1x1_conv2d'):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
strides=(stride, stride),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = activation_func(outputs)
return outputs

def projection_conv2d(inputs, filters, stride):
kernel_size = [1, 1]
with tf.variable_scope('projection_1x1_conv2d'):
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
strides=(stride, stride),
padding='same',
activation=None,
use_bias=False,
kernel_initializer=filter_initializer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
# no activation_func
return outputs

def depthwise_conv2d(inputs,
depthwise_conv_kernel_size,
stride):
with tf.variable_scope('depthwise_conv2d'):
outputs = tf.contrib.layers.separable_conv2d(
inputs,
None, # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
depthwise_conv_kernel_size,
depth_multiplier=1,
stride=(stride, stride),
padding='SAME',
activation_fn=None,
weights_initializer=filter_initializer,
biases_initializer=None)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
outputs = activation_func(outputs)

return outputs

def avg_pool2d(inputs, scope=''):
inputs_shape = inputs.get_shape().as_list()
assert len(inputs_shape) == 4

pool_height = inputs_shape[1]
pool_width = inputs_shape[2]

with tf.variable_scope(scope):
outputs = tf.layers.average_pooling2d(inputs,
[pool_height, pool_width],
strides=(1, 1),
padding='valid')

return outputs

def inverted_residual_block(inputs,
filters,
stride,
expansion=6,
scope=''):
assert stride == 1 or stride == 2

depthwise_conv_kernel_size = [3, 3]
pointwise_conv_filters = filters

with tf.variable_scope(scope):
net = inputs
net = expansion_conv2d(net, expansion, stride=1)
net = depthwise_conv2d(net, depthwise_conv_kernel_size, stride=stride)
net = projection_conv2d(net, pointwise_conv_filters, stride=1)

if stride == 1:
# 如果 net.get_shape().as_list()[3] != inputs.get_shape().as_list()[3]
# 借助一个 1x1 的卷积让他们的 channels 相等,然后才能相加
if net.get_shape().as_list()[3] != inputs.get_shape().as_list()[3]:
inputs = _1x1_conv2d(inputs, net.get_shape().as_list()[3], stride=1)

net = net + inputs
return net
else:
# stride == 2
return net

func_blocks = {}
func_blocks['conv2d'] = conv2d
func_blocks['inverted_residual_block'] = inverted_residual_block
func_blocks['avg_pool2d'] = avg_pool2d
func_blocks['filter_initializer'] = filter_initializer
func_blocks['activation_func'] = activation_func

return func_blocks


def mobilenet_v2(inputs, is_training):
func_blocks = mobilenet_v2_func_blocks(is_training)
_conv2d = func_blocks['conv2d']
_inverted_residual_block = func_blocks['inverted_residual_block']
_avg_pool2d = func_blocks['avg_pool2d']

with tf.variable_scope('mobilenet_v2', 'mobilenet_v2', [inputs]):
end_points = {}
net = inputs

net = _conv2d(net, 32, [3, 3], stride=2, scope='block0_0') # size/2
end_points['block0'] = net

net = _inverted_residual_block(net, 16, stride=1, expansion=1, scope='block1_0')
end_points['block1'] = net

net = _inverted_residual_block(net, 24, stride=2, scope='block2_0') # size/4
net = _inverted_residual_block(net, 24, stride=1, scope='block2_1')
end_points['block2'] = net

net = _inverted_residual_block(net, 32, stride=2, scope='block3_0') # size/8
net = _inverted_residual_block(net, 32, stride=1, scope='block3_1')
net = _inverted_residual_block(net, 32, stride=1, scope='block3_2')
end_points['block3'] = net

net = _inverted_residual_block(net, 64, stride=2, scope='block4_0') # size/16
net = _inverted_residual_block(net, 64, stride=1, scope='block4_1')
net = _inverted_residual_block(net, 64, stride=1, scope='block4_2')
net = _inverted_residual_block(net, 64, stride=1, scope='block4_3')
end_points['block4'] = net

net = _inverted_residual_block(net, 96, stride=1, scope='block5_0')
net = _inverted_residual_block(net, 96, stride=1, scope='block5_1')
net = _inverted_residual_block(net, 96, stride=1, scope='block5_2')
end_points['block5'] = net

net = _inverted_residual_block(net, 160, stride=2, scope='block6_0') # size/32
net = _inverted_residual_block(net, 160, stride=1, scope='block6_1')
net = _inverted_residual_block(net, 160, stride=1, scope='block6_2')
end_points['block6'] = net

net = _inverted_residual_block(net, 320, stride=1, scope='block7_0')
end_points['block7'] = net

net = _conv2d(net, 1280, [1, 1], stride=1, scope='block8_0')
end_points['block8'] = net

output = _avg_pool2d(net, scope='output')

return output, end_points

MobileNet V2 Style HED

原始的 HED 使用 VGG 作为基础网络结构来得到 feature maps,参照这种思路,可以把基础网络部分替换为 MobileNet V2,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def mobilenet_v2_style_hed(inputs, batch_size, is_training):
if const.use_kernel_regularizer:
weights_regularizer = tf.contrib.layers.l2_regularizer(scale=0.0001)
else:
weights_regularizer = None

####################################################
func_blocks = mobilenet_v2_func_blocks(is_training)
# print('============ func_blocks are: %r' % func_blocks)
_conv2d = func_blocks['conv2d']
_inverted_residual_block = func_blocks['inverted_residual_block']
_avg_pool2d = func_blocks['avg_pool2d']
filter_initializer = func_blocks['filter_initializer']
activation_func = func_blocks['activation_func']
####################################################

def _dsn_1x1_conv2d(inputs, filters):
kernel_size = [1, 1]
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## no activation
use_bias=False,
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

outputs = tf.layers.batch_normalization(outputs, training=is_training)
## no activation

return outputs

def _output_1x1_conv2d(inputs, filters):
kernel_size = [1, 1]
outputs = tf.layers.conv2d(inputs,
filters,
kernel_size,
padding='same',
activation=None, ## no activation
use_bias=True, ## use bias
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

## no batch normalization
## no activation

return outputs


def _dsn_deconv2d_with_upsample_factor(inputs, filters, upsample_factor):
## https://github.com/s9xie/hed/blob/master/examples/hed/train_val.prototxt
## 从这个原版代码里看,是这样计算 kernel_size 的
kernel_size = [2 * upsample_factor, 2 * upsample_factor]
outputs = tf.layers.conv2d_transpose(inputs,
filters,
kernel_size,
strides=(upsample_factor, upsample_factor),
padding='same',
activation=None, ## no activation
use_bias=True, ## use bias
kernel_initializer=filter_initializer,
kernel_regularizer=weights_regularizer)

## 概念上来说,deconv2d 已经是最后的输出 layer 了,只不过最后还有一步 1x1 的 conv2d 把 5 个 deconv2d 的输出再融合到一起
## 所以不需要再使用 batch normalization 了

return outputs


with tf.variable_scope('hed', 'hed', [inputs]):
end_points = {}
net = inputs


## mobilenet v2 as base net
with tf.variable_scope('mobilenet_v2'):
# 标准的 mobilenet v2 里面并没有这两层,
# 这里是为了得到和 input image 相同 size 的 feature map 而增加的层
net = _conv2d(net, 3, [3, 3], stride=1, scope='block0_0')
net = _conv2d(net, 6, [3, 3], stride=1, scope='block0_1')

dsn1 = net
net = _conv2d(net, 12, [3, 3], stride=2, scope='block0_2') # size/2

net = _inverted_residual_block(net, 6, stride=1, expansion=1, scope='block1_0')

dsn2 = net
net = _inverted_residual_block(net, 12, stride=2, scope='block2_0') # size/4
net = _inverted_residual_block(net, 12, stride=1, scope='block2_1')

dsn3 = net
net = _inverted_residual_block(net, 24, stride=2, scope='block3_0') # size/8
net = _inverted_residual_block(net, 24, stride=1, scope='block3_1')
net = _inverted_residual_block(net, 24, stride=1, scope='block3_2')

dsn4 = net
net = _inverted_residual_block(net, 48, stride=2, scope='block4_0') # size/16
net = _inverted_residual_block(net, 48, stride=1, scope='block4_1')
net = _inverted_residual_block(net, 48, stride=1, scope='block4_2')
net = _inverted_residual_block(net, 48, stride=1, scope='block4_3')

net = _inverted_residual_block(net, 64, stride=1, scope='block5_0')
net = _inverted_residual_block(net, 64, stride=1, scope='block5_1')
net = _inverted_residual_block(net, 64, stride=1, scope='block5_2')

dsn5 = net


## dsn layers
with tf.variable_scope('dsn1'):
dsn1 = _dsn_1x1_conv2d(dsn1, 1)
## no need deconv2d

with tf.variable_scope('dsn2'):
dsn2 = _dsn_1x1_conv2d(dsn2, 1)
dsn2 = _dsn_deconv2d_with_upsample_factor(dsn2, 1, upsample_factor = 2)

with tf.variable_scope('dsn3'):
dsn3 = _dsn_1x1_conv2d(dsn3, 1)
dsn3 = _dsn_deconv2d_with_upsample_factor(dsn3, 1, upsample_factor = 4)

with tf.variable_scope('dsn4'):
dsn4 = _dsn_1x1_conv2d(dsn4, 1)
dsn4 = _dsn_deconv2d_with_upsample_factor(dsn4, 1, upsample_factor = 8)

with tf.variable_scope('dsn5'):
dsn5 = _dsn_1x1_conv2d(dsn5, 1)
dsn5 = _dsn_deconv2d_with_upsample_factor(dsn5, 1, upsample_factor = 16)


# dsn fuse
with tf.variable_scope('dsn_fuse'):
dsn_fuse = tf.concat([dsn1, dsn2, dsn3, dsn4, dsn5], 3)
dsn_fuse = _output_1x1_conv2d(dsn_fuse, 1)

return dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5

这个 MobileNet V2 风格的 HED 网络,整体结构和 VGG 风格的 HED 并没有区别,只是把 VGG 里面用到的卷积层操作替换成了 MobileNet V2 对应的卷积层,另外,因为 MobileNet V2 的第一个卷积层就设置了 stride=2,并不匹配 dsn1 层的 size,所以额外添加了两个 stride=1 的普通卷积层,把它们的输出作为 dsn1 层。

MobileNet V2 As Base Net

MobileNet 只是针对手机运行环境设计出来的执行 分类任务 的网络结构,但是,和同样执行分类任务的 ResNet、Inception、Xception 这一类网络结构类似,都可以作为执行其他任务的网络结构的 base net,提取输入 image 的 feature maps,我尝试过 mobilenet_v2_style_unet、mobilenet_v2_style_deeplab_v3plus、mobilenet_v2_style_ssd,都是可以看到效果的。

Android 性能瓶颈

作为一个参考值,在 iPhone 7 Plus 上运行这个 mobilenet_v2_style_hed 网络并且执行后续的找点算法,FPS 可以跑到12,基本满足实时性的需求。但是当尝试在 Android 上部署的时候,即便是在高价位高配置的机型上,FPS 也很低,卡顿现象很明显。

经过排查,找到了一些线索。在 iPhone 7 Plus 上,计算量的分布如下图所示:


mobilenet_v2_hed_node_summary_ios

红框中的三种操作占据了大部分的 CPU 时间,用这几个数值做一个粗略估算,1.0 / (32 + 30 + 10 + 6) = 12.8,这和检测到的 FPS 是比较吻合的,说明大量的计算时间都用在神经网络上了,OpenCV 实现的找点算法的耗时是很短的。

但是在 Android 上,情况则完全不一样了,如下图所示:


mobilenet_v2_hed_node_summary_android_32

用红框里的数值计算一下,FPS = 1.0 / (232 + 76 + 29 + 16) = 2.8,达不到实时的要求。从上图还可以看出,在 Android 上,Batch Normalization 消耗了大量的计算时间,而且和 Conv2D 消耗的 CPU 时间相比,不在一个数量级上了,这就和 iOS 平台上完全不是同一种分布规律了。进一步 debug 后发现,我们 Android 平台的 app,由于一些历史原因被限定住了只能使用 32bit 的 .so 动态库,换成 64bit 的 TensorFlow 动态库在独立的 demo app 里面重新测量,mobilenet_v2_style_hed 在 Android 上的运行情况就和 iOS 的接近了,虽然还是比 iOS 慢,但是 CPU 耗时的统计数据是同一种分布规律了。

所以,性能瓶颈就在于 Batch Normalization 在 32bit 的 ARM CPU 环境中执行效率不高,尝试过使用一些编译器优化选项重新编译 32bit 的 TensorFlow 库,但是并没有明显的改善。最后的解决方案是退而求其次,使用 vgg_style_hed,并且不使用 Batch Normalization,经过这样的调整后,Android 上的统计数据如下图:


vgg_hed_node_summary_android_32

关于 TensorFlow Lite

在使用 TensorFlow 1.7 部署模型的时候,TensorFlow Lite 还未支持 transposed convolution,所以没有使用 TF Lite (目前 github 上已经看到有 Lite 版本的 transpose_conv.cc 了)。TensorFlow Lite 目前发展的很快,以后在选择部署方案的时候,TensorFlow Lite 是优先于 TensorFlow Mobile 的。

参考资料

xavier init

How to do Xavier initialization on TensorFlow
聊一聊深度学习的weight initialization

Batch Normalization

Understanding the backward pass through Batch Normalization Layer
机器学习里的黑色艺术:normalization, standardization, regularization
How could I use Batch Normalization in TensorFlow?
add Batch Normalization immediately before non-linearity or after in Keras?

1x1 Convolution

What does 1x1 convolution mean in a neural network?
How are 1x1 convolutions the same as a fully connected layer?
One by One [ 1 x 1 ] Convolution - counter-intuitively useful

Upsampling && Transposed Convolution

Upsampling and Image Segmentation with Tensorflow and TF-Slim
Image Segmentation using deconvolution layer in Tensorflow

ResNet && Inception && Xception

Network In Network architecture: The beginning of Inception
ResNets, HighwayNets, and DenseNets, Oh My!
Inception modules: explained and implemented
TensorFlow implementation of the Xception Model by François Chollet

TensorFlow Lite

TensorFlow Lite 深度解析

手机端运行卷积神经网络的一次实践 -- 基于 TensorFlow 和 OpenCV 实现文档检测功能

Posted on 2017-05-08 | Edited on 2018-06-02

2018-06-02 update:

这篇博客有一个后续更新版本,请看 手机端运行卷积神经网络实现文档检测功能(二) – 从 VGG 到 MobileNetV2 知识梳理

另外,代码也已开源放在 github 上,https://github.com/fengjian0106/hed-tutorial-for-document-scanning

前言

  • 本文不是神经网络或机器学习的入门教学,而是通过一个真实的产品案例,展示了在手机客户端上运行一个神经网络的关键技术点
  • 在卷积神经网络适用的领域里,已经出现了一些很经典的图像分类网络,比如 VGG16/VGG19,Inception v1-v4 Net,ResNet 等,这些分类网络通常又都可以作为其他算法中的基础网络结构,尤其是 VGG 网络,被很多其他的算法借鉴,本文也会使用 VGG16 的基础网络结构,但是不会对 VGG 网络做详细的入门教学
  • 虽然本文不是神经网络技术的入门教程,但是仍然会给出一系列的相关入门教程和技术文档的链接,有助于进一步理解本文的内容
  • 具体使用到的神经网络算法,只是本文的一个组成部分,除此之外,本文还介绍了如何裁剪 TensorFlow 静态库以便于在手机端运行,如何准备训练样本图片,以及训练神经网络时的各种技巧等等

需求是什么

image to point

需求很容易描述清楚,如上图,就是在一张图里,把矩形形状的文档的四个顶点的坐标找出来。

传统的技术方案

Google 搜索 opencv scan document,是可以找到好几篇相关的教程的,这些教程里面的技术手段,也都大同小异,关键步骤就是调用 OpenCV 里面的两个函数,cv2.Canny() 和 cv2.findContours()。

demo method

看上去很容易就能实现出来,但是真实情况是,这些教程,仅仅是个 demo 演示而已,用来演示的图片,都是最理想的简单情况,真实的场景图片会比这个复杂的多,会有各种干扰因素,调用 canny 函数得到的边缘检测结果,也会比 demo 中的情况凌乱的多,比如会检测出很多各种长短的线段,或者是文档的边缘线被截断成了好几条短的线段,线段之间还存在距离不等的空隙。另外,findContours 函数也只能检测闭合的多边形的顶点,但是并不能确保这个多边形就是一个合理的矩形。因此在我们的第一版技术方案中,对这两个关键步骤,进行了大量的改进和调优,概括起来就是:

  • 改进 canny 算法的效果,增加额外的步骤,得到效果更好的边缘检测图
  • 针对 canny 步骤得到的边缘图,建立一套数学算法,从边缘图中寻找出一个合理的矩形区域

传统技术方案的难度和局限性

  • canny 算法的检测效果,依赖于几个阈值参数,这些阈值参数的选择,通常都是人为设置的经验值,在改进的过程中,引入额外的步骤后,通常又会引入一些新的阈值参数,同样,也是依赖于调试结果设置的经验值。整体来看,这些阈值参数的个数,不能特别的多,因为一旦太多了,就很难依赖经验值进行设置,另外,虽然有这些阈值参数,但是最终的参数只是一组或少数几组固定的组合,所以算法的鲁棒性又会打折扣,很容易遇到边缘检测效果不理想的场景
  • 在边缘图上建立的数学模型很复杂,代码实现难度大,而且也会遇到算法无能为力的场景

下面这张图表,能够很好的说明上面列出的这两个问题:

hed vs canny

这张图表的第一列是输入的 image,最后的三列(先不用看这张图表的第二列),是用三组不同阈值参数调用 canny 函数和额外的函数后得到的输出 image,可以看到,边缘检测的效果,并不总是很理想的,有些场景中,矩形的边,出现了很严重的断裂,有些边,甚至被完全擦除掉了,而另一些场景中,又会检测出很多干扰性质的长短边。可想而知,想用一个数学模型,适应这么不规则的边缘图,会是多么困难的一件事情。

思考如何改善

在第一版的技术方案中,负责的同学花费了大量的精力进行各种调优,终于取得了还不错的效果,但是,就像前面描述的那样,还是会遇到检测不出来的场景。在第一版技术方案中,遇到这种情况的时候,采用的做法是针对这些不能检测的场景,人工进行分析和调试,调整已有的一组阈值参数和算法,可能还需要加入一些其他的算法流程(可能还会引入新的一些阈值参数),然后再整合到原有的代码逻辑中。经过若干轮这样的调整后,我们发现,已经进入一个瓶颈,按照这种手段,很难进一步提高检测效果了。

既然传统的算法手段已经到极限了,那不如试试机器学习/神经网络。

无效的神经网络算法

end-to-end 直接拟合

首先想到的,就是仿照人脸对齐(face alignment)的思路,构建一个端到端(end-to-end)的网络,直接回归拟合,也就是让这个神经网络直接输出 4 个顶点的坐标,但是,经过尝试后发现,根本拟合不出来。后来仔细琢磨了一下,觉得不能直接拟合也是对的,因为:

  • 除了分类(classification)问题之外,所有的需求看上去都像是一个回归(regression)问题,如果回归是万能的,学术界为啥还要去搞其他各种各样的网络模型
  • face alignment 之所以可以用回归网络得到很好的拟合效果,是因为在输入 image 上先做了 bounding box 检测,缩小了人脸图像范围后,才做的 regression
  • 人脸上的关键特征点,具有特别明显的统计学特征,所以 regression 可以发挥作用
  • 在需要更高检测精度的场景中,其实也是用到了更复杂的网络模型来解决 face alignment 问题的

YOLO && FCN

后来还尝试过用 YOLO 网络做 Object Detection,用 FCN 网络做像素级的 Semantic Segmentation,但是结果都很不理想,比如:

  • 达不到文档检测功能想要的精确度
  • 网络结构复杂,运算量大,在手机上无法做到实时检测

有效的神经网络算法

前面尝试的几种神经网络算法,都不能得到想要的效果,后来换了一种思路,既然传统的技术手段里包含了两个关键的步骤,那能不能用神经网络来分别改善这两个步骤呢,经过分析发现,可以尝试用神经网络来替换 canny 算法,也就是用神经网络来对图像中的矩形区域进行边缘检测,只要这个边缘检测能够去除更多的干扰因素,那第二个步骤里面的算法也就可以变得更简单了。

神经网络的输入和输出

image to edge

按照这种思路,对于神经网络部分,现在的需求变成了上图所示的样子。

HED(Holistically-Nested Edge Detection) 网络

边缘检测这种需求,在图像处理领域里面,通常叫做 Edge Detection 或 Contour Detection,按照这个思路,找到了 Holistically-Nested Edge Detection 网络模型。

HED 网络模型是在 VGG16 网络结构的基础上设计出来的,所以有必要先看看 VGG16。

vgg detail

上图是 VGG16 的原理图,为了方便从 VGG16 过渡到 HED,我们先把 VGG16 变成下面这种示意图:

vgg to hed 1

在上面这个示意图里,用不同的颜色区分了 VGG16 的不同组成部分。

vgg to hed 2

从示意图上可以看到,绿色代表的卷积层和红色代表的池化层,可以很明显的划分出五组,上图用紫色线条框出来的就是其中的第三组。

vgg to hed 3

HED 网络要使用的就是 VGG16 网络里面的这五组,后面部分的 fully connected 层和 softmax 层,都是不需要的,另外,第五组的池化层(红色)也是不需要的。

vgg to hed 4

去掉不需要的部分后,就得到上图这样的网络结构,因为有池化层的作用,从第二组开始,每一组的输入 image 的长宽值,都是前一组的输入 image 的长宽值的一半。

vgg to hed 5

HED 网络是一种多尺度多融合(multi-scale and multi-level feature learning)的网络结构,所谓的多尺度,就是如上图所示,把 VGG16 的每一组的最后一个卷积层(绿色部分)的输出取出来,因为每一组得到的 image 的长宽尺寸是不一样的,所以这里还需要用转置卷积(transposed convolution)/反卷积(deconv)对每一组得到的 image 再做一遍运算,从效果上看,相当于把第二至五组得到的 image 的长宽尺寸分别扩大 2 至 16 倍,这样在每个尺度(VGG16 的每一组就是一个尺度)上得到的 image,都是相同的大小了。

vgg to hed 6

把每一个尺度上得到的相同大小的 image,再融合到一起,这样就得到了最终的输出 image,也就是具有边缘检测效果的 image。

基于 TensorFlow 编写的 HED 网络结构代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def hed_net(inputs, batch_size):
# ref https://github.com/s9xie/hed/blob/master/examples/hed/train_val.prototxt
with tf.variable_scope('hed', 'hed', [inputs]):
with slim.arg_scope([slim.conv2d, slim.fully_connected],
activation_fn=tf.nn.relu,
weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
weights_regularizer=slim.l2_regularizer(0.0005)):
# vgg16 conv && max_pool layers
net = slim.repeat(inputs, 2, slim.conv2d, 12, [3, 3], scope='conv1')
dsn1 = net
net = slim.max_pool2d(net, [2, 2], scope='pool1')

net = slim.repeat(net, 2, slim.conv2d, 24, [3, 3], scope='conv2')
dsn2 = net
net = slim.max_pool2d(net, [2, 2], scope='pool2')

net = slim.repeat(net, 3, slim.conv2d, 48, [3, 3], scope='conv3')
dsn3 = net
net = slim.max_pool2d(net, [2, 2], scope='pool3')

net = slim.repeat(net, 3, slim.conv2d, 96, [3, 3], scope='conv4')
dsn4 = net
net = slim.max_pool2d(net, [2, 2], scope='pool4')

net = slim.repeat(net, 3, slim.conv2d, 192, [3, 3], scope='conv5')
dsn5 = net
# net = slim.max_pool2d(net, [2, 2], scope='pool5') # no need this pool layer

# dsn layers
dsn1 = slim.conv2d(dsn1, 1, [1, 1], scope='dsn1')
# no need deconv for dsn1

dsn2 = slim.conv2d(dsn2, 1, [1, 1], scope='dsn2')
deconv_shape = tf.pack([batch_size, const.image_height, const.image_width, 1])
dsn2 = deconv_mobile_version(dsn2, 2, deconv_shape) # deconv_mobile_version can work on mobile

dsn3 = slim.conv2d(dsn3, 1, [1, 1], scope='dsn3')
deconv_shape = tf.pack([batch_size, const.image_height, const.image_width, 1])
dsn3 = deconv_mobile_version(dsn3, 4, deconv_shape)

dsn4 = slim.conv2d(dsn4, 1, [1, 1], scope='dsn4')
deconv_shape = tf.pack([batch_size, const.image_height, const.image_width, 1])
dsn4 = deconv_mobile_version(dsn4, 8, deconv_shape)

dsn5 = slim.conv2d(dsn5, 1, [1, 1], scope='dsn5')
deconv_shape = tf.pack([batch_size, const.image_height, const.image_width, 1])
dsn5 = deconv_mobile_version(dsn5, 16, deconv_shape)

# dsn fuse
dsn_fuse = tf.concat(3, [dsn1, dsn2, dsn3, dsn4, dsn5])
dsn_fuse = tf.reshape(dsn_fuse, [batch_size, const.image_height, const.image_width, 5]) #without this, will get error: ValueError: Number of in_channels must be known.

dsn_fuse = slim.conv2d(dsn_fuse, 1, [1, 1], scope='dsn_fuse')

return dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5

训练网络

cost 函数

论文给出的 HED 网络是一个通用的边缘检测网络,按照论文的描述,每一个尺度上得到的 image,都需要参与 cost 的计算,这部分的代码如下:

1
2
3
4
5
6
7
8
9
10
11
input_queue_for_train = tf.train.string_input_producer([FLAGS.csv_path])
image_tensor, annotation_tensor = input_image_pipeline(dataset_root_dir_string, input_queue_for_train, FLAGS.batch_size)

dsn_fuse, dsn1, dsn2, dsn3, dsn4, dsn5 = hed_net(image_tensor, FLAGS.batch_size)

cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor) + \
class_balanced_sigmoid_cross_entropy(dsn1, annotation_tensor) + \
class_balanced_sigmoid_cross_entropy(dsn2, annotation_tensor) + \
class_balanced_sigmoid_cross_entropy(dsn3, annotation_tensor) + \
class_balanced_sigmoid_cross_entropy(dsn4, annotation_tensor) + \
class_balanced_sigmoid_cross_entropy(dsn5, annotation_tensor)

按照这种方式训练出来的网络,检测到的边缘线是有一点粗的,为了得到更细的边缘线,通过多次试验找到了一种优化方案,代码如下:

1
2
3
4
5
6
input_queue_for_train = tf.train.string_input_producer([FLAGS.csv_path])
image_tensor, annotation_tensor = input_image_pipeline(dataset_root_dir_string, input_queue_for_train, FLAGS.batch_size)

dsn_fuse, _, _, _, _, _ = hed_net(image_tensor, FLAGS.batch_size)

cost = class_balanced_sigmoid_cross_entropy(dsn_fuse, annotation_tensor)

也就是不再让每个尺度上得到的 image 都参与 cost 的计算,只使用融合后得到的最终 image 来进行计算。

两种 cost 函数的效果对比如下图所示,右侧是优化过后的效果:

edge thickness

另外还有一点,按照 HED 论文里的要求,计算 cost 的时候,不能使用常见的方差 cost,而应该使用 cost-sensitive loss function,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss,
as in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
This is more numerically stable than class_balanced_cross_entropy

:param logits: size: the logits.
:param label: size: the ground truth in {0,1}, of the same shape as logits.
:returns: a scalar. class-balanced cross entropy loss
"""
y = tf.cast(label, tf.float32)

count_neg = tf.reduce_sum(1. - y) # the number of 0 in y
count_pos = tf.reduce_sum(y) # the number of 1 in y (less than count_neg)
beta = count_neg / (count_neg + count_pos)

pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits, y, pos_weight)
cost = tf.reduce_mean(cost * (1 - beta), name=name)

return cost

转置卷积层的双线性初始化

在尝试 FCN 网络的时候,就被这个问题卡住过很长一段时间,按照 FCN 的要求,在使用转置卷积(transposed convolution)/反卷积(deconv)的时候,要把卷积核的值初始化成双线性放大矩阵(bilinear upsampling kernel),而不是常用的正态分布随机初始化,同时还要使用很小的学习率,这样才更容易让模型收敛。

HED 的论文中,并没有明确的要求也要采用这种方式初始化转置卷积层,但是,在训练过程中发现,采用这种方式进行初始化,模型才更容易收敛。

这部分的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def get_kernel_size(factor):
"""
Find the kernel size given the desired factor of upsampling.
"""
return 2 * factor - factor % 2


def upsample_filt(size):
"""
Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size.
"""
factor = (size + 1) // 2
if size % 2 == 1:
center = factor - 1
else:
center = factor - 0.5
og = np.ogrid[:size, :size]
return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)


def bilinear_upsample_weights(factor, number_of_classes):
"""
Create weights matrix for transposed convolution with bilinear filter
initialization.
"""
filter_size = get_kernel_size(factor)

weights = np.zeros((filter_size,
filter_size,
number_of_classes,
number_of_classes), dtype=np.float32)

upsample_kernel = upsample_filt(filter_size)

for i in xrange(number_of_classes):
weights[:, :, i, i] = upsample_kernel

return weights

训练过程冷启动

HED 网络不像 VGG 网络那样很容易就进入收敛状态,也不太容易进入期望的理想状态,主要是两方面的原因:

  • 前面提到的转置卷积层的双线性初始化,就是一个重要因素,因为在 4 个尺度上,都需要反卷积,如果反卷积层不能收敛,那整个 HED 都不会进入期望的理想状态
  • 另外一个原因,是由 HED 的多尺度引起的,既然是多尺度了,那每个尺度上得到的 image 都应该对模型的最终输出 image 产生贡献,在训练的过程中发现,如果输入 image 的尺寸是 224*224,还是很容易就训练成功的,但是当把输入 image 的尺寸调整为 256*256 后,很容易出现一种状况,就是 5 个尺度上得到的 image,会有 1 ~ 2 个 image 是无效的(全部是黑色)

为了解决这里遇到的问题,采用的办法就是先使用少量样本图片(比如 2000 张)训练网络,在很短的训练时间(比如迭代 1000 次)内,如果 HED 网络不能表现出收敛的趋势,或者不能达到 5 个尺度的 image 全部有效的状态,那就直接放弃这轮的训练结果,重新开启下一轮训练,直到满意为止,然后才使用完整的训练样本集合继续训练网络。

训练数据集(大量合成数据 + 少量真实数据)

HED 论文里使用的训练数据集,是针对通用的边缘检测目的的,什么形状的边缘都有,比如下面这种:

hed training dataset 1

用这份数据训练出来的模型,在做文档扫描的时候,检测出来的边缘效果并不理想,而且这份训练数据集的样本数量也很小,只有一百多张图片(因为这种图片的人工标注成本太高了),这也会影响模型的质量。

现在的需求里,要检测的是具有一定透视和旋转变换效果的矩形区域,所以可以大胆的猜测,如果准备一批针对性更强的训练样本,应该是可以得到更好的边缘检测效果的。

借助第一版技术方案收集回来的真实场景图片,我们开发了一套简单的标注工具,人工标注了 1200 张图片(标注这 1200 张图片的时间成本也很高),但是这 1200 多张图片仍然有很多问题,比如对于神经网络来说,1200 个训练样本其实还是不够的,另外,这些图片覆盖的场景其实也比较少,有些图片的相似度比较高,这样的数据放到神经网络里训练,泛化的效果并不好。

所以,还采用技术手段,合成了80000多张训练样本图片。

hed training dataset 2

如上图所示,一张背景图和一张前景图,可以合成出一对训练样本数据。在合成图片的过程中,用到了下面这些技术和技巧:

  • 在前景图上添加旋转、平移、透视变换
  • 对背景图进行了随机的裁剪
  • 通过试验对比,生成合适宽度的边缘线
  • OpenCV 不支持透明图层之间的旋转和透视变换操作,只能使用最低精度的插值算法,为了改善这一点,后续改成了使用 iOS 模拟器,通过 CALayer 上的操作来合成图片
  • 在不断改进训练样本的过程中,还根据真实样本图片的统计情况和各种途径的反馈信息,刻意模拟了一些更复杂的样本场景,比如凌乱的背景环境、直线边缘干扰等等

经过不断的调整和优化,最终才训练出一个满意的模型,可以再次通过下面这张图表中的第二列看一下神经网络模型的边缘检测效果:

hed vs canny

在手机设备上运行 TensorFlow

在手机上使用 TensorFlow 库

TensorFlow 官方是支持 iOS 和 Android 的,而且有清晰的文档,照着做就行。但是因为 TensorFlow 是依赖于 protobuf 3 的,所以有可能会遇到一些其他的问题,比如下面这两种,就是我们在两个不同的 iOS APP 中遇到的问题和解决办法,可以作为一个参考:

  • A 产品使用的是 protobuf 2,同时由于各种历史原因,使用并且停留在了很旧的某个版本的 Base 库上,而 protobuf 3 的内部也使用了 Base 库,当 A 产品升级到 protobuf 3 后,protobuf 3 的 Base 库和 A 源码中的 Base 库产生了一些奇怪的冲突,最后的解决办法是手动修改了 A 源码中的 Base 库,避免编译时的冲突
  • B 产品也是使用的 protobuf 2,而且 B 产品使用到的多个第三方模块(没有源码,只有二进制文件)也是依赖于 protobuf 2,直接升级 B 产品使用的 protobuf 库就行不通了,最后采用的方法是修改 TensorFlow 和 TensorFlow 中使用的 protobuf 3 的源代码,把 protobuf 3 换了一个命名空间,这样两个不同版本的 protobuf 库就可以共存了

Android 上因为本身是可以使用动态库的,所以即便 app 必须使用 protobuf 2 也没有关系,不同的模块使用 dlopen 的方式加载各自需要的特定版本的库就可以了。

在手机上使用训练得到的模型文件

模型通常都是在 PC 端训练的,对于大部分使用者,都是用 Python 编写的代码,得到 ckpt 格式的模型文件。在使用模型文件的时候,一种做法就是用代码重新构建出完整的神经网络,然后加载这个 ckpt 格式的模型文件,如果是在 PC 上使用模型文件,用这个方法其实也是可以接受的,复制粘贴一下 Python 代码就可以重新构建整个神经网络。但是,在手机上只能使用 TensorFlow 提供的 C++ 接口,如果还是用同样的思路,就需要用 C++ API 重新构建一遍神经网络,这个工作量就有点大了,而且 C++ API 使用起来比 Python API 复杂的多,所以,在 PC 上训练完网络后,还需要把 ckpt 格式的模型文件转换成 pb 格式的模型文件,这个 pb 格式的模型文件,是用 protobuf 序列化得到的二进制文件,里面包含了神经网络的具体结构以及每个矩阵的数值,使用这个 pb 文件的时候,不需要再用代码构建完整的神经网络结构,只需要反序列化一下就可以了,这样的话,用 C++ API 编写的代码就会简单很多,其实这也是 TensorFlow 推荐的使用方法,在 PC 上使用模型的时候,也应该使用这种 pb 文件(训练过程中使用 ckpt 文件)。

HED 网络在手机上遇到的奇怪 crash

在手机上加载 pb 模型文件并且运行的时候,遇到过一个诡异的错误,内容如下:

1
2
3
4
Invalid argument: No OpKernel was registered to support Op 'Mul' with these attrs.  Registered devices: [CPU], Registered kernels:
device='CPU'; T in [DT_FLOAT]

[[Node: hed/mul_1 = Mul[T=DT_INT32](hed/strided_slice_2, hed/mul_1/y)]]

之所以诡异,是因为从字面上看,这个错误的含义是缺少乘法操作(Mul),但是我用其他的神经网络模型做过对比,乘法操作模块是可以正常工作的。

Google 搜索后发现很多人遇到过类似的情况,但是错误信息又并不相同,后来在 TensorFlow 的 github issues 里终于找到了线索,综合起来解释,是因为 TensorFlow 是基于操作(Operation)来模块化设计和编码的,每一个数学计算模块就是一个 Operation,由于各种原因,比如内存占用大小、GPU 独占操作等等,mobile 版的 TensorFlow,并没有包含所有的 Operation,mobile 版的 TensorFlow 支持的 Operation 只是 PC 完整版 TensorFlow 的一个子集,我遇到的这个错误,就是因为使用到的某个 Operation 并不支持 mobile 版。

按照这个线索,在 Python 代码中逐个排查,后来定位到了出问题的代码,修改前后的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def deconv(inputs, upsample_factor):
input_shape = tf.shape(inputs)

# Calculate the ouput size of the upsampled tensor
upsampled_shape = tf.pack([input_shape[0],
input_shape[1] * upsample_factor,
input_shape[2] * upsample_factor,
1])

upsample_filter_np = bilinear_upsample_weights(upsample_factor, 1)
upsample_filter_tensor = tf.constant(upsample_filter_np)

# Perform the upsampling
upsampled_inputs = tf.nn.conv2d_transpose(inputs, upsample_filter_tensor,
output_shape=upsampled_shape,
strides=[1, upsample_factor, upsample_factor, 1])

return upsampled_inputs

def deconv_mobile_version(inputs, upsample_factor, upsampled_shape):
upsample_filter_np = bilinear_upsample_weights(upsample_factor, 1)
upsample_filter_tensor = tf.constant(upsample_filter_np)

# Perform the upsampling
upsampled_inputs = tf.nn.conv2d_transpose(inputs, upsample_filter_tensor,
output_shape=upsampled_shape,
strides=[1, upsample_factor, upsample_factor, 1])

return upsampled_inputs

问题就是由 deconv 函数中的 tf.shape 和 tf.pack 这两个操作引起的,在 PC 版代码中,为了简洁,是基于这两个操作,自动计算出 upsampled_shape,修改过后,则是要求调用者用 hard coding 的方式设置对应的 upsampled_shape。

裁剪 TensorFlow

TensorFlow 是一个很庞大的框架,对于手机来说,它占用的体积是比较大的,所以需要尽量的缩减 TensorFlow 库占用的体积。

其实在解决前面遇到的那个 crash 问题的时候,已经指明了一种裁剪的思路,既然 mobile 版的 TensorFlow 本来就是 PC 版的一个子集,那就意味着可以根据具体的需求,让这个子集变得更小,这也就达到了裁剪的目的。具体来说,就是修改 TensorFlow 源码中的 tensorflow/tensorflow/contrib/makefile/tf_op_files.txt 文件,只保留使用到了的模块。针对 HED 网络,原有的 200 多个模块裁剪到只剩 46 个,裁剪过后的 tf_op_files.txt 文件如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
tensorflow/core/kernels/xent_op.cc
tensorflow/core/kernels/where_op.cc
tensorflow/core/kernels/unpack_op.cc
tensorflow/core/kernels/transpose_op.cc
tensorflow/core/kernels/transpose_functor_cpu.cc
tensorflow/core/kernels/tensor_array_ops.cc
tensorflow/core/kernels/tensor_array.cc
tensorflow/core/kernels/split_op.cc
tensorflow/core/kernels/split_v_op.cc
tensorflow/core/kernels/split_lib_cpu.cc
tensorflow/core/kernels/shape_ops.cc
tensorflow/core/kernels/session_ops.cc
tensorflow/core/kernels/sendrecv_ops.cc
tensorflow/core/kernels/reverse_op.cc
tensorflow/core/kernels/reshape_op.cc
tensorflow/core/kernels/relu_op.cc
tensorflow/core/kernels/pooling_ops_common.cc
tensorflow/core/kernels/pack_op.cc
tensorflow/core/kernels/ops_util.cc
tensorflow/core/kernels/no_op.cc
tensorflow/core/kernels/maxpooling_op.cc
tensorflow/core/kernels/matmul_op.cc
tensorflow/core/kernels/immutable_constant_op.cc
tensorflow/core/kernels/identity_op.cc
tensorflow/core/kernels/gather_op.cc
tensorflow/core/kernels/gather_functor.cc
tensorflow/core/kernels/fill_functor.cc
tensorflow/core/kernels/dense_update_ops.cc
tensorflow/core/kernels/deep_conv2d.cc
tensorflow/core/kernels/xsmm_conv2d.cc
tensorflow/core/kernels/conv_ops_using_gemm.cc
tensorflow/core/kernels/conv_ops_fused.cc
tensorflow/core/kernels/conv_ops.cc
tensorflow/core/kernels/conv_grad_filter_ops.cc
tensorflow/core/kernels/conv_grad_input_ops.cc
tensorflow/core/kernels/conv_grad_ops.cc
tensorflow/core/kernels/constant_op.cc
tensorflow/core/kernels/concat_op.cc
tensorflow/core/kernels/concat_lib_cpu.cc
tensorflow/core/kernels/bias_op.cc
tensorflow/core/ops/sendrecv_ops.cc
tensorflow/core/ops/no_op.cc
tensorflow/core/ops/nn_ops.cc
tensorflow/core/ops/nn_grad.cc
tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/array_grad.cc

需要强调的一点是,这种操作思路,是针对不同的神经网络结构有不同的裁剪方式,原则就是用到什么模块就保留什么模块。当然,因为有些模块之间还存在隐含的依赖关系,所以裁剪的时候也是要反复尝试多次才能成功的。

除此之外,还有下面这些通用手段也可以实现裁剪的目的:

  • 编译器级别的 strip 操作,在链接的时候会自动的把没有调用到的函数去除掉(集成开发环境里通常已经自动将这些参数设置成了最佳组合)
  • 借助一些高级技巧和工具,对二进制文件进行瘦身

借助所有这些裁剪手段,最终我们的 ipa 安装包的大小只增加了 3M。如果不做手动裁剪这一步,那 ipa 的增量,则是 30M 左右。

裁剪 HED 网络

按照 HED 论文给出的参考信息,得到的模型文件的大小是 56M,对于手机来说也是比较大的,而且模型越大也意味着计算量越大,所以需要考虑能否把 HED 网络也裁剪一下。

HED 网络是用 VGG16 作为基础网络结构,而 VGG 又是一个得到广泛验证的基础网络结构,因此修改 HED 的整体结构肯定不是一个明智的选择,至少不是首选的方案。

考虑到现在的需求,只是检测矩形区域的边缘,而并不是检测通用场景下的广义的边缘,可以认为前者的复杂度比后者更低,所以一种可行的思路,就是保留 HED 的整体结构,修改 VGG 每一组卷积层里面的卷积核的数量,让 HED 网络变的更『瘦』。

按照这种思路,经过多次调整和尝试,最终得到了一组合适的卷积核的数量参数,对应的模型文件只有 4.2M,在 iPhone 7P 上,处理每帧图片的时间消耗是 0.1 秒左右,满足实时性的要求。

神经网络的裁剪,目前在学术界也是一个很热门的领域,有好几种不同的理论来实现不同目的的裁剪,但是,也并不是说每一种网络结构都有裁剪的空间,通常来说,应该结合实际情况,使用合适的技术手段,选择一个合适大小的模型文件。

TensorFlow API 的选择

TensorFlow 的 API 是很灵活的,也比较底层,在学习过程中发现,每个人写出来的代码,风格差异很大,而且很多工程师又采用了各种各样的技巧来简化代码,但是这其实反而在无形中又增加了代码的阅读难度,也不利于代码的复用。

第三方社区和 TensorFlow 官方,都意识到了这个问题,所以更好的做法是,使用封装度更高但又保持灵活性的 API 来进行开发。本文中的代码,就是使用 TensorFlow-Slim 编写的。

OpenCV 算法

虽然用神经网络技术,已经得到了一个比 canny 算法更好的边缘检测效果,但是,神经网络也并不是万能的,干扰是仍然存在的,所以,第二个步骤中的数学模型算法,仍然是需要的,只不过因为第一个步骤中的边缘检测有了大幅度改善,所以第二个步骤中的算法,得到了适当的简化,而且算法整体的适应性也更强了。

这部分的算法如下图所示:

find rect 1

按照编号顺序,几个关键步骤做了下面这些事情:

  1. 用 HED 网络检测边缘,可以看到,这里得到的边缘线还是存在一些干扰的
  2. 在前一步得到的图像上,使用 HoughLinesP 函数检测线段(蓝色线段)
  3. 把前一步得到的线段延长成直线(绿色直线)
  4. 在第二步中检测到的线段,有一些是很接近的,或者有些短线段是可以连接成一条更长的线段的,所以可以采用一些策略把它们合并到一起,这个时候,就要借助第三步中得到的直线。定义一种策略判断两条直线是否相等,当遇到相等的两条直线时,把这两条直线各自对应的线段再合并或连接成一条线段。这一步完成后,后面的步骤就只需要蓝色的线段而不需要绿色的直线了
  5. 根据第四步得到的线段,计算它们之间的交叉点,临近的交叉点也可以合并,同时,把每一个交叉点和产生这个交叉点的线段也要关联在一起(每一个蓝色的点,都有一组红色的线段和它关联)
  6. 对于第五步得到的所有交叉点,每次取出其中的 4 个,判断这 4 个点组成的四边形是否是一个合理的矩形(有透视变换效果的矩形),除了常规的判断策略,比如角度、边长的比值之外,还有一个判断条件就是每条边是否可以和第五步中得到的对应的点的关联线段重合,如果不能重合,则这个四边形就不太可能是我们期望检测出来的矩形
  7. 经过第六步的过滤后,如果得到了多个四边形,可以再使用一个简单的过滤策略,比如排序找出周长或面积最大的矩形

对于上面这个例子,第一版技术方案中检测出来的边缘线如下图所示:

find rect 2

有兴趣的读者也可以考虑一下,在这种边缘图中,如何设计算法才能找出我们期望的那个矩形。

总结

算法角度

  • 神经网络的参数/超参数的调优,通常只能基于经验来设置,有 magic trick 的成分
  • 神经网络/机器学习是一门试验科学
  • 对于监督学习,数据的标注成本很高,这一步很容易出现瓶颈
  • 论文、参考代码和自己的代码,这三者之间不完全一致也是正常现象
  • 对于某些需求,可以在模型的准确度、大小和运行速度之间找一个平衡点

工程角度

  • end-to-end 网络无效的时候,可以用 pipeline 的思路考虑问题、拆分业务,针对性的使用神经网络技术
  • 至少要熟练掌握一种神经网络的开发框架,而且要追求代码的工程质量
  • 要掌握神经网络技术中的一些基本套路,举一反三
  • 要在学术界和工业界中间找平衡点,尽可能多的学习一些不同问题领域的神经网络模型,作为技术储备

参考文献

Hacker’s guide to Neural Networks
神经网络浅讲:从神经元到深度学习
分类与回归区别是什么?
神经网络架构演进史:全面回顾从LeNet5到ENet十余种架构

数据的游戏:冰与火
为什么“高大上”的算法工程师变成了数据民工?
Facebook人工智能负责人Yann LeCun谈深度学习的局限性

The best explanation of Convolutional Neural Networks on the Internet!
从入门到精通:卷积神经网络初学者指南
Transposed Convolution, Fractionally Strided Convolution or Deconvolution
A technical report on convolution arithmetic in the context of deep learning

Visualizing what ConvNets learn
Visualizing Features from a Convolutional Neural Network

Neural networks: which cost function to use?
difference between tensorflow tf.nn.softmax and tf.nn.softmax_cross_entropy_with_logits
Why You Should Use Cross-Entropy Error Instead Of Classification Error Or Mean Squared Error For Neural Network Classifier Training

Tensorflow 3 Ways
TensorFlow-Slim
TensorFlow-Slim image classification library

Holistically-Nested Edge Detection
深度卷积神经网络在目标检测中的进展
全卷积网络:从图像级理解到像素级理解
图像语义分割之FCN和CRF

Image Classification and Segmentation with Tensorflow and TF-Slim
Upsampling and Image Segmentation with Tensorflow and TF-Slim
Image Segmentation with Tensorflow using CNNs and Conditional Random Fields

How to Build a Kick-Ass Mobile Document Scanner in Just 5 Minutes
MAKE DOCUMENT SCANNER USING PYTHON AND OPENCV
Fast and Accurate Document Detection for Scanning

用 ReactiveCocoa 事半功倍的写代码(五)

Posted on 2016-07-25 | Edited on 2016-07-26

体会 Composition 的含义

有些读者可能会注意到一点,这个系列教程的英文标题是 The Power Of Composition In FRP,看上去并不像是中文标题的直接翻译,其实这也是纠结过后的一个妥协的选择,其实我个人更喜欢这个英文标题,因为 Composition 这个词,更能体现出 FRP 的一个精髓理念,如果要用一个中文词语来表示,我觉得『组装』这个词更准确一些。

先看下面这段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
- (void)fetchNecessaryDataForAccounts:(NSArray<FMAccount *> *)accounts {
NSParameterAssert(accounts != nil);
NSParameterAssert(accounts.count > 0);

@weakify(self);
[[[[[accounts.rac_sequence signal] //1
map:^id(FMAccount *account) {
//2
return [[[QHOldAccountMigration fetchInitialDataForAccount:account]
map:^id(FMAccount *account) {
//3
return RACTuplePack(account, nil);
}]
catch:^RACSignal *(NSError *error) {
//3
return [RACSignal return:RACTuplePack(account, error)];
}];
}]
collect] //4
flattenMap:^RACStream *(NSArray *arrayOfSignal) {
return [[RACSignal zip:arrayOfSignal] //4
map:^id(RACTuple *tuple) {
NSMutableArray *successAccounts = [[NSMutableArray alloc] init];
NSMutableArray *failAccounts = [[NSMutableArray alloc] init];

for (int i = 0; i < tuple.count; i++) {
RACTuple *t = [tuple objectAtIndex:i];
FMAccount *account = t.first;
NSError *error = t.second;

if (error) {
[failAccounts addObject:account];
} else {
[successAccounts addObject:account];
}
}

return RACTuplePack([successAccounts copy], [failAccounts copy]);
}];
}]
subscribeNext:^(RACTuple *tuple) {
@strongify(self);
NSArray *successAccounts = tuple.first;
NSArray *failAccounts = tuple.second;
//5

if (failAccounts.count == 0) {
[self jumpToOriginalLogic];
} else {
NSMutableString *title;
if (successAccounts.count == 0) {
title = [[NSMutableString alloc] initWithString:@"所有账号迁移失败,请重新登录"];
} else {
title = [[NSMutableString alloc] initWithString:@"邮箱账号"];
for (FMAccount *account in failAccounts) {
[title appendFormat:@"%@, ", account.profile.mailAddress];
}

title = [[title substringToIndex:title.length - 2] mutableCopy];
[title appendString:@"迁移失败,需要重新登录"];
}

UIAlertController *alertController = [UIAlertController alertControllerWithTitle:title message:nil preferredStyle:UIAlertControllerStyleAlert];
[alertController addAction:[UIAlertAction actionWithTitle:@"确定" style:UIAlertActionStyleDefault handler:^(UIAlertAction * _Nonnull action) {
@strongify(self);
for (FMAccount *account in failAccounts) {
FMMigrationFailAccount *failAccount = [FMMigrationFailAccount convertAccountToMigrationFailAccount:account];
[failAccount save];

[[FMManager shareInstance] deleteAccount:account.accountId];
}
[self jumpToOriginalLogic];
}]];

UIViewController *viewController = [UIApplication sharedApplication].keyWindow.rootViewController;
[viewController presentViewController:alertController animated:YES completion:^{
}];
}
} error:^(NSError *error) {

} completed:^{

}];
}

这个 pipeline 其实用的就是 collect + combineLatest 或者 zip 这种管道模型,只不过管道内部具体的业务不一样,这里的业务就是针对每一个 FMAccount 帐号,下载一些必要的初始数据,然后等每个下载都完成后,再执行后续的业务,主要就是下面几个点:

  1. 把包含有 FMAccount 的数组,先变换成 signal。
  2. [QHOldAccountMigration fetchInitialDataForAccount:account] 里面执行的是下载初始数据的具体业务逻辑,其实它的内部就是一个用多个 map 操作串联起来的 pipeline,对应下载初始数据过程中的多个步骤。
  3. 如果 fetchInitialDataForAccount 失败,则把 error 转换成 next 事件,用 tuple 的形式继续向 pipeline 的后续环节传递,等所有的下载都结束后,会统一对 error 进行处理。
  4. 套用 collect + zip 这种 pipeline 模型。
  5. 当所有的下载都结束后,才会运行到这里,successAccounts 和 failAccounts 分别对应成功下载初始数据的所有帐号和下载数据失败了的所有帐号,至于后面 if 分支里的代码,只是后续的一些业务逻辑功能,读者可以不用在意,我们的重点还是在于这个 pipeline 的外形。

在编写程序的时候,通常我们都会提到『复用』这个概念,最简单的场景就是函数复用,这里的 pipeline 也是一种复用,只不过 pipeline 不像普通函数那样通过抽象出输入参数和返回结果来实现复用,pipeline 的复用体现在管道的形状上,这里所谓的形状,就是把 FRP 中对 signal 的各种操作组装起来后 pipeline 的形状。多个 map 串联是一种形状,collect + zip 是一种形状,之前的教程中提到的那些案例,都可以理解为一种形状(甚至还可以看成是多个不同形状的 pipeline 的进一步组装),每一种形状的管道,有输入的数据,有输出的数据,同时,还存在各种各样的中间处理环节,每次复用 pipeline 的时候,输入数据、输出数据以及中间处理环节,都是可以根据具体的业务需求灵活的进行填充的。

回到前面这个例子,accounts 是 pipeline 的输入,successAccounts 和 failAccounts 是 pipeline 的输出,其他的操作都可以看成是中间处理环节。这个 pipeline 仅仅是完成了下载数据的功能,在真实的产品需求中,为了更好的照顾用户体验,还希望能够显示出下载进度信息,也就是说,对于 accounts 这个输入,还需要另外一种形式的输出信息,可以体现出下载进度情况。这里还有一个约束条件,[QHOldAccountMigration fetchInitialDataForAccount:account] 这个操作本身是无法表现出下载数据时的进度信息的,因为并不是下载一个文件(在编程惯例中,通常只在上传和下载文件的时候或类似的场景中,才会设计出能体现进度信息的 API),所以这里还需要想办法模拟出一种进度信息用来在 UI 上进行显示,主要代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
- (void)fetchNecessaryDataForAccounts:(NSArray<FMAccount *> *)accounts {
NSParameterAssert(accounts != nil);
NSParameterAssert(accounts.count > 0);

@weakify(self);
//1
RACSignal *fetchAllInitialData = [[[[accounts.rac_sequence signal]
map:^id(FMAccount *account) {
return [[[[[QHOldAccountMigration fetchInitialDataForAccount:account]
map:^id(FMAccount *account) {
return RACTuplePack(account, nil);
}]
catch:^RACSignal *(NSError *error) {
return [RACSignal return:RACTuplePack(account, error)];
}]
multicast:[RACReplaySubject subject]] //3
autoconnect];
}]
multicast:[RACReplaySubject subject]] //2
autoconnect];


//4
RACSignal *businessLogicSignal = [[fetchAllInitialData collect]
flattenMap:^RACStream *(NSArray *arrayOfSignal) {
return [[RACSignal zip:arrayOfSignal]
map:^id(RACTuple *tuple) {
NSMutableArray *successAccounts = [[NSMutableArray alloc] init];
NSMutableArray *failAccounts = [[NSMutableArray alloc] init];

for (int i = 0; i < tuple.count; i++) {
RACTuple *t = [tuple objectAtIndex:i];
FMAccount *account = t.first;
NSError *error = t.second;

if (error) {
[failAccounts addObject:account];
} else {
[successAccounts addObject:account];
}
}

return RACTuplePack([successAccounts copy], [failAccounts copy]);
}];
}];


[businessLogicSignal
subscribeNext:^(RACTuple *tuple) {
@strongify(self);
NSArray *successAccounts = tuple.first;
NSArray *failAccounts = tuple.second;
//5
if (failAccounts.count == 0) {
[self jumpToOriginalLogic];
} else {
NSMutableString *title;
if (successAccounts.count == 0) {
title = [[NSMutableString alloc] initWithString:@"所有账号迁移失败,请重新登录"];
} else {
title = [[NSMutableString alloc] initWithString:@"邮箱账号"];
for (FMAccount *account in failAccounts) {
[title appendFormat:@"%@, ", account.profile.mailAddress];
}

title = [[title substringToIndex:title.length - 2] mutableCopy];
[title appendString:@"迁移失败,需要重新登录"];
}

UIAlertController *alertController = [UIAlertController alertControllerWithTitle:title message:nil preferredStyle:UIAlertControllerStyleAlert];
[alertController addAction:[UIAlertAction actionWithTitle:@"确定" style:UIAlertActionStyleDefault handler:^(UIAlertAction * _Nonnull action) {
@strongify(self);
for (FMAccount *account in failAccounts) {
FMMigrationFailAccount *failAccount = [FMMigrationFailAccount convertAccountToMigrationFailAccount:account];
[failAccount save];

[[FMManager shareInstance] deleteAccount:account.accountId];
}
[self jumpToOriginalLogic];
}]];

UIViewController *viewController = [UIApplication sharedApplication].keyWindow.rootViewController;
[viewController presentViewController:alertController animated:YES completion:^{
}];
}
} error:^(NSError *error) {

} completed:^{

}];


//11
static const CGFloat tickCount = 60 / 0.5;
RACSignal *timer = [[RACSignal interval:0.5 onScheduler:[RACScheduler mainThreadScheduler]]
map:^id(id value) {
return nil;
}];

NSMutableArray *numbers = [[NSMutableArray alloc] init];
for (NSInteger i = 0; i < tickCount; i++) {
[numbers addObject:@(i)];
}

RACSignal *counter = [[[[[numbers.rac_sequence signal]
zipWith:timer]
map:^id(RACTuple *tuple) {
NSNumber *n = tuple.first;
return RACTuplePack(n, @(tickCount), nil);//12
}]
takeUntil:businessLogicSignal]
logCompleted];



//6
NSMutableArray *sequence = [[NSMutableArray alloc] init];
for (int i = 0; i < accounts.count; i++) {
[sequence addObject:@(i + 1)];
}


static NSInteger progressValue = 0;

[[[[[[[fetchAllInitialData
flatten]//6
map:^id(RACTuple *tuple) {
//7
return tuple.first;
}]
zipWith:[sequence.rac_sequence signal]]//8
combineLatestWith:[RACSignal return:@(accounts.count)]]//8
map:^id(RACTuple *tuple) {
//9
RACTuple *nestedTuple = tuple.first;
NSNumber *accountsCount = tuple.second;

FMAccount *account = nestedTuple.first;
NSNumber *order = nestedTuple.second;

//10
return RACTuplePack(order, accountsCount, account);
}]
merge:counter]//11
subscribeNext:^(RACTuple *tuple) {
NSNumber *order = tuple.first;
NSNumber *accountsCount = tuple.second;
FMAccount *account = tuple.third;

//13
if (account) {
NSLog(@"fetch initial data finished, order is: [%@, %@], account is: %@", order, accountsCount, account.profile.mailAddress);

NSInteger nextValue = order.integerValue * 100 / accountsCount.integerValue;
if (order.integerValue == accountsCount.integerValue) {
nextValue = 100;
progressValue = 100;
}

if (nextValue > progressValue) {
progressValue = nextValue;
}
} else {
NSLog(@"counter info, [%@, %@]", order, accountsCount);
progressValue = progressValue + 1.0;

//14
if (progressValue > 95) {
progressValue = 95.0;
}
}

//14
NSLog(@"======== progressValue is: %ld", (long)progressValue);

} error:^(NSError *error) {
} completed:^{
}];
}

下面看看这个 pipeline 是如何组装出来的:

  1. 需求越复杂,通常 pipeline 也就会越复杂,现在我们遇到了新的需求,但是之前那个 pipeline 做的业务仍然是需要保留的,这种时候,通常可以考虑先把 pipeline 的代码拆分一下,然后对拆出来的 signal 或者 pipeline 重新进行组装。首先就可以把 [QHOldAccountMigration fetchInitialDataForAccount:account] 动作拆分出来,注意一点,这里还没有调用 collect 操作。
  2. 因为后续多个业务逻辑都要用到前面第一步得到的 signal,根据业务的需求,对于每个 FMAccount 只需要下载一次数据,所以这里应该让 signal 变成广播的形式。
  3. 内层嵌套的 signal 才是真正的 fetchInitialDataForAccount 动作,同理,也需要变成广播(其实在刚开始设计 pipeline 的时候,可能还意识不到需要广播,这种时候,可以先组装业务流程,当遇到问题后,再考虑是否需要使用广播 signal)。如果暂时看不明白为什么 2 和 3 两处需要使用广播,没有关系,先接着往后看,把整个 pipeline 看明白后,再倒回来想想为什么需要广播。
  4. 这个中间环节也拆分出来,以备后用。
  5. 这里是对 successAccounts 和 failAccounts 的处理逻辑,和前一个版本的 pipeline 没有区别。
  6. 现在开始考虑如何显示进度信息,虽然每次 [QHOldAccountMigration fetchInitialDataForAccount:account] 调用是没有进度信息的,但是当有多次调用的时候,是可以计算出一种形式的进度信息的,比如总共有 5 个 FMAccount,当第一个 FMAccount 下载完数据(或者失败)的时候,整体进度就是 1/5,当第二个 FMAccount 下载完数据(或者失败)的时候,整体进度就是 2/5,依次类推。
  7. 回忆一下 fetchAllInitialData 里面的内容,因为现在是计算进度信息,并不关心具体的 error,所以这里的 map 操作只需要返回 tuple.first,也就是只需要继续传递 FMAccount。
  8. 这里连续调用 zip 和 combineLatest,如果觉得这里很难理解,没有关系,先分别回忆一下 zip 和 combineLatest 的效果,想象一下这里应该得到什么样的结果。
  9. 前面的 zip 操作会得到一个 tuple,然后这个 tuple 又和 [RACSignal return:@(accounts.count)] 进行一次 combineLatest,所以这里会得到一个嵌套的 tuple。
  10. 8 和 9 的操作,最终就是为了组装出这样的一个 tuple,然后继续在 pipeline 中传递。比如总共有 5 个 FMAccount,当第一个 FMAccount 下载完数据(或者失败)的时候,这个 tuple 的值是 (1, 5, 第一个 FMAccount 对象的指针),当第二个 FMAccount 下载完数据(或者失败)的时候,返回的 tuple 的值是 (2, 5, 第二个 FMAccount 对象的指针),依次类推,后续还会返回 3 个 tuple。
  11. 前面已经组装出进度信息了,但是对于 UI 来说,这种进度信息还是太粗糙了,为了让 UI 上的进度条能够更平滑的进行动画过渡,还应该插入一些更细粒度的进度信息。这里借助 RAC 的定时器来构造出一种和 10 里面的 tuple 具有相同格式的 tuple 数据。关于这部分定时器的 pipeline,和 发送验证码的倒计时按钮 里面的 pipeline 是相似的形状的,可以看看之前的介绍。
  12. 为了和 10 里面返回的 tuple 具有同样的格式,这里需要这样组装数据,按照顺序,这里返回的 tuple 依次将会是 (1, 120, nil)、(2, 120, nil)、(3, 120, nil),依次类推,直到 (120, 120, nil)。
  13. 终于到了 pipeline 的最终输出了,把 tuple 里面的数据先分别取出来,如果 account 不为 nil,则是通过 fetchAllInitialData 计算出来的进度信息,如果 account 为 nil,则对应通过定时器模拟出来的进度信息。假设最终的进度值会达到 100,这里还需要采用适当的手段将两种不同的进度值融合在一起,现在就是用最简单的办法进行的处理。
  14. 如果定时器返回的 tuple 已经达到 (120, 120, nil),而 fetchAllInitialData 还没有执行结束,这种情况下,不应该让进度值达到 100,必须得等所有的 fetchAllInitialData 都结束后进度值才能是 100,所以这里做一个约束,定时器模拟出的进度值,最大只能达到 95。

用 ReactiveCocoa 事半功倍的写代码(四)

Posted on 2016-05-03 | Edited on 2016-05-04

监听系统截屏操作的复杂管道

这是一个很复杂的 Pipeline,因为要做的业务比较繁琐,如下图:

monitor screenshot

需求大致可以描述为:

  1. 当 app 停留在读信页面的时候,要实时的监听用户是否有截屏操作。
  2. 在 1 的基础上,只有 app 前台运行的时候,才实时监听用户是否有截屏操作,如果是后台状态,则不监听。
  3. 如果用户有截图动作,则将截图内容显示在一个预览视图内(如上图中红框区域)。
  4. 如果用户点击了预览视图,则进入后续的业务流程,对截图进行涂鸦编辑等等。
  5. 如果点击了预览视图的外部区域,则隐藏预览视图。
  6. 如果 10 秒钟之内没有任何操作,也自动隐藏预览视图。

主要的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//13
- (void)viewDidAppear:(BOOL)animated {
[self initPipeline];
}

- (void)initPipeline {
//1
RACSignal *isNotActive = [[NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationWillResignActiveNotification object:nil]
map:^id(NSNotification *notification) {
return [NSNumber numberWithBool:NO];
}];

RACSignal *isActive = [[[[NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationDidBecomeActiveNotification object:nil]
map:^id(NSNotification *notification) {
return [NSNumber numberWithBool:YES];
}]
startWith:[NSNumber numberWithBool:YES]]
merge:isNotActive];

RACSignal *isNotInBackground = [[NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationDidEnterBackgroundNotification object:nil]
map:^id(NSNotification *notification) {
return [NSNumber numberWithBool:NO];
}];

RACSignal *isInForeground = [[[[NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationWillEnterForegroundNotification object:nil]
map:^id(NSNotification *notification) {
return [NSNumber numberWithBool:YES];
}]
startWith:[NSNumber numberWithBool:YES]]
merge:isNotInBackground];

//2
RACSignal *didTakeScreenshot = [NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationUserDidTakeScreenshotNotification object:nil];

@weakify(self);
RACSignal *imageSignal = [[[[[[[RACSignal if:[RACSignal merge:@[isInForeground, isActive]] then:didTakeScreenshot else:[RACSignal never]] //3
takeUntil:self.rac_willDeallocSignal]
filter:^BOOL(id value) {
//4
@strongify(self);
return [self filterScreenshotNotification];
}]
filter:^BOOL(id value) {
//5
@strongify(self);
return self.previewShotView == nil;
}]
map:^id(NSNotification *notification) {
//6
@strongify(self);
return [self takeCurrentScreenshotOfWebview];
}]
multicast:[RACReplaySubject subject]]
autoconnect];//7

//8
RACSignal *hotSignalForPreview = [[[imageSignal
map:^id(UIImage *image) {
@strongify(self);
return [self showScreenshotPreviewView:image];
}]
multicast:[RACReplaySubject subject]]
autoconnect];

//9
RACSignal *cancel = [[hotSignalForPreview
map:^id(FMScreenshotPreviewView *previewView) {
return [previewView.cancelSignal
map:^id(id value) {
return nil;
}];
}]
switchToLatest];

//10
RACSignal *editImage = [[hotSignalForPreview
map:^id(FMScreenshotPreviewView *previewView) {
return [previewView.editImage
map:^id(id value) {
return nil;
}];
}]
switchToLatest];

//11
RACSignal *otherActionForHidePreview = [[hotSignalForPreview
map:^id(id value) {
RACSignal *willResignActive = [[[NSNotificationCenter.defaultCenter rac_addObserverForName:UIApplicationWillResignActiveNotification object:nil]
take:1]
takeUntil:[RACSignal merge:@[cancel, editImage]]];

RACSignal *timeout = [[[RACSignal return:nil]
delay:10.0]
takeUntil:[RACSignal merge:@[cancel, editImage, willResignActive]]];


return [[RACSignal merge:@[timeout, willResignActive]]
take:1];
}]
switchToLatest];

//12
RACSignal *shouldHidePreviewView = [RACSignal merge:@[cancel, editImage, otherActionForHidePreview]];

//13
RACSignal *viewWillDisappear = [self rac_signalForSelector:@selector(viewWillDisappear:)];

//14
[[[shouldHidePreviewView
zipWith:hotSignalForPreview]
takeUntil:viewWillDisappear]//13
subscribeNext:^(RACTuple *tuple) {
@strongify(self);
[self hideScreenshotPreviewView:tuple];
} completed:^{
}];

//15
[[[imageSignal sample:editImage]
takeUntil:viewWillDisappear]
subscribeNext:^(UIImage *image) {
@strongify(self);
[self showDrawViewController:image];
} completed:^{
}];
}

代码有点长,而且里面的 signal 也比较多,主要是下面这些点:

  1. 把 app 的 avtive、background、foreground 状态用 signal 的形式表达出来,使用 merge 操作把互为相反状态的 signal 合并在了一起,注意,还使用了 startWith 操作提供初始值。
  2. 这个 signal 是真正的截屏操作,它仅仅是整个 Pipeline 中的一个小环节。
  3. 因为只有 app 前台运行的时候才需要监听截屏事件,所以这里用 if/else 操作做第一层过滤。
  4. 这里是第二层过滤,因为这个 ViewController 里面有很多功能,可能会出现一些页面层叠的情况,比如显示了一个 UIActionSheet 或自定义的菜单选项等等,这个时候,也是不需要监听截屏事件的。
  5. self.previewShotView 就是显示预览图的 view,当已经有一个预览图正在显示的时候,也不需要监听截屏事件。
  6. 终于过滤完了,按照产品的需求,并不是从系统相册里把用户刚才的截图找出来,而是在 app 中自行截图一遍(只截取有效区域,不截取导航栏和工具栏区域),takeCurrentScreenshotOfWebview 方法返回的就是截图得到的 UIImage。
  7. Pipeline 的后续部分,不止一处会用到前面得到的 UIImage,所以需要 hot signal。
  8. 显示预览 view,同时在 Pipeline 中传递这个 view,这个也是 hot signal。
  9. 点击预览 view 外部区域的时候,会发送 cancelSignal signal,因为形成了 signal 的嵌套,所以要通过 switchToLatest 取出来。
  10. 类似的,点击预览 view 的时候,会发送 editImage signal,也是通过 switchToLatest 取出来。
  11. 当已经显示了一个预览 view 的时候,如果超过10秒没有任何操作,或者 app 进入了不活跃状态,也是需要隐藏预览 view 的,这里组装出对应的 signal。注意这里是如何通过 takeUntil 控制 willResignActive 和 timeout 的生命期的。
  12. 用 merge 操作组装出最终用来隐藏预览 view 的 signal。
  13. 把这个 ViewController 的 viewWillDisappear 转换成 signal 的形式。根据需求,只有这个 ViewController 可见的时候,才监听截图事件,所以,在 viewDidAppear 的时候构造 Pipeline,在 viewWillDisappear 的时候释放 Pipeline。
  14. 这里是隐藏预览 view 的具体逻辑。
  15. 当点击了预览 view 的时候,通过 showDrawViewController 方法进入后续的业务逻辑,这里使用了 sample 操作。

用 ReactiveCocoa 事半功倍的写代码(三)

Posted on 2016-04-28 | Edited on 2016-05-04

collect + combineLatest 或者 zip

RAC 里面的 collect 是一个比较容易理解的操作,它的强大之处,在于和其他的操作进行组合之后,可以完成很复杂的业务逻辑。在看真实业务代码之前,先通过下面的代码初步了解一下这种 Pipeline 的行为模式。collect 相当于 Rx 中的 ToArray 操作

版本 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
- (void)testCollectSignalsAndCombineLatestOrZip {
//1
RACSignal *numbers = @[@(0), @(1), @(2)].rac_sequence.signal;

RACSignal *letters1 = @[@"A", @"B", @"C"].rac_sequence.signal;
RACSignal *letters2 = @[@"X", @"Y", @"Z"].rac_sequence.signal;
RACSignal *letters3 = @[@"M", @"N"].rac_sequence.signal;
NSArray *arrayOfSignal = @[letters1, letters2, letters3]; //2


[[[numbers
map:^id(NSNumber *n) {
//3
return arrayOfSignal[n.integerValue];
}]
collect] //4
subscribeNext:^(NSArray *array) {
DDLogVerbose(@"%@, %@", [array class], array);
} completed:^{
DDLogVerbose(@"completed");
}];
}

这个代码纯粹只是为了演示 collect 的行为模式:

  1. 构造一个 NSNumber 的数组,包含数字 0、1、2,并且转换成 signal。
  2. 用同样的方法,构造 3 个字符串的数组,并转换成 signal,再把得到的 3 个 signal 放到数组 arrayOfSignal 中。
  3. 这里形成了一个 signal 的嵌套,但是和以前的处理方式不一样,并不会直接在后续环节中使用 flatten 操作,而是先使用 collect。
  4. collect 操作会把 Pipeline 中所有的 next 发送的数据收集到一个 NSArray 中,然后一次性通过 next 发送给后续的环节。

这段代码的执行结果如下:

1
2
3
4
5
6
2016-04-28 17:45:38:034 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] __NSArrayM, (
"<RACDynamicSignal: 0x7ffee1c9dc10> name: ",
"<RACDynamicSignal: 0x7ffee1c9dda0> name: ",
"<RACDynamicSignal: 0x7ffee1c9df20> name: "
)
2016-04-28 17:45:38:034 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] completed

可以看到,array 里面包含的是 3 个 signal。另外,因为 signal 已经形成嵌套了,所以迟早是要 flatten 的,那么如何 flatten 呢?

版本 2

因为 array 里面有 3 个 signal,所以可以构造一种 Pipeline,把这 3 个 signal 合并成一个 signal,然后对合并后的 signal 再做 flatten 操作。合并的时候,可以有不同的策略,先看下面这段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
- (void)testCollectSignalsAndCombineLatestOrZip {
RACSignal *numbers = @[@(0), @(1), @(2)].rac_sequence.signal;

RACSignal *letters1 = @[@"A", @"B", @"C"].rac_sequence.signal;
RACSignal *letters2 = @[@"X", @"Y", @"Z"].rac_sequence.signal;
RACSignal *letters3 = @[@"M", @"M"].rac_sequence.signal;
NSArray *arrayOfSignal = @[letters1, letters2, letters3];


[[[[numbers
map:^id(NSNumber *n) {
return arrayOfSignal[n.integerValue];
}]
collect]
flattenMap:^RACStream *(NSArray *arrayOfSignal) {
//1
return [RACSignal combineLatest:arrayOfSignal
reduce:^(NSString *first, NSString *second, NSString *third) {
return [NSString stringWithFormat:@"%@-%@-%@", first, second, third];
}];
}]
subscribeNext:^(NSString *x) {
DDLogVerbose(@"%@, %@", [x class], x);
} completed:^{
DDLogVerbose(@"completed");
}];
}

这段代码在接收到 collect 发送的 array 之后,对这个数组里面的 signal 进行了一个 combineLatest 操作,这个时候,原本的 3 个 signal 被 reduce 成了一个 signal,这个 signal 继续被 flatten 一次,然后最终被 Pipeline 的订阅者接收到。

这段代码的执行结果如下(也可能和下面的结果完全不一样,这是正常的,combineLatest 操作就是这样):

1
2
3
4
2016-04-28 18:48:14:453 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] NSTaggedPointerString, A-Z-N
2016-04-28 18:48:14:453 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] NSTaggedPointerString, B-Z-N
2016-04-28 18:48:14:454 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] NSTaggedPointerString, C-Z-N
2016-04-28 18:48:14:455 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] completed

版本 3

除了 combineLatest,zip 操作也可以把多个 signal reduce 成一个,但是 zip 的策略是不一样的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
- (void)testCollectSignalsAndCombineLatestOrZip {
RACSignal *numbers = @[@(0), @(1), @(2)].rac_sequence.signal;

RACSignal *letters1 = @[@"A", @"B", @"C"].rac_sequence.signal;
RACSignal *letters2 = @[@"X", @"Y", @"Z"].rac_sequence.signal;
RACSignal *letters3 = @[@"M", @"M"].rac_sequence.signal;
NSArray *arrayOfSignal = @[letters1, letters2, letters3];


[[[[numbers
map:^id(NSNumber *n) {
return arrayOfSignal[n.integerValue];//!! this is Signal, but just use map NOT flatMap
}]
collect]
flattenMap:^RACStream *(NSArray *arrayOfSignal) {
//1
return [RACSignal zip:arrayOfSignal
reduce:^(NSString *first, NSString *second, NSString *third) {
return [NSString stringWithFormat:@"%@-%@-%@", first, second, third];

}];
}]
subscribeNext:^(NSString *x) {
DDLogVerbose(@"%@, %@", [x class], x);
} completed:^{
DDLogVerbose(@"completed");
}];
}

这段代码的执行结果是下面这个样子,不像前面的 combineLatest,zip 操作的结果,只能出现下面这种唯一的情况:

1
2
3
2016-04-28 18:55:01:208 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] NSTaggedPointerString, A-X-M
2016-04-28 18:55:01:209 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] NSTaggedPointerString, B-Y-N
2016-04-28 18:55:01:209 [com.ReactiveCocoa.RACScheduler.backgroundScheduler] completed

保存联系人的头像

前面的代码很抽象,在业务中,能用上这种 Pipeline 吗?当然是可以的,比如下面这段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
- (RACSignal *)savaAvatar:(UIImage *)image withContact:(FMContact *)contact {
NSParameterAssert(image != nil);
NSParameterAssert(contact.contactItems.count > 0);

//1
RACSignal *addrs = [[contact.contactItems.rac_sequence
map:^(FMContactItem *contactItem) {
return contactItem.email;
}]
signal];

return [[[[addrs
map:^id(NSString *emailAddr) {
return [[[[FMAvatarManager shareInstance] rac_setAvatar:emailAddr image:image] //2
map:^id(id value) {
//4
return RACTuplePack(value, nil);
}]
catch:^RACSignal *(NSError *error) {
//3
return [RACSignal return:RACTuplePack(nil, error)];
}];
}]
collect] //5
flattenMap:^RACStream *(NSArray<RACSignal *> *arrayOfSignal) {
return [[RACSignal zip:arrayOfSignal] //6
map:^id(RACTuple *tuple) { //7
//8
return [tuple allObjects];
}];
}]
map:^id(NSArray<RACTuple *> *value) {
//9
return value;
}];
}

这段代码稍微有点复杂,做的事情是让 FMContact 里面的所有 email 地址,和一个 image 关联在一起,并且保存在服务器端,关键是下面这几个点:

  1. 把 contact.contactItems 里面所有的 email 转换成 signal 的形式发送出来。
  2. 每次 map 的时候,得到一个 email 地址,调用 [[FMAvatarManager shareInstance] rac_setAvatar:emailAddr image:image] 让 email 地址和 image 关联在一起,这个接口也是返回一个 signal,当成功的时候,next 里面发送一个 value (业务中并不关心这个 value 的具体值,只关心是否成功),如果失败,则会发送一个 error。如果不对 error 做特殊处理,当遇到一次 error 的时候,会使整个 Pipeline error,有些业务需要这种处理 error 的默认方式 (n 个小任务中,任何一个出现 error,整个 Pipeline 都要 error),但是我们这里的业务,并不想要这种效果,如果一个 email 上的操作失败了,不希望整个 Pipeline 因为这个 error 而结束,而是要其余的 email 地址继续执行各自的小任务,等所有的 email 都处理完毕后,再由 Pipeline 的订阅者一起处理所有的 error,这个时候,就需要用到 catch 操作了。这里有一个槽点,rac_setAvatar 每次都需要传入 image 和 email 地址,然后调用服务器接口进行保存操作,这种方式的接口,不够优雅,对于每一个 email 地址,都要重新发送一遍 image,也有点浪费流量,这是一个历史原因造成的问题。更好的方案是,先把 image 上传到服务器端,然后得到这个 image 对应的一个唯一值,比如 id,然后在这里,只需要让这个 image 的 id 和 email 能够关联起来就行了。不过这并不影响这里 Pipeline 的设计,不管是 image 还是 id,Pipeline 的形状是没有区别的。
  3. 在 catch 里面,用新的 signal 替换原有的 signal。因为需要把 error 暂存下来,放到最后再做处理,所以,用 RACTuple 把 error 包装起来并且发送出去。
  4. 虽然目前的业务,并不关心 [[FMAvatarManager shareInstance] rac_setAvatar:emailAddr image:image] 发送的 next 数据,但是,把 next 发送的数据和 error 一起用 RACTuple 包装起来,也是一个合理的设计(万一以后需要用到这个值了呢),当接收到 next 的时候,error 就是 nil,当发生 error 的时候,相当于 next 就是 nil,所以在这里,返回的是 RACTuplePack(value, nil),而在前面 3 中,返回的是 RACTuplePack(nil, error)。
  5. 使用 collect 操作。注意,前面 map 操作返回的是一个 signal,signal 的 next 发送的是一个 RACTuple,而 collect 发送的 next 是 NSArray<RACSignal *>。
  6. 前面的 map 已经形成了 signal 的嵌套,而且还通过 collect 把嵌套的 signal 放到了数组里面,所以这里需要先把数组里的 signal 合并成一个,然后再 flatten 出来。zip 操作符合我们的需求。
  7. 这里不像前面的代码演示那样使用 + (instancetype)zip:(id<NSFastEnumeration>)streams reduce:(id (^)())reduceBlock 接口,而是先使用 + (instancetype)zip:(id<NSFastEnumeration>)streams,然后 map,因为前一种 zip,输入参数 streams(数组) 中包含的元素的数量是已知的,所以可以直接在 reduce(变参数方法) 中把所有的参数都罗列出来,我们这里的 Pipeline,arrayOfSignal 里面的元素个数是不固定的,所以只能用原始的 zip 接口,然后在 map 中再进一步处理 zip 发送的 RACTuple。
  8. 在这个 map 里面得到的 RACTuple 是 zip 操作返回的,这个 tuple 里面包含的每一个数据,是前面 4 里面返回的 RACTuple,这里的 RACTuple 里面又包含了 RACTuple,千万不要搞晕了。如果没搞清楚这里的数据到底是怎么来的,可以再倒回去看看前面的步骤。为了方便后续的处理,可以把外层 RACTuple 里面的数据放到一个 NSArray 里面,然后再返回给下一个环节。[tuple allObjects] 就是做的这个动作(其实 RACTuple 的内部,就是用 NSArray 存储的数据)。
  9. 直接把 value 返回,让 Pipeline 的订阅者得到最终的结果。这里没有做任何额外的动作,仅仅是为了说明现在得到的数据是一个 NSArray<RACTuple *>。可以在这里加一些日志,方便调试。不执行这一次 map 操作也是可以的。

表单页面

再看另外一个真实业务,如下图:

edit contact

这是一个编辑联系人的页面,整体是用 UITableView 实现的,可以动态的增加、删减字段,其中有一个需求,只有当至少有一个字段有数据的时候,右上角的『保存』按钮才可以使用。如果这个页面,不需要动态的增加、删减字段,那这个需求是很容易实现的,如果不使用 UITableView,就算要动态的增加、删减字段,这个需求实现起来也还好,不会很困难。但是现在的问题在于,要在 UITableView 的基础上实现,这就有点复杂了,UITableViewCell 是在复用的,所以不能直接依赖 UITableViewCell 里面的 UITextField 来判断『保存』按钮是否可用,必须严格的使用 MVC 的思路,先把 UI 上所有的操作(增加、删减字段,编辑字段内容)都映射到 model 上,通过 model 再来计算『保存』按钮是否可用。UITableView 的代码,是传统代码和 RAC 混合编写的,RAC 做的事情并不多,主要是把 UITextField 的内容用 signal 发送出来,因为并不复杂(但是也挺繁琐的,产品还提了很多很细节的体验要求),所以这里不详细讨论,主要还是看一下基于 model 构造的 Pipeline:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
- (void)initPipline {
@weakify(self);
//1
RACSignal *emailsIsNil = [[RACObserve(self.contact, contactItems) //2
flattenMap:^id(NSMutableArray *items) {
if (items.count == 0) { //3
return [RACSignal return:[NSNumber numberWithBool:YES]];
}

//4
return [[[items.rac_sequence.signal
map:^id(FMContactItem *item) {
return [[RACObserve(item, email) //5
distinctUntilChanged] //6
map:^id(NSString *email) {
//7
return [NSNumber numberWithBool:(email.length == 0)];
}];
}]
collect] //8
flattenMap:^id(NSArray *arrayOfBoolSignal) {
return [[[RACSignal combineLatest:arrayOfBoolSignal] //9
map:^id(RACTuple *tuple) {
//10
BOOL b = YES;
for (NSUInteger i = 0; i < tuple.count; i++) {
NSNumber *n = [tuple objectAtIndex:i];
b = b && n.boolValue;
}
return [NSNumber numberWithBool:b];
}]
distinctUntilChanged]; //11
}];
}]
distinctUntilChanged]; //11

//12
RACSignal *phonesIsNil = [[RACObserve(self.contact, telephone)
map:^id(NSMutableArray *phones) {
if (phones.count == 0) {
return [NSNumber numberWithBool:YES];
}

for (NSString *phone in phones) {
if (phone.length > 0) {
return [NSNumber numberWithBool:NO];
}
}

return [NSNumber numberWithBool:YES];
}]
distinctUntilChanged];

RACSignal *addressIsNil = [[RACObserve(self.contact, familyAddress)
map:^id(NSMutableArray *addrs) {
if (addrs.count == 0) {
return [NSNumber numberWithBool:YES];
}

for (NSString *addr in addrs) {
if (addr.length > 0) {
return [NSNumber numberWithBool:NO];
}
}

return [NSNumber numberWithBool:YES];
}]
distinctUntilChanged];

RACSignal *customInfosIsNil = [[RACObserve(self.contact, customInformations)
flattenMap:^id(NSMutableArray *infos) {
if (infos.count == 0) {
return [RACSignal return:[NSNumber numberWithBool:YES]];
}

return [[[infos.rac_sequence.signal
map:^id(FMCustomInformation *info) {
RACSignal *nameSignal = [[RACObserve(info, name)
distinctUntilChanged]
map:^id(NSString *name) {
return [NSNumber numberWithBool:(name.length == 0)];
}];

RACSignal *infoSignal = [[RACObserve(info, information)
distinctUntilChanged]
map:^id(NSString *i) {
return [NSNumber numberWithBool:(i.length == 0)];
}];

return [RACSignal combineLatest:@[nameSignal, infoSignal]
reduce:(id)^id(NSNumber *name, NSNumber *info){
return [NSNumber numberWithBool:(name.boolValue && info.boolValue)];
}];

}]
collect]
flattenMap:^id(NSArray *arrayOfBoolSignal) {
return [[[RACSignal combineLatest:arrayOfBoolSignal]
map:^id(RACTuple *tuple) {
BOOL b = YES;
for (NSUInteger i = 0; i < tuple.count; i++) {
NSNumber *n = [tuple objectAtIndex:i];
b = b && n.boolValue;
}
return [NSNumber numberWithBool:b];
}]
distinctUntilChanged];
}];
}]
distinctUntilChanged];

RACSignal *nickIsNil = [[RACObserve(self.contact, nick)
map:^id(NSString *nick) {
@strongify(self);
if (self.contact.nick == nil || [self.contact.nick isEqualToString:@""] == YES) {
return [NSNumber numberWithBool:YES];
}
return [NSNumber numberWithBool:NO];
}]
distinctUntilChanged];

RACSignal *markIsNil = [RACObserve(self.contact, mark)
map:^id(NSString *mark) {
@strongify(self);
if (self.contact.mark == nil || [self.contact.mark isEqualToString:@""] == YES) {
return [NSNumber numberWithBool:YES];
}
return [NSNumber numberWithBool:NO];
}];

RACSignal *birthdayIsNil = [RACObserve(self.contact, birthday)
map:^id(NSString *birthday) {
@strongify(self);
if (self.contact.birthday == nil || [self.contact.birthday isEqualToString:@""] == YES) {
return [NSNumber numberWithBool:YES];
}
return [NSNumber numberWithBool:NO];
}];

//13
NSArray *allSignal = @[nickIsNil, emailsIsNil, markIsNil, phonesIsNil, addressIsNil, birthdayIsNil, customInfosIsNil];
self.contactHasNoPros = [[[[RACSignal combineLatest:allSignal] //13
map:^id(RACTuple *tuple) {
//14
BOOL b = YES;
for (NSUInteger i = 0; i < tuple.count; i++) {
NSNumber *n = [tuple objectAtIndex:i];
b = b && n.boolValue;
}
return [NSNumber numberWithBool:b];
}]
distinctUntilChanged]
deliverOnMainThread];
}

这部分代码有点长,不过不用恐惧,中间有很大一部分代码都是做的类似事情,只需要看其中的一个就行,以 email 字段为例子:

edit contact

  1. 联系人的字段,被划分为了好几个部分,比如 email 数组、电话号码数组、备注信息字段等等,每一部分的处理逻辑都是类似的,主要看一下 email 相关的部分。
  2. 当添加或删除 email 的时候,UITableView 部分的代码,已经在 FMContact.contactItems 数组上做了对应的动作,这里通过 RACObserve 对这个 model 进行 KVO,就可以获取到 FMContactItem 的数组。
  3. 如果用户删除了所有的 email 地址(FMContactItem 数组的元素个数为 0),emailsIsNil 就应该为 YES,说明当前输入的 email 是没有值的。
  4. 如果 FMContactItem 数组的元素个数不为 0,则把这个数组里面的 FMContactItem 转换成 signal 的形式发送出去。
  5. UI 模块会实时的更新 FMContactItem.email 字段,所以这里也是使用 RACObserve 监听 email 字段的值。
  6. distinctUntilChanged 操作相当于一种过滤,只有当这一次 next 发送的数据和前一次 next 发送的数据不一样的时候,才会把这次 next 发送的数据继续往后续环节传递。
  7. 拿到一个 email 地址的时候,只要这个 email 的长度大于 0,就认为这个字段是有值的(并没有进行 email 有效性检查,即便输入的 email 不合理,『保存』按钮仍然可用,只有点击『保存』按钮的时候,才会检查 email 是否合理有效,产品需求是设计成这样的)。
  8. 使用 collect。注意前面 5 所在的 map 操作,返回的是 signal,所以这里形成了 signal 的嵌套,然后 collect 又会把这些 signal 全部放到一个数组里面。
  9. 拿到 signal 的数组后,要把这些 signal 合并成一个,combineLatest 满足这里的需求。
  10. 这里实现具体的产品需求,比如现在有 n 个 email 的输入框,当所有的输入框都没有输入内容的时候,才认为 email 是没有值的,只要有任何一个 email 输入框有内容,都认为 email 是有值的。
  11. 这几个地方使用 distinctUntilChanged,都是为了避免不必要的 signal 数据传递。
  12. 这里好几个 signal,都是类似的思考思路和实现方式。
  13. 把不同的 *IsNil signal 放到一个数组里,用 combineLatest 把它们合并成一个。
  14. 和 10 类似,实现产品约定好的需求,当所有输入框都没有内容的时候,这个联系人就是没有值的(通过 self.contactHasNoPros 这个 signal 来传递这个 Bool 值)。

上面这段代码,最终实现出了一个 signal,就是 contactHasNoPros,这个 signal 的订阅者,根据 next 发送的 Bool 值,设置 button 的状态就可以了,代码片段如下:

1
2
3
4
5
6
7
@weakify(self);
[[self.contactEditView.contactHasNoPros
not] //1
subscribeNext:^(NSNumber *x) {
@strongify(self);
self.navigationItem.rightBarButtonItem.enabled = x.boolValue;
}];

因为 contactHasNoPros 发送 YES 的时候,表达的含义是联系人所有的字段都没有值,没有值的时候,『保存』按钮应该是不可用状态,所以这里用 not 操作先做一个 Bool 值的取反,然后再设置 button 的 enabled 状态。

用 ReactiveCocoa 事半功倍的写代码(二)

Posted on 2016-04-26 | Edited on 2016-05-04

利用 map 组装顺序执行的业务

这其实应该是最常见的使用场景,有一类业务,是可以抽象成一组按顺序执行的串行任务的,比如下面这段代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*
ColdSignal<NSString, NSError>
Completion: decode success;
Error: FMBarCodeServiceErrorDomain || NSURLErrorDomain || RACSignalErrorDomain with RACSignalErrorTimedOut //4
Scheduler: specified;
Multicast: NO;
*/
- (RACSignal *)decodeBarWithURLString: (NSString *)urlString {
NSParameterAssert(urlString != nil);

@weakify(self);
return [[[self getUIImageWithURLString:urlString] //1
flattenMap:^(UIImage *image) {
@strongify(self);
return [self decodeBarWithUIImage:image]; //2
}]
timeout:1.5 onScheduler:[RACScheduler schedulerWithPriority:RACSchedulerPriorityDefault]]; //3

}

这段代码做的事情并不复杂,就是传入一个图片的 url 地址,然后下载对应的图片,然后尝试对这张图片进行二维码解码:

  1. getUIImageWithURLString 里面完成的小任务,就是下载 UIImage。当下载失败的时候,会发出一个 NSURLErrorDomain 的 NSError。
  2. 这里的小任务,就是对前一步得到的 UIImage 进行二维码解码。当解码失败的时候,会发出一个 FMBarCodeServiceErrorDomain 的 NSError(自己的业务代码中定义的 error domain)。
  3. 这里的业务需求,是当用户长按一张图片的时候,弹出一个选项菜单,让用户可以选择合适的操作,比如『保存图片』,『转发图片』等等,同时,如果这张图片中能够识别出二维码,在弹出的选项菜单中,还要包含一项『识别图中二维码』。二维码解析是需要消耗一定的时间的,下载图片也是需要时间的,有些情况下,即便图片本身的确是一个二维码,但是二维码可能很复杂,解析的时间就会比较长,为了保证最佳的用户体验,这里需要做一个超时逻辑,如果 1.5 秒内都还没有解析出一个有效的二维码,则放弃当前的解析动作。timeout 操作就是针对这种场景的,当到达设定的超时时间时,如果还没有发送 Next 事件,则会在 Pipeline 中发送一个 RACSignalErrorDomain 的 NSError,error code 是 RACSignalErrorTimedOut。
  4. 这个 Pipeline 是由好几个小任务组合出来的,每一个环节都有可能发送 error,所以对于这个 Pipeline 的订阅者,捕获到的 NSError 会是好几个不同 Domain 的其中之一。

这个 Pipeline 的订阅者的代码会是下面这种样子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
-(void)jsCallImageClick:(NSString *)imageUrl imageClickName:(NSString *)imgClickName {
NSMutableArray *components = [NSMutableArray arrayWithArray:[imageUrl componentsSeparatedByString:@"&qmSrc:"]];
NSMutableArray *temp = [NSMutableArray arrayWithArray:[(NSString*)[components firstObject] componentsSeparatedByString:[NSString stringWithFormat:@"&%@:",imgClickName]]];
[self filterJsArray:temp];
NSString *imageUrlString = [NSString stringWithFormat:@"%@",(NSString *)[temp firstObject]];

RACSignal *barCodeStringSignal = [self.barCodeService decodeBarWithURLString:imageUrlString];

@weakify(self);
[[barCodeStringSignal
deliverOn:[RACScheduler mainThreadScheduler]] //1
subscribeNext:^(NSString *barCodeString) {
@strongify(self);
[self showImageSaveSheetWithImageUrl:imageUrl withImageClickName:imgClickName withBarCode:barCodeString];
} error:^(NSError *error) {

@strongify(self);
[self showImageSaveSheetWithImageUrl:imageUrl withImageClickName:imgClickName withBarCode:nil];
} completed:^{
}];
}

因为 decodeBarWithURLString 的内部在使用 timeout 的时候,已经通过 RACScheduler 切换到了后台线程,所以在订阅者(UI)这里还要切换回 [RACScheduler mainThreadScheduler]。

捕获并且替换 error

下面也是一个真实业务场景中的代码片段,有适当的删减,需求大致可以描述为:FMContact.contactItems 数组里包含的是一个联系人的所有的 email 地址(至少有一个),在用 FMContactCreateAvatarCell 显示这个联系人的头像的时候,要通过其中的一个 email 地址,构造出一个 url 地址,然后下载对应的头像,最后把头像 image 设置到 UIButton 上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
//1
/*
ColdSignal<UIImage?, NoError>
Completion: download image finished;
Error:
Scheduler: specified;
Multicast: NO;
*/
- (RACSignal *)getAvatarWithContact: (FMContact *)contact {
RACSignal *addrs = [[contact.contactItems.rac_sequence
map:^(FMContactItem *contactItem) {
return contactItem.email;
}]
signal];//4

return [[[[addrs take:1] //5
map:^id(NSString *emailAddr) {
return [[[FMAvatarManager shareInstance] rac_asyncGetAvatar:emailAddr]
retry:3]; //6
}]
flatten]
catch:^RACSignal *(NSError *error) {
//7
return [RACSignal return:nil]; //8
}];
}

- (void)initPipelineWithCell:(FMContactCreateAvatarCell *)cell {
@weakify(cell);
[[[[self getAvatarWithContact:self.contact] //1
deliverOnMainThread]
takeUntil:cell.rac_prepareForReuseSignal]
subscribeNext:^(UIImage *image) {
@strongify(cell);
if (image) { //2
[cell.avatarButton setImage:image forState:UIControlStateNormal];
}
} error:^(NSError *error) {
//3
} completed:^{
}];
}

这个业务需求看上去也没有太大的难度,大家肯定都可以用传统的代码写出来,但是如果用 FRP,则可以用声明式(declarative)的代码把逻辑写的更清晰:

  1. getAvatarWithContact 定义了一个 Signal,通过输入参数 FMContact,获取一个对应的头像,如果头像下载成功,则通过 next 把 image 发送给 Pipeline 的订阅者,如果下载图片失败,并不会发送 error,而是在 next 里面发送一个 nil。
  2. 这个 Pipeline 只会有一次 next 事件,按照 Signal 的定义,可能为 nil,所以需要检查。
  3. 这个 Pipeline 是不会产生 error 的,所以这里不需要做任何事情。但是真正的下载图片的操作,也就是 [[FMAvatarManager shareInstance] rac_asyncGetAvatar:emailAddr] 这一句代码产生的 signal,是有 error 事件的,有意思的地方就是如何对这里可能出现的 error 进行处理,请接着往下看。
  4. 把 FMContact.contactItems 数组里面的 email 地址,用 signal 的形式发送出来。
  5. FMContact 至少有一个 email 地址,因为只需要显示一个头像,所以直接用最简单的办法,通过 take 操作取出其中的第一个 email 地址。
  6. 从模块设计的角度来看,应该遵循一个基本原则,如果一个小任务可能出现失败的情况,就应该通过 error 把错误信息发送出去。[[FMAvatarManager shareInstance] rac_asyncGetAvatar:emailAddr] 是在下载头像图片,肯定是存在下载失败的可能性,所以这个小任务应该遵循这个基本原则。但是,为了更好的用户体验,可以在 Pipeline 中增加一个环节,添加一个策略,就是遇到下载失败的时候,自动重新下载一遍,总共尝试 3 次,这个需求可以用 retry 操作方便的实现出来。
  7. 如果运气真的不好,3 次下载都失败了,那 Pipeline 里还是会发送 error 的,但是 getAvatarWithContact 这个 signal 的设计要求是不要 error,这个时候就该用到 catch 操作了。catch 做的事情,就是当 Pipeline 里出现 error 的时候,把这个 error 『吃掉』,然后用另外的一个 signal 来替换原来的 signal,让整个 Pipeline 可以继续发送 next 数据。
  8. [RACSignal return:nil] 就是用来替换的 signal,这个 signal 会在 next 里面发送一次 nil,然后立刻就 complete。(如果业务需求变化,这里也可以通过 [RACSignal return:defaultAvatarImage] 发送一个默认的头像图片,Pipeline 是很方便的,可以灵活的组装)。

用 ReactiveCocoa 事半功倍的写代码(一)

Posted on 2016-04-17 | Edited on 2016-05-04

前言

FRP 是一门学习曲线比较陡峭的技术,回想自己以前的学习过程,也是反反复复好几次,而且总是挫败感很强。不过还好坚持了下来,现在也算是用着比较顺手了。

关于 FRP, 最容易被吐槽的地方就是没有好的学习资料和文档。一开始我也是这种感觉,后来在反复尝试的过程中,发现其实真的不是文档的问题。先说我的结论 —- 不要指望脱离代码能够把 FRP 的原理讲清楚,这是 FRP 和其他编程技术的一个明显差异,这就类似于很难用一段文字把一个数学公式描述清楚一样。而且,即便是开始看用 FRP 编写的各种代码了,还是会觉得太抽象了,仍然需要大量的时间体会代码,或者说,『悟』出其中的一些基本门道。

关于入门学习,没有捷径,最好的办法就是通过代码来学习,下面是我觉得比较好的一些入门学习资料

  • The introduction to Reactive Programming you’ve been missing
  • ReactiveCocoa Documentation 我本人主要是做 iOS 开发,目前使用的是 RAC 这个库,所以它的官方文档也是一个学习途径。另外,本文中的代码也是使用 RAC 进行编写
  • ReactiveCocoa Tutorial – The Definitive Introduction: Part 1/2
  • ReactiveCocoa Tutorial – The Definitive Introduction: Part 2/2
  • Interactive diagrams of Rx Observables 这个是一组动态效果图,用可视化的效果演示了一些 FRP 里常用操作(当然,其实还是很抽象的)

之所以说 FRP 的学习曲线很陡峭,不仅仅是指它的入门学习比较耗时费脑,当入了门或者稍微找到一些感觉之后,紧接着就会面对第二个问题:FRP 里面提供的都是一些比较抽象的函数操作,怎样才能用这些基本函数来解决各种各样的业务问题?尤其是那些很抽象的操作,怎样才能用起来?

这个系列的文章,主要就是针对后面这第二个问题,做的一些 demo 演示。

可以把 FRP 看成是一种更高级的 Pipeline 编程范式,Pipeline 的一个精髓,就是可以灵活的组合,虽然 FRP 里常用的操作也就那么几十个,但是一旦像搭积木那样对它们进行了组装之后,FRP 的强大之处一下子就展现了出来。

FRP 通常是以库或框架的形式提供给使用者,目前已经有很多常见编程语言的具体实现。在这个系列文章中,将使用 RAC 2 (ReactiveCocoa 的 Objective-C 版本) 进行编写。但是 FRP 本质上是一种编程范式,从 Pipeline 的角度来看,它的侧重点在于如何组装出不同形状的 Pipeline,而不太在乎 Pipeline 的具体构成材料(编程语言),从框架的角度来看,虽然有不同语言版本的实现,但是每个版本里,提供的诸如 map、flattenMap、reduce 等基础操作,在概念上和行为模式上,又都是一样的。所以,FRP 也是一门 “Learn once, write anywhere” 的技术。

FRP 有几个明显的好处,比如可以减少中间状态变量的使用,可以编写紧凑的代码,可以用同步风格编写异步运行的代码,在本系列文章中,也会尽量体现出这些特点。

处理键盘的弹出和隐藏

这个业务其实是非常简单的,就是在某个 UIViewController 里面,当检测到键盘弹出的时候,为了避免键盘遮挡住某个 UIView,需要根据键盘的高度重新对 view 进行 layout,用 RAC 写出来的代码是下面这个样子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
//1
- (void)initPipeline {
@weakify(self);
RACSignal *keyboardWillShowNotification =
[[NSNotificationCenter.defaultCenter rac_addObserverForName:UIKeyboardWillShowNotification object:nil]
map:^id(NSNotification *notification) {
//2
NSDictionary* userInfo = [notification userInfo];
NSValue* aValue = [userInfo objectForKey:UIKeyboardFrameEndUserInfoKey];
return aValue;
}];

[[[[[NSNotificationCenter.defaultCenter rac_addObserverForName:UIKeyboardWillHideNotification object:nil]
map:^id(NSNotification *notification) {
//3
return [NSValue valueWithCGRect:CGRectZero];
}]
merge:keyboardWillShowNotification] //4
takeUntil:self.rac_willDeallocSignal] //6
subscribeNext:^(NSValue *value) {
NSLog(@"Keyboard size is: %@", value);
//5
@strongify(self);
self.messageEditViewContainerViewBottomConstraint.constant = 5.0 + [value CGRectValue].size.height;

[self.view updateConstraints];
[UIView animateWithDuration:0.6 animations:^{
@strongify(self);
[self.view layoutIfNeeded];
}];
} completed:^{
//6
NSLog(@"%s, Keyboard Notification signal completed", __PRETTY_FUNCTION__);
}];
}

用数字标注的地方,是比较关键的点:

  1. 很多时候,Pipeline 都是只需要构建一次的,如果是针对 UIViewController,通常都是在 viewDidLoad 方法里调用 [self initPipeline],如果是针对 UIView,则很有可能是在 awakeFromNib 方法里进行调用,这里遵循的一个策略是,在模块『活』起来之后,应该尽快的构造所有的 Pipeline,如果是 model 或 service 类型的模块,则很可能是在 init 完成后,就调用 initPipeline,但是对于 UI 性质的模块,因为有 iOS 平台相关的 view 加载策略,而且 Pipeline 通常又是和 UI 交互有关,所以通常是需要在 view 生命期相关的方法中才构造 Pipeline。
  2. 通过 map 操作,把 UIKeyboardWillShowNotification 转换成一个 CGRect(包装在 NSValue 里面)。map 操作是 FRP 里面最核心的一个基本操作,也是最体现函数式编程(FP)哲学的一个操作,所谓的这个哲学,用通俗的话来描述,就是『把复杂的业务拆分成一个一个的小任务,每一个小任务,都需要一个输入值,并且会给出一个输出值(当然也会反馈错误信息),而且每个小任务都只专心的做一件事情』。如果第一个小任务的输出值,是第二个小任务的输入值,那么,就可以用 map 操作把这两个小任务串联在一起。在接收到 UIKeyboardWillShowNotification 消息通知的时候,这个小任务的输入值就是 NSNotification,输出值是键盘尺寸对应的 CGRect,小任务本身做的事情,就是从 NSNotification 里面取出包装着这个 CGRect 的 NSValue。
  3. 当接收到 UIKeyboardWillHideNotification 消息通知的时候,这个小任务要做的事情,和 2 里面的小任务是类似的,只不过这一次,NSNotification 并没有包含键盘的尺寸,那我们自己用 CGRectZero 构造一个就行了。
  4. 终于到了这段代码的重点了,merge 操作在这里的使用效果,相当于把 2 和 3 里面的两个小任务的输出值作为自己的输入值,按照时间先后顺序排列起来,然后作为自己这个小任务的输出值,返回给 Pipeline 中的下一个环节。这样描述还是很抽象,看不懂,是吧?没关系,早就说过用语言很难描述了。把代码运行起来,通过 NSLog(@"Keyboard size is: %@", value) 这句代码的输出信息体会一下 merge 的实际效果。
  5. 这里才是真正的实现业务想要的效果,根据前一个小任务的输出值(键盘尺寸 CGRect)来计算 layout 的尺寸。
  6. takeUntil 是一个难点,如果没有这一句代码调用,运行代码后会发现,前面 5 里面的业务还是正常执行了,但是当 self 被 dealloc 后(比如 pop UIViewController 后),NSLog(@”Keyboard size is: %@”, value) 这句代码还是会被执行到(因为已经处理过 retain cycle,所以此时 self 是 nil),这是因为当 self 被 dealloc 后,这个 Pipeline 并没有被释放,Pipeline 里面还是有数据在继续流动。这个话题牵扯到 RAC 框架中的内存管理策略,很重要,后面的内容中还会讲到这个话题。这里暂时只需要知道可以借助 takeUntil:self.rac_willDeallocSignal 这样的一行代码方便的解决问题就行了。

Singal上的 next、complete、error

在学习的过程中,发现有一个问题很容易被忽略掉,那就是 Signal 的 next、complete、error 这 3 种数据,会在什么时候被发送出来,针对这个问题做过一个总结,放在了 这篇文档 中,主要目的是使用一种简单易懂的格式把 Signal 的关键信息描述出来,这里简单摘录一下。

基础格式

1
2
3
4
5
HotSignal<T, E>   // or ColdSignal<T, E>
Completion: ...
Error: ...
Scheduler: ...
Multicast: ...

关键字解释

  • HotSignal And ColdSignal:
    • HotSignal: Signal 已经处于活动状态(activated);
    • ColdSignal: Signal 需要订阅(subscribed)才会活动(activate);
  • T: Signal sendNext 的类型, 可以下面几种情况:
    • T: 表示只会发送 1 次 next 事件, 内容是类型 T 的实例;
    • T?: 表示只会发送 1 次 next 事件, 内容是类型 T 的实例或者 nil;
    • [T]: 表示会发送 0 到 n 次 next 事件, 内容是类型 T 的实例;
    • [T?]: 表示会发送 0 到 n 次 next 事件, 内容是类型 T 的实例或者 nil;
    • None: 表示不会发送 next 事件;
  • E: Signal sendError 的类型, 通常是 NSError 或 NoError; NoError 表示 Signal 不会 sendError;
  • Completion: 描述什么情况 sendCompleted;
    • 如果 next 事件的发送次数是 无穷多次,相当于使用者永远也接收不到 Completed 事件,所以这一行可以不写;
  • Error: 描述什么情况 sendError;
    如果 Signal 不会 sendError, 这一行可以不写;
  • Scheduler: Signal 所在的线程,通常是 main specified current, 默认是 current
    • main 模块内部的pipeline有切换不同的scheduler,所以模块内部有责任确保最终的signal始终是在main schedular上的
    • specified 模块内部自定义了一个任务队列,模块会确保最终返回的signal都在这个特定的schedular中(或者是使用全局默认的后台schedular)
    • current 模块内部pipeline没有做任何scheduler的切换,且不指定特定的schedular,所以最终返回的signal和外部调用者的线程保持一致
  • Multicast: 是否广播,通常是 YES NO, 默认是 NO

所有可能出现的有意义的非嵌套 Signal

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
*    HotSignal<T, NoError>
* HotSignal<T?, NoError>
* HotSignal<[T], NoError>
* HotSignal<[T?], NoError>
* HotSignal<None, NoError>
* HotSignal<T, NSError>
* HotSignal<T?, NSError>
* HotSignal<[T], NSError>
* HotSignal<[T?], NSError>
* HotSignal<None, NSError>

* ColdSignal<T, NoError>
* ColdSignal<T?, NoError>
* ColdSignal<[T], NoError>
* ColdSignal<[T?], NoError>
* ColdSignal<None, NoError>
* ColdSignal<T, NSError>
* ColdSignal<T?, NSError>
* ColdSignal<[T], NSError>
* ColdSignal<[T?], NSError>
* ColdSignal<None, NSError>

发送验证码的倒计时按钮

Retry Button

如上图,这里的需求是,点击右上角的按钮后,该按钮不可以使用,同时在按钮上显示一个倒计时时间,当达到倒计时时间后,按钮恢复可用状态。这个需求并不难,相信大家都可以写出来,但是,每个人写出来的代码,风格肯定千差万别,而且,免不了会需要一些状态变量来记录一些信息,比如定时器对象和倒计时的时间等等。如果换用 RAC,则可以在一段连续的代码中,满足所有的需求,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
//1
/*
ColdSignal<RACTuple<NSString, NSNumber<BOOL> >), NoError>
Completion: 1分钟倒计时结束;
Error: none;
Scheduler: main;
Multicast: NO;
*/
- (RACSignal *)retryButtonTitleAndEnable {
static const NSInteger n = 60;

RACSignal *timer = [[[RACSignal interval:1 onScheduler:[RACScheduler mainThreadScheduler]] //7
map:^id(id value) {
return nil; //8
}]
startWith:nil]; //9

//10
NSMutableArray *numbers = [[NSMutableArray alloc] init];
for (NSInteger i = n; i >= 0; i--) {
[numbers addObject:[NSNumber numberWithInteger:i]];
}

return [[[[[numbers.rac_sequence.signal zipWith:timer] //11
map:^id(RACTuple *tuple) {
//12
NSNumber *number = tuple.first;
NSInteger count = number.integerValue;

if (count == 0) {
return RACTuplePack(@"重试", [NSNumber numberWithBool:YES]);
} else {
NSString *title = [NSString stringWithFormat:@"重试(%lds)", (long)count];
return RACTuplePack(title, [NSNumber numberWithBool:NO]);
}
}]
takeUntil:[self rac_willDeallocSignal]] //13
setNameWithFormat:@"%s, retryButtonTitleAndEnable signal", __PRETTY_FUNCTION__]
logCompleted]; //14
}

- (void)initPipeline {
@weakify(self);
[[[[[[self.retryButtton rac_signalForControlEvents:UIControlEventTouchUpInside]
map:^id(id value) {
//2
@strongify(self);
return [self retryButtonTitleAndEnable];
}]
startWith:[self retryButtonTitleAndEnable]] //3
switchToLatest] //4
takeUntil:[self rac_willDeallocSignal]] //5
subscribeNext:^(RACTuple *tuple) {
//6
@strongify(self);
NSString *title = tuple.first;
[self.retryButtton setTitle:title forState:UIControlStateNormal];
self.retryButtton.enabled = ((NSNumber *)tuple.second).boolValue;
} completed:^{
//5
NSLog(@"%s, pipeline completed", __PRETTY_FUNCTION__);
}];

//这里省略了点击 retryButtton 后具体要做的业务逻辑,同时也省略了验证按钮和验证码输入框的处理逻辑
}

对关键代码的描述如下:

  1. 设计一个 RACSignal,这个 Signal 每次发送的 Next 数据里面包含的就是按钮上要显示的文本信息和按钮的可用状态。从模块的角度来看,这个 Signal 的内部细节(倒计时逻辑),外部使用者是不需要知道的,所以后面我们会先看外层 Pipeline 的实现代码,然后再倒回来看这个 Signal 的内部逻辑。
  2. 每当 retryButtton 被点击的时候,要重新启动一个定时器,所以在这个 map 操作里面,调用 [self retryButtonTitleAndEnable] 得到一个 Signal,将这个 Signal 作为这个小任务的输出值。注意,因为这里 map 操作返回的是一个 Signal,形成了一个 Pipeline 的嵌套,所以可以预见到,在外层 Pipeline 的后续操作中,肯定是需要把这个内嵌的 Pipeline flatten 出来的。
  3. 在业务需求中,点击这个 retryButtton 后,要请求服务器发送一个验证码(省略了这部分的代码,如果要用 RAC 实现的话,是比较容易的),同时,当每次进入这个 UI 页面的时候,不需要用户主动点击这个 retryButtton 按钮,首先就要自动的请求服务器发送一个验证码,这种情况下,也要求 retryButtton 开始进入倒计时的模式,所以,用 startWith 操作,在外层 Pipeline 中先插入第一个 Next 数据,因为是同样的倒计时逻辑,所以这里也是调用 [self retryButtonTitleAndEnable] 得到内嵌的 Pipeline。
  4. 前面已经提过了,既然形成了 Pipeline 的嵌套,那肯定是要把这种嵌套解出来的,这里使用 switchToLatest 更合适。要注意区分一下和 flattenMap 的差异。
  5. Pipeline 的生命期控制,前面的例子中已经讲过这种技巧了,但是,这是写上这句,只是一个双保险。复杂的地方在于外层 Pipeline 有 switchToLatest 操作,这个 switchToLatest 后的 Signal 什么时候才会 Completed,请继续看至后面 13 中的解释。
  6. 这里是更新 retryButtton 的 title 和状态。
  7. 现在开始回到内层 Pipeline 的逻辑中去。用 Pipeline 的方式实现一个定时器,借助 RAC 提供的 interval 操作就行。每隔一秒都会在主线程上发送一个 Next。
  8. 7 里面的定时器上的 Next 数据,是当前的系统时间值,我们的需求里面并不需要这个时间值,所以这里直接 map 成 nil。
  9. RACSignal interval 要隔一秒后才会发出第一次,需要用 startWith 立刻发送一个,代表倒计时的初始值。
  10. 把倒计时要用到的数字放到一个数组里面,然后通过 numbers.rac_sequence.signal 语句转换成一个 Signal。
  11. 把前面 10 中得到的 Signal 和 9 中得到的 timer Signal,用 zipWith 组装起来。注意一点,这个通过 zipWith 组装出来的 Signal,会在 numbers.rac_sequence.signal Completed 的时候 Completed (这句话有点绕,需要结合 zipWith 的定义仔细体会一下)。
  12. 根据倒计时的数值,计算按钮上需要显示的 title 信息和按钮的状态。
  13. 前面 11 里面的 zipWith 操作,可以确保倒计时结束时,会触发 Completed,但是万一在倒计时的过程中,用户离开了当前页面,这个时候就需要通过 takeUntil 来触发 Completed。之所以在这里这么注重 Completed,是因为前面的 5 里面的 switchToLatest 操作,会 sends completed when both the receiver and the last sent signal complete。
  14. 通过 setNameWithFormat 和 logCompleted 打印一些 log 信息,方便调试,注意观察一下 Signal 的 Completed。

内存管理,自动释放 Pipeline

从前面的 code 中可以看到,好几个地方都在强调要触发 Completed,这完全就是为了正确的进行内存管理,避免内存泄露,避免手动的调用 disposal。takeUntil:self.rac_willDeallocSignal 是一种常用的手段。

还有一种典型的场景,也可以通过 takeUntil 操作来触发 Completed,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
- (UICollectionViewCell *)collectionView:(UICollectionView *)collectionView cellForItemAtIndexPath:(NSIndexPath *)indexPath {
WWKPhoto *photo = self.photos[indexPath.row];
XMCollectionImageViewCell *cell = [self.imageCollectionView dequeueReusableCellWithReuseIdentifier:NSStringFromClass([XMCollectionImageViewCell class]) forIndexPath:indexPath];
cell.imageView.image = photo.thumbnail;


@weakify(self);
[[[cell.longPressSignal map:^id(XMCollectionImageViewCell *viewCell) {
@strongify(self);
return [self.imageCollectionView indexPathForCell:viewCell];
}]
takeUntil:[cell rac_prepareForReuseSignal]]
subscribeNext:^(NSIndexPath *longPressIndexPath) {
@strongify(self);
UIAlertController *alert= [UIAlertController
alertControllerWithTitle:@"确定删除此图片"
message:nil
preferredStyle:UIAlertControllerStyleAlert];

UIAlertAction* ok = [UIAlertAction actionWithTitle:@"确定" style:UIAlertActionStyleDefault
handler:^(UIAlertAction * action){
@strongify(self);
[[self mutableArrayValueForKey:@keypath(self, photos)] removeObjectAtIndex:longPressIndexPath.row];
[self.imageCollectionView deleteItemsAtIndexPaths:@[longPressIndexPath]];
}];
UIAlertAction* cancel = [UIAlertAction actionWithTitle:@"取消" style:UIAlertActionStyleDefault
handler:^(UIAlertAction * action) {
[alert dismissViewControllerAnimated:YES completion:nil];
}];

[alert addAction:ok];
[alert addAction:cancel];

[self.containerViewController presentViewController:alert animated:YES completion:nil];
} completed:^{
}];

return cell;
}

这段代码也很简单,唯一需要特别注意的就是 takeUntil:[cell rac_prepareForReuseSignal] 这一句,因为 UICollectionViewCell 本身是有一套复用机制的,每个 cell 上的 Pipeline 的生命期和 cell 本身的生命期并不一致,所以不能依赖于 cell.rac_willDeallocSignal,而应该使用 [cell rac_prepareForReuseSignal] 这个更准确的 Signal。

讨论到这里,还可以得到一个结论,在设计 Signal 的时候,要尽量的让这个 Signal 能够发送 Completed 事件,这样才能够充分的利用 Pipeline 的自动释放功能,保持代码的简洁。RAC 框架里,有一些很常用的 Signal,其实它们的内部实现也是用类似 takeUntil 的操作做了这种处理,比如下面这些 Signal:

1
2
3
4
5
6
7
8
9
@interface UIControl (RACSignalSupport)
- (RACSignal *)rac_signalForControlEvents:(UIControlEvents)controlEvents;
@end

@interface UIGestureRecognizer (RACSignalSupport)
- (RACSignal *)rac_gestureSignal;
@end

RACObserve 宏定义

下面这个 Signal,则是没有 Completed 事件的,要求它的使用者来决定什么时候释放对应的 Pipeline:

1
2
3
@interface NSNotificationCenter (RACSupport)
- (RACSignal *)rac_addObserverForName:(NSString *)notificationName object:(id)object;
@end

通过 Swift 学习 CSP 并发模型

Posted on 2016-04-10 | Edited on 2016-05-04

前言

这篇文章的主要内容,是从 Go Concurrency Patterns 翻译过来的。

原文是介绍 Golang 里面的 CSP 并发模型(Communicating Sequential Processes),这里则是使用一个基于 Swift3.0 的库 Venice 编写的代码。

这篇文章的主要目的,并不是鼓励大家立刻就用 Swift 进行后端开发(至少不是目前这个阶段),但是,对于想尝试全栈开发的 iOS 工程师来说,则可以通过这篇文章入门学习 CSP 这种并发编程模型。

2016/04/08 update: 为了成功运行本文中的代码,需要安装 https://swift.org/builds/development/xcode/swift-DEVELOPMENT-SNAPSHOT-2016-03-24-a/swift-DEVELOPMENT-SNAPSHOT-2016-03-24-a-osx.pkg 这个版本的 swift。

2016/04/13 update: Venice 里面的 Channel 不再支持基于自定义运算符的读写操作,只能使用 func api。

为什么要讨论并发 (concurrency)

观察一下我们周围,能发现什么?

我们的世界里发生的事情,总是一步一步按顺序执行的吗?

或者说,发生在我们身边的所有的事件,是一个很复杂的组合体,里面充满了更独立、更小型的事件单元,这些单元之间,则是有各种各样的交互和组织关系。

其实就像后者描述的这样,顺序处理 (Sequential processing) 并不是完美的建模思路。

什么是并发?

并发是独立的计算任务的组合。

并发是一种软件的设计模式,用并发的思维模式,可以编写出更清晰的代码。

并发 (concurrency) 不是并行 (parallelism)

并发不是并行,但是可以在并行的基础上形成并发。

如果只有一个单核处理器(单线程模式),则谈不上并行,但是仍然可以写出并发的代码。

另一方面,如果一段代码已经按照并发的思路进行了设计,那它也是可以很容易的在多核处理器(多线程模式)中并行执行。

关于这个话题,更详细的讨论可以参看 Concurrency is not Parallelism

什么是好的代码架构

  • 要容易理解
  • 要容易使用
  • 要容易描述出设计意图
  • 不需要人人都是专家 (不应该总是出现大量threads,semaphores,locks,barriers等等高深的话题)

CSP 的历史

CSP 并不是新技术,Communicating Sequential Processes 是 Tony Hoare 在 1978 年就提出来的概念,甚至在更早的 1975 年,Edsger Dijkstra 的 Guarded Command Language 里面,也能看到 CSP 的影子。

还有其他的一些语言,也有类似的并发模型

  • Occam (May, 1983)
  • Erlang (Armstrong, 1986)
  • Newsqueak (Pike, 1988)
  • Concurrent ML (Reppy, 1993)
  • Alef (Winterbottom, 1995)
  • Limbo (Dorward, Pike, Winterbottom, 1996).

Venice / Golang 和 Erlang 的差异

Venice / Golang 通过 channels 来实现 CSP。

Erlang 是最接近于原始的 CSP 定义的,通过 name 进行通信,而非 channel。

它们的模型其实是一致的,只不过具体的表现形式有差异。

粗略来看相当于:writing to a file by name (process, Erlang) vs. writing to a file descriptor (channel, Venice / Golang).

CSP 的基本使用

这篇文章最主要的目的是讨论并发模式,为了避免陷入编程语言本身的各种细节,我们只会使用到 Swift 很少的语法特性。

从下面这个简单的 boring 函数开始

1
2
3
4
5
6
7
8
9
10
11
12
import Foundation

private func boring(msg: String) {
for i in 0...10 {
print("\(msg) \(i)")
usleep(100)
}
}

public func run01() {
boring("this is a boring func")
}

很容易想象到,这段代码的执行结果会是下买这个样子

1
2
3
4
5
6
7
8
9
10
11
this is a boring func 0
this is a boring func 1
this is a boring func 2
this is a boring func 3
this is a boring func 4
this is a boring func 5
this is a boring func 6
this is a boring func 7
this is a boring func 8
this is a boring func 9
this is a boring func 10

稍微改动一下

增加一点随机的延时,让 message 出现的时机不可预测 (延迟时间仍然控制在1秒之内)。并且让 boring 函数一直循环运行。

1
2
3
4
5
6
7
8
9
10
11
12
import Foundation

private func boring(msg: String) {
for i in 0..<Int.max {
print("\(msg) \(i)")
usleep(1000 * (arc4random_uniform(1000) + 1))
}
}

public func run02() {
boring("this is a less boring func")
}

进入正题

Venice 的 co 函数,传入的参数是一个函数,在 co 的内部会执行这个传入的函数,但是并不会等待这个函数执行结束,对于 co 的调用者来说,co 函数本身会立刻返回。co 函数其实是开启了一个新的协程 (轻量级线程) 来真正的执行传入的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import Foundation
import Venice

private func boring(msg: String) {
for i in 0..<Int.max {
print("\(msg) \(i)")
nap(for: Int(arc4random_uniform(1000) + 1).milliseconds)//sleep
}
}

public func run03() {
co {
boring("co a less boring func")
}

/**
//if do not want run03() finish, run the loop below
for i in 0..<Int.max {
yield
}
*/
print("run03() will return")
}

上面这段代码的运行结果如下

1
2
co a less boring func 0
run03() will return

可以看到,boring函数里面的循环只执行了一次,这是因为 co 函数是立刻返回的,紧接着,run03() 执行完 print 后也立刻返回,然后 run03() 的调用者 main 函数也就执行结束了 (进程结束),之前 co 启动的协程自然也就无法继续执行了。

如果想让 co 里面的协程一直运行下去,可以在 co 调用返回后,执行代码中的那段 for loop。

要注意的一点是,for loop 里面调用的 yield,是 Venice 引入的一种操作,意思是让出 CPU 给其他的协程。Golang 是不需要手动进行这种调用的,runtime 会自动的进行调度。

在 Venice 里面,如果是在 channel 上进行读写操作,读写的同时已经相当于调用过 yield 了,所以也不需要使用者再次显式的调用 yield。在后面的例子的,就会看到这种不需要手动调用 yield 的场景。

继续改动代码

调整代码成下面这个样子,在 co 调用后,让 run04() 所在的协程 sleep 一小段时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import Foundation
import Venice

private func boring(msg: String) {
for i in 0..<Int.max {
print("\(msg) \(i)")
nap(for: Int(arc4random_uniform(1000) + 1).milliseconds)
}
}

public func run04() {
co(boring("co a less boring func"))
print("I'm listening")
nap(for: 2.second)
print("You're boring; I'm leaving.")
}

这段代码的执行结果是下面这个样子的

1
2
3
4
5
6
7
co a less boring func 0
I'm listening
co a less boring func 1
co a less boring func 2
co a less boring func 3
co a less boring func 4
You're boring; I'm leaving.

nap()是Venice提供的sleep函数,它的内部,相当于调用了yield。

当main函数结束的时候,boring函数所在的协程也会结束。

协程 (coroutine)

协程是一段独立运行的代码集合,通过 co 函数来启动。

协程的系统开销是很小的 (比 thread 小很多),可以同时存在大量的协程 (具体到 Venice 底层使用的 libmill,可以同时运行 2000万个 协程,并且每秒可以进行 5000万次 协程上下文切换)。

协程不是线程。

一个程序里面,可以只运行一个线程,但是在这个线程里面,可以包含千万个协程。

可以把协程看成是轻量级的线程。

通讯 (communication)

在 run04() 里面,是不能看到在协程中运行的 boring 函数的运行结果的。

boring 函数仅仅是把 msg 打印到了终端上。

想在协程之间真正的传递数据,需要用到通讯 (communication)。

Channel

在 Venice 里面,两个协程之间,通过 Channel 进行通讯。

Channel 的基本操作就是下面这3个:

1
2
3
4
5
6
7
8
//声明、初始化
let channel = Channel<String>()

//在channel上发送数据
channel.send("ping")

//在channel上接收数据
let message = channel.receive()!

使用 Channel

用 channel 连接 boring 函数和 run05 函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import Foundation
import Venice

private func boring(msg msg: String, channel: SendingChannel<String> ) {
for i in 0..<Int.max {
channel.send("\(msg) \(i)")
nap(for: Int(arc4random_uniform(1000) + 1).milliseconds)
}
}

public func run05() {
let channel = Channel<String>()

co {
boring(msg: "co a less boring func", channel: channel.sendingChannel)
}

for _ in 0..<5 {
print("You say: \(channel.receivingChannel.receive()!)")
}

print("You're boring; I'm leaving.")
channel.close()
}

运行结果如下

1
2
3
4
5
6
You say: co a less boring func 0
You say: co a less boring func 1
You say: co a less boring func 2
You say: co a less boring func 3
You say: co a less boring func 4
You're boring; I'm leaving.

同步 (Synchronization)

在 channel 上的读、写操作,是同步的、阻塞的。

run05() 执行到 channel.receivingChannel.receive()! 的时候,只有当 channel 里面有数据被写入的时候,这个读操作才会返回 (读到数据的时候才返回),否则 run05() 就会一直在这里等待,不会继续往下执行。

同样的,在 boring 函数里面,执行 channel.send(“(msg) (i)”) 这个写操作的时候,只有当 channel 里面为空的时候,数据才能被写到 channel 里面,channel.send(“(msg) (i)”) 才会返回,否则,send 操作也会阻塞在这里。

在通讯过程中,发送者和接收者,必须都分别完成他们的写和读动作,否则双方就会一直互相等待下去 (死锁)。

channel 在协程之间完成通讯的同时,也达到了同步的目的。

带缓冲的 channel

可以创建具有 buffer 的 channel。

这种 channel,当 buffer 还没有写满的时候,是没有前面描述的那种同步特性的。

buffering 有点类似 Erlang 语言里面的 mailboxes。

没有特殊理由的时候,不应该使用 buffered channel。

这篇文章后续的讨论,都不会使用 buffer。

Golang 哲学

Don’t communicate by sharing memory, share memory by communicating.

模式 (Patterns)

Generator 模式:通过函数返回一个 channel 给调用者

Channel 是一等公民,和 class、struct、closure 同等重要。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
channel.send("\(msg) \(i)")
nap(for: Int(arc4random_uniform(1000) + 1).milliseconds)
}
}

return channel.receivingChannel
}

public func run06() {
let receivingChannel = boring("co a less boring func")

for _ in 0..<5 {
print("You say: \(receivingChannel.receive()!)")
}

print("You're boring; I'm leaving.")
}

这段代码和前面的代码的运行结果,没有什么差别

1
2
3
4
5
6
You say: co a less boring func 0
You say: co a less boring func 1
You say: co a less boring func 2
You say: co a less boring func 3
You say: co a less boring func 4
You're boring; I'm leaving.

但是代码本身确有明显的变化,boring 函数返回一个 channel 给调用者,同时,在 boring 函数内部,通过 co 启动一个新的协程做具体的业务,并且通过刚才创建的 channel 把结果发送出去。

利用 channel 作为 service 的接口

boring 函数对外提供了一个 service,这个 service 运行在独立的协程里面,并且通过channel 把数据传递给 service 的使用者。

可以同时使用多个 service。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
channel.send("\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)")
nap(for: sleepTime)
}
}

return channel.receivingChannel
}


public func run07() {
let joe = boring("Joe")
let ann = boring("Ann")

for _ in 0..<5 {
print("\(joe.receive()!)")
print("\(ann.receive()!)")
}

print("You're both boring; I'm leaving.")
}

运行结果如下

1
2
3
4
5
6
7
8
9
10
11
Joe 0  (will sleep 996 ms)
Ann 0 (will sleep 681 ms)
Joe 1 (will sleep 173 ms)
Ann 1 (will sleep 147 ms)
Joe 2 (will sleep 750 ms)
Ann 2 (will sleep 374 ms)
Joe 3 (will sleep 318 ms)
Ann 3 (will sleep 705 ms)
Joe 4 (will sleep 126 ms)
Ann 4 (will sleep 828 ms)
You're both boring; I'm leaving.

多路复用 (Multiplexing)

前面 run07() 里面的代码,始终都是先从 joe 里面读取数据,然后再从 ann 里面读取。如果 ann 里面的数据早于 joe 里面的数据就发送了,由于 channel 的同步特性,ann channel 其实会阻塞在它的 send 操作上,直到 run07 从 joe 里面读取完数据后,ann 所在的协程才能继续运行。

为了改善这种情况,可以使用 fan-in 模式。不管是 joe 还是 ann,只要有数据准备好并且执行了 send 操作,都可以立刻读取到。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
channel.send("\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)")
nap(for: sleepTime)
}
}

return channel.receivingChannel
}

private func fanIn(input1 input1: ReceivingChannel<String>, input2: ReceivingChannel<String>) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
while true {
channel.send(input1.receive()!)
}
}

co {
while true {
channel.send(input2.receive()!)
}
}

return channel.receivingChannel
}

public func run08() {
let joe = boring("Joe")
let ann = boring("Ann")

let c = fanIn(input1: joe, input2: ann)

for _ in 0..<10 {
print("\(c.receive()!)")
}

print("You're both boring; I'm leaving.")
}

运行结果如下

1
2
3
4
5
6
7
8
9
10
11
Joe 0  (will sleep 75 ms)
Ann 0 (will sleep 473 ms)
Joe 1 (will sleep 57 ms)
Joe 2 (will sleep 219 ms)
Joe 3 (will sleep 20 ms)
Joe 4 (will sleep 723 ms)
Ann 1 (will sleep 712 ms)
Joe 5 (will sleep 377 ms)
Ann 2 (will sleep 431 ms)
Joe 6 (will sleep 228 ms)
You're both boring; I'm leaving.

回复消息 (Restoring sequencing)

前面 run08 里面的 fan-in 模式,boring 函数只负责 send 消息,并不需要消息的接收者做一个答复。如果需要,可以像下面这样修改代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import Foundation
import Venice

private struct Message {
let str: String
let wait: Channel<Bool>
}

private let waitForIt = Channel<Bool>() // Shared between all messages

private func boring(msg: String) -> ReceivingChannel<Message> {
let channel = Channel<Message>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds

let message = Message(str: "\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)", wait: waitForIt)

channel.send(message)
nap(for: sleepTime)

waitForIt.receive()!
}
}

return channel.receivingChannel
}

private func fanIn(input1 input1: ReceivingChannel<Message>, input2: ReceivingChannel<Message>) -> ReceivingChannel<Message> {
let channel = Channel<Message>()

co {
while true {
channel.send(input1.receive()!)
}
}

co {
while true {
channel.send(input2.receive()!)
}
}

return channel.receivingChannel
}

public func run09() {
let joe = boring("Joe")
let ann = boring("Ann")

let c = fanIn(input1: joe, input2: ann)

for _ in 0..<5 {
let message1 = c.receive()!
print("\(message1.str)")
message1.wait.send(true)

let message2 = c.receive()!
print("\(message2.str)")
message2.wait.send(true)
}

print("You're both boring; I'm leaving.")
}

运行结果会是下面这个样子,并没有明显的区别

1
2
3
4
5
6
7
8
9
10
11
Joe 0  (will sleep 551 ms)
Ann 0 (will sleep 53 ms)
Ann 1 (will sleep 543 ms)
Joe 1 (will sleep 412 ms)
Ann 2 (will sleep 847 ms)
Joe 2 (will sleep 46 ms)
Joe 3 (will sleep 274 ms)
Joe 4 (will sleep 69 ms)
Joe 5 (will sleep 202 ms)
Ann 3 (will sleep 962 ms)
You're both boring; I'm leaving.

Select

前面介绍的多路复用技术,是通过启动多个协程实现的,每个 channel 对应一个协程。

另一种更常用的办法,是使用 select 操作,在一个协程里面同时读写多个 channel。

可以用 select 操作重新实现一遍 fan-in 模式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
channel.send("\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)")
nap(for: sleepTime)
}

}

return channel.receivingChannel
}

private func fanIn(input1 input1: ReceivingChannel<String>, input2: ReceivingChannel<String>) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
while true {
select { when in
when.receive(from: input1) { value in
//print("received \(value)")
channel.send(value)
}

when.receive(from: input2) { value in
channel.send(value)
}

when.otherwise {
//print("default case")
}
}
}
}

return channel.receivingChannel
}

public func run10() {
let joe = boring("Joe")
let ann = boring("Ann")

let c = fanIn(input1: joe, input2: ann)

for _ in 0..<10 {
print("\(c.receive()!)")
}

print("You're both boring; I'm leaving.")
}

运行结果和之前的 fan-in 没有区别

1
2
3
4
5
6
7
8
9
10
11
Ann 0  (will sleep 816 ms)
Joe 0 (will sleep 252 ms)
Joe 1 (will sleep 756 ms)
Ann 1 (will sleep 879 ms)
Joe 2 (will sleep 157 ms)
Joe 3 (will sleep 578 ms)
Ann 2 (will sleep 700 ms)
Joe 4 (will sleep 499 ms)
Joe 5 (will sleep 352 ms)
Ann 3 (will sleep 642 ms)
You're both boring; I'm leaving.

这里用的 select 操作,和 Linux / Unix 里面的 select、poll、epoll,都是类似的,只不过前者监听的是 channel,后者监听的是 fd

在 Select 的基础上实现超时机制 (Timeout)

定时器是基于 channel 实现出来的,当达到定时时间的时候,定时器 channel 上会发送一个消息。

定时器可以放在 select 操作的里面

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
nap(for: sleepTime)

channel.send("\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)")
}
}

return channel.receivingChannel
}


public func run11() {
let joe = boring("Joe")

var done = false
while !done {
select { when in
when.receive(from: joe) { value in
print("\(value)")
}

when.timeout(800.millisecond.fromNow()) {
print("You are too slow.")
done = true
}
}
}

print("You're boring; I'm leaving.")
}

运行结果是下面这个样子

1
2
3
4
5
6
Joe 0  (will sleep 48 ms)
Joe 1 (will sleep 706 ms)
Joe 2 (will sleep 747 ms)
Joe 3 (will sleep 304 ms)
You are too slow.
You're boring; I'm leaving.

Select 操作的整体超时

前面的 run11,是在每次进入 select 的时候,设置了一个超时 channel。

也可以在 while 循环的外面,设置一个整体的超时 channel,像下面这样

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import Foundation
import Venice

private func boring(msg: String) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
for i in 0..<Int.max {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
nap(for: sleepTime)

channel.send("\(msg) \(i) (will sleep \(Int(sleepTime * 1000)) ms)")
}
}

return channel.receivingChannel
}

public func run12() {
let joe = boring("Joe")
let timeout = Timer(timingOut: 5.second.fromNow()).channel

var done = false
while !done {
select { when in
when.receive(from: joe) { value in
print("\(value)")
}

when.receive(from: timeout) { _ in
print("You are too slow.")
done = true
}
}
}

print("You're boring; I'm leaving.")
}

运行结果如下

1
2
3
4
5
6
7
8
9
10
11
12
Joe 0  (will sleep 586 ms)
Joe 1 (will sleep 226 ms)
Joe 2 (will sleep 297 ms)
Joe 3 (will sleep 850 ms)
Joe 4 (will sleep 442 ms)
Joe 5 (will sleep 525 ms)
Joe 6 (will sleep 730 ms)
Joe 7 (will sleep 227 ms)
Joe 8 (will sleep 630 ms)
Joe 9 (will sleep 411 ms)
You are too slow.
You're boring; I'm leaving.

取消 (quit channel)

boring 函数的调用者,可以主动的让 boring 内部的协程停止工作,也是通过 channel 来实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import Foundation
import Venice

private func boring(msg msg: String, quit: ReceivingChannel<Bool>) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
forSelect { when, done in
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
nap(for: sleepTime)

when.send("\(msg), and will sleep \(Int(sleepTime * 1000)) ms", to: channel) {
//print("sent value")
}
when.receive(from: quit) { _ in
done()
}
}

channel.close()
}

return channel.receivingChannel
}

public func run13() {
let quit = Channel<Bool>()
let joe = boring(msg: "Joe", quit: quit.receivingChannel)

for _ in 0..<Int64(arc4random_uniform(10) + 1) {
print("\(joe.receive()!)")
}

quit.send(true)

print("You're boring; I'm leaving.")
}

运行结果仍然是类似的

1
2
3
4
5
6
7
8
9
Joe, and will sleep 154 ms
Joe, and will sleep 390 ms
Joe, and will sleep 133 ms
Joe, and will sleep 520 ms
Joe, and will sleep 752 ms
Joe, and will sleep 482 ms
Joe, and will sleep 47 ms
Joe, and will sleep 359 ms
You're boring; I'm leaving.

在 quit channel 上接收消息

接着上面的例子,当 run13 向 quit channel 发送 true 的时候,run13 怎样才能知道 boring 函数成功的结束了自己的运行呢?让 boring 告诉它的调用者就行,同样,还是通过 quit channel。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import Foundation
import Venice

private func cleanup() {
print("Here, do clean up")
}

private func boring(msg msg: String, quit: Channel<String>) -> ReceivingChannel<String> {
let channel = Channel<String>()

co {
forSelect { when, done in
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
nap(for: sleepTime)

when.send("\(msg), and will sleep \(Int(sleepTime * 1000)) ms", to: channel) {
//print("sent value")
}
when.receive(from: quit) { _ in
cleanup()
quit.send("See you!")
done()
}
}

channel.close()
}

return channel.receivingChannel
}

public func run14() {
let quit = Channel<String>()
let joe = boring(msg: "Joe", quit: quit)

for _ in 0..<Int64(arc4random_uniform(10) + 1) {
print("\(joe.receive()!)")
}

quit.send("Bye")
print("Joe says: \(quit.receive()!)")

print("You're boring; I'm leaving.")
}

现在运行结果会变成下面这个样子

1
2
3
4
5
6
7
8
9
10
Joe, and will sleep 220 ms
Joe, and will sleep 736 ms
Joe, and will sleep 308 ms
Joe, and will sleep 858 ms
Joe, and will sleep 527 ms
Joe, and will sleep 163 ms
Joe, and will sleep 844 ms
Here, do clean up
Joe says: See you!
You're boring; I'm leaving.

Daisy-chain

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import Foundation
import Venice

private func f(left left: Channel<Int>, right: Channel<Int>) {
left.send(right.receive()! + 1)
}

public func run15() {
let leftMost = Channel<Int>()

var right = leftMost
var left = leftMost

for _ in 0..<10000 {
right = Channel<Int>()
co {
f(left: left, right: right)
}
left = right
}

co {
right.send(1)
}

print("Joe says: \(leftMost.receive()!)")

print("You're boring; I'm leaving.")
}

运行结果如下

1
2
Joe says: 10001
You're boring; I'm leaving.

系统软件 (Systems Software)

让我们具体看一下 CSP 这种并发模型,是如何用在系统软件的开发中的。

例子:Google Search

问: Google search 需要做什么事情?

答: 输入一个搜索关键字 (query),得到一组搜索结果 (和一些广告)。

问: 怎样获取这样的一组搜索结果?

答: 把搜索关键字分别发送给 Web search service,Image search service,YouTube search service,Maps search service,News search service 等等,然后把它们返回的结果再组合到一起。

那么,怎样做呢?

模拟各种 search service

模拟 3 个 search service,每次执行 search 的时候,随机延时一小段时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import Foundation
import Venice

public typealias GoogleSearchResult = String

internal func fakeSearch(kind: String) -> (String) -> GoogleSearchResult {
func search(query: String) -> GoogleSearchResult {
let sleepTime = Int(arc4random_uniform(1000) + 1).milliseconds
//print("-->\(kind) search use time: \(Int(sleepTime * 1000)) ms")
nap(for: sleepTime)

return GoogleSearchResult("\(kind) result for \(query), use time: \(Int(sleepTime * 1000)) ms")
}

return search
}

let web = fakeSearch("web")
let image = fakeSearch("image")
let video = fakeSearch("video")



//some util
internal func time(desc: String, function: ()->()) {
let start : UInt64 = mach_absolute_time()
function()
let duration : UInt64 = mach_absolute_time() - start

var info : mach_timebase_info = mach_timebase_info(numer: 0, denom: 0)
mach_timebase_info(&info)

let total = (duration * UInt64(info.numer) / UInt64(info.denom)) / NSEC_PER_MSEC
print("\(desc)\(total) ms.")
}


protocol GoogleSearchResultDebugAble {
func log()
}

extension GoogleSearchResult: GoogleSearchResultDebugAble {
func log() {
print(" \(self)")
}
}

internal extension Array where Element: GoogleSearchResultDebugAble {
internal func log() {
print("google search result is:")
for searchResult in self {
searchResult.log()
}
}
}

Google Search 1.0

google 函数有一个输入参数,返回一个数组。

google 内部按照顺序依次调用 web、image、video search service,然后把它们的结果组装在一个数组内。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import Foundation
import Venice

private func google(query: String) -> Array<GoogleSearchResult> {
var results = Array<GoogleSearchResult>()

results.append(web(query))
results.append(image(query))
results.append(video(query))

return results
}

public func run17() {
var result: Array<GoogleSearchResult>?

time("google search v1.0, use time: ") { () -> () in
result = google("CSP")
}

result?.log()
}

运行结果是下面这个样子

1
2
3
4
5
google search v1.0, use time: 1237 ms.
google search result is:
web result for CSP, use time: 743 ms
image result for CSP, use time: 240 ms
video result for CSP, use time: 243 ms

Google Search 2.0

并发调用 web、image、video search service,然后等待它们的返回结果。

不使用锁机制,不使用条件状态变量,不使用 callback。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import Foundation
import Venice

private func google(query: String) -> Array<GoogleSearchResult> {
let channel = Channel<GoogleSearchResult>()

co(channel.send(web(query)))
co(channel.send(image(query)))
co(channel.send(video(query)))

var results = Array<GoogleSearchResult>()
for _ in 0..<3 {
results.append(channel.receive()!)
}
return results
}

public func run18() {
var result: Array<GoogleSearchResult>?

time("google search v2.0, use time: ") { () -> () in
result = google("CSP")
}

result?.log()
}

运行结果如下

1
2
3
4
5
google search v2.0, use time: 871 ms.
google search result is:
image result for CSP, use time: 40 ms
video result for CSP, use time: 307 ms
web result for CSP, use time: 864 ms

很明显,并发执行的效果比顺序执行的效果好很多。

Google Search 2.1

还可以加上超时机制,如果某个 search service 执行的时间太长,就不等待它的返回结果。

不使用锁机制,不使用条件状态变量,不使用 callback。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import Foundation
import Venice

private func google(query: String) -> Array<GoogleSearchResult> {
let channel = Channel<GoogleSearchResult>()

co(channel.send(web(query)))
co(channel.send(image(query)))
co(channel.send(video(query)))

var results = Array<GoogleSearchResult>()

let timeout = Timer(timingOut: 800.milliseconds.fromNow()).channel

var done = false
for _ in 0..<3 {
if done == true {
break
}

select { when in
when.receive(from: channel) { value in
results.append(value)
}

when.receive(from: timeout) { _ in
print("timeout.")
done = true
}
}
}
return results
}

public func run19() {
var result: Array<GoogleSearchResult>?

time("google search v2.1, use time: ") { () -> () in
result = google("CSP")
}

result?.log()
}

如果看到下面这种形式的运行结果,则说明是触发了超时的条件

1
2
3
4
5
timeout.
google search v2.1, use time: 810 ms.
google search result is:
web result for CSP, use time: 341 ms
video result for CSP, use time: 537 ms

避免超时

问:怎样才能避免丢弃响应速度更慢的服务器返回的搜索结果?

答:使用 Replicate 策略。同时向多个同类型的 search service 发送请求,使用第一个返回来的查询结果。

1
2
3
4
5
6
7
8
9
private func first(query query: String, replicas: ((String) -> GoogleSearchResult)...) -> GoogleSearchResult {
let channel = Channel<GoogleSearchResult>()

for search in replicas {
co(channel.send(search(query)))
}

return channel.receive()!
}

Google Search 3.0

仍然不使用锁机制,不使用条件状态变量,不使用 callback。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import Foundation
import Venice


let web1 = fakeSearch("web1")
let web2 = fakeSearch("web2")
let image1 = fakeSearch("image1")
let image2 = fakeSearch("image2")
let video1 = fakeSearch("video1")
let video2 = fakeSearch("video2")


private func first(query query: String, replicas: ((String) -> GoogleSearchResult)...) -> GoogleSearchResult {
let channel = Channel<GoogleSearchResult>()

for search in replicas {
co(channel.send(search(query)))
}

return channel.receive()!
}


private func google(query: String) -> Array<GoogleSearchResult> {
let channel = Channel<GoogleSearchResult>()

co {
channel.send(first(query: query, replicas: web1, web2))
}

co {
channel.send(first(query: query, replicas: image1, image2))
}

co {
channel.send(first(query: query, replicas: video1, video2))
}

var results = Array<GoogleSearchResult>()

let timeout = Timer(timingOut: 1000.milliseconds.fromNow()).channel

var done = false
for _ in 0..<3 {
if done == true {
break
}

select { when in
when.receive(from: channel) { value in
//print("receive \(value)")
results.append(value)
}

when.receive(from: timeout) { _ in
print("timeout.")
done = true
}
}
}
return results
}


public func run20() {
var result: Array<GoogleSearchResult>?

time("google search v3.0, use time: ") { () -> () in
result = google("CSP")
}

result?.log()
}

最终的运行结果如下

1
2
3
4
5
google search v3.0, use time: 506 ms.
google search result is:
web1 result for CSP, use time: 433 ms
image1 result for CSP, use time: 434 ms
video2 result for CSP, use time: 499 ms

不要过度使用

coroutine 和 channel 是一种很好的设计思想,可以解决某些类型的问题。

但是,有时我们仍然会面对一些需要用传统思路来解决的小问题,也就是基于锁机制 (共享内存)。

这两种不同的技术思路,并不冲突,它们是可以共存的。

正确的工具做正确的事情。

后记

这篇文章里面的 demo code 位于 https://github.com/fengjian0106/CSP-tutorial.git

如何实现Minuum和Fleksy输入法中的智能纠错功能

Posted on 2015-02-11 | Edited on 2016-04-10

输入法的产品一直在持续技术演进,最近的一项工作,是实现了一个类似 Minuum 和 Fleksy 这两款输入法中的模糊输入功能的单词智能纠错引擎,前后尝试过4种不同的算法思路,最终才找到适合手机的解决方案,特此记录一下。

  1. 基于贝叶斯推断的,主要线索是 http://norvig.com/spell-correct.html
  2. 基于Levenshtein自动机的,主要线索是 http://blog.notdot.net/2010/07/Damn-Cool-Algorithms-Levenshtein-Automata
  3. 一种基于预处理词库的改进算法,主要线索是 http://blog.faroo.com/2012/06/07/improved-edit-distance-based-spelling-correction/
  4. 使用机器学习中的kNN算法,主要线索是 http://minuum.com/model-your-users-algorithms-behind-the-minuum-keyboard/ 和 http://www.zhihu.com/question/27567987

前三个方案,都是用传统的算法思路,基于编辑距离来实现模糊匹配,但是在手机上无法满足输入法的性能需求,尤其是查询速度这一点,而且也无法做到和Minuum或Fleksy类似的纠错效果。最终的第4个方案,则是彻底更换了思路,直接用机器学习中的 kNN 算法,把字符串映射到更抽象的几何空间中,也就是所谓的特征向量,进行纯粹的数学计算。学习和研究的过程中,是直接用Python做的代码原型验证,放到github上了,有兴趣的朋友可以看看 https://github.com/fengjian0106/Minuum-Fleksy-Fuzzy-Matching

手机超声波通信

Posted on 2014-12-09 | Edited on 2016-04-10

这其实是两年前的一个产品,使用场景类似于基于苹果推出的iBeacon实现室内定位,利用一个小巧的超声波硬件设备,周期性的广播信标信号,手机直接使用麦克风接收这个音频信号并且解码,得到信标信号中的有效数据,最后再根据这个数据进行室内定位的算法逻辑处理。声波通信部分,是技术基础,开发难度比较大,在当时的情况下,对于产品来说,是一个技术壁垒。现在已经过去两年时间了,而且其实出于商业层面的原因,团队也早已放弃这个产品转战其他方向了,所以我还是准备把其中的一些技术细节记录下来。

基于使用场景而提出的一些技术参数指标

  1. 只能使用18kHz~21kHz这个范围内的音频信号(人耳要尽量听不见这个声音,手机要能够接收到这个声音,所以只能限制在这个频率范围内)
  2. 通信距离最远要达到10米
  3. 还有一些是硬件设备的性能要求,和通信协议无关,不是这次的讨论重点,可以忽略

关键技术点

  1. 为了减小数据处理过程中的延迟现象,采取的是实时的进行音频数据采集和数据处理,而不是像某些类似SDK中使用的技术那样,先录音得到一小段音频文件,然后再进行数据处理。具体到iOS平台上,就是使用audio unit框架来搭建PCM音频数据的采集管道,在管道的最后一个节点上,对得到的PCM数据再进行进一步的处理。
  2. 仍然是为了降低延迟,使用手机的DSP硬件来进行快速傅里叶变换,具体到iOS上,就是用了Accelerate.framework框架中的相关函数。
  3. 为了提高数据传输和解码的成功率,在2FSK的基础上,做了一些调整(magic trick)。

前两点没有太多可说的,对应的开发文档中有很详尽的描述,只不过稍微偏底层一些,只要静下心来老老实实的啃啃文档,还是可以搞定的。第3点中用到的技巧,可能不太常见,我会详细解释一下。

基于标准的2FSK,假如约定用18kHz的音频信号表示二进制的0,用19kHz的音频信号表示二进制的1,同时约定每一个bit持续的发送时间为50ms,假设要发送一个8bit的二进制数据0b11001010(忽略同步和校验部分的bit),对于发送端来说,代码逻辑其实比较简单,只需要让特定频率的引号信号发送特定的时间就行了。但是对于接收端来说,代码就很困难了,虽然用的是2FSK,但是并没有专用的硬件来完成调制解调过程,所以要完全用代码来模拟整个过程,这个里面就涉及到了傅里叶变换、滤波等大量的数字信号处理里面的内容,这些处理完后,才会真正的进入到通信协议栈里面处理二进制的0和1。

如果按照标准的2FSK方式,接收端的代码必须用定时器记录0或1(18kHz或19kHz)持续的时间,然后用这个时间值和50ms做比较,才能判断出这一部分音频片段对应了多少个连续的0或1。而且这仅仅是理论上对解码算法的描述,实际情况中,发送端维持的每个bit位的持续时间是50ms,进入空气中后,会和其他的各种各样的音频信号混杂在一起,然后才进入接收端进行变换和滤波等操作,这个时候,是很难保证每个bit位仍然能够维持在50ms的(即便有50ms,代码仍然会很难编写),正式因为这些原因,成功解码数据的概率并不高。为了改善这种情况,对2FSK做了一些调整,这里借鉴了数字电路里面的一些概念和技术。在数电的串行接口电路中,使用高低电平来表示二进制的1和0,根据传输比特率的约定,每个电平会持续特定的时间,这类似于我们的音频系统中约定的每个bit持续发送50ms,这通常称为电平检测(根据电平值持续的时间进行检测),还有另外一种称为边缘检测的技术,它不依赖于每个电平值持续的时间,而是依赖于电平值的变化事件,比如电平从高变为低(从1变为0)。这里就是使用了边缘检测这种方式来处理音频信号,接收端需要关注的,是音频信号频率值的变化,而不是每个频率值持续的时间。为了实现这种方式,还需要对之前的约定做一些调整,调整为18kHz和19kHz的音频信号都可以表示二进制的0,20kHz和21kHz的音频信号都可以表示二进制的1,如果是为了表示两个连续的0,那么就应该是18kHz的音频信号持续50ms,然后变成19kHz的音频信号持续50ms(或者先发送19kHz的,再发送18kHz的),对于连续的1,也采用类似的策略。举个例子,对于二进制数据0b11101000,转换成频率值后,可能就会是这样的一组值 [20kHz,21kHz,20kHz,18kHz,21kHz,18kHz,19kHz,18kHz],因为每一个bit对应的频率值都会发生变化,那么接收端就可以忽略每个bit持续的时间,只需要检测出每一次频率值发生变化就行了,每一次变化后得到的数值,就可以对应到当前的bit位的二进制值。用了这种调制解调的思路后,接收端的代码,写起来就很容易了:]

12
FengJian

FengJian

14 posts
22 tags
RSS
GitHub
© 2018 FengJian
Powered by Hexo v3.7.1
|
Theme — NexT.Pisces v6.3.0