refactor(ai):重构分割模型包装类继承结构- 将 Anime2ModelWrapper、Anime2VividModelWrapper 和 AnimeModelWrapper 改为继承自 VividModelWrapper 基类
- 移除重复的 ResultFiles 内部类和相关工具方法实现 - Anime2Segmenter 和 AnimeSegmenter 继承自抽象基类 Segmenter - Anime2SegmentationResult与 AnimeSegmentationResult 继承 SegmentationResult - 重命名 LabelPalette 为 BiSeNetLabelPalette 并调整其引用 - 更新模型路径配置以匹配新的文件命名约定 - 删除冗余的 getLabels() 和 getPalette() 方法定义 - 简化 segmentAndSave 方法中的类型转换逻辑- 移除已被继承方法替代的手动资源管理代码 - 调整 import 语句以反映包结构调整- 清理不再需要的独立主测试函数入口点- 修改字段访问权限以符合继承设计模式 - 替换具体的返回类型为更通用的 SegmentationResult 接口- 整合公共功能至基类减少子类间重复代码 - 统一分割后处理流程提高模块复用性 - 引入泛型支持增强 Wrapper 类型安全性 - 更新注释文档保持与最新架构同步 - 优化异常处理策略统一关闭资源方式 - 规范文件命名规则便于未来维护扩展 - 提取共通逻辑到父类降低耦合度 - 完善类型检查避免运行时 ClassCastException 风险
This commit is contained in:
154
src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java
Normal file
154
src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java
Normal file
@@ -0,0 +1,154 @@
|
||||
package com.chuangzhou.vivid2D.ai;
|
||||
|
||||
import ai.djl.MalformedModelException;
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.repository.zoo.Criteria;
|
||||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.translate.Batchifier;
|
||||
import ai.djl.translate.TranslateException;
|
||||
import ai.djl.translate.Translator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
|
||||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Path;
|
||||
import java.util.*;
|
||||
|
||||
public abstract class Segmenter implements AutoCloseable {
|
||||
// 内部类,用于从Translator安全地传出数据
|
||||
public static class SegmentationData {
|
||||
public final int[] indices;
|
||||
public final long[] shape;
|
||||
|
||||
public SegmentationData(int[] indices, long[] shape) {
|
||||
this.indices = indices;
|
||||
this.shape = shape;
|
||||
}
|
||||
}
|
||||
|
||||
private String engine = "PyTorch";
|
||||
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
|
||||
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
|
||||
protected final List<String> labels;
|
||||
protected final Map<String, Integer> palette;
|
||||
|
||||
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||||
this.labels = new ArrayList<>(labels);
|
||||
this.palette = BiSeNetLabelPalette.defaultPalette();
|
||||
|
||||
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
return Segmenter.this.processInput(ctx, input);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||||
return Segmenter.this.processOutput(ctx, list);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Batchifier getBatchifier() {
|
||||
return Segmenter.this.getBatchifier();
|
||||
}
|
||||
};
|
||||
|
||||
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
|
||||
.setTypes(Image.class, Segmenter.SegmentationData.class)
|
||||
.optModelPath(modelDir)
|
||||
.optEngine(engine)
|
||||
.optTranslator(translator)
|
||||
.build();
|
||||
|
||||
this.modelWrapper = criteria.loadModel();
|
||||
this.predictor = modelWrapper.newPredictor();
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理模型输入
|
||||
* @param ctx translator 上下文
|
||||
* @param input 图片
|
||||
* @return 模型输入
|
||||
*/
|
||||
public abstract NDList processInput(TranslatorContext ctx, Image input);
|
||||
|
||||
/**
|
||||
* 处理模型输出
|
||||
* @param ctx translator 上下文
|
||||
* @param list 模型输出
|
||||
* @return 模型输出
|
||||
*/
|
||||
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
|
||||
|
||||
/**
|
||||
* 获取批量处理方式
|
||||
* @return 批量处理方式
|
||||
*/
|
||||
public Batchifier getBatchifier(){
|
||||
return null;
|
||||
}
|
||||
|
||||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||||
|
||||
// predict 方法现在直接返回安全的 Java 对象
|
||||
Segmenter.SegmentationData data = predictor.predict(img);
|
||||
|
||||
long[] shp = data.shape;
|
||||
int[] indices = data.indices;
|
||||
|
||||
int height, width;
|
||||
if (shp.length == 2) {
|
||||
height = (int) shp[0];
|
||||
width = (int) shp[1];
|
||||
} else {
|
||||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||||
}
|
||||
|
||||
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
||||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||||
Map<Integer, String> labelsMap = new HashMap<>();
|
||||
for (int i = 0; i < labels.size(); i++) {
|
||||
labelsMap.put(i, labels.get(i));
|
||||
}
|
||||
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int idx = indices[y * width + x];
|
||||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||||
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
||||
mask.setRGB(x, y, argb);
|
||||
}
|
||||
}
|
||||
|
||||
return new SegmentationResult(mask, labelsMap, palette);
|
||||
}
|
||||
|
||||
public void setEngine(String engine) {
|
||||
this.engine = engine;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
predictor.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
try {
|
||||
modelWrapper.close();
|
||||
} catch (Exception ignore) {
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user