Retrain代码分析

 <aside> 😀 这里写文章的前言: Retrain代码分析,包含线程锁,OSS搬运,图片标签处理,通知发送

</aside>

ReentrantLock是Java中的一个可重入锁实现,它是Lock接口的一个实现类。"可重入"意味着同一个线程可以多次获取同一把锁而不会死锁。主要特点包括:

  1. 可重入性: 允许同一个线程多次获取同一把锁。
  2. 公平性选择: 支持公平锁和非公平锁。公平锁按照申请锁的顺序获取锁,非公平锁则允许"插队"。
  3. 可中断性: 等待获取锁的线程可以被中断。
  4. 超时机制: 支持尝试获取锁的超时机制。
  5. 条件变量: 通过newCondition()方法可以获得Condition对象,实现线程间的协调。
  6. 锁投票: 通过方法如tryLock()可以尝试非阻塞地获取锁。
  7. 性能: 在大多数情况下,ReentrantLock的性能优于synchronized关键字。

使用ReentrantLock时,通常的模式是:

ReentrantLock lock = new ReentrantLock();
lock.lock();
try {
    // 临界区代码
} finally {
    lock.unlock();
}

📝 主旨内容

detailed-code-explanation.md

详细代码解释

1. sendRetrain 方法

@Override
public SendRetrainImageResult sendRetrain(SendRetrainParam retrainParam) {
    String currProjectId = this.getProjectId(retrainParam.getHostName(), retrainParam.getProjectId());
    if (currProjectId == null) {
        throw new RuntimeException("找不到对应的projectId");
    }
    retrainParam.setProjectId(currProjectId);
    ReentrantLock reentrantLock = LockConstant.projectIdLockMap.computeIfAbsent(currProjectId, key -> new ReentrantLock());
    reentrantLock.lock();
    try {
        return this.moveLogPredictionImage(retrainParam);
    } catch (Exception e) {
        LOGGER.error("moveLogPredictionImage error", e);
        throw e;
    } finally {
        reentrantLock.unlock();
    }
}

这个方法是类的入口点,用于处理重新训练的请求。

  1. 首先,它通过 getProjectId 方法获取当前项目的ID。
  2. 如果找不到项目ID,抛出运行时异常。
  3. 使用 ReentrantLock 确保对同一个项目的操作是线程安全的。
  4. 在锁定的代码块中调用 moveLogPredictionImage 方法。
  5. 无论操作是否成功,都确保在 finally 块中释放锁。

2. moveLogPredictionImage 方法

这是主要的处理方法,包含了大量的逻辑。让我们分步骤解析:

2.1 初始化和验证

SendRetrainImageResult moveImageResult = new SendRetrainImageResult();
String projectId = retrainParam.getProjectId();
int retrainCount = this.getRetrainCountConfig();
List<RetrainCountInfo> retrainCountInfos = hLogPredictionMapper.findAvailableRetrainInfoByProjectId(projectId);
if (retrainCountInfos.isEmpty()) {
    throw new RuntimeException("复训图片不足");
}
RetrainCountInfo retrainCountInfo = retrainCountInfos.get(0);
if (retrainCountInfo.getCount() < retrainCount) {
    LOGGER.warn("projectId[{}] count[{}] lt {}", retrainCountInfo.getProjectId(), retrainCountInfo.getCount(), retrainCount);
    throw new RuntimeException("复训图片不足");
}

  • 创建结果对象。
  • 获取项目ID和重新训练所需的图片数量。
  • 检查是否有足够的图片可用于重新训练。

2.2 获取和处理图片信息

List<PredictionImage> list = hLogPredictionMapper.findAvailableImageInfoByProjectId(projectId);
Set<String> checkHosts = new HashSet<>();
for (PredictionImage predictionImage : list) {
    String destinationPath = predictionImage.getDestinationPath();
    if (destinationPath != null) {
        predictionImage.setDestinationPath(destinationPath.replaceAll("\\\\\\\\\\\\\\\\", "/")
                .replaceAll(" ", "%20"));
        checkHosts.add(predictionImage.getDestinationPath().substring(0, DataUtils.indexOf(predictionImage.getDestinationPath(), '/', 3)));
    }
}

  • 获取可用的图片信息。
  • 处理图片路径,替换反斜杠和空格。
  • 提取主机信息以进行连接性检查。

2.3 检查主机连接性

List<String> badHosts = new ArrayList<>();
for (String host : checkHosts) {
    boolean canConnect = HttpUtils.checkConnect(host);
    if (!canConnect) {
        badHosts.add(host);
    }
}
if (!badHosts.isEmpty()) {
    LOGGER.info("图片地址无法连接: {}", badHosts);
    throw new RuntimeException("图片地址获取不到,请联系对应窗口开通防火墙");
}

  • 检查每个主机的连接性。
  • 如果有无法连接的主机,记录并抛出异常。

2.4 获取任务信息

Task task = taskService.getWithId(Integer.valueOf(projectId));
if (task == null) {
    LOGGER.warn("Task not exist with projectId[{}]", projectId);
    throw new RuntimeException("方案不存在");
}

  • 获取与项目ID关联的任务信息。
  • 如果任务不存在,抛出异常。

2.5 准备OSS路径和上传图片

String fabId = task.getSite();
String projectName = task.getName();
OssPathParamInfo ossPathParamInfo = new OssPathParamInfo(null, projectId, fabId, projectName);
String ossTrainingDataPath = ossPathService.getImagePath(ossPathParamInfo);
String functionType = task.getFunctionType();
OssFileInfo ossFileInfo = OssFileInfo.parseOssPath(ossTrainingDataPath);
String bucketName = ossFileInfo.getBucketName();
String basePath = ossFileInfo.getBasePath();
String userId = UserUtils.getUserId();
List<PredictionImage> uploaded = new ArrayList<>();

  • 准备OSS(对象存储服务)相关的信息。
  • 开始上传图片的过程。

2.6 上传图片到OSS

for (PredictionImage predictionImage : list) {
    String url = predictionImage.getDestinationPath();
    String imageName = predictionImage.getImageName();
    if (!StringUtils.hasLength(url)) {
        LOGGER.warn("image[{}] destinationPath not exist", imageName);
        continue;
    }
    String pathFileName = basePath + imageName + "." + predictionImage.getImageExtName();
    boolean success = ossService.upload(bucketName, pathFileName, url);
    if (success) {
        uploaded.add(predictionImage);
        hLogPredictionMapper.updateRetrainPath(ossTrainingDataPath, projectId, predictionImage.getImageName());
        LOGGER.info("bucket[{}] image[{}] upload success", bucketName, predictionImage.getImageName());
    } else {
        hLogPredictionMapper.updatePredictionResult(predictionImage.getGroupId(), predictionImage.getImageName(), null, "N", userId);
        LOGGER.error("bucket[{}] image[{}] upload fail", bucketName, predictionImage.getImageName());
    }
}

  • 遍历图片列表,尝试上传每张图片到OSS。
  • 记录上传成功和失败的情况。

2.7 生成和上传清单文件

String[] pathArr = basePath.split("/");
String fileName = pathArr[0] + "_" + pathArr[2];
String fileExtName = "manifest";
String manifestFileFullName = fileName + "." + fileExtName;
StringJoiner manifestContentSJ = new StringJoiner(LineSeparator.Windows, "", LineSeparator.Windows);
for (PredictionImage predictionImage : uploaded) {
    manifestContentSJ.add(makeJsonInfo(predictionImage, ossTrainingDataPath, functionType));
}
String manifestContent = manifestContentSJ.toString();

  • 创建清单文件名和内容。
  • 根据上传的图片生成清单内容。

2.8 处理选定的清单内容

String selectedManifestContent = "";
String trainRecordIds = retrainParam.getTrainRecordIds();
String modelType = retrainParam.getModelType();
if (StringUtils.hasLength(trainRecordIds)) {
    selectedManifestContent = trainFileContentService.getFileContentWithProjectIdAndTrainIds(projectId, trainRecordIds);
}
String allManifestContent = selectedManifestContent + manifestContent;
String manifestFilePath = OssFileInfo.parseFilePath(basePath);
trainFileContentService.uploadManifestFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName);

  • 处理额外的选定清单内容(如果有)。
  • 合并所有清单内容。
  • 上传完整的清单文件。

2.9 特定模型类型的处理

if (TaskUtils.isBox(functionType)) {
    trainFileContentService.uploadManifestTrainAndValFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName, retrainParam.getTrainPercent());
} else if (ModelTrainUtils.isONNX(modelType)) {
    trainFileContentService.uploadManifestTrainAndValFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName, retrainParam.getTrainPercent());
}

  • 根据任务类型或模型类型执行特定的清单文件处理。

2.10 保存训练记录

int totalImageCount = TrainFileUtils.parseForImageCount(allManifestContent);
HTrainRecord hTrainRecord = new HTrainRecord();
// ... (设置 hTrainRecord 的各个属性)
hTrainRecordService.save(hTrainRecord);
Integer id = hTrainRecord.getId();

  • 解析总图片数量。
  • 创建并保存训练记录。

2.11 处理图片标签

Map<String, List<String>> imageLabelsMap = TrainFileUtils.parseForSourceMarksMap(allManifestContent);
Map<String, List<String>> labelImagesMap = new HashMap<>();
// ... (处理标签和图片的映射关系)
List<HTrainRecordLabel> hTrainRecordLabels = new ArrayList<>();
// ... (创建并保存训练记录标签)
hTrainRecordLabelService.saveAll(hTrainRecordLabels);

  • 解析图片和标签的关系。
  • 创建并保存训练记录标签。

2.12 执行自动训练(无需)

List<String> classes = TrainFileUtils.parseForMarks(allManifestContent);
if (TaskUtils.isBox(functionType) || ModelTrainUtils.isONNX(hTrainRecord.getModelType())) {
    UserCommandDetectorData userCommandData = new UserCommandDetectorData();
    // ... (设置 userCommandData 的属性)
    try {
        if (ModelTrainUtils.isONNX(hTrainRecord.getModelType())) {
            // ... (设置ONNX特定参数)
        }
        String jobId = autoTrainService.doTrain(userCommandData);
        hTrainRecordService.setJobIdWithId(jobId, id);
    } catch (Exception e) {
        LOGGER.error("projectId[{}] doTrain error", projectId, e);
    }
} else if (TaskUtils.isClassification(functionType)) {
    // ... (处理分类任务的特定逻辑)
}

  • 根据任务类型和模型类型执行自动训练。
  • 处理可能的异常情况。

2.13 发送通知

this.doSendTornado(projectId, imageCount, totalImageCount, ossTrainingDataPath, manifestFileFullName);
return moveImageResult;

  • 发送训练完成的通知。
  • 返回处理结果。

3. doSendTornado 方法

这个方法负责发送训练完成的通知:

  1. 获取任务信息。
  2. 获取通知URL和发送者信息。
  3. 准备通知内容,包括项目详情、训练数据等。
  4. 使用 tornadoService 发送通知。

总的来说,这个类实现了一个复杂的图像重新训练流程,包括数据验证、图片上传、清单文件生成、训练记录保存、自动训练执行和通知发送等多个步骤。它使用了多个服务和工具类来完成这些任务,并处理了各种可能的异常情况。

🤗 总结归纳

这段代码实现了一个复杂的图像重新训练流程。主要功能包括:

  1. 验证重新训练的请求和数据。
  2. 上传图片到对象存储服务(OSS)。
  3. 生成和上传训练用的清单文件。
  4. 保存训练记录和相关的标签信息。
  5. 根据任务类型和模型类型执行自动训练。
  6. 发送训练完成的通知。

代码使用了多线程安全机制(ReentrantLock),处理了各种可能的异常情况,并与多个外部服务和数据库进行交互,实现涉及了图像处理、数据存储、机器学习训练等多个领域的概念和技术。

评论

发表评论

此博客中的热门博文

test