155 lines
5.1 KiB
Java
155 lines
5.1 KiB
Java
|
|
package com.chuangzhou.vivid2D.ai;
|
|||
|
|
|
|||
|
|
import ai.djl.MalformedModelException;
|
|||
|
|
import ai.djl.inference.Predictor;
|
|||
|
|
import ai.djl.modality.cv.Image;
|
|||
|
|
import ai.djl.modality.cv.ImageFactory;
|
|||
|
|
import ai.djl.ndarray.NDArray;
|
|||
|
|
import ai.djl.ndarray.NDList;
|
|||
|
|
import ai.djl.ndarray.NDManager;
|
|||
|
|
import ai.djl.ndarray.types.DataType;
|
|||
|
|
import ai.djl.repository.zoo.Criteria;
|
|||
|
|
import ai.djl.repository.zoo.ModelNotFoundException;
|
|||
|
|
import ai.djl.repository.zoo.ZooModel;
|
|||
|
|
import ai.djl.translate.Batchifier;
|
|||
|
|
import ai.djl.translate.TranslateException;
|
|||
|
|
import ai.djl.translate.Translator;
|
|||
|
|
import ai.djl.translate.TranslatorContext;
|
|||
|
|
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
|
|||
|
|
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
|
|||
|
|
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;
|
|||
|
|
|
|||
|
|
import javax.imageio.ImageIO;
|
|||
|
|
import java.awt.image.BufferedImage;
|
|||
|
|
import java.io.File;
|
|||
|
|
import java.io.IOException;
|
|||
|
|
import java.nio.file.Path;
|
|||
|
|
import java.util.*;
|
|||
|
|
|
|||
|
|
public abstract class Segmenter implements AutoCloseable {
|
|||
|
|
// 内部类,用于从Translator安全地传出数据
|
|||
|
|
public static class SegmentationData {
|
|||
|
|
public final int[] indices;
|
|||
|
|
public final long[] shape;
|
|||
|
|
|
|||
|
|
public SegmentationData(int[] indices, long[] shape) {
|
|||
|
|
this.indices = indices;
|
|||
|
|
this.shape = shape;
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
private String engine = "PyTorch";
|
|||
|
|
protected final ZooModel<Image, Segmenter.SegmentationData> modelWrapper;
|
|||
|
|
protected final Predictor<Image, Segmenter.SegmentationData> predictor;
|
|||
|
|
protected final List<String> labels;
|
|||
|
|
protected final Map<String, Integer> palette;
|
|||
|
|
|
|||
|
|
public Segmenter(Path modelDir, List<String> labels) throws IOException, MalformedModelException, ModelNotFoundException {
|
|||
|
|
this.labels = new ArrayList<>(labels);
|
|||
|
|
this.palette = BiSeNetLabelPalette.defaultPalette();
|
|||
|
|
|
|||
|
|
Translator<Image, Segmenter.SegmentationData> translator = new Translator<Image, Segmenter.SegmentationData>() {
|
|||
|
|
@Override
|
|||
|
|
public NDList processInput(TranslatorContext ctx, Image input) {
|
|||
|
|
return Segmenter.this.processInput(ctx, input);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
@Override
|
|||
|
|
public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) {
|
|||
|
|
return Segmenter.this.processOutput(ctx, list);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
@Override
|
|||
|
|
public Batchifier getBatchifier() {
|
|||
|
|
return Segmenter.this.getBatchifier();
|
|||
|
|
}
|
|||
|
|
};
|
|||
|
|
|
|||
|
|
Criteria<Image, Segmenter.SegmentationData> criteria = Criteria.builder()
|
|||
|
|
.setTypes(Image.class, Segmenter.SegmentationData.class)
|
|||
|
|
.optModelPath(modelDir)
|
|||
|
|
.optEngine(engine)
|
|||
|
|
.optTranslator(translator)
|
|||
|
|
.build();
|
|||
|
|
|
|||
|
|
this.modelWrapper = criteria.loadModel();
|
|||
|
|
this.predictor = modelWrapper.newPredictor();
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* 处理模型输入
|
|||
|
|
* @param ctx translator 上下文
|
|||
|
|
* @param input 图片
|
|||
|
|
* @return 模型输入
|
|||
|
|
*/
|
|||
|
|
public abstract NDList processInput(TranslatorContext ctx, Image input);
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* 处理模型输出
|
|||
|
|
* @param ctx translator 上下文
|
|||
|
|
* @param list 模型输出
|
|||
|
|
* @return 模型输出
|
|||
|
|
*/
|
|||
|
|
public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);
|
|||
|
|
|
|||
|
|
/**
|
|||
|
|
* 获取批量处理方式
|
|||
|
|
* @return 批量处理方式
|
|||
|
|
*/
|
|||
|
|
public Batchifier getBatchifier(){
|
|||
|
|
return null;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
|
|||
|
|
Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
|
|||
|
|
|
|||
|
|
// predict 方法现在直接返回安全的 Java 对象
|
|||
|
|
Segmenter.SegmentationData data = predictor.predict(img);
|
|||
|
|
|
|||
|
|
long[] shp = data.shape;
|
|||
|
|
int[] indices = data.indices;
|
|||
|
|
|
|||
|
|
int height, width;
|
|||
|
|
if (shp.length == 2) {
|
|||
|
|
height = (int) shp[0];
|
|||
|
|
width = (int) shp[1];
|
|||
|
|
} else {
|
|||
|
|
throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 后续处理完全基于 Java 对象,不再有 Native resource 问题
|
|||
|
|
BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
|
|||
|
|
Map<Integer, String> labelsMap = new HashMap<>();
|
|||
|
|
for (int i = 0; i < labels.size(); i++) {
|
|||
|
|
labelsMap.put(i, labels.get(i));
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for (int y = 0; y < height; y++) {
|
|||
|
|
for (int x = 0; x < width; x++) {
|
|||
|
|
int idx = indices[y * width + x];
|
|||
|
|
String label = labelsMap.getOrDefault(idx, "unknown");
|
|||
|
|
int argb = palette.getOrDefault(label, 0xFF00FF00);
|
|||
|
|
mask.setRGB(x, y, argb);
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return new SegmentationResult(mask, labelsMap, palette);
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
public void setEngine(String engine) {
|
|||
|
|
this.engine = engine;
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
@Override
|
|||
|
|
public void close() {
|
|||
|
|
try {
|
|||
|
|
predictor.close();
|
|||
|
|
} catch (Exception ignore) {
|
|||
|
|
}
|
|||
|
|
try {
|
|||
|
|
modelWrapper.close();
|
|||
|
|
} catch (Exception ignore) {
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|