package com.chuangzhou.vivid2D.ai;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult;
import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.*;

/**
 * Base class for DJL-backed semantic segmenters.
 *
 * <p>The constructor loads a model from a local directory through a DJL {@link Criteria}. The
 * model is wrapped with a {@link Translator} that delegates to the abstract
 * {@link #processInput}/{@link #processOutput} hooks, so subclasses only implement model-specific
 * pre/post-processing.
 *
 * <p>{@link #segment(File)} turns the per-pixel class indices returned by {@link #processOutput}
 * into an ARGB mask, colouring each pixel with {@link BiSeNetLabelPalette#defaultPalette()}.
 *
 * <p>NOTE(review): every call shares one {@link Predictor}. DJL predictors are generally not
 * thread-safe, so concurrent {@code segment} calls likely need external synchronisation. Confirm
 * before sharing an instance across threads.
 */
public abstract class Segmenter implements AutoCloseable {

    /** Engine used by the two-argument constructor. */
    private static final String DEFAULT_ENGINE = "PyTorch";

    /** Opaque green, used for labels that have no palette entry. */
    private static final int FALLBACK_ARGB = 0xFF00FF00;

    /** Label name assigned to class indices that fall outside {@link #labels}. */
    private static final String UNKNOWN_LABEL = "unknown";

    /**
     * Plain-Java carrier for the translator output.
     *
     * <p>It lets the class map leave the {@link Translator} without holding references to native
     * NDArray memory.
     */
    public static class SegmentationData {
        /** Class index per pixel, in row-major order ({@code y * width + x}). */
        public final int[] indices;
        /** Shape of the class map; {@link #segment(File)} expects exactly {@code [height, width]}. */
        public final long[] shape;

        public SegmentationData(int[] indices, long[] shape) {
            this.indices = indices;
            this.shape = shape;
        }
    }

    /** Engine name the model was loaded with; only read during construction. */
    private String engine;
    protected final ZooModel<Image, SegmentationData> modelWrapper;
    protected final Predictor<Image, SegmentationData> predictor;
    protected final List<String> labels;
    protected final Map<String, Integer> palette;

    /**
     * Loads the model in {@code modelDir} using the {@code "PyTorch"} engine.
     *
     * @param modelDir directory containing the model files
     * @param labels   label names; position {@code i} names class index {@code i}
     * @throws IOException             if the model files cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException  if no model matching the criteria is found
     */
    public Segmenter(Path modelDir, List<String> labels)
            throws IOException, MalformedModelException, ModelNotFoundException {
        this(modelDir, labels, DEFAULT_ENGINE);
    }

    /**
     * Loads the model in {@code modelDir} using the given DJL engine.
     *
     * @param modelDir directory containing the model files
     * @param labels   label names; position {@code i} names class index {@code i} (copied)
     * @param engine   DJL engine name passed to {@code Criteria.Builder#optEngine}
     * @throws IOException             if the model files cannot be read
     * @throws MalformedModelException if the model files are invalid
     * @throws ModelNotFoundException  if no model matching the criteria is found
     */
    public Segmenter(Path modelDir, List<String> labels, String engine)
            throws IOException, MalformedModelException, ModelNotFoundException {
        this.labels = new ArrayList<>(Objects.requireNonNull(labels, "labels"));
        this.engine = Objects.requireNonNull(engine, "engine");
        this.palette = BiSeNetLabelPalette.defaultPalette();

        // The translator is only invoked at predict time, so the abstract hooks are not called
        // while the subclass is still being constructed.
        Translator<Image, SegmentationData> translator = new Translator<Image, SegmentationData>() {
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                return Segmenter.this.processInput(ctx, input);
            }

            @Override
            public SegmentationData processOutput(TranslatorContext ctx, NDList list) {
                return Segmenter.this.processOutput(ctx, list);
            }

            @Override
            public Batchifier getBatchifier() {
                return Segmenter.this.getBatchifier();
            }
        };

        Criteria<Image, SegmentationData> criteria = Criteria.builder()
                .setTypes(Image.class, SegmentationData.class)
                .optModelPath(modelDir)
                .optEngine(this.engine)
                .optTranslator(translator)
                .build();

        ZooModel<Image, SegmentationData> model = criteria.loadModel();
        try {
            this.predictor = model.newPredictor();
        } catch (RuntimeException e) {
            // Do not leak the loaded model if the predictor cannot be created.
            model.close();
            throw e;
        }
        this.modelWrapper = model;
    }

    /**
     * Converts an input image into the model's input tensors.
     *
     * @param ctx   translator context
     * @param input image to segment
     * @return model input
     */
    public abstract NDList processInput(TranslatorContext ctx, Image input);

    /**
     * Converts raw model output into a per-pixel class map.
     *
     * @param ctx  translator context
     * @param list raw model output
     * @return class indices plus their {@code [height, width]} shape
     */
    public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list);

    /**
     * Returns the batchifier handed to DJL.
     *
     * <p>This base implementation returns {@code null}, which presumably means DJL applies no
     * batching. In that case {@link #processInput} would need to produce batched tensors itself.
     * Confirm against the DJL {@code Translator} docs.
     *
     * @return batchifier to use, or {@code null}
     */
    public Batchifier getBatchifier() {
        return null;
    }

    /**
     * Segments an image file into a coloured ARGB mask.
     *
     * @param imgFile image to segment
     * @return the mask, the index-to-label map and the palette used
     * @throws TranslateException    if pre/post-processing or inference fails
     * @throws IOException           if the image cannot be read
     * @throws IllegalStateException if {@link #processOutput} returns missing data, a shape other
     *                               than {@code [height, width]}, or too few indices for that shape
     */
    public SegmentationResult segment(File imgFile) throws TranslateException, IOException {
        Image img = ImageFactory.getInstance().fromFile(imgFile.toPath());
        // The translator has already copied the class map into plain Java arrays, so no native
        // NDArray memory is touched past this point.
        SegmentationData data = predictor.predict(img);
        if (data == null || data.shape == null || data.indices == null) {
            throw new IllegalStateException("processOutput returned incomplete SegmentationData");
        }

        long[] shp = data.shape;
        if (shp.length != 2) {
            throw new IllegalStateException(
                    "Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp));
        }
        int height = toDimension(shp[0], shp);
        int width = toDimension(shp[1], shp);
        int[] indices = data.indices;
        if ((long) height * width > indices.length) {
            throw new IllegalStateException("classMap shape " + Arrays.toString(shp)
                    + " needs " + ((long) height * width) + " indices but only "
                    + indices.length + " were returned");
        }

        Map<Integer, String> labelsMap = new HashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            labelsMap.put(i, labels.get(i));
        }

        BufferedImage mask = renderMask(indices, width, height, labelsMap);
        return new SegmentationResult(mask, labelsMap, palette);
    }

    /**
     * Paints each pixel with its label's palette colour.
     *
     * <p>Indices outside {@code [0, labels.size())} use the colour of {@code "unknown"}. Labels
     * with no palette entry fall back to opaque green.
     */
    private BufferedImage renderMask(int[] indices, int width, int height,
                                     Map<Integer, String> labelsMap) {
        // Resolve colours once per class rather than once per pixel.
        int labelCount = labels.size();
        int[] colorByIndex = new int[labelCount];
        for (int i = 0; i < labelCount; i++) {
            colorByIndex[i] = palette.getOrDefault(labelsMap.get(i), FALLBACK_ARGB);
        }
        int unknownColor = palette.getOrDefault(UNKNOWN_LABEL, FALLBACK_ARGB);

        // Safe from overflow: the caller verified width * height <= indices.length.
        int pixelCount = width * height;
        int[] argb = new int[pixelCount];
        for (int p = 0; p < pixelCount; p++) {
            int idx = indices[p];
            argb[p] = (idx >= 0 && idx < labelCount) ? colorByIndex[idx] : unknownColor;
        }

        BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
        mask.setRGB(0, 0, width, height, argb, 0, width);
        return mask;
    }

    /** Validates one class-map dimension and narrows it to {@code int}. */
    private static int toDimension(long value, long[] shape) {
        if (value <= 0 || value > Integer.MAX_VALUE) {
            throw new IllegalStateException("Invalid classMap dimension " + value
                    + " in shape " + Arrays.toString(shape));
        }
        return (int) value;
    }

    /**
     * Records an engine name.
     *
     * @param engine DJL engine name
     * @deprecated The model is loaded in the constructor, so calling this afterwards does not
     *     change the engine of the already-loaded model. Pass the engine to
     *     {@link #Segmenter(Path, List, String)} instead.
     */
    @Deprecated
    public void setEngine(String engine) {
        this.engine = engine;
    }

    /**
     * Releases the predictor and the model.
     *
     * <p>This is best-effort: each resource is closed independently. Failures are deliberately
     * ignored so that {@code close()} never throws and the second resource is still released.
     */
    @Override
    public void close() {
        try {
            predictor.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
        try {
            modelWrapper.close();
        } catch (Exception ignored) {
            // best-effort shutdown
        }
    }
}