- 移除重复的 ResultFiles 内部类和相关工具方法实现
- Anime2Segmenter 和 AnimeSegmenter 继承自抽象基类 Segmenter
- Anime2SegmentationResult 与 AnimeSegmentationResult 继承 SegmentationResult
- 重命名 LabelPalette 为 BiSeNetLabelPalette 并调整其引用
- 更新模型路径配置以匹配新的文件命名约定
- 删除冗余的 getLabels() 和 getPalette() 方法定义
- 简化 segmentAndSave 方法中的类型转换逻辑
- 移除已被继承方法替代的手动资源管理代码
- 调整 import 语句以反映包结构调整
- 清理不再需要的独立主测试函数入口点
- 修改字段访问权限以符合继承设计模式
- 将具体的返回类型替换为更通用的 SegmentationResult 基类
- 整合公共功能至基类,减少子类间的重复代码
- 统一分割后处理流程,提高模块复用性
- 引入泛型支持,增强 Wrapper 类型安全性
- 更新注释文档,保持与最新架构同步
- 优化异常处理策略,统一资源关闭方式
- 规范文件命名规则,便于未来维护扩展
- 提取共通逻辑到父类,降低耦合度
- 完善类型检查,避免运行时 ClassCastException 风险
155 lines
5.1 KiB
Java
155 lines
5.1 KiB
Java
package com.chuangzhou.vivid2D.ai;
|
||
|
||
import ai.djl.MalformedModelException;
|
||
import ai.djl.inference.Predictor;
|
||
import ai.djl.modality.cv.Image;
|
||
import ai.djl.modality.cv.ImageFactory;
|
||
import ai.djl.ndarray.NDArray;
|
||
import ai.djl.ndarray.NDList;
|
||
import ai.djl.ndarray.NDManager;
|
||
import ai.djl.ndarray.types.DataType;
|
||
import ai.djl.repository.zoo.Criteria;
|
||
import ai.djl.repository.zoo.ModelNotFoundException;
|
||
import ai.djl.repository.zoo.ZooModel;
|
||
import ai.djl.translate.Batchifier;
|
||
import ai.djl.translate.TranslateException;
|
||
import ai.djl.translate.Translator;
|
||
import ai.djl.translate.TranslatorContext;
|
||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
|
||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
|
||
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
|
||
|
||
import javax.imageio.ImageIO;
|
||
import java.awt.image.BufferedImage;
|
||
import java.io.File;
|
||
import java.io.IOException;
|
||
import java.nio.file.Path;
|
||
import java.util.*;
|
||
|
||
public abstract class Segmenter implements AutoCloseable {
|
||
// 内部类,用于从Translator安全地传出数据
|
||
public static class SegmentationData {
|
||
public final int[] indices;
|
||
public final long[] shape;
|
||
|
||
public SegmentationData(int[] indices, long[] shape) {
|
||
this.indices = indices;
|
||
this.shape = shape;
|
||
}
|
||
}
|
||
|
||
private String engine = "PyTorch";
|
||
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
|
||
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
|
||
protected final List<String> labels;
|
||
protected final Map<String, Integer> palette;
|
||
|
||
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
||
this.labels = new ArrayList<>(labels);
|
||
this.palette = BiSeNetLabelPalette.defaultPalette();
|
||
|
||
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
|
||
@Override
|
||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||
return Segmenter.this.processInput(ctx, input);
|
||
}
|
||
|
||
@Override
|
||
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
||
return Segmenter.this.processOutput(ctx, list);
|
||
}
|
||
|
||
@Override
|
||
public Batchifier getBatchifier() {
|
||
return Segmenter.this.getBatchifier();
|
||
}
|
||
};
|
||
|
||
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
|
||
.setTypes(Image.class, Segmenter.SegmentationData.class)
|
||
.optModelPath(modelDir)
|
||
.optEngine(engine)
|
||
.optTranslator(translator)
|
||
.build();
|
||
|
||
this.modelWrapper = criteria.loadModel();
|
||
this.predictor = modelWrapper.newPredictor();
|
||
}
|
||
|
||
/**
|
||
* 处理模型输入
|
||
* @param ctx translator 上下文
|
||
* @param input 图片
|
||
* @return 模型输入
|
||
*/
|
||
public abstract NDList processInput(TranslatorContext ctx, Image input);
|
||
|
||
/**
|
||
* 处理模型输出
|
||
* @param ctx translator 上下文
|
||
* @param list 模型输出
|
||
* @return 模型输出
|
||
*/
|
||
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
|
||
|
||
/**
|
||
* 获取批量处理方式
|
||
* @return 批量处理方式
|
||
*/
|
||
public Batchifier getBatchifier(){
|
||
return null;
|
||
}
|
||
|
||
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
||
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
||
|
||
// predict 方法现在直接返回安全的 Java 对象
|
||
Segmenter.SegmentationData data = predictor.predict(img);
|
||
|
||
long[] shp = data.shape;
|
||
int[] indices = data.indices;
|
||
|
||
int height, width;
|
||
if (shp.length == 2) {
|
||
height = (int) shp[0];
|
||
width = (int) shp[1];
|
||
} else {
|
||
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
||
}
|
||
|
||
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
||
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
||
Map<Integer, String> labelsMap = new HashMap<>();
|
||
for (int i = 0; i < labels.size(); i++) {
|
||
labelsMap.put(i, labels.get(i));
|
||
}
|
||
|
||
for (int y = 0; y < height; y++) {
|
||
for (int x = 0; x < width; x++) {
|
||
int idx = indices[y * width + x];
|
||
String label = labelsMap.getOrDefault(idx, "unknown");
|
||
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
||
mask.setRGB(x, y, argb);
|
||
}
|
||
}
|
||
|
||
return new SegmentationResult(mask, labelsMap, palette);
|
||
}
|
||
|
||
public void setEngine(String engine) {
|
||
this.engine = engine;
|
||
}
|
||
|
||
@Override
|
||
public void close() {
|
||
try {
|
||
predictor.close();
|
||
} catch (Exception ignore) {
|
||
}
|
||
try {
|
||
modelWrapper.close();
|
||
} catch (Exception ignore) {
|
||
}
|
||
}
|
||
}
|