Files
window-axis-innovators-box/src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java

155 lines
5.1 KiB
Java
Raw Normal View History

package com.chuangzhou.vivid2D.ai;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;
public abstract class Segmenter implements AutoCloseable {
// 内部类用于从Translator安全地传出数据
public static class SegmentationData {
public final int[] indices;
public final long[] shape;
public SegmentationData(int[] indices, long[] shape) {
this.indices = indices;
this.shape = shape;
}
}
private String engine = "PyTorch";
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
protected final List<String> labels;
protected final Map<String, Integer> palette;
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
this.labels = new ArrayList<>(labels);
this.palette = BiSeNetLabelPalette.defaultPalette();
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
return Segmenter.this.processInput(ctx, input);
}
@Override
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
return Segmenter.this.processOutput(ctx, list);
}
@Override
public Batchifier getBatchifier() {
return Segmenter.this.getBatchifier();
}
};
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
.setTypes(Image.class, Segmenter.SegmentationData.class)
.optModelPath(modelDir)
.optEngine(engine)
.optTranslator(translator)
.build();
this.modelWrapper = criteria.loadModel();
this.predictor = modelWrapper.newPredictor();
}
/**
* 处理模型输入
* @param ctx translator 上下文
* @param input 图片
* @return 模型输入
*/
public abstract NDList processInput(TranslatorContext ctx, Image input);
/**
* 处理模型输出
* @param ctx translator 上下文
* @param list 模型输出
* @return 模型输出
*/
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
/**
* 获取批量处理方式
* @return 批量处理方式
*/
public Batchifier getBatchifier(){
return null;
}
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
// predict 方法现在直接返回安全的 Java 对象
Segmenter.SegmentationData data = predictor.predict(img);
long[] shp = data.shape;
int[] indices = data.indices;
int height, width;
if (shp.length == 2) {
height = (int) shp[0];
width = (int) shp[1];
} else {
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
}
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
Map<Integer, String> labelsMap = new HashMap<>();
for (int i = 0; i < labels.size(); i++) {
labelsMap.put(i, labels.get(i));
}
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int idx = indices[y * width + x];
String label = labelsMap.getOrDefault(idx, "unknown");
int argb = palette.getOrDefault(label, 0xFF00FF00);
mask.setRGB(x, y, argb);
}
}
return new SegmentationResult(mask, labelsMap, palette);
}
public void setEngine(String engine) {
this.engine = engine;
}
@Override
public void close() {
try {
predictor.close();
} catch (Exception ignore) {
}
try {
modelWrapper.close();
} catch (Exception ignore) {
}
}
}