Files
window-axis-innovators-box/src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java
tzdwindows 7 e06c59c8d1 refactor(ai):重构分割模型包装类继承结构
- 将 Anime2ModelWrapper、Anime2VividModelWrapper 和 AnimeModelWrapper 改为继承自 VividModelWrapper 基类
- 移除重复的 ResultFiles 内部类和相关工具方法实现
- Anime2Segmenter 和 AnimeSegmenter 继承自抽象基类 Segmenter
- Anime2SegmentationResult 与 AnimeSegmentationResult 继承 SegmentationResult
- 重命名 LabelPalette 为 BiSeNetLabelPalette 并调整其引用
- 更新模型路径配置以匹配新的文件命名约定
- 删除冗余的 getLabels() 和 getPalette() 方法定义
- 简化 segmentAndSave 方法中的类型转换逻辑
- 移除已被继承方法替代的手动资源管理代码
- 调整 import 语句以反映包结构调整
- 清理不再需要的独立主测试函数入口点
- 修改字段访问权限以符合继承设计模式
- 替换具体的返回类型为更通用的 SegmentationResult 类型
- 整合公共功能至基类减少子类间重复代码
- 统一分割后处理流程提高模块复用性
- 引入泛型支持增强 Wrapper 类型安全性
- 更新注释文档保持与最新架构同步
- 优化异常处理策略统一关闭资源方式
- 规范文件命名规则便于未来维护扩展
- 提取共通逻辑到父类降低耦合度
- 完善类型检查避免运行时 ClassCastException 风险
2025-10-31 09:25:18 +08:00

155 lines
5.1 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package com.chuangzhou.vivid2D.ai;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
/**
 * Base class for DJL-backed image segmenters.
 *
 * <p>The constructor loads the model found at {@code modelDir} and creates a {@link Predictor}
 * whose translator delegates to {@link #processInput}, {@link #processOutput} and
 * {@link #getBatchifier()}. {@link #processOutput} must return a {@link SegmentationData} with a
 * {@code [height, width]} shape; {@link #segment(File)} maps every class index to a label and
 * every label to an ARGB color from {@link #palette}.
 *
 * <p>Instances own the loaded model and predictor and must be {@link #close() closed}.
 */
public abstract class Segmenter implements AutoCloseable {

    /** Engine name used when the caller does not choose one. */
    private static final String DEFAULT_ENGINE = "PyTorch";

    /** ARGB color for labels that have no palette entry (opaque green). */
    private static final int FALLBACK_COLOR = 0xFF00FF00;

    /** Label reported for class indices that fall outside {@link #labels}. */
    private static final String UNKNOWN_LABEL = "unknown";

    /**
     * Plain-Java carrier for the translator output, so that {@link #segment(File)} works on
     * ordinary Java arrays instead of native NDArrays.
     */
    public static class SegmentationData {
        /** Class index per pixel, row-major ({@code indices[y * width + x]}). */
        public final int[] indices;
        /** Shape of the class map; {@link Segmenter#segment(File)} expects {@code [height, width]}. */
        public final long[] shape;

        public SegmentationData(int[] indices, long[] shape) {
            this.indices = indices;
            this.shape = shape;
        }
    }

    /** Engine passed to DJL when the model is loaded; only read during construction. */
    private String engine;
    protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
    protected final Predictor<Image, Segmenter.SegmentationData> predictor;
    /** Label names indexed by class id. */
    protected final List<String> labels;
    /** Label name to ARGB color. */
    protected final Map<String, Integer> palette;

    /**
     * Loads the model with the default engine ({@code "PyTorch"}) and the BiSeNet default palette.
     *
     * @param modelDir directory (or file) handed to DJL as the model path
     * @param labels label names indexed by class id; copied
     * @throws IOException if the model cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException if no model matches the criteria
     */
    public Segmenter(Path modelDir, List<String> labels)
            throws IOException, MalformedModelException, ModelNotFoundException {
        this(modelDir, labels, BiSeNetLabelPalette.defaultPalette(), DEFAULT_ENGINE);
    }

    /**
     * Loads the model with the given DJL engine and the BiSeNet default palette.
     *
     * @param modelDir directory (or file) handed to DJL as the model path
     * @param labels label names indexed by class id; copied
     * @param engine DJL engine name, e.g. {@code "PyTorch"}
     * @throws IOException if the model cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException if no model matches the criteria
     */
    protected Segmenter(Path modelDir, List<String> labels, String engine)
            throws IOException, MalformedModelException, ModelNotFoundException {
        this(modelDir, labels, BiSeNetLabelPalette.defaultPalette(), engine);
    }

    /**
     * Loads the model with the given DJL engine and label palette.
     *
     * @param modelDir directory (or file) handed to DJL as the model path
     * @param labels label names indexed by class id; copied
     * @param palette label name to ARGB color; copied
     * @param engine DJL engine name, e.g. {@code "PyTorch"}
     * @throws IOException if the model cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException if no model matches the criteria
     */
    protected Segmenter(Path modelDir, List<String> labels, Map<String, Integer> palette, String engine)
            throws IOException, MalformedModelException, ModelNotFoundException {
        Objects.requireNonNull(modelDir, "modelDir");
        this.labels = new ArrayList<>(Objects.requireNonNull(labels, "labels"));
        // LinkedHashMap keeps the caller's iteration order and stays mutable for subclasses.
        this.palette = new LinkedHashMap<>(Objects.requireNonNull(palette, "palette"));
        this.engine = Objects.requireNonNull(engine, "engine");

        // Route DJL's translator callbacks to the subclass hooks; they only run at predict time.
        Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                return Segmenter.this.processInput(ctx, input);
            }

            @Override
            public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
                return Segmenter.this.processOutput(ctx, list);
            }

            @Override
            public Batchifier getBatchifier() {
                return Segmenter.this.getBatchifier();
            }
        };

        Criteria<Image, SegmentationData> criteria = Criteria.builder()
                .setTypes(Image.class, SegmentationData.class)
                .optModelPath(modelDir)
                .optEngine(this.engine)
                .optTranslator(translator)
                .build();

        ZooModel<Image, SegmentationData> model = criteria.loadModel();
        Predictor<Image, SegmentationData> newPredictor;
        try {
            newPredictor = model.newPredictor();
        } catch (RuntimeException | Error e) {
            // Don't leak the already-loaded model if the predictor cannot be created.
            try {
                model.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        this.modelWrapper = model;
        this.predictor = newPredictor;
    }

    /**
     * Converts an input image into the model's input tensors.
     *
     * @param ctx translator context
     * @param input image to segment
     * @return model input
     */
    public abstract NDList processInput(TranslatorContext ctx, Image input);

    /**
     * Converts the model's raw output into per-pixel class indices.
     *
     * @param ctx translator context
     * @param list raw model output
     * @return class indices with a {@code [height, width]} shape
     */
    public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);

    /**
     * Returns the batchifier the translator reports to DJL.
     *
     * @return the batchifier; {@code null} by default
     */
    public Batchifier getBatchifier() {
        return null;
    }

    /**
     * Segments an image file into a colored label mask.
     *
     * <p>Class indices outside {@link #labels} are reported as {@code "unknown"}; labels with no
     * palette entry are drawn opaque green ({@code 0xFF00FF00}).
     *
     * @param imgFile image to segment
     * @return the ARGB mask together with the index-to-label map and the palette
     * @throws TranslateException if inference fails
     * @throws IOException if the image cannot be read
     * @throws IllegalStateException if {@link #processOutput} returns missing data, a shape that is
     *     not two-dimensional, or fewer indices than {@code height * width}
     */
    public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
        Objects.requireNonNull(imgFile, "imgFile");
        Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
        // predict() returns plain Java arrays, so everything below is free of native resources.
        SegmentationData data = predictor.predict(img);
        if (data == null || data.shape == null || data.indices == null) {
            throw new IllegalStateException("processOutput returned no segmentation data");
        }

        long[] shp = data.shape;
        if (shp.length != 2) {
            throw new IllegalStateException(
                    "Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
        }
        int height = (int) shp[0];
        int width = (int) shp[1];
        int[] indices = data.indices;
        if (indices.length < (long) width * height) {
            throw new IllegalStateException("SegmentationData has " + indices.length
                    + " indices but shape " + Arrays.toString(shp) + " requires " + ((long) width * height));
        }

        // Created before the pixel buffer so invalid dimensions still fail with BufferedImage's
        // IllegalArgumentException.
        BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);

        Map<Integer, String> labelsMap = new HashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            labelsMap.put(i, labels.get(i));
        }

        // Resolve each class index to its color once instead of two map lookups per pixel.
        int[] colorByIndex = new int[labels.size()];
        for (int i = 0; i < colorByIndex.length; i++) {
            colorByIndex[i] = palette.getOrDefault(labels.get(i), FALLBACK_COLOR);
        }
        int unknownColor = palette.getOrDefault(UNKNOWN_LABEL, FALLBACK_COLOR);

        int[] argb = new int[width * height];
        for (int p = 0; p < argb.length; p++) {
            int idx = indices[p];
            argb[p] = (idx >= 0 && idx < colorByIndex.length) ? colorByIndex[idx] : unknownColor;
        }
        mask.setRGB(0, 0, width, height, argb, 0, width);

        return new SegmentationResult(mask, labelsMap, palette);
    }

    /**
     * Records an engine name.
     *
     * @param engine DJL engine name
     * @deprecated The model is loaded in the constructor, so calling this afterwards does not
     *     change the engine of the already-loaded model. Pass the engine to
     *     {@link #Segmenter(Path, List, String)} instead.
     */
    @Deprecated
    public void setEngine(String engine) {
        this.engine = engine;
    }

    /**
     * Releases the predictor and then the model. Failures are ignored so that one failing close
     * never prevents the other resource from being released.
     */
    @Override
    public void close() {
        try {
            predictor.close();
        } catch (Exception ignored) {
            // Best-effort cleanup; nothing useful can be done with a close failure here.
        }
        try {
            modelWrapper.close();
        } catch (Exception ignored) {
            // Best-effort cleanup; nothing useful can be done with a close failure here.
        }
    }
}