feat(main): add FridaNative, LM, and RegisterTray modules

- Add the FridaNative module for Frida-related native code
- Add the LM module for llama-model-related native code
- Add the RegisterTray module for registering the system tray icon and handling related operations
- Add the corresponding header files, source files, and project configuration files
tzdwindows 7
2025-05-02 19:16:14 +08:00
parent d8099c3489
commit 3253997641
34 changed files with 2476 additions and 0 deletions
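For orientation, the JNI export names below (Java_org_tzd_lm_LM_*) and the method IDs resolved in org_tzd_lm_LM.cpp imply a Java-side binding roughly like the following sketch. This is a reconstruction, not part of the commit: the ProgressCallback name is an assumption (only its onModelLoad(float) method ID appears in the native code), while MessageCallback is named in the generated header comment.

package org.tzd.lm;

// Hypothetical Java binding inferred from the JNI exports in this commit.
public class LM {
    static { System.loadLibrary("LM"); }

    // Callback shapes taken from the GetMethodID calls: "(F)Z" and "(Ljava/lang/String;)V".
    public interface ProgressCallback { boolean onModelLoad(float progress); }
    public interface MessageCallback { void onMessage(String piece); }

    public static native long llamaLoadModelFromFile(String path, boolean vocabOnly,
            boolean useMmap, boolean useMlock, boolean checkTensors, ProgressCallback cb);
    public static native void llamaFreeModel(long modelHandle);
    public static native long createContext(long modelHandle, int nCtx, int nBatch,
            int nSeqMax, int nThreads, int nThreadsBatch, boolean logitsAll,
            boolean embeddings, boolean offloadKqv, boolean flashAttn, boolean noPerf);
    public static native void llamaFreeContext(long ctxHandle);
    public static native String inference(long modelHandle, long ctxHandle,
            float temperature, float minP, float topK, float topP, float dist,
            int penaltyLastN, float penaltyRepeat, float penaltyFreq, float penaltyPresent,
            String prompt, MessageCallback cb);
    // The state functions receive (JNIEnv*, jobject) natively, so they are instance methods.
    public native boolean llamaStateSaveFile(long ctx, String pathSession,
            long[] tokens, int nTokenCount);
    public native boolean llamaStateLoadFile(long ctx, String pathSession,
            long[] tokensOut, int nTokenCapacity, int[] nTokenCountOut);
}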

src/main/Cpp/LM/LM.vcxproj Normal file

@@ -0,0 +1,165 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<VCProjectVersion>17.0</VCProjectVersion>
<Keyword>Win32Proj</Keyword>
<ProjectGuid>{a3131b71-dd4e-41c6-927a-20b8b287fd6c}</ProjectGuid>
<RootNamespace>LM</RootNamespace>
<WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v143</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v143</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v143</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>DynamicLibrary</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v143</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
</ImportGroup>
<ImportGroup Label="Shared">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<IncludePath>C:\Users\Administrator\Desktop\llama资源\include;C:\Users\Administrator\.jdks\corretto-20.0.2.1\include\win32;C:\Users\Administrator\.jdks\corretto-20.0.2.1\include;$(IncludePath)</IncludePath>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;_DEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableUAC>false</EnableUAC>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;NDEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableUAC>false</EnableUAC>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>_DEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableUAC>false</EnableUAC>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<WarningLevel>Level3</WarningLevel>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>NDEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
<LanguageStandard>stdcpp17</LanguageStandard>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableUAC>false</EnableUAC>
<AdditionalLibraryDirectories>C:\Users\Administrator\Desktop\llama资源\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
<AdditionalDependencies>llama.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="framework.h" />
<ClInclude Include="org_tzd_lm_LM.h" />
<ClInclude Include="pch.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="dllmain.cpp" />
<ClCompile Include="org_tzd_lm_LM.cpp" />
<ClCompile Include="pch.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
</Project>

src/main/Cpp/LM/LM.vcxproj.filters Normal file

@@ -0,0 +1,39 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="源文件">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="头文件">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
</Filter>
<Filter Include="资源文件">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="framework.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="pch.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="org_tzd_lm_LM.h">
<Filter>头文件</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="dllmain.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="pch.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="org_tzd_lm_LM.cpp">
<Filter>源文件</Filter>
</ClCompile>
</ItemGroup>
</Project>

src/main/Cpp/LM/LM.vcxproj.user Normal file

@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="Current" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<PropertyGroup>
<ShowAllFiles>false</ShowAllFiles>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
</PropertyGroup>
</Project>

src/main/Cpp/LM/dllmain.cpp Normal file

@@ -0,0 +1,21 @@
// dllmain.cpp : Defines the entry point for the DLL application.
#include "pch.h"
#include "llama.h"
BOOL APIENTRY DllMain( HMODULE hModule,
DWORD ul_reason_for_call,
LPVOID lpReserved
)
{
switch (ul_reason_for_call)
{
case DLL_PROCESS_ATTACH:
case DLL_THREAD_ATTACH:
case DLL_THREAD_DETACH:
case DLL_PROCESS_DETACH:
break;
}
return TRUE;
}

src/main/Cpp/LM/framework.h Normal file

@@ -0,0 +1,5 @@
#pragma once
#define WIN32_LEAN_AND_MEAN             // Exclude rarely-used stuff from Windows headers
// Windows Header Files
#include <windows.h>

src/main/Cpp/LM/org_tzd_lm_LM.cpp Normal file

@@ -0,0 +1,419 @@
#include "pch.h"
#include "org_tzd_lm_LM.h"
#include <ctime>
#include <list>
#include "llama.h"
#include <string>
#include <unordered_set>
#include <vector>
#include <locale>
#include <codecvt>
// Global references for the callback methods
static jmethodID gMessageCallbackMethodID = nullptr;
static jmethodID gProgressCallbackMethodID = nullptr;
JNIEnv* env_;
jobject messageCallback_;
jobject progressCallback__;
static bool isRun = true;
bool ToCppBool(jboolean value) {
return value == JNI_TRUE;
}
bool llamaProgressCallback(float progress, void* user_data) {
JNIEnv* env = (JNIEnv*)user_data;
jfloat j_progress = progress;  // onModelLoad is resolved as "(F)Z", so pass a jfloat
jboolean ret = env->CallBooleanMethod(progressCallback__, gProgressCallbackMethodID, j_progress);
return ToCppBool(ret);
}
//--------------------------------------------------
// Model loading and freeing
//--------------------------------------------------
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_llamaLoadModelFromFile
(JNIEnv* env, jclass clazz, jstring pathModel,
jboolean vocab_only_jboolean,
jboolean use_mmap_jboolean, jboolean use_mlock_jboolean,jboolean check_tensors_jboolean, jobject progressCallback) {
const char* path = env->GetStringUTFChars(pathModel, nullptr);
if (!path) {
return 0;
}
progressCallback__ = progressCallback;
llama_model_params params = llama_model_default_params();
if (progressCallback && !gProgressCallbackMethodID) {
jclass callbackClass = env->GetObjectClass(progressCallback);
if (!callbackClass) {
return 0;
}
gProgressCallbackMethodID = env->GetMethodID(callbackClass, "onModelLoad", "(F)Z");
if (!gProgressCallbackMethodID) {
return 0;
}
params.progress_callback = llamaProgressCallback;
params.progress_callback_user_data = env;
}
params.vocab_only = ToCppBool(vocab_only_jboolean);
params.use_mmap = ToCppBool(use_mmap_jboolean);
params.use_mlock = ToCppBool(use_mlock_jboolean);
params.check_tensors = ToCppBool(check_tensors_jboolean);
llama_model* model = llama_model_load_from_file(path, params);
env->ReleaseStringUTFChars(pathModel, path);
if (!model) {
jclass exClass = env->FindClass("java/io/IOException");
if (exClass) {
env->ThrowNew(exClass, "Failed to load model: check path and parameters");
}
return 0;
}
return reinterpret_cast<jlong>(model);
}
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeModel
(JNIEnv* env, jclass clazz, jlong modelHandle) {
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
llama_model_free(model); // use the newer API
}
//--------------------------------------------------
// Context creation
//--------------------------------------------------
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_createContext
(JNIEnv* env, jclass clazz, jlong modelHandle, jint nCtx,
jint nBatch,
jint nSeqMax,
jint nThreads,
jint nThreadsBatch,
jboolean logitsAll,
jboolean embeddings,
jboolean offloadKqv,
jboolean flashAttn,
jboolean noPerf
) {
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
if (!model) {
jclass exClass = env->FindClass("java/lang/IllegalArgumentException");
if (exClass) {
env->ThrowNew(exClass, "Invalid model handle");
}
return 0;
}
llama_context_params ctx_params = llama_context_default_params();
if (nCtx != 0) {
ctx_params.n_ctx = nCtx;
}
if (nBatch != 0) {
ctx_params.n_batch = nBatch;
}
if (nSeqMax != 0) {
ctx_params.n_seq_max = nSeqMax;
}
if (nThreads != 0) {
ctx_params.n_threads = nThreads;
}
if (nThreadsBatch != 0) {
ctx_params.n_threads_batch = nThreadsBatch;
}
ctx_params.logits_all = static_cast<bool>(logitsAll);
ctx_params.embeddings = static_cast<bool>(embeddings);
ctx_params.offload_kqv = static_cast<bool>(offloadKqv);
ctx_params.flash_attn = static_cast<bool>(flashAttn);
ctx_params.no_perf = static_cast<bool>(noPerf);
llama_context* ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
jclass exClass = env->FindClass("java/io/IOException");
if (exClass) {
env->ThrowNew(exClass, "Failed to create context");
}
return 0;
}
return reinterpret_cast<jlong>(ctx);
}
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeContext
(JNIEnv* env, jclass clazz, jlong ctxHandle)
{
llama_context* ctx = reinterpret_cast<llama_context*>(ctxHandle);
llama_kv_cache_clear(ctx);
llama_free(ctx);
isRun = false;
}
static int tokenize_prompt(
const llama_vocab* vocab,
const std::string& prompt,
std::vector<llama_token>& prompt_tokens,
llama_context* context
) {
const bool is_first = llama_get_kv_cache_used_cells(context) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab,
prompt.c_str(),
prompt.size(),
NULL,
0,
is_first,
true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
true) < 0) {
printf("failed to tokenize the prompt\n");
return -1;
}
return n_prompt_tokens;
}
static int check_context_size(
const llama_context* ctx,
const llama_batch& batch
) {
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("context size exceeded\n");
return 1;
}
return 0;
}
static int convert_token_to_string(const llama_vocab* vocab, const llama_token token_id, std::string& piece) {
char buf[256];
int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
if (n < 0) {
printf("failed to convert token to piece\n");
return 1;
}
piece = std::string(buf, n);
return 0;
}
static void print_word_and_concatenate_to_response(const std::string& piece, std::string& response) {
jstring message = env_->NewStringUTF(piece.c_str());
if (message) {
env_->CallVoidMethod(messageCallback_, gMessageCallbackMethodID, message);
env_->DeleteLocalRef(message);
}
fflush(stdout);
response += piece;
}
static int apply_chat_template_with_error_handling(const bool append, std::string response, int& output_length) {
if (!append)
{
const int new_len = response.length();
if (new_len < 0) {
printf("failed to apply the chat template\n");
return -1;
}
output_length = new_len;
}
return 0;
}
std::vector<llama_token> tokens;
static int generate(
llama_model* llama_data,
llama_context* context,
llama_sampler* smpl,
const std::string& prompt,
std::string& response
) {
//llama_kv_cache_clear(context);
const llama_vocab* vocab = llama_model_get_vocab(llama_data);
isRun = true;
if (tokenize_prompt(vocab, prompt, tokens, context) < 0) {
return 1;
}
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
llama_token new_token_id;
while (true) {
if (check_context_size(context, batch)) {
return 1;
}
if (llama_decode(context, batch)) {
printf("\nfailed to decode\n");
return 1;
}
new_token_id = llama_sampler_sample(smpl, context, -1);
if (llama_vocab_is_eog(vocab, new_token_id)) {
break;
}
std::string piece;
if (convert_token_to_string(vocab, new_token_id, piece)) {
return 1;
}
print_word_and_concatenate_to_response(piece, response);
batch = llama_batch_get_one(&new_token_id, 1);
if (!isRun) {
return 0;
}
}
return 0;
}
llama_sampler* initialize_sampler(float temperature,
float min_p,
float top_k,
float top_p,
float dist,
int32_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present
) {
if (!dist) {
dist = LLAMA_DEFAULT_SEED;
}
llama_sampler_chain_params params = llama_sampler_chain_default_params();
llama_sampler* sampler = llama_sampler_chain_init(params);
// Samplers run in the order they are added to the chain; changing the order changes the results
llama_sampler_chain_add(sampler, llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(top_k));
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(top_p, 1));
llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature));
llama_sampler_chain_add(sampler, llama_sampler_init_min_p(min_p, 1));
llama_sampler_chain_add(sampler, llama_sampler_init_dist(dist));
// Remove the unnecessary greedy sampler
// llama_sampler_chain_add(sampler, llama_sampler_init_greedy());
return sampler;
}
// Helper to convert a UTF-16 jstring to a UTF-8 std::string
std::string jstringToUTF8(JNIEnv* env, jstring jstr) {
if (!jstr) return "";
const jchar* raw = env->GetStringChars(jstr, nullptr);
if (!raw) return "";
jsize len = env->GetStringLength(jstr);
// Convert UTF-16 to UTF-8
int utf8Size = WideCharToMultiByte(CP_UTF8, 0, reinterpret_cast<const wchar_t*>(raw), len, nullptr, 0, nullptr, nullptr);
std::string utf8(utf8Size, 0);
WideCharToMultiByte(CP_UTF8, 0, reinterpret_cast<const wchar_t*>(raw), len, &utf8[0], utf8Size, nullptr, nullptr);
env->ReleaseStringChars(jstr, raw);
return utf8;
}
// Helper to convert a UTF-8 std::string to a jstring
jstring utf8ToJstring(JNIEnv* env, const std::string& utf8) {
// Convert UTF-8 to UTF-16
int utf16Size = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, nullptr, 0);
std::wstring utf16(utf16Size, 0);
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &utf16[0], utf16Size);
return env->NewString(reinterpret_cast<const jchar*>(utf16.c_str()), utf16Size - 1);
}
//--------------------------------------------------
// Inference implementation
//--------------------------------------------------
JNIEXPORT jstring JNICALL Java_org_tzd_lm_LM_inference
(JNIEnv* env, jclass clazz, jlong modelHandle, jlong ctxHandle, jfloat temperature, jfloat minP,
jfloat topK, jfloat topP, jfloat dist, jint penaltyLastN, jfloat penaltyRepeat, jfloat penaltyFreq,
jfloat penaltyPresent, jstring prompt, jobject messageCallback) {
llama_context* ctx = reinterpret_cast<llama_context*>(ctxHandle);
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
env_ = env;
messageCallback_ = messageCallback;
// Check whether ctx is valid
if (!ctx) {
jclass exClass = env->FindClass("java/lang/IllegalArgumentException");
if (exClass) env->ThrowNew(exClass, "Invalid context handle");
return nullptr;
}
// Initialize the callback method ID
if (!gMessageCallbackMethodID) {
jclass callbackClass = env->GetObjectClass(messageCallback);
if (!callbackClass) {
return nullptr;
}
gMessageCallbackMethodID = env->GetMethodID(callbackClass, "onMessage", "(Ljava/lang/String;)V");
if (!gMessageCallbackMethodID) {
return nullptr;
}
}
// Build the response string
std::string response;
std::string prompt_(jstringToUTF8(env, prompt));
// Generate with a sampler built from the given parameters
llama_sampler* sampler = initialize_sampler(temperature, minP, topK,
topP, dist, penaltyLastN, penaltyRepeat,
penaltyFreq, penaltyPresent);
if (generate(model, ctx, sampler, prompt_, response) != 0) {
return nullptr;
}
return utf8ToJstring(env, response);
}
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateLoadFile(
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
jlongArray tokensOut, jint nTokenCapacity, jintArray nTokenCountOut
) {
const char* path = env->GetStringUTFChars(pathSession, NULL);
// llama_token is 32-bit and size_t is 64-bit on x64, so load into temporary
// buffers instead of casting the JNI array pointers directly.
std::vector<llama_token> tokens(nTokenCapacity);
size_t n_token_count = 0;
bool result = llama_state_load_file((struct llama_context*)ctx, path,
tokens.data(),
nTokenCapacity,
&n_token_count);
env->ReleaseStringUTFChars(pathSession, path);
if (result) {
jlong* tokens_out = env->GetLongArrayElements(tokensOut, NULL);
jint* n_token_count_out = env->GetIntArrayElements(nTokenCountOut, NULL);
for (size_t i = 0; i < n_token_count; ++i) {
tokens_out[i] = tokens[i];
}
n_token_count_out[0] = (jint)n_token_count;
env->ReleaseLongArrayElements(tokensOut, tokens_out, 0);
env->ReleaseIntArrayElements(nTokenCountOut, n_token_count_out, 0);
}
return result ? JNI_TRUE : JNI_FALSE;
}
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateSaveFile(
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
jlongArray tokens, jint nTokenCount
) {
const char* path = env->GetStringUTFChars(pathSession, NULL);
jlong* tokens_array = env->GetLongArrayElements(tokens, NULL);
// Narrow the 64-bit Java tokens to 32-bit llama_token before saving.
std::vector<llama_token> native_tokens(nTokenCount);
for (jint i = 0; i < nTokenCount; ++i) {
native_tokens[i] = (llama_token)tokens_array[i];
}
bool result = llama_state_save_file((struct llama_context*)ctx, path,
native_tokens.data(),
nTokenCount);
env->ReleaseStringUTFChars(pathSession, path);
env->ReleaseLongArrayElements(tokens, tokens_array, JNI_ABORT); // JNI_ABORT: the array was not modified
return result ? JNI_TRUE : JNI_FALSE;
}

src/main/Cpp/LM/org_tzd_lm_LM.h Normal file

@@ -0,0 +1,67 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_tzd_lm_LM */
#ifndef _Included_org_tzd_lm_LM
#define _Included_org_tzd_lm_LM
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_tzd_lm_LM
* Method: llamaLoadModelFromFile
* Signature: (Ljava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_llamaLoadModelFromFile
(JNIEnv*, jclass, jstring, jboolean , jboolean, jboolean, jboolean, jobject);
/*
* Class: org_tzd_lm_LM
* Method: llamaFreeModel
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeModel
(JNIEnv*, jclass, jlong);
/*
* Class: org_tzd_lm_LM
* Method: createContext
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_createContext
(JNIEnv* env, jclass clazz, jlong modelHandle, jint nCtx,
jint nBatch,
jint nSeqMax,
jint nThreads,
jint nThreadsBatch,
jboolean logitsAll,
jboolean embeddings,
jboolean offloadKqv,
jboolean flashAttn,
jboolean noPerf);
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeContext
(JNIEnv*, jclass, jlong);
/*
* Class: org_tzd_lm_LM
* Method: inference
* Signature: (JLjava/lang/String;Lorg/tzd/lm/LM/MessageCallback;)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_tzd_lm_LM_inference
(JNIEnv* env, jclass clazz, jlong modelHandle, jlong ctxHandle, jfloat temperature, jfloat minP, jfloat topK, jfloat topP, jfloat dist, jint penaltyLastN, jfloat penaltyRepeat, jfloat penaltyFreq, jfloat penaltyPresent, jstring prompt, jobject messageCallback);
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateSaveFile(
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
jlongArray tokens, jint nTokenCount
);
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateLoadFile(
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
jlongArray tokensOut, jint nTokenCapacity, jintArray nTokenCountOut
);
#ifdef __cplusplus
}
#endif
#endif
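The state pair maps onto llama.cpp session persistence. From Java it would be driven roughly as below (hypothetical; the file name and buffer sizes are illustrative, and the instance methods come from the binding sketch near the top of the commit):

// Hypothetical session round-trip using the assumed LM instance methods.
public class SessionDemo {
    static void restoreAndSave(LM lm, long ctx) {
        long[] tokens = new long[4096];   // capacity for restored session tokens
        int[] count = new int[1];         // the native side writes the token count here
        if (lm.llamaStateLoadFile(ctx, "session.bin", tokens, tokens.length, count)) {
            System.out.println("restored " + count[0] + " tokens");
        }
        // ... run inference ...
        lm.llamaStateSaveFile(ctx, "session.bin", tokens, count[0]);
    }
}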

src/main/Cpp/LM/pch.cpp Normal file

@@ -0,0 +1,5 @@
// pch.cpp: source file corresponding to the precompiled header
#include "pch.h"
// When precompiled headers are used, this source file is necessary for compilation to succeed.

src/main/Cpp/LM/pch.h Normal file

@@ -0,0 +1,13 @@
// pch.h: This is a precompiled header file.
// Files listed below are compiled only once, improving build performance for future builds.
// This also affects IntelliSense performance, including code completion and many code browsing features.
// However, files listed here are ALL re-compiled if any one of them is updated between builds.
// Do not add files here that you will be updating frequently as this negates the performance advantage.
#ifndef PCH_H
#define PCH_H
// add headers that you want to pre-compile here
#include "framework.h"
#endif //PCH_H

src/main/Cpp/LM/tiktoken.h Normal file

@@ -0,0 +1,269 @@
#pragma once
#include <re2/re2.h>
#include "unordered_dense.h"
#include <cassert>
#include <limits>
#include <optional>
#include <regex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace tiktoken {
static auto _byte_pair_merge(
const std::string &piece,
const ankerl::unordered_dense::map<std::string, int> &ranks,
std::function<int (int, int)> func
) -> std::vector<int> {
std::vector<std::pair<int, int>> parts;
parts.reserve(piece.size() + 1);
for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
parts.emplace_back(idx, std::numeric_limits<int>::max());
}
auto get_rank = [&piece, &ranks](
const std::vector<std::pair<int, int>> &parts,
int start_idx,
int skip
) -> std::optional<int> {
if (start_idx + skip + 2 < parts.size()) {
auto s = parts[start_idx].first;
auto e = parts[start_idx + skip + 2].first;
auto key = piece.substr(s, e - s);
auto iter = ranks.find(key);
if (iter != ranks.end()) {
return iter->second;
}
}
return std::nullopt;
};
for (auto i = 0U; i < parts.size() - 2; ++i) {
auto rank = get_rank(parts, i, 0);
if (rank) {
assert(*rank != std::numeric_limits<int>::max());
parts[i].second = *rank;
}
}
while (true) {
if (parts.size() == 1) break;
auto min_rank = std::make_pair<int, int>(std::numeric_limits<int>::max(), 0);
for (auto i = 0U; i < parts.size() - 1; ++i) {
auto rank = parts[i].second;
if (rank < min_rank.first) {
min_rank = { rank, i };
}
}
if (min_rank.first != std::numeric_limits<int>::max()) {
auto i = min_rank.second;
auto rank = get_rank(parts, i, 1);
if (rank) {
parts[i].second = *rank;
} else {
parts[i].second = std::numeric_limits<int>::max();
}
if (i > 0) {
auto rank = get_rank(parts, i - 1, 1);
if (rank) {
parts[i - 1].second = *rank;
} else {
parts[i - 1].second = std::numeric_limits<int>::max();
}
}
parts.erase(parts.begin() + (i + 1));
} else {
break;
}
}
std::vector<int> out;
out.reserve(parts.size() - 1);
for (auto i = 0U; i < parts.size() - 1; ++i) {
out.push_back(func(parts[i].first, parts[i + 1].first));
}
return out;
}
static auto byte_pair_encode(
const std::string &piece,
const ankerl::unordered_dense::map<std::string, int> &ranks
) -> std::vector<int> {
if (piece.size() == 1) {
return {ranks.at(piece)};
}
auto func = [&piece, &ranks](int start, int stop) -> int {
std::string key = piece.substr(start, stop - start);
return ranks.at(key);
};
return _byte_pair_merge(piece, ranks, func);
}
class tiktoken {
public:
tiktoken() = default;
tiktoken(
ankerl::unordered_dense::map<std::string, int> encoder,
ankerl::unordered_dense::map<std::string, int> special_encoder,
const std::string &pattern
) {
regex_ = std::make_unique<re2::RE2>("(" + pattern + ")");
std::string special_pattern;
for (const auto &item : special_encoder) {
if (!special_pattern.empty()) {
special_pattern += "|";
}
special_pattern += re2::RE2::QuoteMeta(item.first);
}
if (special_pattern.empty()) {
special_regex_ = nullptr;
} else {
special_regex_ = std::make_unique<re2::RE2>("(" + special_pattern + ")");
}
encoder_ = std::move(encoder);
special_tokens_encoder = std::move(special_encoder);
for (const auto &[k, v] : encoder_) {
decoder_.emplace(v, k);
}
assert(encoder_.size() == decoder_.size() && "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?");
for (const auto &[k, v] : special_tokens_encoder) {
special_tokens_decoder.emplace(v, k);
}
}
auto encode_ordinary(const std::string &text) const -> std::vector<int> {
return _encode_ordinary_native(text);
}
auto encode(const std::string &text) const -> std::vector<int> {
return _encode_native(text, special_tokens_encoder).first;
}
auto encode_single_piece(const std::string &text) const -> std::vector<int> {
auto iter = encoder_.find(text);
if (iter != encoder_.end()) {
return {iter->second};
}
return byte_pair_encode(text, encoder_);
}
auto decode(const std::vector<int> &tokens) const -> std::string {
return _decode_native(tokens);
}
private:
auto split_with_allowed_special_token(
re2::StringPiece &input,
const ankerl::unordered_dense::map<std::string, int> &allowed_special
) const -> std::pair<std::optional<std::string>, re2::StringPiece> {
if (special_regex_ == nullptr) return { std::nullopt, input };
auto start = input.begin();
std::string special;
while (true) {
if (!re2::RE2::FindAndConsume(&input, *special_regex_, &special)) {
break;
}
if (allowed_special.count(special) == 1) {
return { std::move(special), re2::StringPiece(start, input.begin() - start - special.size()) };
}
}
return { std::nullopt, input };
}
auto _encode_ordinary_native(const std::string &text) const -> std::vector<int> {
std::vector<int> ret;
re2::StringPiece input(text);
std::string piece;
while (re2::RE2::FindAndConsume(&input, *regex_, &piece)) {
auto iter = encoder_.find(piece);
if (iter != encoder_.end()) {
ret.push_back(iter->second);
continue;
}
auto tokens = byte_pair_encode(piece, encoder_);
ret.insert(ret.end(), tokens.begin(), tokens.end());
}
return ret;
}
auto _encode_native(
const std::string &text,
const ankerl::unordered_dense::map<std::string, int> &allowed_special
) const -> std::pair<std::vector<int>, int> {
std::vector<int> ret;
int last_piece_token_len = 0;
re2::StringPiece input(text);
while (true) {
auto [special, sub_input] = split_with_allowed_special_token(input, allowed_special);
std::string piece;
while (re2::RE2::FindAndConsume(&sub_input, *regex_, &piece)) {
auto iter = encoder_.find(piece);
if (iter != encoder_.end()) {
last_piece_token_len = 1;
ret.push_back(iter->second);
continue;
}
auto tokens = byte_pair_encode(piece, encoder_);
last_piece_token_len = tokens.size();
ret.insert(ret.end(), tokens.begin(), tokens.end());
}
if (special) {
int token = special_tokens_encoder.at(*special);
ret.push_back(token);
last_piece_token_len = 0;
} else {
break;
}
}
return { ret, last_piece_token_len };
}
auto _decode_native(const std::vector<int> &tokens) const -> std::string {
std::string ret;
ret.reserve(tokens.size() * 2);
for (auto token : tokens) {
std::string token_bytes;
auto iter = decoder_.find(token);
if (iter != decoder_.end()) {
token_bytes = iter->second;
} else {
iter = special_tokens_decoder.find(token);
if (iter != special_tokens_decoder.end()) {
token_bytes = iter->second;
} else {
throw std::runtime_error("unknown token: " + std::to_string(token));
}
}
ret += token_bytes;
}
return ret;
}
ankerl::unordered_dense::map<std::string, int> encoder_;
ankerl::unordered_dense::map<std::string, int> special_tokens_encoder;
ankerl::unordered_dense::map<int, std::string> decoder_;
ankerl::unordered_dense::map<int, std::string> special_tokens_decoder;
std::unique_ptr<re2::RE2> regex_;
std::unique_ptr<re2::RE2> special_regex_;
};
} // namespace tiktoken