diff --git a/powerjob-worker/src/main/java/tech/powerjob/worker/persistence/TaskDAOImpl.java b/powerjob-worker/src/main/java/tech/powerjob/worker/persistence/TaskDAOImpl.java index 5ba746bb..f1793a2b 100644 --- a/powerjob-worker/src/main/java/tech/powerjob/worker/persistence/TaskDAOImpl.java +++ b/powerjob-worker/src/main/java/tech/powerjob/worker/persistence/TaskDAOImpl.java @@ -47,18 +47,24 @@ public class TaskDAOImpl implements TaskDAO { @Override public boolean batchSave(Collection tasks) throws SQLException { - String insertSQL = "insert into task_info(task_id, instance_id, sub_instance_id, task_name, task_content, address, status, result, failed_cnt, created_time, last_modified_time, last_report_time) values (?,?,?,?,?,?,?,?,?,?,?,?)"; - try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(insertSQL)) { - - for (TaskDO task : tasks) { - - fillInsertPreparedStatement(task, ps); - ps.addBatch(); + String insertSql = "insert into task_info(task_id, instance_id, sub_instance_id, task_name, task_content, address, status, result, failed_cnt, created_time, last_modified_time, last_report_time) values (?,?,?,?,?,?,?,?,?,?,?,?)"; + boolean originAutoCommitFlag ; + try (Connection conn = connectionFactory.getConnection()) { + originAutoCommitFlag = conn.getAutoCommit(); + conn.setAutoCommit(false); + try ( PreparedStatement ps = conn.prepareStatement(insertSql)) { + for (TaskDO task : tasks) { + fillInsertPreparedStatement(task, ps); + ps.addBatch(); + } + ps.executeBatch(); + return true; + } catch (Throwable e) { + conn.rollback(); + throw e; + } finally { + conn.setAutoCommit(originAutoCommitFlag); } - - ps.executeBatch(); - return true; - } } diff --git a/powerjob-worker/src/test/java/tech/powerjob/worker/test/PersistenceServiceTest.java b/powerjob-worker/src/test/java/tech/powerjob/worker/test/PersistenceServiceTest.java index 0fb33e4b..f922a0b5 100644 --- a/powerjob-worker/src/test/java/tech/powerjob/worker/test/PersistenceServiceTest.java +++ b/powerjob-worker/src/test/java/tech/powerjob/worker/test/PersistenceServiceTest.java @@ -11,6 +11,8 @@ import org.junit.jupiter.api.*; import java.util.List; import java.util.concurrent.ThreadLocalRandom; +import static tech.powerjob.worker.core.tracker.task.CommonTaskTracker.ROOT_TASK_ID; + /** * H2 数据库持久化测试 * @@ -63,6 +65,31 @@ public class PersistenceServiceTest { } + @Test + public void testBatchSave(){ + List taskList = Lists.newLinkedList(); + long instanceId = 10086L + ThreadLocalRandom.current().nextInt(2); + for (int i = 0; i < 100; i++) { + TaskDO task = new TaskDO(); + taskList.add(task); + task.setSubInstanceId(instanceId); + task.setInstanceId(instanceId); + task.setTaskId(ROOT_TASK_ID + "." + i); + task.setFailedCnt(0); + task.setStatus(TaskStatus.WORKER_RECEIVED.getValue()); + task.setTaskName("ROOT_TASK"); + task.setAddress(NetUtils.getLocalHost()); + task.setLastModifiedTime(System.currentTimeMillis()); + task.setCreatedTime(System.currentTimeMillis()); + task.setLastReportTime(System.currentTimeMillis()); + task.setResult(""); + } + TaskDO firstTask = taskList.get(0); + taskList.add(firstTask); + taskPersistenceService.batchSave(taskList); + + } + @Test public void testDeleteAllTasks() {