feat: support other protocol's server elect #209

This commit is contained in:
tjq 2021-02-08 21:07:00 +08:00
parent 0f1e17e862
commit eda39a6372
4 changed files with 31 additions and 21 deletions

View File

@ -3,7 +3,9 @@ package com.github.kfcfans.powerjob.server.service.ha;
import akka.actor.ActorSelection; import akka.actor.ActorSelection;
import akka.pattern.Patterns; import akka.pattern.Patterns;
import com.github.kfcfans.powerjob.common.PowerJobException; import com.github.kfcfans.powerjob.common.PowerJobException;
import com.github.kfcfans.powerjob.common.Protocol;
import com.github.kfcfans.powerjob.common.response.AskResponse; import com.github.kfcfans.powerjob.common.response.AskResponse;
import com.github.kfcfans.powerjob.server.transport.TransportService;
import com.github.kfcfans.powerjob.server.transport.starter.AkkaStarter; import com.github.kfcfans.powerjob.server.transport.starter.AkkaStarter;
import com.github.kfcfans.powerjob.server.handler.inner.requests.Ping; import com.github.kfcfans.powerjob.server.handler.inner.requests.Ping;
import com.github.kfcfans.powerjob.server.persistence.core.model.AppInfoDO; import com.github.kfcfans.powerjob.server.persistence.core.model.AppInfoDO;
@ -37,6 +39,8 @@ public class ServerSelectService {
@Resource @Resource
private LockService lockService; private LockService lockService;
@Resource @Resource
private TransportService transportService;
@Resource
private AppInfoRepository appInfoRepository; private AppInfoRepository appInfoRepository;
@Value("${oms.accurate.select.server.percentage}") @Value("${oms.accurate.select.server.percentage}")
@ -47,17 +51,17 @@ public class ServerSelectService {
private static final String SERVER_ELECT_LOCK = "server_elect_%d"; private static final String SERVER_ELECT_LOCK = "server_elect_%d";
public String getServer(Long appId, String currentServer) { public String getServer(Long appId, String currentServer, String protocol) {
if (!accurate()) { if (!accurate()) {
// 如果是本机就不需要查数据库那么复杂的操作了直接返回成功 // 如果是本机就不需要查数据库那么复杂的操作了直接返回成功
if (AkkaStarter.getActorSystemAddress().equals(currentServer)) { if (getThisServerAddress(protocol).equals(currentServer)) {
return currentServer; return currentServer;
} }
} }
return getServer0(appId); return getServer0(appId, protocol);
} }
private String getServer0(Long appId) { private String getServer0(Long appId, String protocol) {
Set<String> downServerCache = Sets.newHashSet(); Set<String> downServerCache = Sets.newHashSet();
@ -93,7 +97,7 @@ public class ServerSelectService {
} }
// 篡位本机作为Server // 篡位本机作为Server
appInfo.setCurrentServer(AkkaStarter.getActorSystemAddress()); appInfo.setCurrentServer(getThisServerAddress(protocol));
appInfo.setGmtModified(new Date()); appInfo.setGmtModified(new Date());
appInfoRepository.saveAndFlush(appInfo); appInfoRepository.saveAndFlush(appInfo);
@ -123,10 +127,6 @@ public class ServerSelectService {
return false; return false;
} }
if (AkkaStarter.getActorSystemAddress().equals(serverAddress)) {
return true;
}
Ping ping = new Ping(); Ping ping = new Ping();
ping.setCurrentTime(System.currentTimeMillis()); ping.setCurrentTime(System.currentTimeMillis());
@ -146,4 +146,9 @@ public class ServerSelectService {
private boolean accurate() { private boolean accurate() {
return ThreadLocalRandom.current().nextInt(100) < accurateSelectServerPercentage; return ThreadLocalRandom.current().nextInt(100) < accurateSelectServerPercentage;
} }
private String getThisServerAddress(String protocol) {
Protocol pt = Protocol.of(protocol);
return transportService.getTransporter(pt).getAddress();
}
} }

View File

@ -8,7 +8,6 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -33,20 +32,26 @@ public class TransportService {
} }
public void tell(Protocol protocol, String address, OmsSerializable object) { public void tell(Protocol protocol, String address, OmsSerializable object) {
Transporter transporter = protocol2Transporter.get(protocol); getTransporter(protocol).tell(address, object);
if (transporter == null) {
log.error("[TransportService] can't find transporter by protocol[{}], this is a bug!", protocol);
return;
}
transporter.tell(address, object);
} }
public AskResponse ask(Protocol protocol, String address, OmsSerializable object) throws Exception { public AskResponse ask(Protocol protocol, String address, OmsSerializable object) throws Exception {
return getTransporter(protocol).ask(address, object);
}
public Transporter getTransporter(Protocol protocol) {
Transporter transporter = protocol2Transporter.get(protocol); Transporter transporter = protocol2Transporter.get(protocol);
if (transporter == null) { if (transporter == null) {
log.error("[TransportService] can't find transporter by protocol[{}], this is a bug!", protocol); log.error("[TransportService] can't find transporter by protocol[{}], this is a bug!", protocol);
throw new IOException("can't find transporter by protocol: " + protocol); throw new UnknownProtocolException("can't find transporter by protocol: " + protocol);
}
return transporter;
}
public static class UnknownProtocolException extends RuntimeException {
public UnknownProtocolException(String message) {
super(message);
} }
return transporter.ask(address, object);
} }
} }

View File

@ -43,8 +43,8 @@ public class ServerController {
} }
@GetMapping("/acquire") @GetMapping("/acquire")
public ResultDTO<String> acquireServer(Long appId, String currentServer) { public ResultDTO<String> acquireServer(Long appId, String currentServer, String protocol) {
return ResultDTO.success(serverSelectService.getServer(appId, currentServer)); return ResultDTO.success(serverSelectService.getServer(appId, currentServer, protocol));
} }
@GetMapping("/hello") @GetMapping("/hello")

View File

@ -27,7 +27,7 @@ public class ServerDiscoveryService {
// 配置的可发起HTTP请求的ServerIP:Port // 配置的可发起HTTP请求的ServerIP:Port
private static final Map<String, String> IP2ADDRESS = Maps.newHashMap(); private static final Map<String, String> IP2ADDRESS = Maps.newHashMap();
// 服务发现地址 // 服务发现地址
private static final String DISCOVERY_URL = "http://%s/server/acquire?appId=%d&currentServer=%s"; private static final String DISCOVERY_URL = "http://%s/server/acquire?appId=%d&currentServer=%s&protocol=AKKA";
// 失败次数 // 失败次数
private static int FAILED_COUNT = 0; private static int FAILED_COUNT = 0;
// 最大失败次数 // 最大失败次数