diff --git a/build.gradle b/build.gradle index 664d6e7..c03c720 100644 --- a/build.gradle +++ b/build.gradle @@ -51,6 +51,16 @@ dependencies { implementation files('libs/dog api 1.3.jar') implementation files('libs/DesktopWallpaperSdk-1.0-SNAPSHOT.jar') + // === DJL API === + implementation platform('ai.djl:bom:0.35.0') + implementation 'ai.djl:api' + implementation 'ai.djl:model-zoo' + implementation 'ai.djl.pytorch:pytorch-model-zoo:0.35.0' + implementation 'ai.djl.pytorch:pytorch-engine' + implementation 'ai.djl:basicdataset' + implementation 'ai.djl.onnxruntime:onnxruntime-engine' + runtimeOnly 'ai.djl.pytorch:pytorch-native-cpu:2.7.1' + runtimeOnly 'ai.djl.onnxruntime:onnxruntime-native-cpu:1.3.0' // === 核心工具库 === implementation 'com.google.code.gson:gson:2.10.1' // 统一版本 implementation 'org.apache.logging.log4j:log4j-api:2.20.0' diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeLabelPalette.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeLabelPalette.java new file mode 100644 index 0000000..efc3265 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeLabelPalette.java @@ -0,0 +1,62 @@ +package com.chuangzhou.vivid2D.ai.anime_face_segmentation; + +import java.util.*; + +/** + * Anime-Face-Segmentation UNet 模型的标签和颜色调色板。 + * 基于 Anime-Face-Segmentation 项目的 util.py 中的颜色定义。 + * 标签索引必须与模型输出索引一致(0-6)。 + */ +public class AnimeLabelPalette { + + /** + * Anime-Face-Segmentation UNet 模型的标准标签(7个类别,索引 0-6) + */ + public static List defaultLabels() { + return Arrays.asList( + "background", // 0 - 青色 (0,255,255) + "hair", // 1 - 蓝色 (255,0,0) + "eye", // 2 - 红色 (0,0,255) + "mouth", // 3 - 白色 (255,255,255) + "face", // 4 - 绿色 (0,255,0) + "skin", // 5 - 黄色 (255,255,0) + "clothes" // 6 - 紫色 (255,0,255) + ); + } + + /** + * 返回对应的调色板:类别名 -> ARGB 颜色值。 + * 颜色值基于 util.py 中的 PALETTE 数组的 RGB 值转换为 ARGB 格式 (0xFFRRGGBB)。 + */ + public static Map defaultPalette() { + Map map = new HashMap<>(); + // 索引 0: background -> (0,255,255) 青色 + map.put("background", 0xFF00FFFF); + // 索引 1: hair -> (255,0,0) 蓝色 + map.put("hair", 0xFFFF0000); + // 索引 2: eye -> (0,0,255) 红色 + map.put("eye", 0xFF0000FF); + // 索引 3: mouth -> (255,255,255) 白色 + map.put("mouth", 0xFFFFFFFF); + // 索引 4: face -> (0,255,0) 绿色 + map.put("face", 0xFF00FF00); + // 索引 5: skin -> (255,255,0) 黄色 + map.put("skin", 0xFFFFFF00); + // 索引 6: clothes -> (255,0,255) 紫色 + map.put("clothes", 0xFFFF00FF); + + return map; + } + + /** + * 获取类别索引到名称的映射 + */ + public static Map getIndexToLabelMap() { + List labels = defaultLabels(); + Map map = new HashMap<>(); + for (int i = 0; i < labels.size(); i++) { + map.put(i, labels.get(i)); + } + return map; + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java new file mode 100644 index 0000000..5d6658e --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java @@ -0,0 +1,426 @@ +package com.chuangzhou.vivid2D.ai.anime_face_segmentation; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +/** + * AnimeModelWrapper - 专门为 Anime-Face-Segmentation 模型封装的 Wrapper + */ +public class AnimeModelWrapper implements AutoCloseable { + + private final AnimeSegmenter segmenter; + private final List labels; // index -> name + private final Map palette; // name -> ARGB + + private AnimeModelWrapper(AnimeSegmenter segmenter, List labels, Map palette) { + this.segmenter = segmenter; + this.labels = labels; + this.palette = palette; + } + + /** + * 加载模型 + */ + public static AnimeModelWrapper load(Path modelDir) throws Exception { + List labels = loadLabelsFromSynset(modelDir).orElseGet(AnimeLabelPalette::defaultLabels); + AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels); + Map palette = AnimeLabelPalette.defaultPalette(); + return new AnimeModelWrapper(segmenter, labels, palette); + } + + public List getLabels() { + return Collections.unmodifiableList(labels); + } + + public Map getPalette() { + return Collections.unmodifiableMap(palette); + } + + /** + * 直接返回分割结果(在丢给底层 segmenter 前会做通用预处理:RGB 转换 + 等比 letterbox 缩放到模型输入尺寸) + */ + public AnimeSegmentationResult segment(File inputImage) throws Exception { + File pre = null; + try { + pre = preprocessAndSave(inputImage); + // 将预处理后的临时文件丢给底层 segmenter + return segmenter.segment(pre); + } finally { + if (pre != null && pre.exists()) { + try { Files.deleteIfExists(pre.toPath()); } catch (Exception ignore) {} + } + } + } + + /** + * 分割并保存结果 + */ + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + AnimeSegmentationResult res = segment(inputImage); + BufferedImage original = ImageIO.read(inputImage); + BufferedImage maskImage = res.getMaskImage(); + + int maskW = maskImage.getWidth(); + int maskH = maskImage.getHeight(); + + // 解析 targets + Set realTargets = parseTargetsSet(targets); + Map saved = new LinkedHashMap<>(); + + for (String target : realTargets) { + if (!palette.containsKey(target)) { + // 尝试忽略大小写匹配 + String finalTarget = target; + Optional matched = palette.keySet().stream() + .filter(k -> k.equalsIgnoreCase(finalTarget)) + .findFirst(); + if (matched.isPresent()) target = matched.get(); + else { + System.err.println("Warning: unknown label '" + target + "' - skip."); + continue; + } + } + + int targetColor = palette.get(target); + + // 1) 生成透明背景的二值掩码(只保留 target 像素) + BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < maskH; y++) { + for (int x = 0; x < maskW; x++) { + int c = maskImage.getRGB(x, y); + if (c == targetColor) { + partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明 + } else { + partMask.setRGB(x, y, 0x00000000); + } + } + } + + // 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明) + BufferedImage maskResized = partMask; + if (original.getWidth() != maskW || original.getHeight() != maskH) { + maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = maskResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + // 半透明颜色(alpha = 0x88) + int rgbOnly = (targetColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = maskResized.getRGB(x, y); + if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + // 保存 + String safe = safeFileName(target); + File maskOut = outDir.resolve(safe + "_mask.png").toFile(); + File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); + + ImageIO.write(maskResized, "png", maskOut); + ImageIO.write(overlay, "png", overlayOut); + + saved.put(target, new ResultFiles(maskOut, overlayOut)); + } + + return saved; + } + + private static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } + + private static Set parseTargetsSet(Set in) { + if (in == null || in.isEmpty()) return Collections.emptySet(); + // 若包含单个 "all" + if (in.size() == 1) { + String only = in.iterator().next(); + if ("all".equalsIgnoreCase(only.trim())) { + // 返回所有标签 + return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels()); + } + } + // 直接返回 trim 后的集合 + Set out = new LinkedHashSet<>(); + for (String s : in) { + if (s != null) out.add(s.trim()); + } + return out; + } + + /** + * 专门提取眼睛的方法(在丢给底层 segmenter 前做预处理) + */ + public ResultFiles extractEyes(File inputImage, Path outDir) throws Exception { + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + File pre = null; + BufferedImage eyes; + try { + pre = preprocessAndSave(inputImage); + eyes = segmenter.extractEyes(pre); + } finally { + if (pre != null && pre.exists()) { + try { Files.deleteIfExists(pre.toPath()); } catch (Exception ignore) {} + } + } + + File eyesMask = outDir.resolve("eyes_mask.png").toFile(); + ImageIO.write(eyes, "png", eyesMask); + + // 创建眼睛的 overlay(原有逻辑,保持不变) + BufferedImage original = ImageIO.read(inputImage); + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + + // 缩放眼睛掩码到原图尺寸 + BufferedImage eyesResized = eyes; + if (original.getWidth() != eyes.getWidth() || original.getHeight() != eyes.getHeight()) { + eyesResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = eyesResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(eyes, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + int eyeColor = palette.getOrDefault("eye", 0xFF00FF); // 若没有 eye,给个显眼默认色 + int rgbOnly = (eyeColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = eyesResized.getRGB(x, y); + if ((mc & 0x00FFFFFF) == (eyeColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + File eyesOverlay = outDir.resolve("eyes_overlay.png").toFile(); + ImageIO.write(overlay, "png", eyesOverlay); + + return new ResultFiles(eyesMask, eyesOverlay); + } + + /** + * 关闭底层资源 + */ + @Override + public void close() { + try { + segmenter.close(); + } catch (Exception ignore) {} + } + + /** + * 存放结果文件路径 + */ + public static class ResultFiles { + private final File maskFile; + private final File overlayFile; + + public ResultFiles(File maskFile, File overlayFile) { + this.maskFile = maskFile; + this.overlayFile = overlayFile; + } + + public File getMaskFile() { + return maskFile; + } + + public File getOverlayFile() { + return overlayFile; + } + } + + /* ================= helper: 从 modelDir 读取 synset.txt ================= */ + private static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + // ========== 新增:预处理并保存到临时文件 ========== + private File preprocessAndSave(File inputImage) throws IOException { + BufferedImage img = ImageIO.read(inputImage); + if (img == null) throw new IOException("无法读取图片: " + inputImage); + + // 转成标准 RGB(去掉 alpha / 保证三通道) + BufferedImage rgb = new BufferedImage(img.getWidth(), img.getHeight(), BufferedImage.TYPE_INT_RGB); + Graphics2D g = rgb.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(img, 0, 0, null); + g.dispose(); + + // 获取模型输入尺寸(尝试反射读取,找不到则使用默认 512x512) + int[] size = getModelInputSize(); + int targetW = size[0], targetH = size[1]; + + // 等比缩放并居中填充(letterbox),背景用白色 + double scale = Math.min((double) targetW / rgb.getWidth(), (double) targetH / rgb.getHeight()); + int newW = Math.max(1, (int) Math.round(rgb.getWidth() * scale)); + int newH = Math.max(1, (int) Math.round(rgb.getHeight() * scale)); + + BufferedImage resized = new BufferedImage(targetW, targetH, BufferedImage.TYPE_INT_RGB); + Graphics2D g2 = resized.createGraphics(); + g2.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g2.setColor(Color.WHITE); + g2.fillRect(0, 0, targetW, targetH); + int x = (targetW - newW) / 2; + int y = (targetH - newH) / 2; + g2.drawImage(rgb, x, y, newW, newH, null); + g2.dispose(); + + // 保存为临时 PNG 文件(确保无压缩失真) + File tmp = Files.createTempFile("anime_pre_", ".png").toFile(); + ImageIO.write(resized, "png", tmp); + return tmp; + } + + // ========== 新增:尝试通过反射从 segmenter 上读取模型输入尺寸 ========== + private int[] getModelInputSize() { + // 默认值 + int defaultSize = 512; + int w = defaultSize, h = defaultSize; + + try { + Class cls = segmenter.getClass(); + + // 尝试方法 getInputWidth/getInputHeight + try { + Method mw = cls.getMethod("getInputWidth"); + Method mh = cls.getMethod("getInputHeight"); + Object ow = mw.invoke(segmenter); + Object oh = mh.invoke(segmenter); + if (ow instanceof Number && oh instanceof Number) { + int iw = ((Number) ow).intValue(); + int ih = ((Number) oh).intValue(); + if (iw > 0 && ih > 0) { + return new int[]{iw, ih}; + } + } + } catch (NoSuchMethodException ignored) {} + + // 尝试方法 getInputSize 返回 int[] 或 Dimension + try { + Method ms = cls.getMethod("getInputSize"); + Object os = ms.invoke(segmenter); + if (os instanceof int[] && ((int[]) os).length >= 2) { + int iw = ((int[]) os)[0]; + int ih = ((int[]) os)[1]; + if (iw > 0 && ih > 0) return new int[]{iw, ih}; + } else if (os != null) { + // 处理 java.awt.Dimension + try { + Method gw = os.getClass().getMethod("getWidth"); + Method gh = os.getClass().getMethod("getHeight"); + Object ow2 = gw.invoke(os); + Object oh2 = gh.invoke(os); + if (ow2 instanceof Number && oh2 instanceof Number) { + int iw = ((Number) ow2).intValue(); + int ih = ((Number) oh2).intValue(); + if (iw > 0 && ih > 0) return new int[]{iw, ih}; + } + } catch (Exception ignored2) {} + } + } catch (NoSuchMethodException ignored) {} + + // 尝试字段 inputWidth/inputHeight + try { + try { + java.lang.reflect.Field fw = cls.getDeclaredField("inputWidth"); + java.lang.reflect.Field fh = cls.getDeclaredField("inputHeight"); + fw.setAccessible(true); fh.setAccessible(true); + Object ow = fw.get(segmenter); + Object oh = fh.get(segmenter); + if (ow instanceof Number && oh instanceof Number) { + int iw = ((Number) ow).intValue(); + int ih = ((Number) oh).intValue(); + if (iw > 0 && ih > 0) return new int[]{iw, ih}; + } + } catch (NoSuchFieldException ignoredField) {} + } catch (Exception ignored) {} + + } catch (Exception ignored) { + // 任何反射异常都回退到默认值 + } + + return new int[]{w, h}; + } + + /* ================= convenience 主方法(快速测试) ================= */ + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.out.println("用法: AnimeModelWrapper "); + System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face"); + System.out.println("标签: " + AnimeLabelPalette.defaultLabels()); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + Path out = Path.of(args[2]); + String targetsArg = args[3]; + + Set targets; + if ("all".equalsIgnoreCase(targetsArg.trim())) { + targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels()); + } else { + String[] parts = targetsArg.split(","); + targets = new LinkedHashSet<>(); + for (String p : parts) { + if (!p.trim().isEmpty()) targets.add(p.trim()); + } + } + + try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) { + Map m = wrapper.segmentAndSave(input, targets, out); + m.forEach((k, v) -> { + System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); + }); + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java new file mode 100644 index 0000000..c354cf8 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java @@ -0,0 +1,61 @@ +package com.chuangzhou.vivid2D.ai.anime_face_segmentation; + +import java.awt.image.BufferedImage; +import java.util.Map; + +/** + * 动漫分割结果容器 + */ +public class AnimeSegmentationResult { + // 分割掩码图(每个像素的颜色为对应类别颜色) + private final BufferedImage maskImage; + + // 分割概率图(每个像素的类别概率分布) + private final float[][][] probabilityMap; + + // 类别索引 -> 类别名称 + private final Map labels; + + // 类别名称 -> ARGB 颜色 + private final Map palette; + + public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap, + Map labels, Map palette) { + this.maskImage = maskImage; + this.probabilityMap = probabilityMap; + this.labels = labels; + this.palette = palette; + } + + public BufferedImage getMaskImage() { + return maskImage; + } + + public float[][][] getProbabilityMap() { + return probabilityMap; + } + + public Map getLabels() { + return labels; + } + + public Map getPalette() { + return palette; + } + + /** + * 获取指定类别的概率图 + */ + public float[][] getClassProbability(int classIndex) { + if (probabilityMap == null) return null; + int height = probabilityMap.length; + int width = probabilityMap[0].length; + float[][] result = new float[height][width]; + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + result[y][x] = probabilityMap[y][x][classIndex]; + } + } + return result; + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java new file mode 100644 index 0000000..98d0c6c --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java @@ -0,0 +1,230 @@ +package com.chuangzhou.vivid2D.ai.anime_face_segmentation; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +/** + * AnimeSegmenter: 专门为 Anime-Face-Segmentation UNet 模型设计的分割器 + */ +public class AnimeSegmenter implements AutoCloseable { + + // 模型默认输入大小(与训练时一致)。若模型不同可以修改为实际值或让 caller 通过构造参数传入。 + private static final int MODEL_INPUT_W = 512; + private static final int MODEL_INPUT_H = 512; + + // 内部类,用于从Translator安全地传出数据 + public static class SegmentationData { + final int[] indices; // 类别索引 [H * W] + final float[][][] probMap; // 概率图 [H][W][C] + final long[] shape; // 形状 [H, W] + + public SegmentationData(int[] indices, float[][][] probMap, long[] shape) { + this.indices = indices; + this.probMap = probMap; + this.shape = shape; + } + } + + private final ZooModel modelWrapper; + private final Predictor predictor; + private final List labels; + private final Map palette; + + public AnimeSegmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { + this.labels = new ArrayList<>(labels); + this.palette = AnimeLabelPalette.defaultPalette(); + + Translator translator = new Translator() { + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + + // 如果图片已经是模型输入大小则不再 resize(避免重复缩放导致失真) + Image toUse = input; + if (!(input.getWidth() == MODEL_INPUT_W && input.getHeight() == MODEL_INPUT_H)) { + toUse = input.resize(MODEL_INPUT_W, MODEL_INPUT_H, true); + } + + // 转换为NDArray并预处理 + NDArray array = toUse.toNDArray(manager); + // DJL 返回 HWC 格式数组,转换为 CHW,并标准化到 [0,1] + array = array.transpose(2, 0, 1) // HWC -> CHW + .toType(DataType.FLOAT32, false) + .div(255f) // 归一化到[0,1] + .expandDims(0); // 添加batch维度 [1,3,H,W] + + return new NDList(array); + } + + @Override + public SegmentationData processOutput(TranslatorContext ctx, NDList list) { + if (list == null || list.isEmpty()) { + throw new IllegalStateException("Model did not return any output."); + } + + NDArray output = list.get(0); // 期望形状 [1,C,H,W] 或 [1,C,W,H](以训练时一致为准) + + // 确保维度:把 output 视作 [1, C, H, W] + Shape outShape = output.getShape(); + if (outShape.dimension() != 4) { + throw new IllegalStateException("Unexpected output shape: " + outShape); + } + + // 1. 获取类别索引(argmax) -> [H, W] + NDArray squeezed = output.squeeze(0); // [C,H,W] + NDArray classMap = squeezed.argMax(0).toType(DataType.INT32, false); // argMax over channel维度 + + // 2. 获取概率图(softmax 输出或模型已经输出概率),转换为 [H,W,C] + NDArray probabilities = squeezed.transpose(1, 2, 0) // [H,W,C] + .toType(DataType.FLOAT32, false); + + // 3. 转换为Java数组 + long[] shape = classMap.getShape().getShape(); // [H, W] + int[] indices = classMap.toIntArray(); + long[] probShape = probabilities.getShape().getShape(); // [H, W, C] + int height = (int) probShape[0]; + int width = (int) probShape[1]; + int classes = (int) probShape[2]; + float[] flatProbMap = probabilities.toFloatArray(); + float[][][] probMap = new float[height][width][classes]; + for (int i = 0; i < height; i++) { + for (int j = 0; j < width; j++) { + for (int k = 0; k < classes; k++) { + int index = i * width * classes + j * classes + k; + probMap[i][j][k] = flatProbMap[index]; + } + } + } + + return new SegmentationData(indices, probMap, shape); + } + + @Override + public Batchifier getBatchifier() { + return null; + } + }; + + Criteria criteria = Criteria.builder() + .setTypes(Image.class, SegmentationData.class) + .optModelPath(modelDir) + .optEngine("PyTorch") + .optTranslator(translator) + .build(); + + this.modelWrapper = criteria.loadModel(); + this.predictor = modelWrapper.newPredictor(); + } + + public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException { + Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); + + // 预测并获取分割数据 + SegmentationData data = predictor.predict(img); + + long[] shp = data.shape; + int[] indices = data.indices; + float[][][] probMap = data.probMap; + + int height = (int) shp[0]; + int width = (int) shp[1]; + + // 创建掩码图像 + BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); + Map labelsMap = AnimeLabelPalette.getIndexToLabelMap(); + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + int idx = indices[y * width + x]; + String label = labelsMap.getOrDefault(idx, "unknown"); + int argb = palette.getOrDefault(label, 0xFF00FF00); // 默认绿色 + mask.setRGB(x, y, argb); + } + } + + return new AnimeSegmentationResult(mask, probMap, labelsMap, palette); + } + + /** + * 专门针对眼睛的分割方法 + */ + public BufferedImage extractEyes(File imgFile) throws TranslateException, IOException { + AnimeSegmentationResult result = segment(imgFile); + BufferedImage mask = result.getMaskImage(); + BufferedImage eyeMask = new BufferedImage(mask.getWidth(), mask.getHeight(), BufferedImage.TYPE_INT_ARGB); + + int eyeColor = palette.get("eye"); + + for (int y = 0; y < mask.getHeight(); y++) { + for (int x = 0; x < mask.getWidth(); x++) { + int rgb = mask.getRGB(x, y); + if (rgb == eyeColor) { + eyeMask.setRGB(x, y, eyeColor); + } else { + eyeMask.setRGB(x, y, 0x00000000); // 透明 + } + } + } + + return eyeMask; + } + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignore) { + } + try { + modelWrapper.close(); + } catch (Exception ignore) { + } + } + + // 测试主函数 + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.out.println("用法: java AnimeSegmenter "); + System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png"); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + File out = new File(args[2]); + + List labels = AnimeLabelPalette.defaultLabels(); + + try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) { + AnimeSegmentationResult res = segmenter.segment(input); + ImageIO.write(res.getMaskImage(), "png", out); + System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath()); + + // 额外保存眼睛分割结果 + BufferedImage eyes = segmenter.extractEyes(input); + File eyesOut = new File(out.getParent(), "eyes_" + out.getName()); + ImageIO.write(eyes, "png", eyesOut); + System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath()); + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2LabelPalette.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2LabelPalette.java new file mode 100644 index 0000000..da9ed07 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2LabelPalette.java @@ -0,0 +1,46 @@ +package com.chuangzhou.vivid2D.ai.anime_segmentation; + +import java.util.*; + +/** + * 动漫分割模型的标签和颜色调色板。 + * 这是一个二分类模型:背景和前景(动漫人物) + */ +public class Anime2LabelPalette { + + /** + * 动漫分割模型的标准标签(2个类别) + */ + public static List defaultLabels() { + return Arrays.asList( + "background", // 0 + "foreground" // 1 + ); + } + + /** + * 返回动漫分割模型的调色板 + */ + public static Map defaultPalette() { + Map map = new HashMap<>(); + // 索引 0: background - 黑色 + map.put("background", 0xFF000000); + // 索引 1: foreground - 白色 + map.put("foreground", 0xFFFFFFFF); + + return map; + } + + /** + * 专门为动漫分割模型设计的调色板(可视化更友好) + */ + public static Map animeSegmentationPalette() { + Map map = new HashMap<>(); + // 背景 - 透明 + map.put("background", 0x00000000); + // 前景 - 红色(用于可视化) + map.put("foreground", 0xFFFF0000); + + return map; + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java new file mode 100644 index 0000000..de57b2e --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java @@ -0,0 +1,244 @@ +package com.chuangzhou.vivid2D.ai.anime_segmentation; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +/** + * Anime2ModelWrapper - 对动漫分割模型的封装 + * + * 用法示例: + * Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir")); + * Map out = wrapper.segmentAndSave( + * new File("input.jpg"), + * Set.of("foreground"), // 动漫分割主要关注前景 + * Paths.get("outDir") + * ); + * wrapper.close(); + */ +public class Anime2ModelWrapper implements AutoCloseable { + + private final Anime2Segmenter segmenter; + private final List labels; // index -> name + private final Map palette; // name -> ARGB + + private Anime2ModelWrapper(Anime2Segmenter segmenter, List labels, Map palette) { + this.segmenter = segmenter; + this.labels = labels; + this.palette = palette; + } + + /** + * 创建 Anime2Segmenter 实例 + */ + public static Anime2ModelWrapper load(Path modelDir) throws Exception { + List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); + Anime2Segmenter s = new Anime2Segmenter(modelDir, labels); + Map palette = Anime2LabelPalette.animeSegmentationPalette(); + return new Anime2ModelWrapper(s, labels, palette); + } + + public List getLabels() { + return Collections.unmodifiableList(labels); + } + + public Map getPalette() { + return Collections.unmodifiableMap(palette); + } + + /** + * 直接返回分割结果 + */ + public Anime2SegmentationResult segment(File inputImage) throws Exception { + return segmenter.segment(inputImage); + } + + /** + * 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir + */ + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + Anime2SegmentationResult res = segment(inputImage); + BufferedImage original = ImageIO.read(inputImage); + BufferedImage maskImage = res.getMaskImage(); + + int maskW = maskImage.getWidth(); + int maskH = maskImage.getHeight(); + + // 解析 targets + Set realTargets = parseTargetsSet(targets); + Map saved = new LinkedHashMap<>(); + + for (String target : realTargets) { + if (!palette.containsKey(target)) { + System.err.println("Warning: unknown label '" + target + "' - skip."); + continue; + } + + int targetColor = palette.get(target); + + // 1) 生成透明背景的二值掩码(只保留 target 像素) + BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < maskH; y++) { + for (int x = 0; x < maskW; x++) { + int c = maskImage.getRGB(x, y); + if (c == targetColor) { + partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明 + } else { + partMask.setRGB(x, y, 0x00000000); + } + } + } + + // 2) 将 mask 缩放到与原图一致 + BufferedImage maskResized = partMask; + if (original.getWidth() != maskW || original.getHeight() != maskH) { + maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = maskResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + // 3) 生成叠加图 + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + + int rgbOnly = (targetColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = maskResized.getRGB(x, y); + if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + // 保存 + String safe = safeFileName(target); + File maskOut = outDir.resolve(safe + "_mask.png").toFile(); + File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); + + ImageIO.write(maskResized, "png", maskOut); + ImageIO.write(overlay, "png", overlayOut); + + saved.put(target, new ResultFiles(maskOut, overlayOut)); + } + + return saved; + } + + private static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } + + private static Set parseTargetsSet(Set in) { + if (in == null || in.isEmpty()) return Collections.emptySet(); + if (in.size() == 1) { + String only = in.iterator().next(); + if ("all".equalsIgnoreCase(only.trim())) { + return Set.of("foreground"); // 动漫分割主要关注前景 + } + } + Set out = new LinkedHashSet<>(); + for (String s : in) { + if (s != null) out.add(s.trim()); + } + return out; + } + + /** + * 关闭底层资源 + */ + @Override + public void close() { + try { + segmenter.close(); + } catch (Exception ignore) {} + } + + /** + * 存放结果文件路径 + */ + public static class ResultFiles { + private final File maskFile; + private final File overlayFile; + + public ResultFiles(File maskFile, File overlayFile) { + this.maskFile = maskFile; + this.overlayFile = overlayFile; + } + + public File getMaskFile() { + return maskFile; + } + + public File getOverlayFile() { + return overlayFile; + } + } + + /* ================= helper: 从 modelDir 读取 synset.txt ================= */ + private static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + /* ================= convenience 主方法(快速测试) ================= */ + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.out.println("用法: Anime2ModelWrapper "); + System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground"); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + Path out = Path.of(args[2]); + String targetsArg = args[3]; + + List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); + Set targets; + if ("all".equalsIgnoreCase(targetsArg.trim())) { + targets = new LinkedHashSet<>(labels); + } else { + String[] parts = targetsArg.split(","); + targets = new LinkedHashSet<>(); + for (String p : parts) { + if (!p.trim().isEmpty()) targets.add(p.trim()); + } + } + + try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) { + Map m = wrapper.segmentAndSave(input, targets, out); + m.forEach((k, v) -> { + System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); + }); + } + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java new file mode 100644 index 0000000..d638f29 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java @@ -0,0 +1,36 @@ +package com.chuangzhou.vivid2D.ai.anime_segmentation; + +import java.awt.image.BufferedImage; +import java.util.Map; + +/** + * 动漫分割结果容器 + */ +public class Anime2SegmentationResult { + // 分割掩码图(每个像素的颜色为对应类别颜色) + private final BufferedImage maskImage; + + // 类别索引 -> 类别名称 + private final Map labels; + + // 类别名称 -> ARGB 颜色 + private final Map palette; + + public Anime2SegmentationResult(BufferedImage maskImage, Map labels, Map palette) { + this.maskImage = maskImage; + this.labels = labels; + this.palette = palette; + } + + public BufferedImage getMaskImage() { + return maskImage; + } + + public Map getLabels() { + return labels; + } + + public Map getPalette() { + return palette; + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java new file mode 100644 index 0000000..2482b92 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java @@ -0,0 +1,175 @@ +package com.chuangzhou.vivid2D.ai.anime_segmentation; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +/** + * Anime2Segmenter: 专门用于动漫分割模型 + * 处理 anime-segmentation 模型的二值分割输出 + */ +public class Anime2Segmenter implements AutoCloseable { + + public static class SegmentationData { + final int[] indices; + final long[] shape; + + public SegmentationData(int[] indices, long[] shape) { + this.indices = indices; + this.shape = shape; + } + } + + private final ZooModel modelWrapper; + private final Predictor predictor; + private final List labels; + private final Map palette; + + public Anime2Segmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { + this.labels = new ArrayList<>(labels); + this.palette = Anime2LabelPalette.animeSegmentationPalette(); + + Translator translator = new Translator() { + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + + // 调整输入图像尺寸到模型期望的大小 (1024x1024) + Image resized = input.resize(1024, 1024, true); + NDArray array = resized.toNDArray(manager); + + // 转换为 CHW 格式并归一化 + array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); + array = array.div(255f); + array = array.expandDims(0); // 添加batch维度 + + return new NDList(array); + } + + @Override + public SegmentationData processOutput(TranslatorContext ctx, NDList list) { + if (list == null || list.isEmpty()) { + throw new IllegalStateException("Model did not return any output."); + } + + NDArray out = list.get(0); + + // 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图 + // 应用sigmoid并二值化 + NDArray probabilities = out.div(out.neg().exp().add(1)); + NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false); + + // 移除batch和channel维度 + if (binaryMask.getShape().dimension() == 4) { + binaryMask = binaryMask.squeeze(0).squeeze(0); + } + + // 转换为Java数组 + long[] finalShape = binaryMask.getShape().getShape(); + int[] indices = binaryMask.toIntArray(); + + return new SegmentationData(indices, finalShape); + } + + @Override + public Batchifier getBatchifier() { + return null; + } + }; + + Criteria criteria = Criteria.builder() + .setTypes(Image.class, SegmentationData.class) + .optModelPath(modelDir) + .optEngine("PyTorch") + .optTranslator(translator) + .build(); + + this.modelWrapper = criteria.loadModel(); + this.predictor = modelWrapper.newPredictor(); + } + + public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException { + Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); + + SegmentationData data = predictor.predict(img); + + long[] shp = data.shape; + int[] indices = data.indices; + + int height, width; + if (shp.length == 2) { + height = (int) shp[0]; + width = (int) shp[1]; + } else { + throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp)); + } + + // 创建分割掩码 + BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); + Map labelsMap = new HashMap<>(); + for (int i = 0; i < labels.size(); i++) { + labelsMap.put(i, labels.get(i)); + } + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + int idx = indices[y * width + x]; + String label = labelsMap.getOrDefault(idx, "unknown"); + int argb = palette.getOrDefault(label, 0xFFFF0000); + mask.setRGB(x, y, argb); + } + } + + return new Anime2SegmentationResult(mask, labelsMap, palette); + } + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignore) { + } + try { + modelWrapper.close(); + } catch (Exception ignore) { + } + } + + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.out.println("用法: java Anime2Segmenter "); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + File out = new File(args[2]); + + List labels = Anime2LabelPalette.defaultLabels(); + + try (Anime2Segmenter s = new Anime2Segmenter(modelDir, labels)) { + Anime2SegmentationResult res = s.segment(input); + ImageIO.write(res.getMaskImage(), "png", out); + System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath()); + } + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java new file mode 100644 index 0000000..3643dc8 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java @@ -0,0 +1,262 @@ +package com.chuangzhou.vivid2D.ai.anime_segmentation; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +/** + * Anime2VividModelWrapper - 对之前 Anime2Segmenter 的封装,提供更便捷的API + * + * 用法示例: + * Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("/path/to/modelDir")); + * Map out = wrapper.segmentAndSave( + * new File("input.jpg"), + * Set.of("foreground"), // 动漫分割主要关注前景 + * Paths.get("outDir") + * ); + * // out contains 每个目标标签对应的 mask+overlay 文件路径 + * wrapper.close(); + */ +public class Anime2VividModelWrapper implements AutoCloseable { + + private final Anime2Segmenter segmenter; + private final List labels; // index -> name + private final Map palette; // name -> ARGB + + private Anime2VividModelWrapper(Anime2Segmenter segmenter, List labels, Map palette) { + this.segmenter = segmenter; + this.labels = labels; + this.palette = palette; + } + + /** + * 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 Anime2LabelPalette.defaultLabels() + * 并创建 Anime2Segmenter 实例。 + */ + public static Anime2VividModelWrapper load(Path modelDir) throws Exception { + List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); + Anime2Segmenter s = new Anime2Segmenter(modelDir, labels); + Map palette = Anime2LabelPalette.animeSegmentationPalette(); + return new Anime2VividModelWrapper(s, labels, palette); + } + + public List getLabels() { + return Collections.unmodifiableList(labels); + } + + public Map getPalette() { + return Collections.unmodifiableMap(palette); + } + + /** + * 直接返回分割结果(Anime2SegmentationResult) + */ + public Anime2SegmentationResult segment(File inputImage) throws Exception { + return segmenter.segment(inputImage); + } + + /** + * 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。 + * 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。 + *

+ * 返回值:Map,ResultFiles 包含 maskFile、overlayFile(两个 PNG) + */ + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + Anime2SegmentationResult res = segment(inputImage); + BufferedImage original = ImageIO.read(inputImage); + BufferedImage maskImage = res.getMaskImage(); + + int maskW = maskImage.getWidth(); + int maskH = maskImage.getHeight(); + + // 解析 targets + Set realTargets = parseTargetsSet(targets); + Map saved = new LinkedHashMap<>(); + + for (String target : realTargets) { + if (!palette.containsKey(target)) { + // 尝试忽略大小写匹配 + String finalTarget = target; + Optional matched = palette.keySet().stream() + .filter(k -> k.equalsIgnoreCase(finalTarget)) + .findFirst(); + if (matched.isPresent()) target = matched.get(); + else { + System.err.println("Warning: unknown label '" + target + "' - skip."); + continue; + } + } + + int targetColor = palette.get(target); + + // 1) 生成透明背景的二值掩码(只保留 target 像素) + BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < maskH; y++) { + for (int x = 0; x < maskW; x++) { + int c = maskImage.getRGB(x, y); + if (c == targetColor) { + partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明 + } else { + partMask.setRGB(x, y, 0x00000000); + } + } + } + + // 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明) + BufferedImage maskResized = partMask; + if (original.getWidth() != maskW || original.getHeight() != maskH) { + maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = maskResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + // 半透明颜色(alpha = 0x88) + int rgbOnly = (targetColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = maskResized.getRGB(x, y); + if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + // 保存 + String safe = safeFileName(target); + File maskOut = outDir.resolve(safe + "_mask.png").toFile(); + File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); + + ImageIO.write(maskResized, "png", maskOut); + ImageIO.write(overlay, "png", overlayOut); + + saved.put(target, new ResultFiles(maskOut, overlayOut)); + } + + return saved; + } + + private static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } + + private static Set parseTargetsSet(Set in) { + if (in == null || in.isEmpty()) return Collections.emptySet(); + // 若包含单个 "all" + if (in.size() == 1) { + String only = in.iterator().next(); + if ("all".equalsIgnoreCase(only.trim())) { + // 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette) + // For convenience, return a set containing "all" and let caller logic handle it earlier. + return Set.of("all"); + } + } + // 直接返回 trim 后的小写不变集合(保持用户传入的名字) + Set out = new LinkedHashSet<>(); + for (String s : in) { + if (s != null) out.add(s.trim()); + } + return out; + } + + /** + * 关闭底层资源 + */ + @Override + public void close() { + try { + segmenter.close(); + } catch (Exception ignore) {} + } + + /** + * 存放结果文件路径 + */ + public static class ResultFiles { + private final File maskFile; + private final File overlayFile; + + public ResultFiles(File maskFile, File overlayFile) { + this.maskFile = maskFile; + this.overlayFile = overlayFile; + } + + public File getMaskFile() { + return maskFile; + } + + public File getOverlayFile() { + return overlayFile; + } + } + + /* ================= helper: 从 modelDir 读取 synset.txt ================= */ + + private static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + /* ================= convenience 主方法(快速测试) ================= */ + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.out.println("用法: Anime2VividModelWrapper "); + System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground"); + System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all"); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + Path out = Path.of(args[2]); + String targetsArg = args[3]; + + List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); + Set targets; + if ("all".equalsIgnoreCase(targetsArg.trim())) { + targets = new LinkedHashSet<>(labels); + } else { + String[] parts = targetsArg.split(","); + targets = new LinkedHashSet<>(); + for (String p : parts) { + if (!p.trim().isEmpty()) targets.add(p.trim()); + } + } + + try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) { + Map m = wrapper.segmentAndSave(input, targets, out); + m.forEach((k, v) -> { + System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); + }); + } + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java new file mode 100644 index 0000000..12f33b8 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java @@ -0,0 +1,89 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import java.util.*; + +/** + * BiSeNet 人脸解析模型的标准标签和颜色调色板。 + * 颜色值基于 zllrunning/face-parsing.PyTorch 仓库的 test.py 文件。 + * 标签索引必须与模型输出索引一致(0-18)。 + */ +public class LabelPalette { + + /** + * BiSeNet 人脸解析模型的标准标签(19个类别,索引 0-18) + */ + public static List defaultLabels() { + return Arrays.asList( + "background", // 0 + "skin", // 1 + "nose", // 2 + "eye_left", // 3 + "eye_right", // 4 + "eyebrow_left", // 5 + "eyebrow_right",// 6 + "ear_left", // 7 + "ear_right", // 8 + "mouth", // 9 + "lip_upper", // 10 + "lip_lower", // 11 + "hair", // 12 + "hat", // 13 + "earring", // 14 + "necklace", // 15 + "clothes", // 16 + "facial_hair",// 17 + "neck" // 18 + ); + } + + /** + * 返回一个对应的调色板:类别名 -> ARGB 颜色值。 + * 颜色值基于 test.py 中 part_colors 数组的 RGB 值转换为 ARGB 格式 (0xFFRRGGBB)。 + */ + public static Map defaultPalette() { + Map map = new HashMap<>(); + // 索引 0: background + map.put("background", 0xFF000000); // 黑色 + + // 索引 1-18: 对应 part_colors 数组的前 18 个颜色 + // 注意:这里假设 part_colors[i-1] 对应 索引 i 的标签。 + // 索引 1: skin -> [255, 0, 0] + map.put("skin", 0xFFFF0000); + // 索引 2: nose -> [255, 85, 0] + map.put("nose", 0xFFFF5500); + // 索引 3: eye_left -> [255, 170, 0] + map.put("eye_left", 0xFFFFAA00); + // 索引 4: eye_right -> [255, 0, 85] + map.put("eye_right", 0xFFFF0055); + // 索引 5: eyebrow_left -> [255, 0, 170] + map.put("eyebrow_left",0xFFFF00AA); + // 索引 6: eyebrow_right -> [0, 255, 0] + map.put("eyebrow_right",0xFF00FF00); + // 索引 7: ear_left -> [85, 255, 0] + map.put("ear_left", 0xFF55FF00); + // 索引 8: ear_right -> [170, 255, 0] + map.put("ear_right", 0xFFAAFF00); + // 索引 9: mouth -> [0, 255, 85] + map.put("mouth", 0xFF00FF55); + // 索引 10: lip_upper -> [0, 255, 170] + map.put("lip_upper", 0xFF00FFAA); + // 索引 11: lip_lower -> [0, 0, 255] + map.put("lip_lower", 0xFF0000FF); + // 索引 12: hair -> [85, 0, 255] + map.put("hair", 0xFF5500FF); + // 索引 13: hat -> [170, 0, 255] + map.put("hat", 0xFFAA00FF); + // 索引 14: earring -> [0, 85, 255] + map.put("earring", 0xFF0055FF); + // 索引 15: necklace -> [0, 170, 255] + map.put("necklace", 0xFF00AAFF); + // 索引 16: clothes -> [255, 255, 0] + map.put("clothes", 0xFFFFFF00); + // 索引 17: facial_hair -> [255, 85, 85] + map.put("facial_hair", 0xFFFF5555); + // 索引 18: neck -> [255, 170, 170] + map.put("neck", 0xFFFFAAAA); + + return map; + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java new file mode 100644 index 0000000..2dda41b --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java @@ -0,0 +1,36 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import java.awt.image.BufferedImage; +import java.util.Map; + +/** + * 分割结果容器 + */ +public class SegmentationResult { + // 分割掩码图(每个像素的颜色为对应类别颜色) + private final BufferedImage maskImage; + + // 类别索引 -> 类别名称 + private final Map labels; + + // 类别名称 -> ARGB 颜色 + private final Map palette; + + public SegmentationResult(BufferedImage maskImage, Map labels, Map palette) { + this.maskImage = maskImage; + this.labels = labels; + this.palette = palette; + } + + public BufferedImage getMaskImage() { + return maskImage; + } + + public Map getLabels() { + return labels; + } + + public Map getPalette() { + return palette; + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java new file mode 100644 index 0000000..b4f0717 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java @@ -0,0 +1,193 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +/** + * Segmenter: 加载模型并对图片做语义分割 + * + * 说明: + * - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray, + * 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。 + * - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()), + * 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。 + */ +public class Segmenter implements AutoCloseable { + + // 内部类,用于从Translator安全地传出数据 + public static class SegmentationData { + final int[] indices; + final long[] shape; + + public SegmentationData(int[] indices, long[] shape) { + this.indices = indices; + this.shape = shape; + } + } + + private final ZooModel modelWrapper; + private final Predictor predictor; + private final List labels; + private final Map palette; + + public Segmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { + this.labels = new ArrayList<>(labels); + this.palette = LabelPalette.defaultPalette(); + + // Translator 的输出类型现在是 SegmentationData + Translator translator = new Translator() { + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + NDArray array = input.toNDArray(manager); + array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); + array = array.div(255f); + array = array.expandDims(0); + return new NDList(array); + } + + @Override + public SegmentationData processOutput(TranslatorContext ctx, NDList list) { + if (list == null || list.isEmpty()) { + throw new IllegalStateException("Model did not return any output."); + } + + NDArray out = list.get(0); + NDArray classMap; + + // 1. 解析模型输出,得到类别图谱 (classMap) + long[] shape = out.getShape().getShape(); + if (shape.length == 4 && shape[1] > 1) { + classMap = out.argMax(1); + } else if (shape.length == 3) { + classMap = (shape[0] == 1) ? out : out.argMax(0); + } else if (shape.length == 2) { + classMap = out; + } else { + throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape)); + } + + if (classMap.getShape().dimension() == 3) { + classMap = classMap.squeeze(0); + } + + // 2. *** 关键步骤 *** + // 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型 + + // 首先,确保数据类型是 INT32 + NDArray int32ClassMap = classMap.toType(DataType.INT32, false); + + // 然后,获取形状和 int[] 数组 + long[] finalShape = int32ClassMap.getShape().getShape(); + int[] indices = int32ClassMap.toIntArray(); + + // 3. 将 Java 对象封装并返回 + return new SegmentationData(indices, finalShape); + } + + @Override + public Batchifier getBatchifier() { + return null; // 或者根据需要使用 Batchifier.STACK + } + }; + + // Criteria 的类型也需要更新 + Criteria criteria = Criteria.builder() + .setTypes(Image.class, SegmentationData.class) + .optModelPath(modelDir) + .optEngine("PyTorch") + .optTranslator(translator) + .build(); + + this.modelWrapper = criteria.loadModel(); + this.predictor = modelWrapper.newPredictor(); + } + + public SegmentationResult segment(File imgFile) throws TranslateException, IOException { + Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); + + // predict 方法现在直接返回安全的 Java 对象 + SegmentationData data = predictor.predict(img); + + long[] shp = data.shape; + int[] indices = data.indices; + + int height, width; + if (shp.length == 2) { + height = (int) shp[0]; + width = (int) shp[1]; + } else { + throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp)); + } + + // 后续处理完全基于 Java 对象,不再有 Native resource 问题 + BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); + Map labelsMap = new HashMap<>(); + for (int i = 0; i < labels.size(); i++) { + labelsMap.put(i, labels.get(i)); + } + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + int idx = indices[y * width + x]; + String label = labelsMap.getOrDefault(idx, "unknown"); + int argb = palette.getOrDefault(label, 0xFF00FF00); + mask.setRGB(x, y, argb); + } + } + + return new SegmentationResult(mask, labelsMap, palette); + } + + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignore) { + } + try { + modelWrapper.close(); + } catch (Exception ignore) { + } + } + + // 小测试主函数(示例) + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.out.println("用法: java Segmenter "); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + File out = new File(args[2]); + + List labels = LabelPalette.defaultLabels(); + + try (Segmenter s = new Segmenter(modelDir, labels)) { + SegmentationResult res = s.segment(input); + ImageIO.write(res.getMaskImage(), "png", out); + System.out.println("分割掩码已保存到: " + out.getAbsolutePath()); + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java new file mode 100644 index 0000000..0a1ad11 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java @@ -0,0 +1,184 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.awt.*; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +/** + * SegmenterExample + * + * 使用说明(命令行): + * java -cp com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \ + * + * + * 示例: + * java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye,face" + * java ... SegmenterExample /models/face_bisent /images/in.jpg /out all + */ +public class SegmenterExample { + + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.err.println("用法: SegmenterExample "); + System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face"); + return; + } + + Path modelDir = Path.of(args[0]); + File inputImage = new File(args[1]); + Path outDir = Path.of(args[2]); + String targetsArg = args[3]; + + if (!Files.exists(modelDir)) { + System.err.println("modelDir 不存在: " + modelDir); + return; + } + if (!inputImage.exists()) { + System.err.println("输入图片不存在: " + inputImage.getAbsolutePath()); + return; + } + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + // 读取 synset.txt(如果有),否则使用默认 LabelPalette + List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); + + // 打开 Segmenter + try (Segmenter segmenter = new Segmenter(modelDir, labels)) { + SegmentationResult res = segmenter.segment(inputImage); + + // 原始图片 + BufferedImage original = ImageIO.read(inputImage); + + // palette: labelName -> ARGB int + Map palette = res.getPalette(); + Map labelsMap = res.getLabels(); // index -> name + + // 解析目标 labels 列表 + Set targets = parseTargets(targetsArg, labels); + + System.out.println("Will export targets: " + targets); + + // maskImage: 每像素是类别颜色(ARGB) + BufferedImage maskImage = res.getMaskImage(); + int w = maskImage.getWidth(); + int h = maskImage.getHeight(); + + // 为快速查 color -> labelName + Map colorToLabel = new HashMap<>(); + for (Map.Entry e : palette.entrySet()) { + colorToLabel.put(e.getValue(), e.getKey()); + } + + // 对每个 target 生成单独的 mask 和 overlay + for (String target : targets) { + if (!palette.containsKey(target)) { + System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。"); + continue; + } + int targetColor = palette.get(target); + + // 1) 生成透明背景的二值掩码(只保留 target 像素) + BufferedImage partMask = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + int c = maskImage.getRGB(x, y); + if (c == targetColor) { + // 保留为不透明(使用原始颜色) + partMask.setRGB(x, y, targetColor); + } else { + // 透明 + partMask.setRGB(x, y, 0x00000000); + } + } + } + + // 2) 生成 overlay:在原图上叠加半透明的 targetColor 区域 + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + // 若分辨率不同,先缩放 mask 到原图大小(简单粗暴地按尺寸相同假设) + BufferedImage maskResized = maskImage; + if (original.getWidth() != w || original.getHeight() != h) { + // 简单缩放 mask 到原图尺寸 + maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = maskResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(maskImage, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + // 在 overlay 上先画原图 + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + + // 创建半透明颜色(将 targetColor 的 alpha 设为 0x88) + int rgbOnly = (targetColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + // 创建一个图像,把 mask 中对应像素设置为 translucent,否则透明 + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = maskResized.getRGB(x, y); + if (mc == targetColor) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + // 将 colorOverlay 画到 overlay 上 + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + // 保存文件 + File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile(); + File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile(); + + ImageIO.write(partMask, "png", maskOut); + ImageIO.write(overlay, "png", overlayOut); + + System.out.println("Saved mask: " + maskOut.getAbsolutePath()); + System.out.println("Saved overlay: " + overlayOut.getAbsolutePath()); + } + } + } + + private static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + private static Set parseTargets(String arg, List availableLabels) { + String a = arg.trim(); + if (a.equalsIgnoreCase("all")) { + return new LinkedHashSet<>(availableLabels); + } + String[] parts = a.split(","); + Set out = new LinkedHashSet<>(); + for (String p : parts) { + String t = p.trim(); + if (!t.isEmpty()) out.add(t); + } + return out; + } + + private static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java new file mode 100644 index 0000000..641e0df --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java @@ -0,0 +1,261 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +/** + * VividModelWrapper - 对之前 Segmenter / SegmenterExample 的封装 + * + * 用法示例: + * VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("/path/to/modelDir")); + * Map out = wrapper.segmentAndSave( + * new File("input.jpg"), + * Set.of("eye","face"), // 或 Set.of(all labels...);若想全部传 "all" 可以用 helper parseTargets + * Paths.get("outDir") + * ); + * // out contains 每个目标标签对应的 mask+overlay 文件路径 + * wrapper.close(); + */ +public class VividModelWrapper implements AutoCloseable { + + private final Segmenter segmenter; + private final List labels; // index -> name + private final Map palette; // name -> ARGB + + private VividModelWrapper(Segmenter segmenter, List labels, Map palette) { + this.segmenter = segmenter; + this.labels = labels; + this.palette = palette; + } + + /** + * 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 LabelPalette.defaultLabels() + * 并创建 Segmenter 实例。 + */ + public static VividModelWrapper load(Path modelDir) throws Exception { + List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); + Segmenter s = new Segmenter(modelDir, labels); + Map palette = LabelPalette.defaultPalette(); + return new VividModelWrapper(s, labels, palette); + } + + public List getLabels() { + return Collections.unmodifiableList(labels); + } + + public Map getPalette() { + return Collections.unmodifiableMap(palette); + } + + /** + * 直接返回分割结果(SegmentationResult) + */ + public SegmentationResult segment(File inputImage) throws Exception { + return segmenter.segment(inputImage); + } + + /** + * 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。 + * 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。 + *

+ * 返回值:Map,ResultFiles 包含 maskFile、overlayFile(两个 PNG) + */ + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + if (!Files.exists(outDir)) { + Files.createDirectories(outDir); + } + + SegmentationResult res = segment(inputImage); + BufferedImage original = ImageIO.read(inputImage); + BufferedImage maskImage = res.getMaskImage(); + + int maskW = maskImage.getWidth(); + int maskH = maskImage.getHeight(); + + // 解析 targets + Set realTargets = parseTargetsSet(targets); + Map saved = new LinkedHashMap<>(); + + for (String target : realTargets) { + if (!palette.containsKey(target)) { + // 尝试忽略大小写匹配 + String finalTarget = target; + Optional matched = palette.keySet().stream() + .filter(k -> k.equalsIgnoreCase(finalTarget)) + .findFirst(); + if (matched.isPresent()) target = matched.get(); + else { + System.err.println("Warning: unknown label '" + target + "' - skip."); + continue; + } + } + + int targetColor = palette.get(target); + + // 1) 生成透明背景的二值掩码(只保留 target 像素) + BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < maskH; y++) { + for (int x = 0; x < maskW; x++) { + int c = maskImage.getRGB(x, y); + if (c == targetColor) { + partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明 + } else { + partMask.setRGB(x, y, 0x00000000); + } + } + } + + // 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明) + BufferedImage maskResized = partMask; + if (original.getWidth() != maskW || original.getHeight() != maskH) { + maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g = maskResized.createGraphics(); + g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); + g.dispose(); + } + + BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); + Graphics2D g2 = overlay.createGraphics(); + g2.drawImage(original, 0, 0, null); + // 半透明颜色(alpha = 0x88) + int rgbOnly = (targetColor & 0x00FFFFFF); + int translucent = (0x88 << 24) | rgbOnly; + BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); + for (int y = 0; y < colorOverlay.getHeight(); y++) { + for (int x = 0; x < colorOverlay.getWidth(); x++) { + int mc = maskResized.getRGB(x, y); + if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { + colorOverlay.setRGB(x, y, translucent); + } else { + colorOverlay.setRGB(x, y, 0x00000000); + } + } + } + g2.drawImage(colorOverlay, 0, 0, null); + g2.dispose(); + + // 保存 + String safe = safeFileName(target); + File maskOut = outDir.resolve(safe + "_mask.png").toFile(); + File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); + + ImageIO.write(maskResized, "png", maskOut); + ImageIO.write(overlay, "png", overlayOut); + + saved.put(target, new ResultFiles(maskOut, overlayOut)); + } + + return saved; + } + + private static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } + + private static Set parseTargetsSet(Set in) { + if (in == null || in.isEmpty()) return Collections.emptySet(); + // 若包含单个 "all" + if (in.size() == 1) { + String only = in.iterator().next(); + if ("all".equalsIgnoreCase(only.trim())) { + // 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette) + // For convenience, return a set containing "all" and let caller logic handle it earlier. + return Set.of("all"); + } + } + // 直接返回 trim 后的小写不变集合(保持用户传入的名字) + Set out = new LinkedHashSet<>(); + for (String s : in) { + if (s != null) out.add(s.trim()); + } + return out; + } + + /** + * 关闭底层资源 + */ + @Override + public void close() { + try { + segmenter.close(); + } catch (Exception ignore) {} + } + + /** + * 存放结果文件路径 + */ + public static class ResultFiles { + private final File maskFile; + private final File overlayFile; + + public ResultFiles(File maskFile, File overlayFile) { + this.maskFile = maskFile; + this.overlayFile = overlayFile; + } + + public File getMaskFile() { + return maskFile; + } + + public File getOverlayFile() { + return overlayFile; + } + } + + /* ================= helper: 从 modelDir 读取 synset.txt ================= */ + + private static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + /* ================= convenience 主方法(快速测试) ================= */ + public static void main(String[] args) throws Exception { + if (args.length < 4) { + System.out.println("用法: VividModelWrapper "); + System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face"); + return; + } + Path modelDir = Path.of(args[0]); + File input = new File(args[1]); + Path out = Path.of(args[2]); + String targetsArg = args[3]; + + List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); + Set targets; + if ("all".equalsIgnoreCase(targetsArg.trim())) { + targets = new LinkedHashSet<>(labels); + } else { + String[] parts = targetsArg.split(","); + targets = new LinkedHashSet<>(); + for (String p : parts) { + if (!p.trim().isEmpty()) targets.add(p.trim()); + } + } + + try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) { + Map m = wrapper.segmentAndSave(input, targets, out); + m.forEach((k, v) -> { + System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); + }); + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java b/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java new file mode 100644 index 0000000..ee60a24 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java @@ -0,0 +1,38 @@ +package com.chuangzhou.vivid2D.test; + +import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper; + +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.util.Set; + +/** + * 用来分析人物的脸部信息头发、眼睛、嘴巴、脸部、皮肤、衣服 + */ +public class AI2Test { + public static void main(String[] args) throws Exception { + System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8)); + System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); + + // 使用 AnimeModelWrapper 而不是 VividModelWrapper + AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt")); + + // 使用 Anime-Face-Segmentation 的 7 个标签 + Set animeLabels = Set.of( + "background", + "hair", // 头发 + "eye", // 眼睛 + "mouth", // 嘴巴 + "face", // 脸部 + "skin", // 皮肤 + "clothes" // 衣服 + ); + + wrapper.segmentAndSave( + Paths.get("C:\\Users\\Administrator\\Desktop\\b_215609167a3a20ac2075487bd532bbff.jpg").toFile(), + animeLabels, + Paths.get("C:\\models\\out") + ); + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java b/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java new file mode 100644 index 0000000..0a3956e --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java @@ -0,0 +1,28 @@ +package com.chuangzhou.vivid2D.test; + +import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper; +import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper; + +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.util.Set; + +/** + * 这个ai模型负责分离人物与背景 + */ +public class AI3Test { + public static void main(String[] args) throws Exception { + System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8)); + System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); + Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt")); + + Set faceLabels = Set.of("foreground"); + + wrapper.segmentAndSave( + Paths.get("C:\\Users\\Administrator\\Desktop\\b_7a8349adece17d1e4bebd20cb2387cf6.jpg").toFile(), + faceLabels, + Paths.get("C:\\models\\out") + ); + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AITest.java b/src/main/java/com/chuangzhou/vivid2D/test/AITest.java new file mode 100644 index 0000000..c79036d --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/test/AITest.java @@ -0,0 +1,33 @@ +package com.chuangzhou.vivid2D.test; + +import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper; + +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.util.Set; + +/** + * 测试人脸解析模型 + */ +public class AITest { + public static void main(String[] args) throws Exception { + System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8)); + System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); + VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt")); + + // 使用 BiSeNet 人脸解析模型的 18 个非背景标签 + Set faceLabels = Set.of( + "skin", "nose", "eye_left", "eye_right", "eyebrow_left", + "eyebrow_right", "ear_left", "ear_right", "mouth", "lip_upper", + "lip_lower", "hair", "hat", "earring", "necklace", "clothes", + "facial_hair", "neck" + ); + + wrapper.segmentAndSave( + Paths.get("C:\\Users\\Administrator\\Desktop\\b_f4881214f0d18b6cf848b6736f554821.png").toFile(), + faceLabels, + Paths.get("C:\\models\\out") + ); + } +}