From e06c59c8d1d8d596f88b225181323eab31a217b4 Mon Sep 17 00:00:00 2001 From: tzdwindows 7 <3076584115@qq.com> Date: Fri, 31 Oct 2025 09:25:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor(ai):=E9=87=8D=E6=9E=84=E5=88=86?= =?UTF-8?q?=E5=89=B2=E6=A8=A1=E5=9E=8B=E5=8C=85=E8=A3=85=E7=B1=BB=E7=BB=A7?= =?UTF-8?q?=E6=89=BF=E7=BB=93=E6=9E=84-=20=E5=B0=86=20Anime2ModelWrapper?= =?UTF-8?q?=E3=80=81Anime2VividModelWrapper=20=E5=92=8C=20AnimeModelWrappe?= =?UTF-8?q?r=20=20=E6=94=B9=E4=B8=BA=E7=BB=A7=E6=89=BF=E8=87=AA=20VividMod?= =?UTF-8?q?elWrapper=20=E5=9F=BA=E7=B1=BB=20-=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E7=9A=84=20ResultFiles=20=E5=86=85=E9=83=A8?= =?UTF-8?q?=E7=B1=BB=E5=92=8C=E7=9B=B8=E5=85=B3=E5=B7=A5=E5=85=B7=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E5=AE=9E=E7=8E=B0=20-=20Anime2Segmenter=20=E5=92=8C?= =?UTF-8?q?=20AnimeSegmenter=20=E7=BB=A7=E6=89=BF=E8=87=AA=E6=8A=BD?= =?UTF-8?q?=E8=B1=A1=E5=9F=BA=E7=B1=BB=20Segmenter=20-=20Anime2Segmentatio?= =?UTF-8?q?nResult=E4=B8=8E=20AnimeSegmentationResult=20=E7=BB=A7=E6=89=BF?= =?UTF-8?q?=20SegmentationResult=20-=20=E9=87=8D=E5=91=BD=E5=90=8D=20Label?= =?UTF-8?q?Palette=20=E4=B8=BA=20BiSeNetLabelPalette=20=E5=B9=B6=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E5=85=B6=E5=BC=95=E7=94=A8=20-=20=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=B7=AF=E5=BE=84=E9=85=8D=E7=BD=AE=E4=BB=A5?= =?UTF-8?q?=E5=8C=B9=E9=85=8D=E6=96=B0=E7=9A=84=E6=96=87=E4=BB=B6=E5=91=BD?= =?UTF-8?q?=E5=90=8D=E7=BA=A6=E5=AE=9A=20-=20=E5=88=A0=E9=99=A4=E5=86=97?= =?UTF-8?q?=E4=BD=99=E7=9A=84=20getLabels()=20=E5=92=8C=20getPalette()=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E5=AE=9A=E4=B9=89=20-=20=E7=AE=80=E5=8C=96?= =?UTF-8?q?=20segmentAndSave=20=E6=96=B9=E6=B3=95=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E8=BD=AC=E6=8D=A2=E9=80=BB=E8=BE=91-=20?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=E5=B7=B2=E8=A2=AB=E7=BB=A7=E6=89=BF=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E6=9B=BF=E4=BB=A3=E7=9A=84=E6=89=8B=E5=8A=A8=E8=B5=84?= =?UTF-8?q?=E6=BA=90=E7=AE=A1=E7=90=86=E4=BB=A3=E7=A0=81=20-=20=E8=B0=83?= =?UTF-8?q?=E6=95=B4=20import=20=E8=AF=AD=E5=8F=A5=E4=BB=A5=E5=8F=8D?= =?UTF-8?q?=E6=98=A0=E5=8C=85=E7=BB=93=E6=9E=84=E8=B0=83=E6=95=B4-=20?= =?UTF-8?q?=E6=B8=85=E7=90=86=E4=B8=8D=E5=86=8D=E9=9C=80=E8=A6=81=E7=9A=84?= =?UTF-8?q?=E7=8B=AC=E7=AB=8B=E4=B8=BB=E6=B5=8B=E8=AF=95=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E5=85=A5=E5=8F=A3=E7=82=B9-=20=E4=BF=AE=E6=94=B9=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=E8=AE=BF=E9=97=AE=E6=9D=83=E9=99=90=E4=BB=A5=E7=AC=A6?= =?UTF-8?q?=E5=90=88=E7=BB=A7=E6=89=BF=E8=AE=BE=E8=AE=A1=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=20-=20=E6=9B=BF=E6=8D=A2=E5=85=B7=E4=BD=93=E7=9A=84=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E7=B1=BB=E5=9E=8B=E4=B8=BA=E6=9B=B4=E9=80=9A=E7=94=A8?= =?UTF-8?q?=E7=9A=84=20SegmentationResult=20=E6=8E=A5=E5=8F=A3-=20?= =?UTF-8?q?=E6=95=B4=E5=90=88=E5=85=AC=E5=85=B1=E5=8A=9F=E8=83=BD=E8=87=B3?= =?UTF-8?q?=E5=9F=BA=E7=B1=BB=E5=87=8F=E5=B0=91=E5=AD=90=E7=B1=BB=E9=97=B4?= =?UTF-8?q?=E9=87=8D=E5=A4=8D=E4=BB=A3=E7=A0=81=20-=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E5=88=86=E5=89=B2=E5=90=8E=E5=A4=84=E7=90=86=E6=B5=81=E7=A8=8B?= =?UTF-8?q?=E6=8F=90=E9=AB=98=E6=A8=A1=E5=9D=97=E5=A4=8D=E7=94=A8=E6=80=A7?= =?UTF-8?q?=20-=20=E5=BC=95=E5=85=A5=E6=B3=9B=E5=9E=8B=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=20Wrapper=20=E7=B1=BB=E5=9E=8B=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=80=A7=20-=20=E6=9B=B4=E6=96=B0=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E4=BF=9D=E6=8C=81=E4=B8=8E=E6=9C=80=E6=96=B0?= =?UTF-8?q?=E6=9E=B6=E6=9E=84=E5=90=8C=E6=AD=A5=20-=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=E5=A4=84=E7=90=86=E7=AD=96=E7=95=A5=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E5=85=B3=E9=97=AD=E8=B5=84=E6=BA=90=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=20-=20=E8=A7=84=E8=8C=83=E6=96=87=E4=BB=B6=E5=91=BD=E5=90=8D?= =?UTF-8?q?=E8=A7=84=E5=88=99=E4=BE=BF=E4=BA=8E=E6=9C=AA=E6=9D=A5=E7=BB=B4?= =?UTF-8?q?=E6=8A=A4=E6=89=A9=E5=B1=95=20-=20=E6=8F=90=E5=8F=96=E5=85=B1?= =?UTF-8?q?=E9=80=9A=E9=80=BB=E8=BE=91=E5=88=B0=E7=88=B6=E7=B1=BB=E9=99=8D?= =?UTF-8?q?=E4=BD=8E=E8=80=A6=E5=90=88=E5=BA=A6=20-=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=9F=A5=E9=81=BF=E5=85=8D=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=97=B6=20ClassCastException=20=E9=A3=8E=E9=99=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../vivid2D/ai/ModelManagement.java | 221 ++++++++++++++++ .../SegmentationResult.java | 5 +- .../com/chuangzhou/vivid2D/ai/Segmenter.java | 154 +++++++++++ .../vivid2D/ai/VividModelWrapper.java | 113 ++++++++ .../AnimeModelWrapper.java | 112 +------- .../AnimeSegmentationResult.java | 5 +- .../AnimeSegmenter.java | 42 +-- .../Anime2ModelWrapper.java | 244 ------------------ .../Anime2SegmentationResult.java | 29 +-- .../anime_segmentation/Anime2Segmenter.java | 153 +++-------- .../Anime2VividModelWrapper.java | 126 +-------- ...lPalette.java => BiSeNetLabelPalette.java} | 2 +- .../BiSeNetSegmentationResult.java | 15 ++ .../ai/face_parsing/BiSeNetSegmenter.java | 101 ++++++++ ...per.java => BiSeNetVividModelWrapper.java} | 142 +--------- .../vivid2D/ai/face_parsing/Segmenter.java | 193 -------------- .../ai/face_parsing/SegmenterExample.java | 184 ------------- .../com/chuangzhou/vivid2D/test/AI2Test.java | 2 +- .../com/chuangzhou/vivid2D/test/AI3Test.java | 3 +- .../com/chuangzhou/vivid2D/test/AITest.java | 4 +- 20 files changed, 700 insertions(+), 1150 deletions(-) create mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java rename src/main/java/com/chuangzhou/vivid2D/ai/{face_parsing => }/SegmentationResult.java (91%) create mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java create mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java delete mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java rename src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/{LabelPalette.java => BiSeNetLabelPalette.java} (98%) create mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmentationResult.java create mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmenter.java rename src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/{VividModelWrapper.java => BiSeNetVividModelWrapper.java} (52%) delete mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java delete mode 100644 src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java b/src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java new file mode 100644 index 0000000..7e90b3b --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/ModelManagement.java @@ -0,0 +1,221 @@ +package com.chuangzhou.vivid2D.ai; + +import com.chuangzhou.vivid2D.ai.anime_face_segmentation.AnimeModelWrapper; +import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper; +import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +/** + * 模型管理器 - 负责模型的注册、分类和检索 + */ +public class ModelManagement { + private final Map> models = new ConcurrentHashMap<>(); + private final Map> modelsByCategory = new ConcurrentHashMap<>(); + private final List modelDisplayNames = new ArrayList<>(); + private final Map displayNameToRegistrationName = new ConcurrentHashMap<>(); + + private ModelManagement() { + initializeDefaultCategories(); + registerDefaultModels(); + } + + /** + * 初始化默认分类 + */ + private void initializeDefaultCategories() { + modelsByCategory.put("Image Segmentation", new ArrayList<>()); + modelsByCategory.put("Image Processing", new ArrayList<>()); + modelsByCategory.put("Image Generation", new ArrayList<>()); + modelsByCategory.put("Image Inpainting", new ArrayList<>()); + modelsByCategory.put("Image Completion", new ArrayList<>()); + modelsByCategory.put("Face Analysis", new ArrayList<>()); + } + + /** + * 注册默认模型 + */ + private void registerDefaultModels() { + registerModel("segmentation:anime_face", "Anime Face Segmentation", + AnimeModelWrapper.class, "Image Segmentation"); + registerModel("segmentation:anime", "Anime Image Segmentation", + Anime2VividModelWrapper.class, "Image Segmentation"); + registerModel("segmentation:face_parsing", "Face Parsing", + BiSeNetVividModelWrapper.class, "Image Segmentation"); + } + + /** + * 注册模型 + * @param modelRegistrationName 注册名称,格式必须为 "category:model_name" + * @param modelDisplayName 模型显示名称 + * @param modelClass 模型类 + * @param category 模型类别 + */ + public void registerModel(String modelRegistrationName, String modelDisplayName, + Class modelClass, String category) { + if (!isValidRegistrationName(modelRegistrationName)) { + throw new IllegalArgumentException( + "Invalid registration name format. Expected 'category:model_name', got: " + modelRegistrationName); + } + if (models.containsKey(modelRegistrationName)) { + throw new IllegalArgumentException( + "Model registration name already exists: " + modelRegistrationName); + } + if (displayNameToRegistrationName.containsKey(modelDisplayName)) { + throw new IllegalArgumentException( + "Model display name already exists: " + modelDisplayName); + } + if (!modelsByCategory.containsKey(category)) { + modelsByCategory.put(category, new ArrayList<>()); + } + models.put(modelRegistrationName, modelClass); + displayNameToRegistrationName.put(modelDisplayName, modelRegistrationName); + modelDisplayNames.add(modelDisplayName); + modelsByCategory.get(category).add(modelRegistrationName); + } + + /** + * 验证注册名称格式 + */ + private boolean isValidRegistrationName(String name) { + return name != null && name.matches("^[a-zA-Z0-9_]+:[a-zA-Z0-9_]+$"); + } + + /** + * 通过显示名称获取模型类 + */ + public Class getModel(String modelDisplayName) { + String registrationName = displayNameToRegistrationName.get(modelDisplayName); + return registrationName != null ? models.get(registrationName) : null; + } + + /** + * 通过索引获取模型类 + */ + public Class getModel(int modelIndex) { + if (modelIndex >= 0 && modelIndex < modelDisplayNames.size()) { + String displayName = modelDisplayNames.get(modelIndex); + return getModel(displayName); + } + return null; + } + + /** + * 通过注册名称获取模型类 + */ + public Class getModelByRegistrationName(String registrationName) { + return models.get(registrationName); + } + + /** + * 通过类名获取模型类 + */ + public Class getModelByClassName(String className) { + for (Class modelClass : models.values()) { + if (modelClass.getName().equals(className)) { + return modelClass; + } + } + return null; + } + + /** + * 获取所有模型的显示名称 + */ + public List getAllModelDisplayNames() { + return Collections.unmodifiableList(modelDisplayNames); + } + + /** + * 获取所有模型的注册名称 + */ + public Set getAllModelRegistrationNames() { + return Collections.unmodifiableSet(models.keySet()); + } + + /** + * 按类别获取模型注册名称 + */ + public List getModelsByCategory(String category) { + return Collections.unmodifiableList( + modelsByCategory.getOrDefault(category, new ArrayList<>()) + ); + } + + /** + * 获取所有可用的类别 + */ + public Set getAllCategories() { + return Collections.unmodifiableSet(modelsByCategory.keySet()); + } + + /** + * 获取模型数量 + */ + public int getModelCount() { + return modelDisplayNames.size(); + } + + /** + * 获取模型显示名称对应的注册名称 + */ + public String getRegistrationName(String modelDisplayName) { + return displayNameToRegistrationName.get(modelDisplayName); + } + + /** + * 获取模型注册名称对应的显示名称 + */ + public String getDisplayName(String registrationName) { + for (Map.Entry entry : displayNameToRegistrationName.entrySet()) { + if (entry.getValue().equals(registrationName)) { + return entry.getKey(); + } + } + return null; + } + + /** + * 检查模型是否存在 + */ + public boolean containsModel(String modelDisplayName) { + return displayNameToRegistrationName.containsKey(modelDisplayName); + } + + /** + * 检查注册名称是否存在 + */ + public boolean containsRegistrationName(String registrationName) { + return models.containsKey(registrationName); + } + + /** + * 移除模型 + */ + public boolean removeModel(String modelDisplayName) { + String registrationName = displayNameToRegistrationName.get(modelDisplayName); + if (registrationName != null) { + // 从所有存储中移除 + models.remove(registrationName); + displayNameToRegistrationName.remove(modelDisplayName); + modelDisplayNames.remove(modelDisplayName); + + // 从类别中移除 + for (List categoryModels : modelsByCategory.values()) { + categoryModels.remove(registrationName); + } + + return true; + } + return false; + } + + private static final class InstanceHolder { + private static final ModelManagement instance = new ModelManagement(); + } + + public static ModelManagement getInstance() { + return InstanceHolder.instance; + } +} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/SegmentationResult.java similarity index 91% rename from src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java rename to src/main/java/com/chuangzhou/vivid2D/ai/SegmentationResult.java index 2dda41b..3644106 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmentationResult.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/SegmentationResult.java @@ -1,11 +1,8 @@ -package com.chuangzhou.vivid2D.ai.face_parsing; +package com.chuangzhou.vivid2D.ai; import java.awt.image.BufferedImage; import java.util.Map; -/** - * 分割结果容器 - */ public class SegmentationResult { // 分割掩码图(每个像素的颜色为对应类别颜色) private final BufferedImage maskImage; diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java new file mode 100644 index 0000000..db5c1d7 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/Segmenter.java @@ -0,0 +1,154 @@ +package com.chuangzhou.vivid2D.ai; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetLabelPalette; +import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmentationResult; +import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetSegmenter; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +public abstract class Segmenter implements AutoCloseable { + // 内部类,用于从Translator安全地传出数据 + public static class SegmentationData { + public final int[] indices; + public final long[] shape; + + public SegmentationData(int[] indices, long[] shape) { + this.indices = indices; + this.shape = shape; + } + } + + private String engine = "PyTorch"; + protected final ZooModel modelWrapper; + protected final Predictor predictor; + protected final List labels; + protected final Map palette; + + public Segmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { + this.labels = new ArrayList<>(labels); + this.palette = BiSeNetLabelPalette.defaultPalette(); + + Translator translator = new Translator() { + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + return Segmenter.this.processInput(ctx, input); + } + + @Override + public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) { + return Segmenter.this.processOutput(ctx, list); + } + + @Override + public Batchifier getBatchifier() { + return Segmenter.this.getBatchifier(); + } + }; + + Criteria criteria = Criteria.builder() + .setTypes(Image.class, Segmenter.SegmentationData.class) + .optModelPath(modelDir) + .optEngine(engine) + .optTranslator(translator) + .build(); + + this.modelWrapper = criteria.loadModel(); + this.predictor = modelWrapper.newPredictor(); + } + + /** + * 处理模型输入 + * @param ctx translator 上下文 + * @param input 图片 + * @return 模型输入 + */ + public abstract NDList processInput(TranslatorContext ctx, Image input); + + /** + * 处理模型输出 + * @param ctx translator 上下文 + * @param list 模型输出 + * @return 模型输出 + */ + public abstract Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list); + + /** + * 获取批量处理方式 + * @return 批量处理方式 + */ + public Batchifier getBatchifier(){ + return null; + } + + public SegmentationResult segment(File imgFile) throws TranslateException, IOException { + Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); + + // predict 方法现在直接返回安全的 Java 对象 + Segmenter.SegmentationData data = predictor.predict(img); + + long[] shp = data.shape; + int[] indices = data.indices; + + int height, width; + if (shp.length == 2) { + height = (int) shp[0]; + width = (int) shp[1]; + } else { + throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp)); + } + + // 后续处理完全基于 Java 对象,不再有 Native resource 问题 + BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); + Map labelsMap = new HashMap<>(); + for (int i = 0; i < labels.size(); i++) { + labelsMap.put(i, labels.get(i)); + } + + for (int y = 0; y < height; y++) { + for (int x = 0; x < width; x++) { + int idx = indices[y * width + x]; + String label = labelsMap.getOrDefault(idx, "unknown"); + int argb = palette.getOrDefault(label, 0xFF00FF00); + mask.setRGB(x, y, argb); + } + } + + return new SegmentationResult(mask, labelsMap, palette); + } + + public void setEngine(String engine) { + this.engine = engine; + } + + @Override + public void close() { + try { + predictor.close(); + } catch (Exception ignore) { + } + try { + modelWrapper.close(); + } catch (Exception ignore) { + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java new file mode 100644 index 0000000..140b540 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/VividModelWrapper.java @@ -0,0 +1,113 @@ +package com.chuangzhou.vivid2D.ai; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; + +public abstract class VividModelWrapper implements AutoCloseable{ + protected final s segmenter; + protected final List labels; // index -> name + protected final Map palette; // name -> ARGB + + protected VividModelWrapper(s segmenter, List labels, Map palette) { + this.segmenter = segmenter; + this.labels = labels; + this.palette = palette; + } + + public List getLabels() { + return Collections.unmodifiableList(labels); + } + + public Map getPalette() { + return Collections.unmodifiableMap(palette); + } + + /** + * 直接返回分割结果(SegmentationResult) + */ + public SegmentationResult segment(File inputImage) throws Exception { + return segmenter.segment(inputImage); + } + + /** + * 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir。 + * 如果 targets 包含单个元素 "all"(忽略大小写),则保存所有标签。 + *

+ * 返回值:Map,ResultFiles 包含 maskFile、overlayFile(两个 PNG) + */ + public abstract Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception; + + protected static String safeFileName(String s) { + return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); + } + + protected static Set parseTargetsSet(Set in) { + if (in == null || in.isEmpty()) return Collections.emptySet(); + // 若包含单个 "all" + if (in.size() == 1) { + String only = in.iterator().next(); + if ("all".equalsIgnoreCase(only.trim())) { + return Set.of("all"); + } + } + // 直接返回 trim 后的小写不变集合(保持用户传入的名字) + Set out = new LinkedHashSet<>(); + for (String s : in) { + if (s != null) out.add(s.trim()); + } + return out; + } + + /** + * 关闭底层资源 + */ + @Override + public void close() { + try { + segmenter.close(); + } catch (Exception ignore) {} + } + + /* ================= helper: 从 modelDir 读取 synset.txt ================= */ + + protected static Optional> loadLabelsFromSynset(Path modelDir) { + Path syn = modelDir.resolve("synset.txt"); + if (Files.exists(syn)) { + try { + List lines = Files.readAllLines(syn); + List cleaned = new ArrayList<>(); + for (String l : lines) { + String s = l.trim(); + if (!s.isEmpty()) cleaned.add(s); + } + if (!cleaned.isEmpty()) return Optional.of(cleaned); + } catch (IOException ignore) {} + } + return Optional.empty(); + } + + /** + * 存放结果文件路径 + */ + public static class ResultFiles { + private final File maskFile; + private final File overlayFile; + + public ResultFiles(File maskFile, File overlayFile) { + this.maskFile = maskFile; + this.overlayFile = overlayFile; + } + + public File getMaskFile() { + return maskFile; + } + + public File getOverlayFile() { + return overlayFile; + } + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java index 5d6658e..f8a4a3f 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeModelWrapper.java @@ -1,5 +1,7 @@ package com.chuangzhou.vivid2D.ai.anime_face_segmentation; +import com.chuangzhou.vivid2D.ai.VividModelWrapper; + import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; @@ -14,16 +16,10 @@ import java.util.List; /** * AnimeModelWrapper - 专门为 Anime-Face-Segmentation 模型封装的 Wrapper */ -public class AnimeModelWrapper implements AutoCloseable { - - private final AnimeSegmenter segmenter; - private final List labels; // index -> name - private final Map palette; // name -> ARGB +public class AnimeModelWrapper extends VividModelWrapper { private AnimeModelWrapper(AnimeSegmenter segmenter, List labels, Map palette) { - this.segmenter = segmenter; - this.labels = labels; - this.palette = palette; + super(segmenter, labels, palette); } /** @@ -36,14 +32,6 @@ public class AnimeModelWrapper implements AutoCloseable { return new AnimeModelWrapper(segmenter, labels, palette); } - public List getLabels() { - return Collections.unmodifiableList(labels); - } - - public Map getPalette() { - return Collections.unmodifiableMap(palette); - } - /** * 直接返回分割结果(在丢给底层 segmenter 前会做通用预处理:RGB 转换 + 等比 letterbox 缩放到模型输入尺寸) */ @@ -152,28 +140,6 @@ public class AnimeModelWrapper implements AutoCloseable { return saved; } - private static String safeFileName(String s) { - return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); - } - - private static Set parseTargetsSet(Set in) { - if (in == null || in.isEmpty()) return Collections.emptySet(); - // 若包含单个 "all" - if (in.size() == 1) { - String only = in.iterator().next(); - if ("all".equalsIgnoreCase(only.trim())) { - // 返回所有标签 - return new LinkedHashSet<>(AnimeLabelPalette.defaultLabels()); - } - } - // 直接返回 trim 后的集合 - Set out = new LinkedHashSet<>(); - for (String s : in) { - if (s != null) out.add(s.trim()); - } - return out; - } - /** * 专门提取眼睛的方法(在丢给底层 segmenter 前做预处理) */ @@ -246,44 +212,6 @@ public class AnimeModelWrapper implements AutoCloseable { } catch (Exception ignore) {} } - /** - * 存放结果文件路径 - */ - public static class ResultFiles { - private final File maskFile; - private final File overlayFile; - - public ResultFiles(File maskFile, File overlayFile) { - this.maskFile = maskFile; - this.overlayFile = overlayFile; - } - - public File getMaskFile() { - return maskFile; - } - - public File getOverlayFile() { - return overlayFile; - } - } - - /* ================= helper: 从 modelDir 读取 synset.txt ================= */ - private static Optional> loadLabelsFromSynset(Path modelDir) { - Path syn = modelDir.resolve("synset.txt"); - if (Files.exists(syn)) { - try { - List lines = Files.readAllLines(syn); - List cleaned = new ArrayList<>(); - for (String l : lines) { - String s = l.trim(); - if (!s.isEmpty()) cleaned.add(s); - } - if (!cleaned.isEmpty()) return Optional.of(cleaned); - } catch (IOException ignore) {} - } - return Optional.empty(); - } - // ========== 新增:预处理并保存到临时文件 ========== private File preprocessAndSave(File inputImage) throws IOException { BufferedImage img = ImageIO.read(inputImage); @@ -391,36 +319,4 @@ public class AnimeModelWrapper implements AutoCloseable { return new int[]{w, h}; } - - /* ================= convenience 主方法(快速测试) ================= */ - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.out.println("用法: AnimeModelWrapper "); - System.out.println("示例: AnimeModelWrapper ./anime_unet.pt input.jpg outDir eye,face"); - System.out.println("标签: " + AnimeLabelPalette.defaultLabels()); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - Path out = Path.of(args[2]); - String targetsArg = args[3]; - - Set targets; - if ("all".equalsIgnoreCase(targetsArg.trim())) { - targets = new LinkedHashSet<>(AnimeLabelPalette.defaultLabels()); - } else { - String[] parts = targetsArg.split(","); - targets = new LinkedHashSet<>(); - for (String p : parts) { - if (!p.trim().isEmpty()) targets.add(p.trim()); - } - } - - try (AnimeModelWrapper wrapper = AnimeModelWrapper.load(modelDir)) { - Map m = wrapper.segmentAndSave(input, targets, out); - m.forEach((k, v) -> { - System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); - }); - } - } } diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java index c354cf8..bb18069 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmentationResult.java @@ -1,12 +1,14 @@ package com.chuangzhou.vivid2D.ai.anime_face_segmentation; +import com.chuangzhou.vivid2D.ai.SegmentationResult; + import java.awt.image.BufferedImage; import java.util.Map; /** * 动漫分割结果容器 */ -public class AnimeSegmentationResult { +public class AnimeSegmentationResult extends SegmentationResult { // 分割掩码图(每个像素的颜色为对应类别颜色) private final BufferedImage maskImage; @@ -21,6 +23,7 @@ public class AnimeSegmentationResult { public AnimeSegmentationResult(BufferedImage maskImage, float[][][] probabilityMap, Map labels, Map palette) { + super(maskImage, labels, palette); this.maskImage = maskImage; this.probabilityMap = probabilityMap; this.labels = labels; diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java index 98d0c6c..ed773a2 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_face_segmentation/AnimeSegmenter.java @@ -16,6 +16,7 @@ import ai.djl.translate.Batchifier; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import com.chuangzhou.vivid2D.ai.Segmenter; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; @@ -27,7 +28,7 @@ import java.util.*; /** * AnimeSegmenter: 专门为 Anime-Face-Segmentation UNet 模型设计的分割器 */ -public class AnimeSegmenter implements AutoCloseable { +public class AnimeSegmenter extends Segmenter { // 模型默认输入大小(与训练时一致)。若模型不同可以修改为实际值或让 caller 通过构造参数传入。 private static final int MODEL_INPUT_W = 512; @@ -48,11 +49,10 @@ public class AnimeSegmenter implements AutoCloseable { private final ZooModel modelWrapper; private final Predictor predictor; - private final List labels; private final Map palette; public AnimeSegmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { - this.labels = new ArrayList<>(labels); + super(modelDir, labels); this.palette = AnimeLabelPalette.defaultPalette(); Translator translator = new Translator() { @@ -137,6 +137,16 @@ public class AnimeSegmenter implements AutoCloseable { this.predictor = modelWrapper.newPredictor(); } + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + return null; + } + + @Override + public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) { + return null; + } + public AnimeSegmentationResult segment(File imgFile) throws TranslateException, IOException { Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); @@ -201,30 +211,4 @@ public class AnimeSegmenter implements AutoCloseable { } catch (Exception ignore) { } } - - // 测试主函数 - public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.out.println("用法: java AnimeSegmenter "); - System.out.println("示例: java AnimeSegmenter ./anime_unet.pt input.jpg output.png"); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - File out = new File(args[2]); - - List labels = AnimeLabelPalette.defaultLabels(); - - try (AnimeSegmenter segmenter = new AnimeSegmenter(modelDir, labels)) { - AnimeSegmentationResult res = segmenter.segment(input); - ImageIO.write(res.getMaskImage(), "png", out); - System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath()); - - // 额外保存眼睛分割结果 - BufferedImage eyes = segmenter.extractEyes(input); - File eyesOut = new File(out.getParent(), "eyes_" + out.getName()); - ImageIO.write(eyes, "png", eyesOut); - System.out.println("眼睛分割结果已保存到: " + eyesOut.getAbsolutePath()); - } - } } diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java deleted file mode 100644 index de57b2e..0000000 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2ModelWrapper.java +++ /dev/null @@ -1,244 +0,0 @@ -package com.chuangzhou.vivid2D.ai.anime_segmentation; - -import javax.imageio.ImageIO; -import java.awt.*; -import java.awt.image.BufferedImage; -import java.io.*; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.*; -import java.util.List; - -/** - * Anime2ModelWrapper - 对动漫分割模型的封装 - * - * 用法示例: - * Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(Paths.get("/path/to/modelDir")); - * Map out = wrapper.segmentAndSave( - * new File("input.jpg"), - * Set.of("foreground"), // 动漫分割主要关注前景 - * Paths.get("outDir") - * ); - * wrapper.close(); - */ -public class Anime2ModelWrapper implements AutoCloseable { - - private final Anime2Segmenter segmenter; - private final List labels; // index -> name - private final Map palette; // name -> ARGB - - private Anime2ModelWrapper(Anime2Segmenter segmenter, List labels, Map palette) { - this.segmenter = segmenter; - this.labels = labels; - this.palette = palette; - } - - /** - * 创建 Anime2Segmenter 实例 - */ - public static Anime2ModelWrapper load(Path modelDir) throws Exception { - List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); - Anime2Segmenter s = new Anime2Segmenter(modelDir, labels); - Map palette = Anime2LabelPalette.animeSegmentationPalette(); - return new Anime2ModelWrapper(s, labels, palette); - } - - public List getLabels() { - return Collections.unmodifiableList(labels); - } - - public Map getPalette() { - return Collections.unmodifiableMap(palette); - } - - /** - * 直接返回分割结果 - */ - public Anime2SegmentationResult segment(File inputImage) throws Exception { - return segmenter.segment(inputImage); - } - - /** - * 把指定 targets(标签名集合)从输入图片中分割并保存到 outDir - */ - public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { - if (!Files.exists(outDir)) { - Files.createDirectories(outDir); - } - - Anime2SegmentationResult res = segment(inputImage); - BufferedImage original = ImageIO.read(inputImage); - BufferedImage maskImage = res.getMaskImage(); - - int maskW = maskImage.getWidth(); - int maskH = maskImage.getHeight(); - - // 解析 targets - Set realTargets = parseTargetsSet(targets); - Map saved = new LinkedHashMap<>(); - - for (String target : realTargets) { - if (!palette.containsKey(target)) { - System.err.println("Warning: unknown label '" + target + "' - skip."); - continue; - } - - int targetColor = palette.get(target); - - // 1) 生成透明背景的二值掩码(只保留 target 像素) - BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); - for (int y = 0; y < maskH; y++) { - for (int x = 0; x < maskW; x++) { - int c = maskImage.getRGB(x, y); - if (c == targetColor) { - partMask.setRGB(x, y, targetColor | 0xFF000000); // 保证不透明 - } else { - partMask.setRGB(x, y, 0x00000000); - } - } - } - - // 2) 将 mask 缩放到与原图一致 - BufferedImage maskResized = partMask; - if (original.getWidth() != maskW || original.getHeight() != maskH) { - maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); - Graphics2D g = maskResized.createGraphics(); - g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); - g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); - g.dispose(); - } - - // 3) 生成叠加图 - BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); - Graphics2D g2 = overlay.createGraphics(); - g2.drawImage(original, 0, 0, null); - - int rgbOnly = (targetColor & 0x00FFFFFF); - int translucent = (0x88 << 24) | rgbOnly; - BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); - for (int y = 0; y < colorOverlay.getHeight(); y++) { - for (int x = 0; x < colorOverlay.getWidth(); x++) { - int mc = maskResized.getRGB(x, y); - if ((mc & 0x00FFFFFF) == (targetColor & 0x00FFFFFF) && ((mc >>> 24) != 0)) { - colorOverlay.setRGB(x, y, translucent); - } else { - colorOverlay.setRGB(x, y, 0x00000000); - } - } - } - g2.drawImage(colorOverlay, 0, 0, null); - g2.dispose(); - - // 保存 - String safe = safeFileName(target); - File maskOut = outDir.resolve(safe + "_mask.png").toFile(); - File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); - - ImageIO.write(maskResized, "png", maskOut); - ImageIO.write(overlay, "png", overlayOut); - - saved.put(target, new ResultFiles(maskOut, overlayOut)); - } - - return saved; - } - - private static String safeFileName(String s) { - return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); - } - - private static Set parseTargetsSet(Set in) { - if (in == null || in.isEmpty()) return Collections.emptySet(); - if (in.size() == 1) { - String only = in.iterator().next(); - if ("all".equalsIgnoreCase(only.trim())) { - return Set.of("foreground"); // 动漫分割主要关注前景 - } - } - Set out = new LinkedHashSet<>(); - for (String s : in) { - if (s != null) out.add(s.trim()); - } - return out; - } - - /** - * 关闭底层资源 - */ - @Override - public void close() { - try { - segmenter.close(); - } catch (Exception ignore) {} - } - - /** - * 存放结果文件路径 - */ - public static class ResultFiles { - private final File maskFile; - private final File overlayFile; - - public ResultFiles(File maskFile, File overlayFile) { - this.maskFile = maskFile; - this.overlayFile = overlayFile; - } - - public File getMaskFile() { - return maskFile; - } - - public File getOverlayFile() { - return overlayFile; - } - } - - /* ================= helper: 从 modelDir 读取 synset.txt ================= */ - private static Optional> loadLabelsFromSynset(Path modelDir) { - Path syn = modelDir.resolve("synset.txt"); - if (Files.exists(syn)) { - try { - List lines = Files.readAllLines(syn); - List cleaned = new ArrayList<>(); - for (String l : lines) { - String s = l.trim(); - if (!s.isEmpty()) cleaned.add(s); - } - if (!cleaned.isEmpty()) return Optional.of(cleaned); - } catch (IOException ignore) {} - } - return Optional.empty(); - } - - /* ================= convenience 主方法(快速测试) ================= */ - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.out.println("用法: Anime2ModelWrapper "); - System.out.println("示例: Anime2ModelWrapper /models/anime_seg /images/in.jpg outDir foreground"); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - Path out = Path.of(args[2]); - String targetsArg = args[3]; - - List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); - Set targets; - if ("all".equalsIgnoreCase(targetsArg.trim())) { - targets = new LinkedHashSet<>(labels); - } else { - String[] parts = targetsArg.split(","); - targets = new LinkedHashSet<>(); - for (String p : parts) { - if (!p.trim().isEmpty()) targets.add(p.trim()); - } - } - - try (Anime2ModelWrapper wrapper = Anime2ModelWrapper.load(modelDir)) { - Map m = wrapper.segmentAndSave(input, targets, out); - m.forEach((k, v) -> { - System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); - }); - } - } -} \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java index d638f29..6b36eba 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2SegmentationResult.java @@ -1,36 +1,15 @@ package com.chuangzhou.vivid2D.ai.anime_segmentation; +import com.chuangzhou.vivid2D.ai.SegmentationResult; + import java.awt.image.BufferedImage; import java.util.Map; /** * 动漫分割结果容器 */ -public class Anime2SegmentationResult { - // 分割掩码图(每个像素的颜色为对应类别颜色) - private final BufferedImage maskImage; - - // 类别索引 -> 类别名称 - private final Map labels; - - // 类别名称 -> ARGB 颜色 - private final Map palette; - +public class Anime2SegmentationResult extends SegmentationResult { public Anime2SegmentationResult(BufferedImage maskImage, Map labels, Map palette) { - this.maskImage = maskImage; - this.labels = labels; - this.palette = palette; - } - - public BufferedImage getMaskImage() { - return maskImage; - } - - public Map getLabels() { - return labels; - } - - public Map getPalette() { - return palette; + super(maskImage, labels, palette); } } \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java index 2482b92..10a969b 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2Segmenter.java @@ -1,21 +1,17 @@ package com.chuangzhou.vivid2D.ai.anime_segmentation; import ai.djl.MalformedModelException; -import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; -import ai.djl.ndarray.types.Shape; -import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.translate.Batchifier; import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import com.chuangzhou.vivid2D.ai.SegmentationResult; +import com.chuangzhou.vivid2D.ai.Segmenter; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; @@ -28,94 +24,51 @@ import java.util.*; * Anime2Segmenter: 专门用于动漫分割模型 * 处理 anime-segmentation 模型的二值分割输出 */ -public class Anime2Segmenter implements AutoCloseable { - - public static class SegmentationData { - final int[] indices; - final long[] shape; - - public SegmentationData(int[] indices, long[] shape) { - this.indices = indices; - this.shape = shape; - } - } - - private final ZooModel modelWrapper; - private final Predictor predictor; - private final List labels; - private final Map palette; - +public class Anime2Segmenter extends Segmenter { public Anime2Segmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { - this.labels = new ArrayList<>(labels); - this.palette = Anime2LabelPalette.animeSegmentationPalette(); - - Translator translator = new Translator() { - @Override - public NDList processInput(TranslatorContext ctx, Image input) { - NDManager manager = ctx.getNDManager(); - - // 调整输入图像尺寸到模型期望的大小 (1024x1024) - Image resized = input.resize(1024, 1024, true); - NDArray array = resized.toNDArray(manager); - - // 转换为 CHW 格式并归一化 - array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); - array = array.div(255f); - array = array.expandDims(0); // 添加batch维度 - - return new NDList(array); - } - - @Override - public SegmentationData processOutput(TranslatorContext ctx, NDList list) { - if (list == null || list.isEmpty()) { - throw new IllegalStateException("Model did not return any output."); - } - - NDArray out = list.get(0); - - // 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图 - // 应用sigmoid并二值化 - NDArray probabilities = out.div(out.neg().exp().add(1)); - NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false); - - // 移除batch和channel维度 - if (binaryMask.getShape().dimension() == 4) { - binaryMask = binaryMask.squeeze(0).squeeze(0); - } - - // 转换为Java数组 - long[] finalShape = binaryMask.getShape().getShape(); - int[] indices = binaryMask.toIntArray(); - - return new SegmentationData(indices, finalShape); - } - - @Override - public Batchifier getBatchifier() { - return null; - } - }; - - Criteria criteria = Criteria.builder() - .setTypes(Image.class, SegmentationData.class) - .optModelPath(modelDir) - .optEngine("PyTorch") - .optTranslator(translator) - .build(); - - this.modelWrapper = criteria.loadModel(); - this.predictor = modelWrapper.newPredictor(); + super(modelDir, labels); } - public Anime2SegmentationResult segment(File imgFile) throws TranslateException, IOException { + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + + // 调整输入图像尺寸到模型期望的大小 (1024x1024) + Image resized = input.resize(1024, 1024, true); + NDArray array = resized.toNDArray(manager); + + // 转换为 CHW 格式并归一化 + array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); + array = array.div(255f); + array = array.expandDims(0); // 添加batch维度 + + return new NDList(array); + } + + @Override + public SegmentationData processOutput(TranslatorContext ctx, NDList list) { + if (list == null || list.isEmpty()) { + throw new IllegalStateException("Model did not return any output."); + } + NDArray out = list.get(0); + // 动漫分割模型输出形状: [1, 1, H, W] - 单通道概率图 + // 应用sigmoid并二值化 + NDArray probabilities = out.div(out.neg().exp().add(1)); + NDArray binaryMask = probabilities.gt(0.5).toType(DataType.INT32, false); + if (binaryMask.getShape().dimension() == 4) { + binaryMask = binaryMask.squeeze(0).squeeze(0); + } + long[] finalShape = binaryMask.getShape().getShape(); + int[] indices = binaryMask.toIntArray(); + return new SegmentationData(indices, finalShape); + } + + @Override + public SegmentationResult segment(File imgFile) throws TranslateException, IOException { Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); - - SegmentationData data = predictor.predict(img); - + Segmenter.SegmentationData data = predictor.predict(img); long[] shp = data.shape; int[] indices = data.indices; - int height, width; if (shp.length == 2) { height = (int) shp[0]; @@ -123,14 +76,11 @@ public class Anime2Segmenter implements AutoCloseable { } else { throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp)); } - - // 创建分割掩码 BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); Map labelsMap = new HashMap<>(); for (int i = 0; i < labels.size(); i++) { labelsMap.put(i, labels.get(i)); } - for (int y = 0; y < height; y++) { for (int x = 0; x < width; x++) { int idx = indices[y * width + x]; @@ -139,8 +89,7 @@ public class Anime2Segmenter implements AutoCloseable { mask.setRGB(x, y, argb); } } - - return new Anime2SegmentationResult(mask, labelsMap, palette); + return new SegmentationResult(mask, labelsMap, palette); } @Override @@ -154,22 +103,4 @@ public class Anime2Segmenter implements AutoCloseable { } catch (Exception ignore) { } } - - public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.out.println("用法: java Anime2Segmenter "); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - File out = new File(args[2]); - - List labels = Anime2LabelPalette.defaultLabels(); - - try (Anime2Segmenter s = new Anime2Segmenter(modelDir, labels)) { - Anime2SegmentationResult res = s.segment(input); - ImageIO.write(res.getMaskImage(), "png", out); - System.out.println("动漫分割掩码已保存到: " + out.getAbsolutePath()); - } - } } \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java index 3643dc8..786f6e2 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/anime_segmentation/Anime2VividModelWrapper.java @@ -1,5 +1,8 @@ package com.chuangzhou.vivid2D.ai.anime_segmentation; +import com.chuangzhou.vivid2D.ai.SegmentationResult; +import com.chuangzhou.vivid2D.ai.VividModelWrapper; + import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; @@ -22,16 +25,11 @@ import java.util.List; * // out contains 每个目标标签对应的 mask+overlay 文件路径 * wrapper.close(); */ -public class Anime2VividModelWrapper implements AutoCloseable { - - private final Anime2Segmenter segmenter; - private final List labels; // index -> name - private final Map palette; // name -> ARGB +public class Anime2VividModelWrapper extends VividModelWrapper { private Anime2VividModelWrapper(Anime2Segmenter segmenter, List labels, Map palette) { - this.segmenter = segmenter; - this.labels = labels; - this.palette = palette; + super(segmenter, labels, palette); + } /** @@ -56,7 +54,7 @@ public class Anime2VividModelWrapper implements AutoCloseable { /** * 直接返回分割结果(Anime2SegmentationResult) */ - public Anime2SegmentationResult segment(File inputImage) throws Exception { + public SegmentationResult segment(File inputImage) throws Exception { return segmenter.segment(inputImage); } @@ -66,12 +64,12 @@ public class Anime2VividModelWrapper implements AutoCloseable { *

* 返回值:Map,ResultFiles 包含 maskFile、overlayFile(两个 PNG) */ - public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { if (!Files.exists(outDir)) { Files.createDirectories(outDir); } - Anime2SegmentationResult res = segment(inputImage); + SegmentationResult res = segment(inputImage); BufferedImage original = ImageIO.read(inputImage); BufferedImage maskImage = res.getMaskImage(); @@ -84,7 +82,6 @@ public class Anime2VividModelWrapper implements AutoCloseable { for (String target : realTargets) { if (!palette.containsKey(target)) { - // 尝试忽略大小写匹配 String finalTarget = target; Optional matched = palette.keySet().stream() .filter(k -> k.equalsIgnoreCase(finalTarget)) @@ -154,109 +151,4 @@ public class Anime2VividModelWrapper implements AutoCloseable { return saved; } - - private static String safeFileName(String s) { - return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); - } - - private static Set parseTargetsSet(Set in) { - if (in == null || in.isEmpty()) return Collections.emptySet(); - // 若包含单个 "all" - if (in.size() == 1) { - String only = in.iterator().next(); - if ("all".equalsIgnoreCase(only.trim())) { - // 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette) - // For convenience, return a set containing "all" and let caller logic handle it earlier. - return Set.of("all"); - } - } - // 直接返回 trim 后的小写不变集合(保持用户传入的名字) - Set out = new LinkedHashSet<>(); - for (String s : in) { - if (s != null) out.add(s.trim()); - } - return out; - } - - /** - * 关闭底层资源 - */ - @Override - public void close() { - try { - segmenter.close(); - } catch (Exception ignore) {} - } - - /** - * 存放结果文件路径 - */ - public static class ResultFiles { - private final File maskFile; - private final File overlayFile; - - public ResultFiles(File maskFile, File overlayFile) { - this.maskFile = maskFile; - this.overlayFile = overlayFile; - } - - public File getMaskFile() { - return maskFile; - } - - public File getOverlayFile() { - return overlayFile; - } - } - - /* ================= helper: 从 modelDir 读取 synset.txt ================= */ - - private static Optional> loadLabelsFromSynset(Path modelDir) { - Path syn = modelDir.resolve("synset.txt"); - if (Files.exists(syn)) { - try { - List lines = Files.readAllLines(syn); - List cleaned = new ArrayList<>(); - for (String l : lines) { - String s = l.trim(); - if (!s.isEmpty()) cleaned.add(s); - } - if (!cleaned.isEmpty()) return Optional.of(cleaned); - } catch (IOException ignore) {} - } - return Optional.empty(); - } - - /* ================= convenience 主方法(快速测试) ================= */ - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.out.println("用法: Anime2VividModelWrapper "); - System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir foreground"); - System.out.println("示例: Anime2VividModelWrapper /models/anime_seg /images/in.jpg outDir all"); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - Path out = Path.of(args[2]); - String targetsArg = args[3]; - - List labels = loadLabelsFromSynset(modelDir).orElseGet(Anime2LabelPalette::defaultLabels); - Set targets; - if ("all".equalsIgnoreCase(targetsArg.trim())) { - targets = new LinkedHashSet<>(labels); - } else { - String[] parts = targetsArg.split(","); - targets = new LinkedHashSet<>(); - for (String p : parts) { - if (!p.trim().isEmpty()) targets.add(p.trim()); - } - } - - try (Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(modelDir)) { - Map m = wrapper.segmentAndSave(input, targets, out); - m.forEach((k, v) -> { - System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); - }); - } - } } \ No newline at end of file diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetLabelPalette.java similarity index 98% rename from src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java rename to src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetLabelPalette.java index 12f33b8..c4c25f5 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/LabelPalette.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetLabelPalette.java @@ -7,7 +7,7 @@ import java.util.*; * 颜色值基于 zllrunning/face-parsing.PyTorch 仓库的 test.py 文件。 * 标签索引必须与模型输出索引一致(0-18)。 */ -public class LabelPalette { +public class BiSeNetLabelPalette { /** * BiSeNet 人脸解析模型的标准标签(19个类别,索引 0-18) diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmentationResult.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmentationResult.java new file mode 100644 index 0000000..5e634ac --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmentationResult.java @@ -0,0 +1,15 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import com.chuangzhou.vivid2D.ai.SegmentationResult; + +import java.awt.image.BufferedImage; +import java.util.Map; + +/** + * 分割结果容器 + */ +public class BiSeNetSegmentationResult extends SegmentationResult { + public BiSeNetSegmentationResult(BufferedImage maskImage, Map labels, Map palette) { + super(maskImage, labels, palette); + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmenter.java new file mode 100644 index 0000000..30c71e3 --- /dev/null +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetSegmenter.java @@ -0,0 +1,101 @@ +package com.chuangzhou.vivid2D.ai.face_parsing; + +import ai.djl.MalformedModelException; +import ai.djl.inference.Predictor; +import ai.djl.modality.cv.Image; +import ai.djl.modality.cv.ImageFactory; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.repository.zoo.Criteria; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ZooModel; +import ai.djl.translate.Batchifier; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorContext; +import com.chuangzhou.vivid2D.ai.SegmentationResult; +import com.chuangzhou.vivid2D.ai.Segmenter; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +/** + * Segmenter: 加载模型并对图片做语义分割 + * + * 说明: + * - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray, + * 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。 + * - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()), + * 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。 + */ +public class BiSeNetSegmenter extends Segmenter { + + public BiSeNetSegmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { + super(modelDir, labels); + } + + @Override + public NDList processInput(TranslatorContext ctx, Image input) { + NDManager manager = ctx.getNDManager(); + NDArray array = input.toNDArray(manager); + array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); + array = array.div(255f); + array = array.expandDims(0); + return new NDList(array); + } + + @Override + public Segmenter.SegmentationData processOutput(TranslatorContext ctx, NDList list) { + if (list == null || list.isEmpty()) { + throw new IllegalStateException("Model did not return any output."); + } + + NDArray out = list.get(0); + NDArray classMap; + + // 1. 解析模型输出,得到类别图谱 (classMap) + long[] shape = out.getShape().getShape(); + if (shape.length == 4 && shape[1] > 1) { + classMap = out.argMax(1); + } else if (shape.length == 3) { + classMap = (shape[0] == 1) ? out : out.argMax(0); + } else if (shape.length == 2) { + classMap = out; + } else { + throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape)); + } + + if (classMap.getShape().dimension() == 3) { + classMap = classMap.squeeze(0); + } + + // 2. *** 关键步骤 *** + // 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型 + + // 首先,确保数据类型是 INT32 + NDArray int32ClassMap = classMap.toType(DataType.INT32, false); + + // 然后,获取形状和 int[] 数组 + long[] finalShape = int32ClassMap.getShape().getShape(); + int[] indices = int32ClassMap.toIntArray(); + + // 3. 将 Java 对象封装并返回 + return new SegmentationData(indices, finalShape); + } + + @Override + public SegmentationResult segment(File imgFile) throws TranslateException, IOException { + return super.segment(imgFile); + } + + @Override + public void close() { + super.close(); + } +} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetVividModelWrapper.java similarity index 52% rename from src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java rename to src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetVividModelWrapper.java index 641e0df..7f2e393 100644 --- a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/VividModelWrapper.java +++ b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/BiSeNetVividModelWrapper.java @@ -1,5 +1,8 @@ package com.chuangzhou.vivid2D.ai.face_parsing; +import com.chuangzhou.vivid2D.ai.SegmentationResult; +import com.chuangzhou.vivid2D.ai.VividModelWrapper; + import javax.imageio.ImageIO; import java.awt.*; import java.awt.image.BufferedImage; @@ -22,35 +25,29 @@ import java.util.List; * // out contains 每个目标标签对应的 mask+overlay 文件路径 * wrapper.close(); */ -public class VividModelWrapper implements AutoCloseable { +public class BiSeNetVividModelWrapper extends VividModelWrapper { - private final Segmenter segmenter; - private final List labels; // index -> name - private final Map palette; // name -> ARGB - - private VividModelWrapper(Segmenter segmenter, List labels, Map palette) { - this.segmenter = segmenter; - this.labels = labels; - this.palette = palette; + private BiSeNetVividModelWrapper(BiSeNetSegmenter segmenter, List labels, Map palette) { + super(segmenter, labels, palette); } /** * 读取 modelDir/synset.txt(每行一个标签),若不存在则使用 LabelPalette.defaultLabels() * 并创建 Segmenter 实例。 */ - public static VividModelWrapper load(Path modelDir) throws Exception { - List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); - Segmenter s = new Segmenter(modelDir, labels); - Map palette = LabelPalette.defaultPalette(); - return new VividModelWrapper(s, labels, palette); + public static BiSeNetVividModelWrapper load(Path modelDir) throws Exception { + List labels = loadLabelsFromSynset(modelDir).orElseGet(BiSeNetLabelPalette::defaultLabels); + BiSeNetSegmenter s = new BiSeNetSegmenter(modelDir, labels); + Map palette = BiSeNetLabelPalette.defaultPalette(); + return new BiSeNetVividModelWrapper(s, labels, palette); } public List getLabels() { - return Collections.unmodifiableList(labels); + return super.getLabels(); } public Map getPalette() { - return Collections.unmodifiableMap(palette); + return super.getPalette(); } /** @@ -66,25 +63,19 @@ public class VividModelWrapper implements AutoCloseable { *

* 返回值:Map,ResultFiles 包含 maskFile、overlayFile(两个 PNG) */ - public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { + public Map segmentAndSave(File inputImage, Set targets, Path outDir) throws Exception { if (!Files.exists(outDir)) { Files.createDirectories(outDir); } - SegmentationResult res = segment(inputImage); BufferedImage original = ImageIO.read(inputImage); BufferedImage maskImage = res.getMaskImage(); - int maskW = maskImage.getWidth(); int maskH = maskImage.getHeight(); - - // 解析 targets Set realTargets = parseTargetsSet(targets); Map saved = new LinkedHashMap<>(); - for (String target : realTargets) { if (!palette.containsKey(target)) { - // 尝试忽略大小写匹配 String finalTarget = target; Optional matched = palette.keySet().stream() .filter(k -> k.equalsIgnoreCase(finalTarget)) @@ -95,10 +86,7 @@ public class VividModelWrapper implements AutoCloseable { continue; } } - int targetColor = palette.get(target); - - // 1) 生成透明背景的二值掩码(只保留 target 像素) BufferedImage partMask = new BufferedImage(maskW, maskH, BufferedImage.TYPE_INT_ARGB); for (int y = 0; y < maskH; y++) { for (int x = 0; x < maskW; x++) { @@ -110,8 +98,6 @@ public class VividModelWrapper implements AutoCloseable { } } } - - // 2) 将 mask 缩放到与原图一致(如果需要),并生成 overlay(半透明) BufferedImage maskResized = partMask; if (original.getWidth() != maskW || original.getHeight() != maskH) { maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); @@ -120,11 +106,9 @@ public class VividModelWrapper implements AutoCloseable { g.drawImage(partMask, 0, 0, original.getWidth(), original.getHeight(), null); g.dispose(); } - BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); Graphics2D g2 = overlay.createGraphics(); g2.drawImage(original, 0, 0, null); - // 半透明颜色(alpha = 0x88) int rgbOnly = (targetColor & 0x00FFFFFF); int translucent = (0x88 << 24) | rgbOnly; BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); @@ -140,44 +124,17 @@ public class VividModelWrapper implements AutoCloseable { } g2.drawImage(colorOverlay, 0, 0, null); g2.dispose(); - - // 保存 String safe = safeFileName(target); File maskOut = outDir.resolve(safe + "_mask.png").toFile(); File overlayOut = outDir.resolve(safe + "_overlay.png").toFile(); - ImageIO.write(maskResized, "png", maskOut); ImageIO.write(overlay, "png", overlayOut); - saved.put(target, new ResultFiles(maskOut, overlayOut)); } return saved; } - private static String safeFileName(String s) { - return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); - } - - private static Set parseTargetsSet(Set in) { - if (in == null || in.isEmpty()) return Collections.emptySet(); - // 若包含单个 "all" - if (in.size() == 1) { - String only = in.iterator().next(); - if ("all".equalsIgnoreCase(only.trim())) { - // 由调用方自行取 labels(这里返回 sentinel, but caller already checks palette) - // For convenience, return a set containing "all" and let caller logic handle it earlier. - return Set.of("all"); - } - } - // 直接返回 trim 后的小写不变集合(保持用户传入的名字) - Set out = new LinkedHashSet<>(); - for (String s : in) { - if (s != null) out.add(s.trim()); - } - return out; - } - /** * 关闭底层资源 */ @@ -187,75 +144,4 @@ public class VividModelWrapper implements AutoCloseable { segmenter.close(); } catch (Exception ignore) {} } - - /** - * 存放结果文件路径 - */ - public static class ResultFiles { - private final File maskFile; - private final File overlayFile; - - public ResultFiles(File maskFile, File overlayFile) { - this.maskFile = maskFile; - this.overlayFile = overlayFile; - } - - public File getMaskFile() { - return maskFile; - } - - public File getOverlayFile() { - return overlayFile; - } - } - - /* ================= helper: 从 modelDir 读取 synset.txt ================= */ - - private static Optional> loadLabelsFromSynset(Path modelDir) { - Path syn = modelDir.resolve("synset.txt"); - if (Files.exists(syn)) { - try { - List lines = Files.readAllLines(syn); - List cleaned = new ArrayList<>(); - for (String l : lines) { - String s = l.trim(); - if (!s.isEmpty()) cleaned.add(s); - } - if (!cleaned.isEmpty()) return Optional.of(cleaned); - } catch (IOException ignore) {} - } - return Optional.empty(); - } - - /* ================= convenience 主方法(快速测试) ================= */ - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.out.println("用法: VividModelWrapper "); - System.out.println("示例: VividModelWrapper /models/bisenet /images/in.jpg outDir eye,face"); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - Path out = Path.of(args[2]); - String targetsArg = args[3]; - - List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); - Set targets; - if ("all".equalsIgnoreCase(targetsArg.trim())) { - targets = new LinkedHashSet<>(labels); - } else { - String[] parts = targetsArg.split(","); - targets = new LinkedHashSet<>(); - for (String p : parts) { - if (!p.trim().isEmpty()) targets.add(p.trim()); - } - } - - try (VividModelWrapper wrapper = VividModelWrapper.load(modelDir)) { - Map m = wrapper.segmentAndSave(input, targets, out); - m.forEach((k, v) -> { - System.out.println(String.format("Label=%s, mask=%s, overlay=%s", k, v.getMaskFile().getAbsolutePath(), v.getOverlayFile().getAbsolutePath())); - }); - } - } } diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java deleted file mode 100644 index b4f0717..0000000 --- a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/Segmenter.java +++ /dev/null @@ -1,193 +0,0 @@ -package com.chuangzhou.vivid2D.ai.face_parsing; - -import ai.djl.MalformedModelException; -import ai.djl.inference.Predictor; -import ai.djl.modality.cv.Image; -import ai.djl.modality.cv.ImageFactory; -import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.NDList; -import ai.djl.ndarray.NDManager; -import ai.djl.ndarray.types.DataType; -import ai.djl.repository.zoo.Criteria; -import ai.djl.repository.zoo.ModelNotFoundException; -import ai.djl.repository.zoo.ZooModel; -import ai.djl.translate.Batchifier; -import ai.djl.translate.TranslateException; -import ai.djl.translate.Translator; -import ai.djl.translate.TranslatorContext; - -import javax.imageio.ImageIO; -import java.awt.image.BufferedImage; -import java.io.File; -import java.io.IOException; -import java.nio.file.Path; -import java.util.*; - -/** - * Segmenter: 加载模型并对图片做语义分割 - * - * 说明: - * - Translator.processOutput 在翻译器层就把模型输出处理成 (H, W) 的类别索引 NDArray, - * 并把该 NDArray 拷贝到 persistentManager 中返回,从而避免后续 native 资源被释放的问题。 - * - 这里改为在 Translator 内部把 classMap 转为 Java int[](通过 classMap.toIntArray()), - * 再用 persistentManager.create(int[], shape) 创建新的 NDArray 返回,确保安全。 - */ -public class Segmenter implements AutoCloseable { - - // 内部类,用于从Translator安全地传出数据 - public static class SegmentationData { - final int[] indices; - final long[] shape; - - public SegmentationData(int[] indices, long[] shape) { - this.indices = indices; - this.shape = shape; - } - } - - private final ZooModel modelWrapper; - private final Predictor predictor; - private final List labels; - private final Map palette; - - public Segmenter(Path modelDir, List labels) throws IOException, MalformedModelException, ModelNotFoundException { - this.labels = new ArrayList<>(labels); - this.palette = LabelPalette.defaultPalette(); - - // Translator 的输出类型现在是 SegmentationData - Translator translator = new Translator() { - @Override - public NDList processInput(TranslatorContext ctx, Image input) { - NDManager manager = ctx.getNDManager(); - NDArray array = input.toNDArray(manager); - array = array.transpose(2, 0, 1).toType(DataType.FLOAT32, false); - array = array.div(255f); - array = array.expandDims(0); - return new NDList(array); - } - - @Override - public SegmentationData processOutput(TranslatorContext ctx, NDList list) { - if (list == null || list.isEmpty()) { - throw new IllegalStateException("Model did not return any output."); - } - - NDArray out = list.get(0); - NDArray classMap; - - // 1. 解析模型输出,得到类别图谱 (classMap) - long[] shape = out.getShape().getShape(); - if (shape.length == 4 && shape[1] > 1) { - classMap = out.argMax(1); - } else if (shape.length == 3) { - classMap = (shape[0] == 1) ? out : out.argMax(0); - } else if (shape.length == 2) { - classMap = out; - } else { - throw new IllegalStateException("Unexpected output shape: " + Arrays.toString(shape)); - } - - if (classMap.getShape().dimension() == 3) { - classMap = classMap.squeeze(0); - } - - // 2. *** 关键步骤 *** - // 在 NDArray 仍然有效的上下文中,将其转换为 Java 原生类型 - - // 首先,确保数据类型是 INT32 - NDArray int32ClassMap = classMap.toType(DataType.INT32, false); - - // 然后,获取形状和 int[] 数组 - long[] finalShape = int32ClassMap.getShape().getShape(); - int[] indices = int32ClassMap.toIntArray(); - - // 3. 将 Java 对象封装并返回 - return new SegmentationData(indices, finalShape); - } - - @Override - public Batchifier getBatchifier() { - return null; // 或者根据需要使用 Batchifier.STACK - } - }; - - // Criteria 的类型也需要更新 - Criteria criteria = Criteria.builder() - .setTypes(Image.class, SegmentationData.class) - .optModelPath(modelDir) - .optEngine("PyTorch") - .optTranslator(translator) - .build(); - - this.modelWrapper = criteria.loadModel(); - this.predictor = modelWrapper.newPredictor(); - } - - public SegmentationResult segment(File imgFile) throws TranslateException, IOException { - Image img = ImageFactory.getInstance().fromFile(imgFile.toPath()); - - // predict 方法现在直接返回安全的 Java 对象 - SegmentationData data = predictor.predict(img); - - long[] shp = data.shape; - int[] indices = data.indices; - - int height, width; - if (shp.length == 2) { - height = (int) shp[0]; - width = (int) shp[1]; - } else { - throw new RuntimeException("Unexpected classMap shape from SegmentationData: " + Arrays.toString(shp)); - } - - // 后续处理完全基于 Java 对象,不再有 Native resource 问题 - BufferedImage mask = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB); - Map labelsMap = new HashMap<>(); - for (int i = 0; i < labels.size(); i++) { - labelsMap.put(i, labels.get(i)); - } - - for (int y = 0; y < height; y++) { - for (int x = 0; x < width; x++) { - int idx = indices[y * width + x]; - String label = labelsMap.getOrDefault(idx, "unknown"); - int argb = palette.getOrDefault(label, 0xFF00FF00); - mask.setRGB(x, y, argb); - } - } - - return new SegmentationResult(mask, labelsMap, palette); - } - - - @Override - public void close() { - try { - predictor.close(); - } catch (Exception ignore) { - } - try { - modelWrapper.close(); - } catch (Exception ignore) { - } - } - - // 小测试主函数(示例) - public static void main(String[] args) throws Exception { - if (args.length < 3) { - System.out.println("用法: java Segmenter "); - return; - } - Path modelDir = Path.of(args[0]); - File input = new File(args[1]); - File out = new File(args[2]); - - List labels = LabelPalette.defaultLabels(); - - try (Segmenter s = new Segmenter(modelDir, labels)) { - SegmentationResult res = s.segment(input); - ImageIO.write(res.getMaskImage(), "png", out); - System.out.println("分割掩码已保存到: " + out.getAbsolutePath()); - } - } -} diff --git a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java b/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java deleted file mode 100644 index 0a1ad11..0000000 --- a/src/main/java/com/chuangzhou/vivid2D/ai/face_parsing/SegmenterExample.java +++ /dev/null @@ -1,184 +0,0 @@ -package com.chuangzhou.vivid2D.ai.face_parsing; - -import javax.imageio.ImageIO; -import java.awt.image.BufferedImage; -import java.awt.*; -import java.io.*; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.*; -import java.util.List; - -/** - * SegmenterExample - * - * 使用说明(命令行): - * java -cp com.chuangzhou.vivid2D.ai.face_parsing.SegmenterExample \ - * - * - * 示例: - * java ... SegmenterExample /models/face_bisent /images/in.jpg /out "eye,face" - * java ... SegmenterExample /models/face_bisent /images/in.jpg /out all - */ -public class SegmenterExample { - - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.err.println("用法: SegmenterExample "); - System.err.println("例如: SegmenterExample /models/face_bisent input.jpg outDir eye,face"); - return; - } - - Path modelDir = Path.of(args[0]); - File inputImage = new File(args[1]); - Path outDir = Path.of(args[2]); - String targetsArg = args[3]; - - if (!Files.exists(modelDir)) { - System.err.println("modelDir 不存在: " + modelDir); - return; - } - if (!inputImage.exists()) { - System.err.println("输入图片不存在: " + inputImage.getAbsolutePath()); - return; - } - if (!Files.exists(outDir)) { - Files.createDirectories(outDir); - } - - // 读取 synset.txt(如果有),否则使用默认 LabelPalette - List labels = loadLabelsFromSynset(modelDir).orElseGet(LabelPalette::defaultLabels); - - // 打开 Segmenter - try (Segmenter segmenter = new Segmenter(modelDir, labels)) { - SegmentationResult res = segmenter.segment(inputImage); - - // 原始图片 - BufferedImage original = ImageIO.read(inputImage); - - // palette: labelName -> ARGB int - Map palette = res.getPalette(); - Map labelsMap = res.getLabels(); // index -> name - - // 解析目标 labels 列表 - Set targets = parseTargets(targetsArg, labels); - - System.out.println("Will export targets: " + targets); - - // maskImage: 每像素是类别颜色(ARGB) - BufferedImage maskImage = res.getMaskImage(); - int w = maskImage.getWidth(); - int h = maskImage.getHeight(); - - // 为快速查 color -> labelName - Map colorToLabel = new HashMap<>(); - for (Map.Entry e : palette.entrySet()) { - colorToLabel.put(e.getValue(), e.getKey()); - } - - // 对每个 target 生成单独的 mask 和 overlay - for (String target : targets) { - if (!palette.containsKey(target)) { - System.err.println("警告:模型 palette 中没有标签 '" + target + "',跳过。"); - continue; - } - int targetColor = palette.get(target); - - // 1) 生成透明背景的二值掩码(只保留 target 像素) - BufferedImage partMask = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB); - for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) { - int c = maskImage.getRGB(x, y); - if (c == targetColor) { - // 保留为不透明(使用原始颜色) - partMask.setRGB(x, y, targetColor); - } else { - // 透明 - partMask.setRGB(x, y, 0x00000000); - } - } - } - - // 2) 生成 overlay:在原图上叠加半透明的 targetColor 区域 - BufferedImage overlay = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); - // 若分辨率不同,先缩放 mask 到原图大小(简单粗暴地按尺寸相同假设) - BufferedImage maskResized = maskImage; - if (original.getWidth() != w || original.getHeight() != h) { - // 简单缩放 mask 到原图尺寸 - maskResized = new BufferedImage(original.getWidth(), original.getHeight(), BufferedImage.TYPE_INT_ARGB); - Graphics2D g = maskResized.createGraphics(); - g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); - g.drawImage(maskImage, 0, 0, original.getWidth(), original.getHeight(), null); - g.dispose(); - } - - // 在 overlay 上先画原图 - Graphics2D g2 = overlay.createGraphics(); - g2.drawImage(original, 0, 0, null); - - // 创建半透明颜色(将 targetColor 的 alpha 设为 0x88) - int rgbOnly = (targetColor & 0x00FFFFFF); - int translucent = (0x88 << 24) | rgbOnly; - // 创建一个图像,把 mask 中对应像素设置为 translucent,否则透明 - BufferedImage colorOverlay = new BufferedImage(overlay.getWidth(), overlay.getHeight(), BufferedImage.TYPE_INT_ARGB); - for (int y = 0; y < colorOverlay.getHeight(); y++) { - for (int x = 0; x < colorOverlay.getWidth(); x++) { - int mc = maskResized.getRGB(x, y); - if (mc == targetColor) { - colorOverlay.setRGB(x, y, translucent); - } else { - colorOverlay.setRGB(x, y, 0x00000000); - } - } - } - // 将 colorOverlay 画到 overlay 上 - g2.drawImage(colorOverlay, 0, 0, null); - g2.dispose(); - - // 保存文件 - File maskOut = outDir.resolve(safeFileName(target) + "_mask.png").toFile(); - File overlayOut = outDir.resolve(safeFileName(target) + "_overlay.png").toFile(); - - ImageIO.write(partMask, "png", maskOut); - ImageIO.write(overlay, "png", overlayOut); - - System.out.println("Saved mask: " + maskOut.getAbsolutePath()); - System.out.println("Saved overlay: " + overlayOut.getAbsolutePath()); - } - } - } - - private static Optional> loadLabelsFromSynset(Path modelDir) { - Path syn = modelDir.resolve("synset.txt"); - if (Files.exists(syn)) { - try { - List lines = Files.readAllLines(syn); - List cleaned = new ArrayList<>(); - for (String l : lines) { - String s = l.trim(); - if (!s.isEmpty()) cleaned.add(s); - } - if (!cleaned.isEmpty()) return Optional.of(cleaned); - } catch (IOException ignore) {} - } - return Optional.empty(); - } - - private static Set parseTargets(String arg, List availableLabels) { - String a = arg.trim(); - if (a.equalsIgnoreCase("all")) { - return new LinkedHashSet<>(availableLabels); - } - String[] parts = a.split(","); - Set out = new LinkedHashSet<>(); - for (String p : parts) { - String t = p.trim(); - if (!t.isEmpty()) out.add(t); - } - return out; - } - - private static String safeFileName(String s) { - return s.replaceAll("[^a-zA-Z0-9_\\-\\.]", "_"); - } -} diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java b/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java index ee60a24..6097ca9 100644 --- a/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java +++ b/src/main/java/com/chuangzhou/vivid2D/test/AI2Test.java @@ -16,7 +16,7 @@ public class AI2Test { System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); // 使用 AnimeModelWrapper 而不是 VividModelWrapper - AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\anime_unet.pt")); + AnimeModelWrapper wrapper = AnimeModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\Anime-Face-Segmentation\\Anime-Face-Segmentation.pt")); // 使用 Anime-Face-Segmentation 的 7 个标签 Set animeLabels = Set.of( diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java b/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java index 0a3956e..73ec954 100644 --- a/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java +++ b/src/main/java/com/chuangzhou/vivid2D/test/AI3Test.java @@ -1,7 +1,6 @@ package com.chuangzhou.vivid2D.test; import com.chuangzhou.vivid2D.ai.anime_segmentation.Anime2VividModelWrapper; -import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper; import java.io.PrintStream; import java.nio.charset.StandardCharsets; @@ -15,7 +14,7 @@ public class AI3Test { public static void main(String[] args) throws Exception { System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8)); System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); - Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\isnetis_traced.pt")); + Anime2VividModelWrapper wrapper = Anime2VividModelWrapper.load(Paths.get("C:\\Users\\Administrator\\Desktop\\model\\anime-segmentation-main\\anime-segmentation.pt")); Set faceLabels = Set.of("foreground"); diff --git a/src/main/java/com/chuangzhou/vivid2D/test/AITest.java b/src/main/java/com/chuangzhou/vivid2D/test/AITest.java index c79036d..b89ff83 100644 --- a/src/main/java/com/chuangzhou/vivid2D/test/AITest.java +++ b/src/main/java/com/chuangzhou/vivid2D/test/AITest.java @@ -1,6 +1,6 @@ package com.chuangzhou.vivid2D.test; -import com.chuangzhou.vivid2D.ai.face_parsing.VividModelWrapper; +import com.chuangzhou.vivid2D.ai.face_parsing.BiSeNetVividModelWrapper; import java.io.PrintStream; import java.nio.charset.StandardCharsets; @@ -14,7 +14,7 @@ public class AITest { public static void main(String[] args) throws Exception { System.setOut(new PrintStream(System.out, true, StandardCharsets.UTF_8)); System.setErr(new PrintStream(System.err, true, StandardCharsets.UTF_8)); - VividModelWrapper wrapper = VividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt")); + BiSeNetVividModelWrapper wrapper = BiSeNetVividModelWrapper.load(Paths.get("C:\\models\\bisenet_face_parsing.pt")); // 使用 BiSeNet 人脸解析模型的 18 个非背景标签 Set faceLabels = Set.of(