Retrain代码分析
<aside> 😀 这里写文章的前言: Retrain代码分析,包含线程锁,OSS搬运,图片标签处理,通知发送
</aside>
ReentrantLock是Java中的一个可重入锁实现,它是Lock接口的一个实现类。"可重入"意味着同一个线程可以多次获取同一把锁而不会死锁。主要特点包括:
- 可重入性: 允许同一个线程多次获取同一把锁。
- 公平性选择: 支持公平锁和非公平锁。公平锁按照申请锁的顺序获取锁,非公平锁则允许"插队"。
- 可中断性: 等待获取锁的线程可以被中断。
- 超时机制: 支持尝试获取锁的超时机制。
- 条件变量: 通过newCondition()方法可以获得Condition对象,实现线程间的协调。
- 锁投票: 通过方法如tryLock()可以尝试非阻塞地获取锁。
- 性能: 在大多数情况下,ReentrantLock的性能优于synchronized关键字。
使用ReentrantLock时,通常的模式是:
ReentrantLock lock = new ReentrantLock();
lock.lock();
try {
// 临界区代码
} finally {
lock.unlock();
}
📝 主旨内容
详细代码解释
1. sendRetrain 方法
@Override
public SendRetrainImageResult sendRetrain(SendRetrainParam retrainParam) {
String currProjectId = this.getProjectId(retrainParam.getHostName(), retrainParam.getProjectId());
if (currProjectId == null) {
throw new RuntimeException("找不到对应的projectId");
}
retrainParam.setProjectId(currProjectId);
ReentrantLock reentrantLock = LockConstant.projectIdLockMap.computeIfAbsent(currProjectId, key -> new ReentrantLock());
reentrantLock.lock();
try {
return this.moveLogPredictionImage(retrainParam);
} catch (Exception e) {
LOGGER.error("moveLogPredictionImage error", e);
throw e;
} finally {
reentrantLock.unlock();
}
}
这个方法是类的入口点,用于处理重新训练的请求。
- 首先,它通过
getProjectId
方法获取当前项目的ID。 - 如果找不到项目ID,抛出运行时异常。
- 使用
ReentrantLock
确保对同一个项目的操作是线程安全的。 - 在锁定的代码块中调用
moveLogPredictionImage
方法。 - 无论操作是否成功,都确保在
finally
块中释放锁。
2. moveLogPredictionImage 方法
这是主要的处理方法,包含了大量的逻辑。让我们分步骤解析:
2.1 初始化和验证
SendRetrainImageResult moveImageResult = new SendRetrainImageResult();
String projectId = retrainParam.getProjectId();
int retrainCount = this.getRetrainCountConfig();
List<RetrainCountInfo> retrainCountInfos = hLogPredictionMapper.findAvailableRetrainInfoByProjectId(projectId);
if (retrainCountInfos.isEmpty()) {
throw new RuntimeException("复训图片不足");
}
RetrainCountInfo retrainCountInfo = retrainCountInfos.get(0);
if (retrainCountInfo.getCount() < retrainCount) {
LOGGER.warn("projectId[{}] count[{}] lt {}", retrainCountInfo.getProjectId(), retrainCountInfo.getCount(), retrainCount);
throw new RuntimeException("复训图片不足");
}
- 创建结果对象。
- 获取项目ID和重新训练所需的图片数量。
- 检查是否有足够的图片可用于重新训练。
2.2 获取和处理图片信息
List<PredictionImage> list = hLogPredictionMapper.findAvailableImageInfoByProjectId(projectId);
Set<String> checkHosts = new HashSet<>();
for (PredictionImage predictionImage : list) {
String destinationPath = predictionImage.getDestinationPath();
if (destinationPath != null) {
predictionImage.setDestinationPath(destinationPath.replaceAll("\\\\\\\\\\\\\\\\", "/")
.replaceAll(" ", "%20"));
checkHosts.add(predictionImage.getDestinationPath().substring(0, DataUtils.indexOf(predictionImage.getDestinationPath(), '/', 3)));
}
}
- 获取可用的图片信息。
- 处理图片路径,替换反斜杠和空格。
- 提取主机信息以进行连接性检查。
2.3 检查主机连接性
List<String> badHosts = new ArrayList<>();
for (String host : checkHosts) {
boolean canConnect = HttpUtils.checkConnect(host);
if (!canConnect) {
badHosts.add(host);
}
}
if (!badHosts.isEmpty()) {
LOGGER.info("图片地址无法连接: {}", badHosts);
throw new RuntimeException("图片地址获取不到,请联系对应窗口开通防火墙");
}
- 检查每个主机的连接性。
- 如果有无法连接的主机,记录并抛出异常。
2.4 获取任务信息
Task task = taskService.getWithId(Integer.valueOf(projectId));
if (task == null) {
LOGGER.warn("Task not exist with projectId[{}]", projectId);
throw new RuntimeException("方案不存在");
}
- 获取与项目ID关联的任务信息。
- 如果任务不存在,抛出异常。
2.5 准备OSS路径和上传图片
String fabId = task.getSite();
String projectName = task.getName();
OssPathParamInfo ossPathParamInfo = new OssPathParamInfo(null, projectId, fabId, projectName);
String ossTrainingDataPath = ossPathService.getImagePath(ossPathParamInfo);
String functionType = task.getFunctionType();
OssFileInfo ossFileInfo = OssFileInfo.parseOssPath(ossTrainingDataPath);
String bucketName = ossFileInfo.getBucketName();
String basePath = ossFileInfo.getBasePath();
String userId = UserUtils.getUserId();
List<PredictionImage> uploaded = new ArrayList<>();
- 准备OSS(对象存储服务)相关的信息。
- 开始上传图片的过程。
2.6 上传图片到OSS
for (PredictionImage predictionImage : list) {
String url = predictionImage.getDestinationPath();
String imageName = predictionImage.getImageName();
if (!StringUtils.hasLength(url)) {
LOGGER.warn("image[{}] destinationPath not exist", imageName);
continue;
}
String pathFileName = basePath + imageName + "." + predictionImage.getImageExtName();
boolean success = ossService.upload(bucketName, pathFileName, url);
if (success) {
uploaded.add(predictionImage);
hLogPredictionMapper.updateRetrainPath(ossTrainingDataPath, projectId, predictionImage.getImageName());
LOGGER.info("bucket[{}] image[{}] upload success", bucketName, predictionImage.getImageName());
} else {
hLogPredictionMapper.updatePredictionResult(predictionImage.getGroupId(), predictionImage.getImageName(), null, "N", userId);
LOGGER.error("bucket[{}] image[{}] upload fail", bucketName, predictionImage.getImageName());
}
}
- 遍历图片列表,尝试上传每张图片到OSS。
- 记录上传成功和失败的情况。
2.7 生成和上传清单文件
String[] pathArr = basePath.split("/");
String fileName = pathArr[0] + "_" + pathArr[2];
String fileExtName = "manifest";
String manifestFileFullName = fileName + "." + fileExtName;
StringJoiner manifestContentSJ = new StringJoiner(LineSeparator.Windows, "", LineSeparator.Windows);
for (PredictionImage predictionImage : uploaded) {
manifestContentSJ.add(makeJsonInfo(predictionImage, ossTrainingDataPath, functionType));
}
String manifestContent = manifestContentSJ.toString();
- 创建清单文件名和内容。
- 根据上传的图片生成清单内容。
2.8 处理选定的清单内容
String selectedManifestContent = "";
String trainRecordIds = retrainParam.getTrainRecordIds();
String modelType = retrainParam.getModelType();
if (StringUtils.hasLength(trainRecordIds)) {
selectedManifestContent = trainFileContentService.getFileContentWithProjectIdAndTrainIds(projectId, trainRecordIds);
}
String allManifestContent = selectedManifestContent + manifestContent;
String manifestFilePath = OssFileInfo.parseFilePath(basePath);
trainFileContentService.uploadManifestFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName);
- 处理额外的选定清单内容(如果有)。
- 合并所有清单内容。
- 上传完整的清单文件。
2.9 特定模型类型的处理
if (TaskUtils.isBox(functionType)) {
trainFileContentService.uploadManifestTrainAndValFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName, retrainParam.getTrainPercent());
} else if (ModelTrainUtils.isONNX(modelType)) {
trainFileContentService.uploadManifestTrainAndValFile(allManifestContent, bucketName, manifestFilePath, manifestFileFullName, retrainParam.getTrainPercent());
}
- 根据任务类型或模型类型执行特定的清单文件处理。
2.10 保存训练记录
int totalImageCount = TrainFileUtils.parseForImageCount(allManifestContent);
HTrainRecord hTrainRecord = new HTrainRecord();
// ... (设置 hTrainRecord 的各个属性)
hTrainRecordService.save(hTrainRecord);
Integer id = hTrainRecord.getId();
- 解析总图片数量。
- 创建并保存训练记录。
2.11 处理图片标签
Map<String, List<String>> imageLabelsMap = TrainFileUtils.parseForSourceMarksMap(allManifestContent);
Map<String, List<String>> labelImagesMap = new HashMap<>();
// ... (处理标签和图片的映射关系)
List<HTrainRecordLabel> hTrainRecordLabels = new ArrayList<>();
// ... (创建并保存训练记录标签)
hTrainRecordLabelService.saveAll(hTrainRecordLabels);
- 解析图片和标签的关系。
- 创建并保存训练记录标签。
2.12 执行自动训练(无需)
List<String> classes = TrainFileUtils.parseForMarks(allManifestContent);
if (TaskUtils.isBox(functionType) || ModelTrainUtils.isONNX(hTrainRecord.getModelType())) {
UserCommandDetectorData userCommandData = new UserCommandDetectorData();
// ... (设置 userCommandData 的属性)
try {
if (ModelTrainUtils.isONNX(hTrainRecord.getModelType())) {
// ... (设置ONNX特定参数)
}
String jobId = autoTrainService.doTrain(userCommandData);
hTrainRecordService.setJobIdWithId(jobId, id);
} catch (Exception e) {
LOGGER.error("projectId[{}] doTrain error", projectId, e);
}
} else if (TaskUtils.isClassification(functionType)) {
// ... (处理分类任务的特定逻辑)
}
- 根据任务类型和模型类型执行自动训练。
- 处理可能的异常情况。
2.13 发送通知
this.doSendTornado(projectId, imageCount, totalImageCount, ossTrainingDataPath, manifestFileFullName);
return moveImageResult;
- 发送训练完成的通知。
- 返回处理结果。
3. doSendTornado 方法
这个方法负责发送训练完成的通知:
- 获取任务信息。
- 获取通知URL和发送者信息。
- 准备通知内容,包括项目详情、训练数据等。
- 使用
tornadoService
发送通知。
总的来说,这个类实现了一个复杂的图像重新训练流程,包括数据验证、图片上传、清单文件生成、训练记录保存、自动训练执行和通知发送等多个步骤。它使用了多个服务和工具类来完成这些任务,并处理了各种可能的异常情况。
🤗 总结归纳
这段代码实现了一个复杂的图像重新训练流程。主要功能包括:
- 验证重新训练的请求和数据。
- 上传图片到对象存储服务(OSS)。
- 生成和上传训练用的清单文件。
- 保存训练记录和相关的标签信息。
- 根据任务类型和模型类型执行自动训练。
- 发送训练完成的通知。
代码使用了多线程安全机制(ReentrantLock),处理了各种可能的异常情况,并与多个外部服务和数据库进行交互,实现涉及了图像处理、数据存储、机器学习训练等多个领域的概念和技术。
helloTest
回复删除