finished processor api design

tjq 2020-03-18 20:11:26 +08:00
parent 415da176ed
commit fafb708c7a
25 changed files with 509 additions and 52 deletions

View File

@ -13,5 +13,17 @@
<version>1.0.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<slf4j.version>1.7.30</slf4j.version>
</properties>
<dependencies>
<!-- slf4j -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,21 @@
package com.github.kfcfans.common;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* Status of a job instance.
*
* @author tjq
* @since 2020/3/17
*/
@Getter
@AllArgsConstructor
public enum JobInstanceStatus {
RUNNING(1, "running"),
SUCCEED(2, "succeeded"),
FAILED(3, "failed");
private int value;
private String des;
}

View File

@ -1,4 +1,4 @@
- package com.github.kfcfans.oms.worker.pojo.request;
+ package com.github.kfcfans.common.request;
import lombok.Data;

View File

@ -0,0 +1,26 @@
package com.github.kfcfans.common.request;
import lombok.Data;
/**
* The TaskTracker reports instance status to the server.
*
* @author tjq
* @since 2020/3/17
*/
@Data
public class TaskTrackerReportInstanceStatusReq {
private String jobId;
private String instanceId;
private int instanceStatus;
private String result;
/* ********* statistics ********* */
private long totalTaskNum;
private long runningTaskNum;
private long succeedTaskNum;
private long failedTaskNum;
}

View File

@ -0,0 +1,38 @@
package com.github.kfcfans.common.utils;
import lombok.extern.slf4j.Slf4j;
import java.util.function.Supplier;
/**
* Common utility class
*
* @author tjq
* @since 2020/3/18
*/
@Slf4j
public class CommonUtils {
/**
* Execute with retry; only suitable for methods that signal failure by throwing an exception.
* @param executor the method to execute
* @param retryTimes the total number of attempts
* @param intervalMS interval in milliseconds before the next attempt after a failure
* @return the return value of a successful execution
* @throws Exception if execution ultimately fails; the caller handles it
*/
public static <T> T executeWithRetry(Supplier<T> executor, int retryTimes, long intervalMS) throws Exception {
if (retryTimes <= 1 || intervalMS <= 0) {
return executor.get();
}
for (int i = 1; i < retryTimes; i++) {
try {
return executor.get();
}catch (Exception e) {
log.warn("[CommonUtils] executeWithRetry failed, system will retry after {}ms.", intervalMS, e);
Thread.sleep(intervalMS);
}
}
return executor.get();
}
}
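A quick usage sketch of executeWithRetry (the RetryDemo class and the simulated failure are illustrative, not part of this commit): the Supplier is retried as long as it throws, with intervalMS of sleep between attempts, and the final attempt's exception is left for the caller.

import com.github.kfcfans.common.utils.CommonUtils;

public class RetryDemo {
    public static void main(String[] args) throws Exception {
        // up to 3 attempts in total, waiting 100 ms after each failed attempt;
        // the Supplier can only throw unchecked exceptions (Supplier#get has no throws clause)
        String value = CommonUtils.executeWithRetry(() -> {
            if (Math.random() < 0.5) {
                throw new IllegalStateException("simulated transient failure");
            }
            return "ok";
        }, 3, 100);
        System.out.println(value);
    }
}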

View File

@ -16,11 +16,11 @@
<properties>
<spring.version>5.2.4.RELEASE</spring.version>
<akka.version>2.6.4</akka.version>
<slf4j.version>1.7.30</slf4j.version>
<oms.common.version>1.0.0-SNAPSHOT</oms.common.version>
<h2.db.version>1.4.200</h2.db.version>
<hikaricp.version>3.4.2</hikaricp.version>
<guava.version>28.2-jre</guava.version>
<fastjson.version>1.2.58</fastjson.version>
</properties>
<dependencies>
@ -40,13 +40,6 @@
<version>${akka.version}</version>
</dependency>
<!-- slf4j -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.30</version>
</dependency>
<!-- oms-common -->
<dependency>
<groupId>com.github.kfcfans</groupId>
@ -74,6 +67,14 @@
<version>${guava.version}</version>
</dependency>
<!-- fastJSON -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>${fastjson.version}</version>
</dependency>
<!-- logback for development-stage logging -->
<dependency>
<groupId>ch.qos.logback</groupId>

View File

@ -1,7 +1,7 @@
package com.github.kfcfans.oms.worker.actors;
import akka.actor.AbstractActor;
- import com.github.kfcfans.oms.worker.pojo.request.ServerScheduleJobReq;
+ import com.github.kfcfans.common.request.ServerScheduleJobReq;
import lombok.extern.slf4j.Slf4j;
/**

View File

@ -0,0 +1,16 @@
package com.github.kfcfans.oms.worker.common;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
/**
* Stores things that are inconvenient to pass around directly.
* #attention: beware of memory leaks; ideally call remove when the ProcessorTracker is destroyed
*
* @author tjq
* @since 2020/3/18
*/
public class ThreadLocalStore {
public static final ThreadLocal<TaskContext> TASK_CONTEXT_THREAD_LOCAL = new ThreadLocal<>();
}
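A minimal sketch of the lifecycle the #attention note asks for (the wrapper class and method are illustrative, not part of this commit): set the context before the processor runs and always remove it afterwards, so pooled worker threads do not retain stale TaskContext references.

import com.github.kfcfans.oms.worker.common.ThreadLocalStore;
import com.github.kfcfans.oms.worker.sdk.TaskContext;

public class ThreadLocalLifecycleSketch {

    // illustrative: a ProcessorTracker-like executor would do the equivalent around every task execution
    public static void runWithContext(TaskContext context, Runnable processorInvocation) {
        ThreadLocalStore.TASK_CONTEXT_THREAD_LOCAL.set(context);
        try {
            // the processor reads the context through ThreadLocalStore while it runs
            processorInvocation.run();
        } finally {
            // always remove, otherwise pooled worker threads keep a stale TaskContext (the leak the note warns about)
            ThreadLocalStore.TASK_CONTEXT_THREAD_LOCAL.remove();
        }
    }
}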

View File

@ -13,7 +13,7 @@ public class AkkaConstant {
*/
public static final String ACTOR_SYSTEM_NAME = "oms";
- public static final String JOB_TRACKER_ACTOR_NAME = "job_tracker";
+ public static final String Task_TRACKER_ACTOR_NAME = "task_tracker";
public static final String WORKER_ACTOR_NAME = "worker";
}

View File

@ -15,7 +15,7 @@ public enum TaskStatus {
/* ******************* TaskTracker only ******************* */
WAITING_DISPATCH(1, "waiting to be dispatched"),
- DISPATCH_SUCCESS(2, "dispatched, but worker receipt not guaranteed"),
+ DISPATCH_SUCCESS_WORKER_UNCHECK(2, "dispatched, but worker receipt not guaranteed"),
WORKER_PROCESSING(3, "worker started processing"),
WORKER_PROCESS_SUCCESS(4, "worker processed successfully"),
WORKER_PROCESS_FAILED(5, "worker processing failed"),
@ -32,7 +32,7 @@ public enum TaskStatus {
public static TaskStatus of(int v) {
switch (v) {
case 1: return WAITING_DISPATCH;
- case 2: return DISPATCH_SUCCESS;
+ case 2: return DISPATCH_SUCCESS_WORKER_UNCHECK;
case 3: return WORKER_PROCESSING;
case 4: return WORKER_PROCESS_SUCCESS;
case 5: return WORKER_PROCESS_FAILED;

View File

@ -22,11 +22,16 @@ public class SimpleTaskQuery {
private Integer status;
// custom query condition (the clause after WHERE), e.g. created_time > 10086 and status = 3
- private String conditionSQL;
+ private String queryCondition;
// additional query condition, e.g. GROUP BY status
private String otherCondition;
// columns to query; defaults to *
private String queryContent = " * ";
private Integer limit;
- public String getConditionSQL() {
+ public String getQueryCondition() {
StringBuilder sb = new StringBuilder();
if (!StringUtils.isEmpty(taskId)) {
sb.append("task_id = '").append(taskId).append("'").append(LINK);
@ -47,8 +52,8 @@ public class SimpleTaskQuery {
sb.append("status = ").append(status).append(LINK); sb.append("status = ").append(status).append(LINK);
} }
if (!StringUtils.isEmpty(conditionSQL)) { if (!StringUtils.isEmpty(queryCondition)) {
sb.append(conditionSQL).append(LINK); sb.append(queryCondition).append(LINK);
} }
String substring = sb.substring(0, sb.length() - LINK.length()); String substring = sb.substring(0, sb.length() - LINK.length());

View File

@ -2,6 +2,7 @@ package com.github.kfcfans.oms.worker.persistence;
import java.util.Collection;
import java.util.List;
import java.util.Map;
/**
* Task persistence interface
@ -31,6 +32,8 @@ public interface TaskDAO {
List<TaskDO> simpleQuery(SimpleTaskQuery query);
List<Map<String, Object>> simpleQueryPlus(SimpleTaskQuery query);
boolean simpleUpdate(SimpleTaskQuery condition, TaskDO updateField);
}

View File

@ -2,12 +2,13 @@ package com.github.kfcfans.oms.worker.persistence;
import com.github.kfcfans.oms.worker.common.constants.TaskStatus;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import java.sql.*;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* Task persistence implementation layer (table name: task_info)
@ -22,7 +23,7 @@ public class TaskDAOImpl implements TaskDAO {
public boolean initTable() {
String delTableSQL = "drop table if exists task_info";
- String createTableSQL = "create table task_info (task_id varchar(20), instance_id varchar(20), job_id varchar(20), task_name varchar(20), task_content blob, address varchar(20), status int(11), result text, failed_cnt int(11), created_time bigint(20), last_modified_time bigint(20), unique key pkey (instance_id, task_id))";
+ String createTableSQL = "create table task_info (task_id varchar(20), instance_id varchar(20), job_id varchar(20), task_name varchar(20), task_content blob, address varchar(20), status int(11), result text, failed_cnt int(11), created_time bigint(20), last_modified_time bigint(20), unique KEY pkey (instance_id, task_id))";
try (Connection conn = ConnectionFactory.getConnection(); Statement stat = conn.createStatement()) {
stat.execute(delTableSQL);
@ -99,7 +100,7 @@ public class TaskDAOImpl implements TaskDAO {
@Override
public List<TaskDO> simpleQuery(SimpleTaskQuery query) {
ResultSet rs = null;
- String sql = "select * from task_info where " + query.getConditionSQL();
+ String sql = "select * from task_info where " + query.getQueryCondition();
List<TaskDO> result = Lists.newLinkedList();
try (Connection conn = ConnectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) {
rs = ps.executeQuery();
@ -119,10 +120,43 @@ public class TaskDAOImpl implements TaskDAO {
return result;
}
@Override
public List<Map<String, Object>> simpleQueryPlus(SimpleTaskQuery query) {
ResultSet rs = null;
String sqlFormat = "select %s from task_info where %s";
String sql = String.format(sqlFormat, query.getQueryContent(), query.getQueryCondition());
List<Map<String, Object>> result = Lists.newLinkedList();
try (Connection conn = ConnectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) {
rs = ps.executeQuery();
// the result set metadata contains the column names
ResultSetMetaData metaData = rs.getMetaData();
while (rs.next()) {
Map<String, Object> row = Maps.newHashMap();
result.add(row);
for (int i = 0; i < metaData.getColumnCount(); i++) {
String colName = metaData.getColumnName(i + 1);
Object colValue = rs.getObject(colName);
row.put(colName, colValue);
}
}
}catch (Exception e) {
log.error("[TaskDAO] simpleQuery failed(sql = {}).", sql, e);
}finally {
if (rs != null) {
try {
rs.close();
}catch (Exception ignore) {
}
}
}
return result;
}
@Override
public boolean simpleUpdate(SimpleTaskQuery condition, TaskDO updateField) {
String sqlFormat = "update task_info set %s where %s";
- String updateSQL = String.format(sqlFormat, updateField.getUpdateSQL(), condition.getConditionSQL());
+ String updateSQL = String.format(sqlFormat, updateField.getUpdateSQL(), condition.getQueryCondition());
try (Connection conn = ConnectionFactory.getConnection(); PreparedStatement stat = conn.prepareStatement(updateSQL)) {
return stat.execute();
}catch (Exception e) {
@ -193,6 +227,13 @@ public class TaskDAOImpl implements TaskDAO {
final List<TaskDO> res2 = taskDAO.simpleQuery(query);
System.out.println(res2);
SimpleTaskQuery query3 = new SimpleTaskQuery();
query3.setInstanceId("22");
query3.setQueryContent("status, count(*) as num");
query3.setOtherCondition("GROUP BY status");
List<Map<String, Object>> dbRES = taskDAO.simpleQueryPlus(query3);
System.out.println(dbRES);
Thread.sleep(100000);
}
}

View File

@ -3,9 +3,11 @@ package com.github.kfcfans.oms.worker.persistence;
import com.github.kfcfans.oms.worker.common.constants.TaskStatus;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
/**
* Task persistence service
@ -67,4 +69,24 @@ public class TaskPersistenceService {
updateEntity.setStatus(status.getValue());
return taskDAO.simpleUpdate(condition, updateEntity);
}
/**
* Get the status statistics of the sub-tasks managed by this TaskTracker.
* TaskStatus -> num
*/
public Map<TaskStatus, Long> getTaskStatusStatistics(String instanceId) {
SimpleTaskQuery query = new SimpleTaskQuery();
query.setInstanceId(instanceId);
query.setQueryContent("status, count(*) as num");
query.setOtherCondition("GROUP BY status");
List<Map<String, Object>> dbRES = taskDAO.simpleQueryPlus(query);
Map<TaskStatus, Long> result = Maps.newHashMap();
dbRES.forEach(row -> {
int status = Integer.parseInt(String.valueOf(row.get("status")));
long num = Long.parseLong(String.valueOf(row.get("num")));
result.put(TaskStatus.of(status), num);
});
return result;
}
}

View File

@ -1,10 +1,53 @@
package com.github.kfcfans.oms.worker.pojo.request;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.kfcfans.oms.worker.common.constants.TaskConstant;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.List;
/**
* WorkerMapTaskRequest
*
* @author tjq
* @since 2020/3/17
*/
@Getter
@NoArgsConstructor
public class WorkerMapTaskRequest {
private String instanceId;
private String jobId;
private String taskName;
private List<SubTask> subTasks;
@NoArgsConstructor
@AllArgsConstructor
private static class SubTask {
private String taskId;
private byte[] taskContent;
}
public WorkerMapTaskRequest(TaskContext taskContext, List<?> subTaskList, String taskName) {
this.instanceId = taskContext.getInstanceId();
this.jobId = taskContext.getJobId();
this.taskName = taskName;
this.subTasks = Lists.newLinkedList();
for (int i = 0; i < subTaskList.size(); i++) {
// the taskId prefix differs between executing threads, so this ID is unique across the cluster
String subTaskId = taskContext.getTaskId() + "." + i;
// write the class name into the payload to make deserialization easier
byte[] content = JSON.toJSONBytes(subTaskList.get(i), SerializerFeature.WriteClassName);
subTasks.add(new SubTask(subTaskId, content));
}
}
}

View File

@ -0,0 +1,14 @@
package com.github.kfcfans.oms.worker.pojo.response;
import lombok.Data;
/**
* Response to WorkerMapTaskRequest
*
* @author tjq
* @since 2020/3/18
*/
@Data
public class MapTaskResponse {
private boolean success;
}

View File

@ -0,0 +1,23 @@
package com.github.kfcfans.oms.worker.sdk;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
/**
* Processor execution result
*
* @author tjq
* @since 2020/3/18
*/
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class ProcessResult {
private boolean success = false;
private String msg;
}

View File

@ -0,0 +1,32 @@
package com.github.kfcfans.oms.worker.sdk;
import lombok.Data;
/**
* Task context.
* To unify concepts: every worker only deals with Tasks; the notions of Job and JobInstance exist only on the Server and in the TaskTracker.
* Standalone job -> the whole Job becomes a single Task
* Broadcast job -> the whole Job becomes a set of identical Tasks
* MapReduce job -> tasks produced by map are all treated as sub-Tasks of the root Task
*
* @author tjq
* @since 2020/3/18
*/
@Data
public class TaskContext {
private String jobId;
private String instanceId;
private String taskId;
private String taskName;
private String jobParams;
private String instanceParams;
private int maxRetryTimes;
private int currentRetryTimes;
private Object subTask;
private String taskTrackerAddress;
}

View File

@ -0,0 +1,16 @@
package com.github.kfcfans.oms.worker.sdk.api;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.github.kfcfans.oms.worker.sdk.ProcessResult;
/**
* Basic processor, suitable for standalone execution
*
* @author tjq
* @since 2020/3/18
*/
public interface BasicProcessor {
ProcessResult process(TaskContext context);
}
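A minimal example implementation of BasicProcessor (the DemoBasicProcessor class is illustrative, not part of this commit):

import com.github.kfcfans.oms.worker.sdk.ProcessResult;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.github.kfcfans.oms.worker.sdk.api.BasicProcessor;

// Illustrative standalone processor: reads the job parameters from the context and reports success.
public class DemoBasicProcessor implements BasicProcessor {

    @Override
    public ProcessResult process(TaskContext context) {
        String params = context.getJobParams();
        // ... business logic for a single-machine job goes here ...
        return new ProcessResult(true, "processed with params: " + params);
    }
}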

View File

@ -0,0 +1,21 @@
package com.github.kfcfans.oms.worker.sdk.api;
import com.github.kfcfans.oms.worker.sdk.ProcessResult;
/**
* Broadcast processor, suitable for broadcast execution
*
* @author tjq
* @since 2020/3/18
*/
public interface BroadcastProcessor extends BasicProcessor {
/**
* Executed before the broadcast run starts on all nodes; runs exactly once, on a single machine
*/
ProcessResult preProcess();
/**
* Executed after the broadcast run completes on all nodes; runs exactly once, on a single machine
*/
ProcessResult postProcess();
}
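A sketch of a BroadcastProcessor implementation (illustrative only): preProcess and postProcess each run once on a single machine, while the inherited process method runs on every node.

import com.github.kfcfans.oms.worker.sdk.ProcessResult;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.github.kfcfans.oms.worker.sdk.api.BroadcastProcessor;

public class DemoBroadcastProcessor implements BroadcastProcessor {

    @Override
    public ProcessResult preProcess() {
        // runs once before the broadcast, e.g. record the start of the run in a shared store
        return new ProcessResult(true, "pre ok");
    }

    @Override
    public ProcessResult process(TaskContext context) {
        // runs on every node, e.g. clear a local cache
        return new ProcessResult(true, "node done");
    }

    @Override
    public ProcessResult postProcess() {
        // runs once after all nodes have finished, e.g. aggregate or notify
        return new ProcessResult(true, "post ok");
    }
}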

View File

@ -0,0 +1,75 @@
package com.github.kfcfans.oms.worker.sdk.api;
import akka.actor.ActorSelection;
import akka.pattern.Patterns;
import com.github.kfcfans.oms.worker.OhMyWorker;
import com.github.kfcfans.oms.worker.common.ThreadLocalStore;
import com.github.kfcfans.oms.worker.common.constants.AkkaConstant;
import com.github.kfcfans.oms.worker.common.utils.AkkaUtils;
import com.github.kfcfans.oms.worker.pojo.request.WorkerMapTaskRequest;
import com.github.kfcfans.oms.worker.pojo.response.MapTaskResponse;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.github.kfcfans.oms.worker.sdk.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
/**
* MapReduce processor, suitable for MapReduce jobs
*
* @author tjq
* @since 2020/3/18
*/
@Slf4j
public abstract class MapReduceProcessor implements BasicProcessor {
private static final int RECOMMEND_BATCH_SIZE = 200;
private static final int REQUEST_TIMEOUT_MS = 5000;
/**
* Distribute sub-tasks.
* @param taskList the sub-tasks; when they are executed later they can be retrieved via TaskContext#getSubTask
* @param taskName name of the sub-tasks; not particularly important
* @return the result of the map operation
*/
public ProcessResult map(List<?> taskList, String taskName) {
if (CollectionUtils.isEmpty(taskList)) {
return new ProcessResult(false, "taskList can't be null");
}
if (taskList.size() > RECOMMEND_BATCH_SIZE) {
log.warn("[MapReduceProcessor] map task size is too large, network maybe overload... please try to split the tasks.");
}
TaskContext taskContext = ThreadLocalStore.TASK_CONTEXT_THREAD_LOCAL.get();
// 1. build the request
WorkerMapTaskRequest req = new WorkerMapTaskRequest(taskContext, taskList, taskName);
// 2. send the request reliably; tasks must not be lost, so use the ask pattern, which throws on failure
boolean requestSucceed = false;
try {
String akkaRemotePath = AkkaUtils.getAkkaRemotePath(taskContext.getTaskTrackerAddress(), AkkaConstant.Task_TRACKER_ACTOR_NAME);
ActorSelection actorSelection = OhMyWorker.actorSystem.actorSelection(akkaRemotePath);
CompletionStage<Object> requestCS = Patterns.ask(actorSelection, req, Duration.ofMillis(REQUEST_TIMEOUT_MS));
MapTaskResponse respObj = (MapTaskResponse) requestCS.toCompletableFuture().get(REQUEST_TIMEOUT_MS, TimeUnit.MILLISECONDS);
requestSucceed = respObj.isSuccess();
}catch (Exception e) {
log.warn("[MapReduceProcessor] map failed.", e);
}
if (requestSucceed) {
return new ProcessResult(true, "MAP_SUCCESS");
}else {
return new ProcessResult(false, "MAP_FAILED");
}
}
public abstract ProcessResult reduce(TaskContext taskContext, Map<String, String> taskId2Result);
}
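A sketch of how a concrete MapReduceProcessor could use map(): the example class, the SubTask payload, and the assumption that the root invocation sees a null sub-task are all illustrative, not part of this commit.

import com.github.kfcfans.oms.worker.sdk.ProcessResult;
import com.github.kfcfans.oms.worker.sdk.TaskContext;
import com.github.kfcfans.oms.worker.sdk.api.MapReduceProcessor;
import com.google.common.collect.Lists;

import java.util.List;
import java.util.Map;

// Illustrative MapReduce processor: the root task maps out sub-tasks, the sub-tasks do the work,
// and reduce sees the collected results (taskId -> result).
public class DemoMapReduceProcessor extends MapReduceProcessor {

    // hypothetical sub-task payload; WorkerMapTaskRequest serializes it together with its class name
    public static class SubTask {
        public int start;
        public int end;
    }

    @Override
    public ProcessResult process(TaskContext context) {
        Object subTask = context.getSubTask();
        if (subTask == null) {
            // assumed root invocation: split the work into payloads and hand them to map()
            List<SubTask> subTasks = Lists.newArrayList();
            for (int i = 0; i < 10; i++) {
                SubTask st = new SubTask();
                st.start = i * 100;
                st.end = (i + 1) * 100;
                subTasks.add(st);
            }
            return map(subTasks, "RANGE_TASK");
        }
        // sub-task invocation: do the actual work for the mapped range
        SubTask st = (SubTask) subTask;
        return new ProcessResult(true, "processed range " + st.start + "-" + st.end);
    }

    @Override
    public ProcessResult reduce(TaskContext taskContext, Map<String, String> taskId2Result) {
        return new ProcessResult(true, "sub task count: " + taskId2Result.size());
    }
}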

View File

@ -1,5 +1,8 @@
package com.github.kfcfans.oms.worker.tracker;
import akka.actor.ActorRef;
import com.github.kfcfans.oms.worker.pojo.model.JobInstanceInfo;
/**
* TaskTracker used by broadcast jobs
*
@ -8,6 +11,11 @@ package com.github.kfcfans.oms.worker.tracker;
*/
public class BroadcastTaskTracker extends TaskTracker {
public BroadcastTaskTracker(JobInstanceInfo jobInstanceInfo, ActorRef taskTrackerActorRef) {
super(jobInstanceInfo, taskTrackerActorRef);
}
@Override
public void dispatch() {

View File

@ -1,5 +1,7 @@
package com.github.kfcfans.oms.worker.tracker;
import akka.actor.ActorRef;
import com.github.kfcfans.oms.worker.pojo.model.JobInstanceInfo;
import com.github.kfcfans.oms.worker.pojo.request.WorkerMapTaskRequest;
@ -11,6 +13,11 @@ import com.github.kfcfans.oms.worker.pojo.request.WorkerMapTaskRequest;
*/
public class MapReduceTaskTracker extends StandaloneTaskTracker {
public MapReduceTaskTracker(JobInstanceInfo jobInstanceInfo, ActorRef taskTrackerActorRef) {
super(jobInstanceInfo, taskTrackerActorRef);
}
public void newTask(WorkerMapTaskRequest mapRequest) {
}

View File

@ -1,5 +1,8 @@
package com.github.kfcfans.oms.worker.tracker;
import akka.actor.ActorRef;
import com.github.kfcfans.oms.worker.pojo.model.JobInstanceInfo;
/**
* TaskTracker used by standalone jobs
*
@ -8,6 +11,11 @@ package com.github.kfcfans.oms.worker.tracker;
*/
public class StandaloneTaskTracker extends TaskTracker {
public StandaloneTaskTracker(JobInstanceInfo jobInstanceInfo, ActorRef taskTrackerActorRef) {
super(jobInstanceInfo, taskTrackerActorRef);
}
@Override
public void dispatch() {

View File

@ -3,6 +3,8 @@ package com.github.kfcfans.oms.worker.tracker;
import akka.actor.ActorRef;
import akka.actor.ActorSelection;
import com.github.kfcfans.common.ExecuteType;
import com.github.kfcfans.common.JobInstanceStatus;
import com.github.kfcfans.common.request.TaskTrackerReportInstanceStatusReq;
import com.github.kfcfans.oms.worker.OhMyWorker;
import com.github.kfcfans.oms.worker.common.constants.AkkaConstant;
import com.github.kfcfans.oms.worker.common.constants.CommonSJ;
@ -21,6 +23,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@ -43,14 +46,7 @@ public abstract class TaskTracker {
protected TaskPersistenceService taskPersistenceService;
protected ScheduledExecutorService scheduledPool;
// statistics
protected AtomicBoolean finished = new AtomicBoolean(false);
protected AtomicLong needDispatchTaskNum = new AtomicLong(0);
protected AtomicLong dispatchedTaskNum = new AtomicLong(0);
protected AtomicLong waitingToRunTaskNum = new AtomicLong(0);
protected AtomicLong runningTaskNum = new AtomicLong(0);
protected AtomicLong successTaskNum = new AtomicLong(0);
protected AtomicLong failedTaskNum = new AtomicLong(0);
public TaskTracker(JobInstanceInfo jobInstanceInfo, ActorRef taskTrackerActorRef) {
@ -62,6 +58,15 @@ public abstract class TaskTracker {
this.scheduledPool = Executors.newScheduledThreadPool(2, factory);
allWorkerAddress = CommonSJ.commaSplitter.splitToList(jobInstanceInfo.getAllWorkerAddress());
// persist the root task
persistenceRootTask();
// scheduled task 1: task dispatch
scheduledPool.scheduleWithFixedDelay(new DispatcherRunnable(), 0, 5, TimeUnit.SECONDS);
// scheduled task 2: status check
scheduledPool.scheduleWithFixedDelay(new StatusCheckRunnable(), 10, 10, TimeUnit.SECONDS);
}
@ -70,16 +75,17 @@ public abstract class TaskTracker {
*/
public abstract void dispatch();
- public void updateTaskStatus(WorkerReportTaskStatusReq statusReportRequest) {
- TaskStatus taskStatus = TaskStatus.of(statusReportRequest.getStatus());
+ public void updateTaskStatus(WorkerReportTaskStatusReq req) {
+ TaskStatus taskStatus = TaskStatus.of(req.getStatus());
// persist
// update statistics
switch (taskStatus) {
case RECEIVE_SUCCESS:
waitingToRunTaskNum.incrementAndGet();break;
case PROCESSING:
// if persistence fails, retry once; local database operations can almost be considered reliable...
boolean updateResult = taskPersistenceService.updateTaskStatus(req.getInstanceId(), req.getTaskId(), taskStatus);
if (!updateResult) {
try {
Thread.sleep(100);
taskPersistenceService.updateTaskStatus(req.getInstanceId(), req.getTaskId(), taskStatus);
}catch (Exception ignore) {
}
}
}
@ -90,7 +96,7 @@ public abstract class TaskTracker {
/**
* Persist the root task; only after persistence completes can the job be considered running (persist first, then report to the server)
*/
- private void persistenceTask() {
+ private void persistenceRootTask() {
ExecuteType executeType = ExecuteType.valueOf(jobInstanceInfo.getExecuteType());
boolean persistenceResult;
@ -109,7 +115,6 @@ public abstract class TaskTracker {
rootTask.setCreatedTime(System.currentTimeMillis());
persistenceResult = taskPersistenceService.save(rootTask);
needDispatchTaskNum.incrementAndGet();
}else {
List<TaskDO> taskList = Lists.newLinkedList();
List<String> addrList = CommonSJ.commaSplitter.splitToList(jobInstanceInfo.getAllWorkerAddress());
@ -128,7 +133,6 @@ public abstract class TaskTracker {
taskList.add(task);
}
persistenceResult = taskPersistenceService.batchSave(taskList);
needDispatchTaskNum.addAndGet(taskList.size());
}
if (!persistenceResult) {
@ -136,12 +140,6 @@ public abstract class TaskTracker {
}
}
/**
* Start the task dispatcher
*/
private void initDispatcher() {
}
public void destroy() {
scheduledPool.shutdown();
@ -171,12 +169,7 @@ public abstract class TaskTracker {
targetActor.tell(req, taskTrackerActorRef);
// update the database; if the update fails it may cause duplicate execution (not handled for now)
- taskPersistenceService.updateTaskStatus(task.getInstanceId(), task.getTaskId(), TaskStatus.DISPATCH_SUCCESS);
+ taskPersistenceService.updateTaskStatus(task.getInstanceId(), task.getTaskId(), TaskStatus.DISPATCH_SUCCESS_WORKER_UNCHECK);
// update statistics
needDispatchTaskNum.decrementAndGet();
dispatchedTaskNum.incrementAndGet();
}catch (Exception e) {
// dispatch failed: do not modify the database; it will be randomly re-dispatched to a remote actor next time
log.warn("[TaskTracker] dispatch task({}) failed.", task);
@ -193,6 +186,38 @@ public abstract class TaskTracker {
@Override
public void run() {
// 1. query the task status statistics
Map<TaskStatus, Long> status2Num = taskPersistenceService.getTaskStatusStatistics(jobInstanceInfo.getInstanceId());
// statuses with no tasks are absent from the result map, so default to 0 to avoid unboxing a null
long waitingDispatchNum = status2Num.getOrDefault(TaskStatus.WAITING_DISPATCH, 0L);
long workerUnreceivedNum = status2Num.getOrDefault(TaskStatus.DISPATCH_SUCCESS_WORKER_UNCHECK, 0L);
long receivedNum = status2Num.getOrDefault(TaskStatus.RECEIVE_SUCCESS, 0L);
long succeedNum = status2Num.getOrDefault(TaskStatus.WORKER_PROCESS_SUCCESS, 0L);
long failedNum = status2Num.getOrDefault(TaskStatus.WORKER_PROCESS_FAILED, 0L);
long finishedNum = succeedNum + failedNum;
long unfinishedNum = waitingDispatchNum + workerUnreceivedNum + receivedNum;
log.debug("[TaskTracker] status check result({})", status2Num);
TaskTrackerReportInstanceStatusReq req = new TaskTrackerReportInstanceStatusReq();
// 2. if the number of unfinished tasks is 0, report to the server
if (unfinishedNum == 0) {
finished.set(true);
if (failedNum == 0) {
req.setInstanceStatus(JobInstanceStatus.SUCCEED.getValue());
}else {
req.setInstanceStatus(JobInstanceStatus.FAILED.getValue());
}
// special handling for MapReduce jobs (run reduce)
// special handling for broadcast jobs (run postProcess)
}else {
}
}
}
}
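The else branch above is still empty and the statistics fields of the report are not yet filled in; a hedged sketch of how the per-status counts could be folded into TaskTrackerReportInstanceStatusReq (the helper class and the mapping of runningTaskNum to the received count are assumptions, not part of this commit):

import com.github.kfcfans.common.JobInstanceStatus;
import com.github.kfcfans.common.request.TaskTrackerReportInstanceStatusReq;
import com.github.kfcfans.oms.worker.common.constants.TaskStatus;

import java.util.Map;

// Hypothetical helper mirroring the StatusCheckRunnable logic: turn per-status counts into the report object.
public class InstanceStatusReportBuilder {

    public static TaskTrackerReportInstanceStatusReq build(String jobId, String instanceId,
                                                           Map<TaskStatus, Long> status2Num) {
        long waiting = status2Num.getOrDefault(TaskStatus.WAITING_DISPATCH, 0L);
        long unreceived = status2Num.getOrDefault(TaskStatus.DISPATCH_SUCCESS_WORKER_UNCHECK, 0L);
        long received = status2Num.getOrDefault(TaskStatus.RECEIVE_SUCCESS, 0L);
        long succeed = status2Num.getOrDefault(TaskStatus.WORKER_PROCESS_SUCCESS, 0L);
        long failed = status2Num.getOrDefault(TaskStatus.WORKER_PROCESS_FAILED, 0L);
        long unfinished = waiting + unreceived + received;

        TaskTrackerReportInstanceStatusReq req = new TaskTrackerReportInstanceStatusReq();
        req.setJobId(jobId);
        req.setInstanceId(instanceId);
        req.setTotalTaskNum(waiting + unreceived + received + succeed + failed);
        req.setRunningTaskNum(received);
        req.setSucceedTaskNum(succeed);
        req.setFailedTaskNum(failed);
        if (unfinished == 0) {
            // mirrors the branch above: all tasks finished, overall result depends on failures
            req.setInstanceStatus(failed == 0 ? JobInstanceStatus.SUCCEED.getValue()
                                              : JobInstanceStatus.FAILED.getValue());
        } else {
            req.setInstanceStatus(JobInstanceStatus.RUNNING.getValue());
        }
        return req;
    }
}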