Zookeeper 负载均衡

负载均衡是一种手段,用来把对某种资源的访问分摊给不同的设备,从而减轻单点的压力。

架构图

![架构图](clipboard.png)

图中左侧为ZooKeeper集群,右侧上方为工作服务器,下方为客户端。每台工作服务器在启动时都会在zookeeper的servers节点下注册一个临时节点;每台客户端在启动时都会从servers节点下取得所有可用工作服务器的列表,通过一定的负载均衡算法选出一台工作服务器,并与之建立网络连接。网络连接部分我们采用开源框架Netty实现。

流程图

负载均衡客户端流程

![负载均衡客户端流程](clipboard.png)

服务端主体流程

![服务端主体流程](clipboard.png)

类图

Server端核心类

![Server端核心类](clipboard.png)
每个服务端对应一个Server接口,ServerImpl是服务端的实现类。我们把服务端启动时的注册过程抽象为接口RegistProvider,并提供一个默认实现DefaultRegistProvider,它会用到上下文类ZooKeeperRegistContext。我们的服务端基于Netty,需要ServerHandler来处理与客户端之间的连接;当有客户端建立或断开连接时,都需要修改当前服务器的负载信息,因此我们把修改负载信息的过程也抽象为接口BalanceUpdateProvider,并提供了一个默认实现DefaultBalanceUpdateProvider。ServerRunner是调度类,负责调度我们的Server。

Client端核心类

![Client端核心类](clipboard.png)
每个客户端都需要实现Client接口,ClientImpl是其实现类。Client需要ClientHandler来处理与服务器之间的通讯,同时需要BalanceProvider为它提供负载均衡算法。BalanceProvider是接口,它有2个实现类:一个是抽象实现AbstractBalanceProvider,另一个是默认实现DefaultBalanceProvider。ServerData是服务端和客户端共用的类:服务端会把自己的基本信息(包括负载信息)打包成ServerData并写入zookeeper,客户端在计算负载时需要从zookeeper中取得ServerData,从而得到服务端的地址和负载信息。ClientRunner是客户端的调度类,负责启动客户端。

代码实现

先看Server端的代码:

public class ServerData implements Serializable,Comparable<ServerData> {

private static final long serialVersionUID = -8892569870391530906L;


private Integer balance;
private String host;
private Integer port;


public Integer getBalance() {
return balance;
}
public void setBalance(Integer balance) {
this.balance = balance;
}
public String getHost() {
return host;
}
public void setHost(String host) {
this.host = host;
}
public Integer getPort() {
return port;
}
public void setPort(Integer port) {
this.port = port;
}


@Override
public String toString() {
return "ServerData [balance=" + balance + ", host=" + host + ", port="
+ port + "]";
}
public int compareTo(ServerData o) {
return this.getBalance().compareTo(o.getBalance());
}

}
/**
 * A work server that can be started so that clients can connect to it.
 */
public interface Server {

    /** Starts the server (registering it and listening for client connections). */
    void bind();

}
/**
 * Netty-based work server. On {@link #bind()} it registers itself in ZooKeeper as a node under
 * {@code serversPath}, then listens on {@code sd.getPort()} until the server channel closes.
 * Every accepted connection gets a ServerHandler that adjusts the load stored in that node.
 */
public class ServerImpl implements Server {

    private EventLoopGroup bossGroup = new NioEventLoopGroup();
    private EventLoopGroup workGroup = new NioEventLoopGroup();
    private ServerBootstrap bootStrap = new ServerBootstrap();
    private ChannelFuture cf;
    private String zkAddress;
    private String serversPath;
    private String currentServerPath;
    private ServerData sd;

    private volatile boolean binded = false;

    private final ZkClient zc;
    private final RegistProvider registProvider;

    private static final Integer SESSION_TIME_OUT = 10000;
    private static final Integer CONNECT_TIME_OUT = 10000;

    /** @return the ZooKeeper path this server registered under, or null before registration */
    public String getCurrentServerPath() {
        return currentServerPath;
    }

    public String getZkAddress() {
        return zkAddress;
    }

    public String getServersPath() {
        return serversPath;
    }

    public ServerData getSd() {
        return sd;
    }

    /** Replaces the local ServerData; this does not rewrite the already-registered ZooKeeper node. */
    public void setSd(ServerData sd) {
        this.sd = sd;
    }

    /**
     * @param zkAddress   ZooKeeper connect string
     * @param serversPath parent node under which servers register
     * @param sd          this server's data; its port is used both for listening and as node name
     */
    public ServerImpl(String zkAddress, String serversPath, ServerData sd) {
        this.zkAddress = zkAddress;
        this.serversPath = serversPath;
        this.zc = new ZkClient(this.zkAddress, SESSION_TIME_OUT, CONNECT_TIME_OUT, new SerializableSerializer());
        this.registProvider = new DefaultRegistProvider();
        this.sd = sd;
    }

    /**
     * Registers this server in ZooKeeper at {@code serversPath/<port>}.
     * NOTE(review): the node name is only the port, so two hosts using the same port under the same
     * serversPath would collide.
     */
    private void initRunning() throws Exception {
        String mePath = serversPath.concat("/").concat(sd.getPort().toString());
        registProvider.regist(new ZooKeeperRegistContext(mePath, zc, sd));
        currentServerPath = mePath;
    }

    /**
     * Registers this server, binds the port and blocks until the server channel is closed.
     * Returns immediately if the server is already bound. Whenever this method exits after having
     * started, the event loops are shut down and the ZooKeeper client is closed. Closing the session
     * removes the ephemeral registration node, so clients are no longer directed to this server.
     */
    public void bind() {

        if (binded) {
            return;
        }

        System.out.println(sd.getPort() + ":binding...");

        try {
            initRunning();
        } catch (Exception e) {
            e.printStackTrace();
            releaseResources();
            return;
        }

        bootStrap.group(bossGroup, workGroup)
                .channel(NioServerSocketChannel.class)
                .option(ChannelOption.SO_BACKLOG, 1024)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch) throws Exception {
                        ChannelPipeline p = ch.pipeline();
                        p.addLast(new ServerHandler(new DefaultBalanceUpdateProvider(currentServerPath, zc)));
                    }
                });

        boolean interrupted = false;
        try {
            cf = bootStrap.bind(sd.getPort()).sync();
            binded = true;
            System.out.println(sd.getPort() + ":binded...");
            cf.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            e.printStackTrace();
            interrupted = true;
        } finally {
            // Also runs when bind() fails (e.g. port in use), so a server that is not listening
            // does not stay registered in ZooKeeper.
            releaseResources();
            if (interrupted) {
                // Restored only after cleanup so the pending interrupt cannot cut the cleanup short.
                Thread.currentThread().interrupt();
            }
        }

    }

    /** Shuts down both event loop groups and closes the ZooKeeper client (ending its session). */
    private void releaseResources() {
        bossGroup.shutdownGracefully();
        workGroup.shutdownGracefully();
        zc.close();
    }

}
/**
 * Strategy for publishing a server's registration and withdrawing it again.
 * The concrete type of {@code context} is defined by each implementation.
 */
public interface RegistProvider {

    /** Registers using the information carried by {@code context}. */
    void regist(Object context) throws Exception;

    /** Withdraws a registration described by {@code context}. */
    void unRegist(Object context) throws Exception;

}
/**
 * Default RegistProvider backed by ZooKeeper. Both methods expect a ZooKeeperRegistContext as
 * {@code context}; anything else fails with ClassCastException.
 */
public class DefaultRegistProvider implements RegistProvider {

    /**
     * Registers a server by creating an ephemeral node at the context's path that holds the
     * context's data. If the parent node does not exist yet, it (and any missing ancestors) is
     * created as a persistent node and the registration is retried.
     *
     * @param context a ZooKeeperRegistContext carrying path, zkClient and data
     */
    public void regist(Object context) throws Exception {

        ZooKeeperRegistContext registContext = (ZooKeeperRegistContext) context;
        String path = registContext.getPath();
        ZkClient zc = registContext.getZkClient();

        try {
            zc.createEphemeral(path, registContext.getData());
        } catch (ZkNoNodeException e) {
            // Parent (e.g. /servers) missing: create it persistently, then try again.
            String parentDir = path.substring(0, path.lastIndexOf('/'));
            zc.createPersistent(parentDir, true);
            regist(registContext);
        }
    }

    /**
     * Removes the node created by {@link #regist(Object)}. The node is ephemeral and would otherwise
     * only disappear when the ZooKeeper session ends; this removes it immediately.
     *
     * @param context a ZooKeeperRegistContext carrying the path and zkClient used for registration
     */
    public void unRegist(Object context) throws Exception {
        ZooKeeperRegistContext registContext = (ZooKeeperRegistContext) context;
        // ZkClient.delete returns false (instead of throwing) when the node is already gone,
        // so unregistering twice is harmless.
        registContext.getZkClient().delete(registContext.getPath());
    }

}
/**
 * Bundles what DefaultRegistProvider needs to register a server: the node path, the ZooKeeper
 * client to use, and the data object to store in the node.
 */
public class ZooKeeperRegistContext {

    private String path;
    private ZkClient zkClient;
    private Object data;

    /**
     * @param path     full ZooKeeper path of the node to create
     * @param zkClient client used to talk to ZooKeeper
     * @param data     payload written into the node
     */
    public ZooKeeperRegistContext(String path, ZkClient zkClient, Object data) {
        this.path = path;
        this.zkClient = zkClient;
        this.data = data;
    }

    public String getPath() { return path; }

    public void setPath(String path) { this.path = path; }

    public ZkClient getZkClient() { return zkClient; }

    public void setZkClient(ZkClient zkClient) { this.zkClient = zkClient; }

    public Object getData() { return data; }

    public void setData(Object data) { this.data = data; }

}
/**
* 处理服务端与客户端之间的通信
*/
public class ServerHandler extends ChannelHandlerAdapter{

private final BalanceUpdateProvider balanceUpdater;
private static final Integer BALANCE_STEP = 1;


public ServerHandler(BalanceUpdateProvider balanceUpdater){
this.balanceUpdater = balanceUpdater;

}

public BalanceUpdateProvider getBalanceUpdater() {
return balanceUpdater;
}

// 建立连接时增加负载
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
System.out.println("one client connect...");
balanceUpdater.addBalance(BALANCE_STEP);
}

// 断开连接时减少负载
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
balanceUpdater.reduceBalance(BALANCE_STEP);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}


}
/**
 * Updates a server's published load value.
 */
public interface BalanceUpdateProvider {

    /**
     * Increases the load.
     *
     * @param step amount to add
     * @return whether the update was applied
     */
    boolean addBalance(Integer step);

    /**
     * Decreases the load.
     *
     * @param step amount to subtract
     * @return whether the update was applied
     */
    boolean reduceBalance(Integer step);

}
/**
 * Default BalanceUpdateProvider: maintains the load value stored in a server's ZooKeeper node.
 * Updates use optimistic concurrency. The node is read together with its version and written
 * back conditioned on that version, and the update is retried when another writer (another
 * connection on the same server) changed the node in between.
 */
public class DefaultBalanceUpdateProvider implements BalanceUpdateProvider {

    private String serverPath;
    private ZkClient zc;

    /**
     * @param serverPath ZooKeeper path of the node holding this server's ServerData
     * @param zkClient   client used for reads and conditional writes
     */
    public DefaultBalanceUpdateProvider(String serverPath, ZkClient zkClient) {
        this.serverPath = serverPath;
        this.zc = zkClient;
    }

    /**
     * Adds {@code step} to the stored load.
     *
     * @return true once written; false on any error other than a version conflict
     */
    public boolean addBalance(Integer step) {
        return updateBalance(step, true);
    }

    /**
     * Subtracts {@code step} from the stored load, clamping at 0.
     *
     * @return true once written; false on any error other than a version conflict
     */
    public boolean reduceBalance(Integer step) {
        return updateBalance(step, false);
    }

    /** Shared read-modify-conditional-write loop behind addBalance and reduceBalance. */
    private boolean updateBalance(Integer step, boolean increase) {
        Stat stat = new Stat();

        while (true) {
            try {
                // readData also fills stat with the node's current version.
                ServerData sd = zc.readData(this.serverPath, stat);
                final Integer currBalance = sd.getBalance();
                if (increase) {
                    sd.setBalance(currBalance + step);
                } else {
                    sd.setBalance(currBalance > step ? currBalance - step : 0);
                }
                // Write only if nobody else changed the node since we read it.
                zc.writeData(this.serverPath, sd, stat.getVersion());
                return true;
            } catch (ZkBadVersionException e) {
                // Lost the race against a concurrent update: re-read and retry.
            } catch (Exception e) {
                return false;
            }
        }
    }

}
/**
 * Test launcher: starts SERVER_QTY work servers, each on its own thread, on ports 6000, 6001, ...
 * advertised as 127.0.0.1 with an initial load of 0. It then waits for all of them to finish.
 */
public class ServerRunner {

    private static final int SERVER_QTY = 2;
    private static final String ZOOKEEPER_SERVER = "192.168.1.105:2181";
    private static final String SERVERS_PATH = "/servers";

    public static void main(String[] args) {

        List<Thread> workers = new ArrayList<Thread>();

        for (int index = 0; index < SERVER_QTY; index++) {
            final int port = 6000 + index;

            Thread worker = new Thread(new Runnable() {
                public void run() {
                    ServerData sd = new ServerData();
                    sd.setBalance(0);
                    sd.setHost("127.0.0.1");
                    sd.setPort(port);
                    // bind() blocks until the server shuts down.
                    new ServerImpl(ZOOKEEPER_SERVER, SERVERS_PATH, sd).bind();
                }
            });
            workers.add(worker);
            worker.start();
        }

        for (Thread worker : workers) {
            try {
                worker.join();
            } catch (InterruptedException ignore) {
                // keep waiting for the remaining servers
            }
        }

    }

}

再看Client端的代码:

/**
 * A load-balanced client that connects to one work server.
 */
public interface Client {

    /** Connects to a server. */
    void connect() throws Exception;

    /** Disconnects from the server. */
    void disConnect() throws Exception;

}
/**
 * Netty client that asks a BalanceProvider for a server and opens a connection to it.
 */
public class ClientImpl implements Client {

    private final BalanceProvider<ServerData> provider;
    // volatile: ClientRunner calls connect() on a worker thread but disConnect() on the main thread.
    private volatile EventLoopGroup group = null;
    private volatile Channel channel = null;

    private final Logger log = LoggerFactory.getLogger(getClass());

    /** @param provider chooses which server to connect to */
    public ClientImpl(BalanceProvider<ServerData> provider) {
        this.provider = provider;
    }

    public BalanceProvider<ServerData> getProvider() {
        return provider;
    }

    /**
     * Picks a server through the balance provider and connects to it. Failures are reported on
     * stdout rather than thrown. On failure, the event loop group created for this attempt is shut
     * down so its threads do not leak.
     */
    public void connect() {

        EventLoopGroup attemptGroup = null;
        try {

            // Ask the provider for the server to use (the lowest-load one with DefaultBalanceProvider).
            ServerData serverData = provider.getBalanceItem();
            if (serverData == null) {
                // DefaultBalanceProvider returns null when no server is registered.
                System.out.println("连接异常:没有可用的服务器");
                return;
            }

            System.out.println("connecting to " + serverData.getHost() + ":" + serverData.getPort() + ", it's balance:" + serverData.getBalance());

            attemptGroup = new NioEventLoopGroup();
            Bootstrap b = new Bootstrap();
            b.group(attemptGroup)
                    .channel(NioSocketChannel.class)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch) throws Exception {
                            ChannelPipeline p = ch.pipeline();
                            p.addLast(new ClientHandler());
                        }
                    });
            ChannelFuture f = b.connect(serverData.getHost(), serverData.getPort()).syncUninterruptibly();
            group = attemptGroup;
            channel = f.channel();

            System.out.println("started success!");

        } catch (Exception e) {

            System.out.println("连接异常:" + e.getMessage());
            if (attemptGroup != null) {
                attemptGroup.shutdownGracefully();
            }

        }

    }

    /**
     * Closes the connection (if any) and shuts down the event loop group (if any).
     * Safe to call when never connected or when already disconnected.
     */
    public void disConnect() {

        Channel ch = channel;
        EventLoopGroup g = group;
        channel = null;
        group = null;

        try {
            if (ch != null) {
                ch.close().syncUninterruptibly();
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        } finally {
            // Shut the group down even if closing the channel failed.
            if (g != null) {
                g.shutdownGracefully();
            }
        }

        log.debug("disconnected!");
    }

}
public class ClientHandler extends ChannelHandlerAdapter {

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
// Close the connection when an exception is raised.
cause.printStackTrace();
ctx.close();
}
}
/**
 * Chooses one item (for example, a server) according to some load-balancing rule.
 *
 * @param <T> type of item being chosen
 */
public interface BalanceProvider<T> {

    /**
     * @return the chosen item; implementations may return null when there is nothing to choose
     *         (DefaultBalanceProvider does)
     */
    T getBalanceItem();

}
public abstract class AbstractBalanceProvider<T> implements BalanceProvider<T> {

protected abstract T balanceAlgorithm(List<T> items);
protected abstract List<T> getBalanceItems();

public T getBalanceItem(){
return balanceAlgorithm(getBalanceItems());
}

}
/**
 * Balance provider that reads all registered servers from ZooKeeper and picks the one with the
 * lowest load.
 */
public class DefaultBalanceProvider extends AbstractBalanceProvider<ServerData> {

    private final String zkServer; // ZooKeeper connect string
    private final String serversPath; // parent node under which servers register
    private final ZkClient zc;

    private static final Integer SESSION_TIME_OUT = 10000;
    private static final Integer CONNECT_TIME_OUT = 10000;

    /**
     * @param zkServer    ZooKeeper connect string
     * @param serversPath parent node under which work servers register
     */
    public DefaultBalanceProvider(String zkServer, String serversPath) {
        this.serversPath = serversPath;
        this.zkServer = zkServer;

        this.zc = new ZkClient(this.zkServer, SESSION_TIME_OUT, CONNECT_TIME_OUT,
                new SerializableSerializer());

    }

    /**
     * Returns the least-loaded server, or null when {@code items} is null or empty. Among equally
     * loaded servers the first one in list order wins. The list is not modified.
     */
    @Override
    protected ServerData balanceAlgorithm(List<ServerData> items) {
        if (items == null || items.isEmpty()) {
            return null;
        }
        return Collections.min(items);
    }

    /**
     * Reads the ServerData of every work server currently registered under {@code serversPath}.
     * Returns an empty list when the parent node does not exist yet (no server has registered).
     */
    @Override
    protected List<ServerData> getBalanceItems() {

        List<ServerData> sdList = new ArrayList<ServerData>();
        if (!zc.exists(this.serversPath)) {
            return sdList;
        }

        List<String> children = zc.getChildren(this.serversPath);
        for (String child : children) {
            // A server may disappear between getChildren() and readData(). With
            // returnNullIfPathNotExists=true, readData yields null instead of throwing; skip it.
            ServerData sd = zc.readData(serversPath + "/" + child, true);
            if (sd != null) {
                sdList.add(sd);
            }
        }
        return sdList;

    }

}
/**
 * Test launcher: starts CLIENT_QTY clients, one every 2 seconds, each connecting on its own thread
 * to the server chosen by a shared DefaultBalanceProvider. On Enter, it waits for the connect
 * threads and then disconnects every client.
 */
public class ClientRunner {

    private static final int CLIENT_QTY = 3;
    private static final String ZOOKEEPER_SERVER = "192.168.1.105:2181";
    private static final String SERVERS_PATH = "/servers";

    public static void main(String[] args) {

        List<Thread> threadList = new ArrayList<Thread>(CLIENT_QTY);
        // Only ever touched by the main thread (clients are created here, not in the workers).
        final List<Client> clientList = new ArrayList<Client>(CLIENT_QTY);
        final BalanceProvider<ServerData> balanceProvider = new DefaultBalanceProvider(ZOOKEEPER_SERVER, SERVERS_PATH);

        try {

            for (int i = 0; i < CLIENT_QTY; i++) {

                final Client client = new ClientImpl(balanceProvider);
                clientList.add(client);

                Thread thread = new Thread(new Runnable() {

                    public void run() {
                        try {
                            client.connect();
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                });
                threadList.add(thread);
                thread.start();
                // Delay between clients; presumably so each one sees the load published after the
                // previous connection was counted.
                Thread.sleep(2000);

            }

            System.out.println("敲回车键退出!\n");
            new BufferedReader(new InputStreamReader(System.in)).readLine();

        } catch (Exception e) {

            e.printStackTrace();

        } finally {
            // Wait for the connect threads first, so that no client is still connecting (and thus
            // missed) when we disconnect below.
            for (Thread thread : threadList) {
                thread.interrupt();
                try {
                    thread.join();
                } catch (InterruptedException e) {
                    // ignore
                }
            }
            // Close the clients.
            for (Client client : clientList) {
                try {
                    client.disConnect();
                } catch (Exception ignore) {
                    // best effort: keep closing the remaining clients
                }
            }
        }
    }
}

我们先启动服务端ServerRunner:

6000:binding...
6000:binded...
6001:binding...
6001:binded...

再来启动客户端ClientRunner:

connecting to 127.0.0.1:6000, it's balance:0
started success!
connecting to 127.0.0.1:6001, it's balance:0
started success!
connecting to 127.0.0.1:6000, it's balance:1
started success!
敲回车键退出!