feat(main): add FridaNative, LM, and RegisterTray modules
- Add the FridaNative module for Frida-related native code
- Add the LM module for llama-model-related native code
- Add the RegisterTray module for registering the system tray icon and related operations
- Create the corresponding header files, source files, and project configuration files
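For orientation, below is a minimal, hypothetical sketch of the Java-side class that the LM module's JNI exports appear to bind to. The package and class name are taken from the exported symbols (Java_org_tzd_lm_LM_*); the callback interface names are assumptions, since only the method signatures "onModelLoad(F)Z" and "onMessage(Ljava/lang/String;)V" appear in the native sources. The native side may raise IOException or IllegalArgumentException through env->ThrowNew even though no throws clause is declared here.

package org.tzd.lm;

public class LM {

    static {
        System.loadLibrary("LM"); // assumes the built DLL is named LM.dll and is on java.library.path
    }

    public interface ProgressCallback {
        boolean onModelLoad(float progress); // return false to abort model loading
    }

    public interface MessageCallback {
        void onMessage(String piece); // receives each generated piece as it is sampled
    }

    public static native long llamaLoadModelFromFile(String pathModel, boolean vocabOnly,
            boolean useMmap, boolean useMlock, boolean checkTensors, ProgressCallback callback);

    public static native void llamaFreeModel(long modelHandle);

    public static native long createContext(long modelHandle, int nCtx, int nBatch, int nSeqMax,
            int nThreads, int nThreadsBatch, boolean logitsAll, boolean embeddings,
            boolean offloadKqv, boolean flashAttn, boolean noPerf);

    public static native void llamaFreeContext(long ctxHandle);

    public static native String inference(long modelHandle, long ctxHandle, float temperature,
            float minP, float topK, float topP, float dist, int penaltyLastN, float penaltyRepeat,
            float penaltyFreq, float penaltyPresent, String prompt, MessageCallback callback);

    // Exported with a jobject receiver in the native code, hence instance methods.
    public native boolean llamaStateSaveFile(long ctx, String pathSession, long[] tokens, int nTokenCount);

    public native boolean llamaStateLoadFile(long ctx, String pathSession, long[] tokensOut,
            int nTokenCapacity, int[] nTokenCountOut);
}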
165
src/main/Cpp/LM/LM.vcxproj
Normal file
@@ -0,0 +1,165 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|Win32">
      <Configuration>Debug</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|Win32">
      <Configuration>Release</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <VCProjectVersion>17.0</VCProjectVersion>
    <Keyword>Win32Proj</Keyword>
    <ProjectGuid>{a3131b71-dd4e-41c6-927a-20b8b287fd6c}</ProjectGuid>
    <RootNamespace>LM</RootNamespace>
    <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v143</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <IncludePath>C:\Users\Administrator\Desktop\llama资源\include;C:\Users\Administrator\.jdks\corretto-20.0.2.1\include\win32;C:\Users\Administrator\.jdks\corretto-20.0.2.1\include;$(IncludePath)</IncludePath>
  </PropertyGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>WIN32;_DEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableUAC>false</EnableUAC>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>WIN32;NDEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableUAC>false</EnableUAC>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>_DEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableUAC>false</EnableUAC>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <SDLCheck>true</SDLCheck>
      <PreprocessorDefinitions>NDEBUG;LM_EXPORTS;_WINDOWS;_USRDLL;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <ConformanceMode>true</ConformanceMode>
      <PrecompiledHeader>Use</PrecompiledHeader>
      <PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
      <LanguageStandard>stdcpp17</LanguageStandard>
    </ClCompile>
    <Link>
      <SubSystem>Windows</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableUAC>false</EnableUAC>
      <AdditionalLibraryDirectories>C:\Users\Administrator\Desktop\llama资源\lib;%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
      <AdditionalDependencies>llama.lib;%(AdditionalDependencies)</AdditionalDependencies>
    </Link>
  </ItemDefinitionGroup>
  <ItemGroup>
    <ClInclude Include="framework.h" />
    <ClInclude Include="org_tzd_lm_LM.h" />
    <ClInclude Include="pch.h" />
  </ItemGroup>
  <ItemGroup>
    <ClCompile Include="dllmain.cpp" />
    <ClCompile Include="org_tzd_lm_LM.cpp" />
    <ClCompile Include="pch.cpp">
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>
      <PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
    </ClCompile>
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>
39
src/main/Cpp/LM/LM.vcxproj.filters
Normal file
@@ -0,0 +1,39 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup>
    <Filter Include="源文件">
      <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
      <Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
    </Filter>
    <Filter Include="头文件">
      <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
      <Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
    </Filter>
    <Filter Include="资源文件">
      <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
      <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
    </Filter>
  </ItemGroup>
  <ItemGroup>
    <ClInclude Include="framework.h">
      <Filter>头文件</Filter>
    </ClInclude>
    <ClInclude Include="pch.h">
      <Filter>头文件</Filter>
    </ClInclude>
    <ClInclude Include="org_tzd_lm_LM.h">
      <Filter>头文件</Filter>
    </ClInclude>
  </ItemGroup>
  <ItemGroup>
    <ClCompile Include="dllmain.cpp">
      <Filter>源文件</Filter>
    </ClCompile>
    <ClCompile Include="pch.cpp">
      <Filter>源文件</Filter>
    </ClCompile>
    <ClCompile Include="org_tzd_lm_LM.cpp">
      <Filter>源文件</Filter>
    </ClCompile>
  </ItemGroup>
</Project>
9
src/main/Cpp/LM/LM.vcxproj.user
Normal file
@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="Current" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <PropertyGroup>
    <ShowAllFiles>false</ShowAllFiles>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <DebuggerFlavor>WindowsLocalDebugger</DebuggerFlavor>
  </PropertyGroup>
</Project>
21
src/main/Cpp/LM/dllmain.cpp
Normal file
@@ -0,0 +1,21 @@
// dllmain.cpp : Defines the entry point for the DLL application.
#include "pch.h"

#include "llama.h"

BOOL APIENTRY DllMain( HMODULE hModule,
                       DWORD  ul_reason_for_call,
                       LPVOID lpReserved
                     )
{
    switch (ul_reason_for_call)
    {
    case DLL_PROCESS_ATTACH:
    case DLL_THREAD_ATTACH:
    case DLL_THREAD_DETACH:
    case DLL_PROCESS_DETACH:
        break;
    }
    return TRUE;
}
5
src/main/Cpp/LM/framework.h
Normal file
@@ -0,0 +1,5 @@
#pragma once

#define WIN32_LEAN_AND_MEAN             // Exclude rarely-used content from the Windows headers
// Windows header files
#include <windows.h>
419
src/main/Cpp/LM/org_tzd_lm_LM.cpp
Normal file
@@ -0,0 +1,419 @@
#include "pch.h"
|
||||
#include "org_tzd_lm_LM.h"
|
||||
|
||||
#include <ctime>
|
||||
#include <list>
|
||||
|
||||
#include "llama.h"
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
|
||||
// <20>ص<EFBFBD><D8B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ȫ<EFBFBD><C8AB><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||
static jmethodID gMessageCallbackMethodID = nullptr;
|
||||
static jmethodID gProgressCallbackMethodID = nullptr;
|
||||
JNIEnv* env_;
|
||||
jobject messageCallback_;
|
||||
jobject progressCallback__;
|
||||
|
||||
static bool isRun = true;
|
||||
|
||||
bool ToCppBool(jboolean value) {
|
||||
return value == JNI_TRUE;
|
||||
}
|
||||
|
||||
bool llamaProgressCallback(float progress, void* user_data) {
|
||||
JNIEnv* env = (JNIEnv*)user_data;
|
||||
jint j_progress = progress;
|
||||
jboolean ret = env->CallBooleanMethod(progressCallback__, gProgressCallbackMethodID, j_progress);
|
||||
return ToCppBool(ret);
|
||||
}
|
||||
|
||||
//--------------------------------------------------
|
||||
// ģ<>ͼ<EFBFBD><CDBC>غ<EFBFBD><D8BA>ͷ<EFBFBD>
|
||||
//--------------------------------------------------
|
||||
|
||||
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_llamaLoadModelFromFile
|
||||
(JNIEnv* env, jclass clazz, jstring pathModel,
|
||||
jboolean vocab_only_jboolean,
|
||||
jboolean use_mmap_jboolean, jboolean use_mlock_jboolean,jboolean check_tensors_jboolean, jobject progressCallback) {
|
||||
const char* path = env->GetStringUTFChars(pathModel, nullptr);
|
||||
if (!path) {
|
||||
return 0;
|
||||
}
|
||||
progressCallback__ = progressCallback;
|
||||
llama_model_params params = llama_model_default_params();
|
||||
if (progressCallback && !gProgressCallbackMethodID) {
|
||||
jclass callbackClass = env->GetObjectClass(progressCallback);
|
||||
if (!callbackClass) {
|
||||
return 0;
|
||||
}
|
||||
gProgressCallbackMethodID = env->GetMethodID(callbackClass, "onModelLoad", "(F)Z");
|
||||
if (!gProgressCallbackMethodID) {
|
||||
return 0;
|
||||
}
|
||||
params.progress_callback = llamaProgressCallback;
|
||||
params.progress_callback_user_data = env;
|
||||
}
|
||||
|
||||
bool vocab_only = ToCppBool(vocab_only_jboolean);
|
||||
bool use_mmap = ToCppBool(use_mmap_jboolean);
|
||||
bool use_mlock = ToCppBool(use_mlock_jboolean);
|
||||
bool check_tensors = ToCppBool(check_tensors_jboolean);
|
||||
|
||||
params.vocab_only = static_cast<bool>(vocab_only);
|
||||
params.use_mmap = static_cast<bool>(use_mmap);
|
||||
params.use_mlock = static_cast<bool>(use_mlock);
|
||||
params.check_tensors = static_cast<bool>(check_tensors);
|
||||
|
||||
llama_model* model = llama_model_load_from_file(path, params);
|
||||
env->ReleaseStringUTFChars(pathModel, path);
|
||||
|
||||
if (!model) {
|
||||
jclass exClass = env->FindClass("java/io/IOException");
|
||||
if (exClass) {
|
||||
env->ThrowNew(exClass, "Failed to load model: check path and parameters");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(model);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeModel
|
||||
(JNIEnv* env, jclass clazz, jlong modelHandle) {
|
||||
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
|
||||
llama_model_free(model); // ʹ<><CAB9><EFBFBD>µ<EFBFBD> API
|
||||
}
|
||||
|
||||
//--------------------------------------------------
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD>Ĵ<EFBFBD><C4B4><EFBFBD>
|
||||
//--------------------------------------------------
|
||||
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_createContext
|
||||
(JNIEnv* env, jclass clazz, jlong modelHandle, jint nCtx,
|
||||
jint nBatch,
|
||||
jint nSeqMax,
|
||||
jint nThreads,
|
||||
jint nThreadsBatch,
|
||||
jboolean logitsAll,
|
||||
jboolean embeddings,
|
||||
jboolean offloadKqv,
|
||||
jboolean flashAttn,
|
||||
jboolean noPerf
|
||||
) {
|
||||
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
|
||||
if (!model) {
|
||||
jclass exClass = env->FindClass("java/lang/IllegalArgumentException");
|
||||
if (exClass) {
|
||||
env->ThrowNew(exClass, "Invalid model handle");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
|
||||
if (nCtx != 0) {
|
||||
ctx_params.n_ctx = nCtx;
|
||||
}
|
||||
if (nBatch != 0) {
|
||||
ctx_params.n_batch = nBatch;
|
||||
}
|
||||
if (nSeqMax != 0) {
|
||||
ctx_params.n_seq_max = nSeqMax;
|
||||
}
|
||||
if (nThreads != 0) {
|
||||
ctx_params.n_threads = nThreads;
|
||||
}
|
||||
if (nThreadsBatch != 0) {
|
||||
ctx_params.n_threads_batch = nThreadsBatch;
|
||||
}
|
||||
|
||||
ctx_params.logits_all = static_cast<bool>(logitsAll);
|
||||
ctx_params.embeddings = static_cast<bool>(embeddings);
|
||||
ctx_params.offload_kqv = static_cast<bool>(offloadKqv);
|
||||
ctx_params.flash_attn = static_cast<bool>(flashAttn);
|
||||
ctx_params.no_perf = static_cast<bool>(noPerf);
|
||||
|
||||
llama_context* ctx = llama_init_from_model(model, ctx_params);
|
||||
if (!ctx) {
|
||||
jclass exClass = env->FindClass("java/io/IOException");
|
||||
if (exClass) {
|
||||
env->ThrowNew(exClass, "Failed to create context");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
return reinterpret_cast<jlong>(ctx);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeContext
|
||||
(JNIEnv* env, jclass clazz, jlong ctxHandle)
|
||||
{
|
||||
llama_context* ctx = reinterpret_cast<llama_context*>(ctxHandle);
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_free(ctx);
|
||||
isRun = false;
|
||||
}
|
||||
|
||||
static int tokenize_prompt(
|
||||
const llama_vocab* vocab,
|
||||
const std::string& prompt,
|
||||
std::vector<llama_token>& prompt_tokens,
|
||||
llama_context* context
|
||||
) {
|
||||
const bool is_first = llama_get_kv_cache_used_cells(context) == 0;
|
||||
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab,
|
||||
prompt.c_str(),
|
||||
prompt.size(),
|
||||
NULL,
|
||||
0,
|
||||
is_first,
|
||||
true);
|
||||
prompt_tokens.resize(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
|
||||
true) < 0) {
|
||||
printf("failed to tokenize the prompt\n");
|
||||
return -1;
|
||||
}
|
||||
return n_prompt_tokens;
|
||||
}
|
||||
|
||||
static int check_context_size(
|
||||
const llama_context* ctx,
|
||||
const llama_batch& batch
|
||||
) {
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
|
||||
if (n_ctx_used + batch.n_tokens > n_ctx) {
|
||||
printf("context size exceeded\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int convert_token_to_string(const llama_vocab* vocab, const llama_token token_id, std::string& piece) {
|
||||
char buf[256];
|
||||
int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
|
||||
if (n < 0) {
|
||||
printf("failed to convert token to piece\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
piece = std::string(buf, n);
|
||||
return 0;
|
||||
}
|
||||
static void print_word_and_concatenate_to_response(const std::string& piece, std::string& response) {
|
||||
jstring message = env_->NewStringUTF(piece.c_str());
|
||||
if (message) {
|
||||
env_->CallVoidMethod(messageCallback_, gMessageCallbackMethodID, message);
|
||||
env_->DeleteLocalRef(message);
|
||||
}
|
||||
fflush(stdout);
|
||||
response += piece;
|
||||
}
|
||||
|
||||
static int apply_chat_template_with_error_handling(const bool append, std::string response, int& output_length) {
|
||||
if (!append)
|
||||
{
|
||||
const int new_len = response.length();
|
||||
if (new_len < 0) {
|
||||
printf("failed to apply the chat template\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
output_length = new_len;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
static int generate(
|
||||
llama_model* llama_data,
|
||||
llama_context* context,
|
||||
llama_sampler* smpl,
|
||||
const std::string& prompt,
|
||||
std::string& response
|
||||
) {
|
||||
//llama_kv_cache_clear(context);
|
||||
|
||||
const llama_vocab* vocab = llama_model_get_vocab(llama_data);
|
||||
isRun = true;
|
||||
if (tokenize_prompt(vocab, prompt, tokens, context) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
|
||||
llama_token new_token_id;
|
||||
while (true) {
|
||||
check_context_size(context, batch);
|
||||
if (llama_decode(context, batch)) {
|
||||
printf("\nfailed to decode\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
new_token_id = llama_sampler_sample(smpl, context, -1);
|
||||
if (llama_vocab_is_eog(vocab, new_token_id)) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::string piece;
|
||||
if (convert_token_to_string(vocab, new_token_id, piece)) {
|
||||
return 1;
|
||||
}
|
||||
print_word_and_concatenate_to_response(piece, response);
|
||||
batch = llama_batch_get_one(&new_token_id, 1);
|
||||
if (!isRun) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
llama_sampler* initialize_sampler(float temperature,
|
||||
float min_p,
|
||||
float top_k,
|
||||
float top_p,
|
||||
float dist,
|
||||
int32_t penalty_last_n,
|
||||
float penalty_repeat,
|
||||
float penalty_freq,
|
||||
float penalty_present
|
||||
) {
|
||||
if (!dist) {
|
||||
dist = LLAMA_DEFAULT_SEED;
|
||||
}
|
||||
llama_sampler_chain_params params = llama_sampler_chain_default_params();
|
||||
llama_sampler* sampler = llama_sampler_chain_init(params);
|
||||
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>˳<EFBFBD><CBB3><EFBFBD><EFBFBD>ʾ<EFBFBD><CABE>˳<EFBFBD><EFBFBD><F2A3ACB8><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_penalties(penalty_last_n, penalty_repeat, penalty_freq, penalty_present));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(top_k));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(top_p, 1));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_min_p(min_p, 1));
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(dist));
|
||||
|
||||
// <20>Ƴ<EFBFBD><C6B3>DZ<EFBFBD>Ҫ<EFBFBD><D2AA>̰<EFBFBD><CCB0><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||
// llama_sampler_chain_add(sampler, llama_sampler_init_greedy());
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> UTF-16 jstring ת<><D7AA>Ϊ UTF-8 std::string
|
||||
std::string jstringToUTF8(JNIEnv* env, jstring jstr) {
|
||||
if (!jstr) return "";
|
||||
|
||||
const jchar* raw = env->GetStringChars(jstr, nullptr);
|
||||
if (!raw) return "";
|
||||
|
||||
jsize len = env->GetStringLength(jstr);
|
||||
|
||||
// <20><> UTF-16 ת<><D7AA>Ϊ UTF-8
|
||||
int utf8Size = WideCharToMultiByte(CP_UTF8, 0, reinterpret_cast<const wchar_t*>(raw), len, nullptr, 0, nullptr, nullptr);
|
||||
std::string utf8(utf8Size, 0);
|
||||
WideCharToMultiByte(CP_UTF8, 0, reinterpret_cast<const wchar_t*>(raw), len, &utf8[0], utf8Size, nullptr, nullptr);
|
||||
|
||||
env->ReleaseStringChars(jstr, raw);
|
||||
return utf8;
|
||||
}
|
||||
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> UTF-8 std::string ת<><D7AA>Ϊ jstring
|
||||
jstring utf8ToJstring(JNIEnv* env, const std::string& utf8) {
|
||||
// <20><> UTF-8 ת<><D7AA>Ϊ UTF-16
|
||||
int utf16Size = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, nullptr, 0);
|
||||
std::wstring utf16(utf16Size, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &utf16[0], utf16Size);
|
||||
|
||||
return env->NewString(reinterpret_cast<const jchar*>(utf16.c_str()), utf16Size - 1);
|
||||
}
|
||||
|
||||
//--------------------------------------------------
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʵ<EFBFBD><CAB5>
|
||||
//--------------------------------------------------
|
||||
JNIEXPORT jstring JNICALL Java_org_tzd_lm_LM_inference
|
||||
(JNIEnv* env, jclass clazz, jlong modelHandle, jlong ctxHandle, jfloat temperature, jfloat minP,
|
||||
jfloat topK, jfloat topP, jfloat dist, jint penaltyLastN, jfloat penaltyRepeat, jfloat penaltyFreq,
|
||||
jfloat penaltyPresent, jstring prompt, jobject messageCallback) {
|
||||
llama_context* ctx = reinterpret_cast<llama_context*>(ctxHandle);
|
||||
llama_model* model = reinterpret_cast<llama_model*>(modelHandle);
|
||||
env_ = env;
|
||||
messageCallback_ = messageCallback;
|
||||
|
||||
// <20><><EFBFBD><EFBFBD> ctx <20>Ƿ<EFBFBD><C7B7><EFBFBD>Ч
|
||||
if (!ctx) {
|
||||
jclass exClass = env->FindClass("java/lang/IllegalArgumentException");
|
||||
if (exClass) env->ThrowNew(exClass, "Invalid context handle");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
// <20><>ʼ<EFBFBD><CABC><EFBFBD>ص<EFBFBD><D8B5><EFBFBD><EFBFBD><EFBFBD>ID
|
||||
if (!gMessageCallbackMethodID) {
|
||||
jclass callbackClass = env->GetObjectClass(messageCallback);
|
||||
if (!callbackClass) {
|
||||
return nullptr;
|
||||
}
|
||||
gMessageCallbackMethodID = env->GetMethodID(callbackClass, "onMessage", "(Ljava/lang/String;)V");
|
||||
if (!gMessageCallbackMethodID) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ӧ<EFBFBD>ַ<EFBFBD><D6B7><EFBFBD>
|
||||
std::string response;
|
||||
std::string prompt_(jstringToUTF8(env, prompt));
|
||||
// ʹ<>ó<EFBFBD>ʼ<EFBFBD><CABC><EFBFBD>IJ<EFBFBD><C4B2><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ɽ<EFBFBD><C9BD><EFBFBD>
|
||||
llama_sampler* sampler = initialize_sampler(temperature, minP, topK,
|
||||
topP, dist, penaltyLastN, penaltyRepeat,
|
||||
penaltyFreq, penaltyPresent);
|
||||
|
||||
if (generate(model, ctx, sampler, prompt_, response) != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return utf8ToJstring(env, response);
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateLoadFile(
|
||||
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
|
||||
jlongArray tokensOut, jint nTokenCapacity, jintArray nTokenCountOut
|
||||
) {
|
||||
const char* path = env->GetStringUTFChars(pathSession, NULL);
|
||||
|
||||
jlong* tokens_out = env->GetLongArrayElements(tokensOut, NULL);
|
||||
jint* n_token_count_out = env->GetIntArrayElements(nTokenCountOut, NULL);
|
||||
|
||||
bool result = llama_state_load_file((struct llama_context*)ctx, path,
|
||||
(llama_token*)tokens_out,
|
||||
nTokenCapacity,
|
||||
(size_t*)n_token_count_out);
|
||||
|
||||
env->ReleaseStringUTFChars(pathSession, path);
|
||||
env->ReleaseLongArrayElements(tokensOut, tokens_out, 0);
|
||||
env->ReleaseIntArrayElements(nTokenCountOut, n_token_count_out, 0);
|
||||
|
||||
return result ? JNI_TRUE : JNI_FALSE;
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateSaveFile(
|
||||
JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
|
||||
jlongArray tokens, jint nTokenCount
|
||||
) {
|
||||
const char* path = env->GetStringUTFChars(pathSession, NULL);
|
||||
|
||||
jlong* tokens_array = env->GetLongArrayElements(tokens, NULL);
|
||||
|
||||
bool result = llama_state_save_file((struct llama_context*)ctx, path,
|
||||
(const llama_token*)tokens_array,
|
||||
nTokenCount);
|
||||
|
||||
env->ReleaseStringUTFChars(pathSession, path);
|
||||
env->ReleaseLongArrayElements(tokens, tokens_array, 0);
|
||||
return result ? JNI_TRUE : JNI_FALSE;
|
||||
}
|
||||
67
src/main/Cpp/LM/org_tzd_lm_LM.h
Normal file
@@ -0,0 +1,67 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_tzd_lm_LM */

#ifndef _Included_org_tzd_lm_LM
#define _Included_org_tzd_lm_LM
#ifdef __cplusplus
extern "C" {
#endif
/*
 * Class:     org_tzd_lm_LM
 * Method:    llamaLoadModelFromFile
 * Signature: (Ljava/lang/String;)J
 */
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_llamaLoadModelFromFile
  (JNIEnv*, jclass, jstring, jboolean, jboolean, jboolean, jboolean, jobject);

/*
 * Class:     org_tzd_lm_LM
 * Method:    llamaFreeModel
 * Signature: (J)V
 */
JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeModel
  (JNIEnv*, jclass, jlong);

/*
 * Class:     org_tzd_lm_LM
 * Method:    createContext
 * Signature: (J)J
 */
JNIEXPORT jlong JNICALL Java_org_tzd_lm_LM_createContext
  (JNIEnv* env, jclass clazz, jlong modelHandle, jint nCtx,
    jint nBatch,
    jint nSeqMax,
    jint nThreads,
    jint nThreadsBatch,
    jboolean logitsAll,
    jboolean embeddings,
    jboolean offloadKqv,
    jboolean flashAttn,
    jboolean noPerf);

JNIEXPORT void JNICALL Java_org_tzd_lm_LM_llamaFreeContext
  (JNIEnv*, jclass, jlong);

/*
 * Class:     org_tzd_lm_LM
 * Method:    inference
 * Signature: (JLjava/lang/String;Lorg/tzd/lm/LM/MessageCallback;)Ljava/lang/String;
 */
JNIEXPORT jstring JNICALL Java_org_tzd_lm_LM_inference
  (JNIEnv* env, jclass clazz, jlong modelHandle, jlong ctxHandle, jfloat temperature, jfloat minP, jfloat topK, jfloat topP, jfloat dist, jint penaltyLastN, jfloat penaltyRepeat, jfloat penaltyFreq, jfloat penaltyPresent, jstring prompt, jobject messageCallback);

JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateSaveFile(
    JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
    jlongArray tokens, jint nTokenCount
);

JNIEXPORT jboolean JNICALL Java_org_tzd_lm_LM_llamaStateLoadFile(
    JNIEnv* env, jobject obj, jlong ctx, jstring pathSession,
    jlongArray tokensOut, jint nTokenCapacity, jintArray nTokenCountOut
);

#ifdef __cplusplus
}
#endif
#endif
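The header above declares the full native surface of the LM module. A minimal, hypothetical call sequence from the Java side, assuming the org.tzd.lm.LM declarations sketched after the commit message; the model path and sampling values are placeholders, passing 0 for a numeric context option keeps the llama.cpp default, and dist == 0 falls back to LLAMA_DEFAULT_SEED in the native code:

import org.tzd.lm.LM;

public class LMDemo {
    public static void main(String[] args) {
        long model = LM.llamaLoadModelFromFile("path/to/model.gguf", false, true, false, false,
                progress -> true);               // return true to keep loading
        long ctx = LM.createContext(model, 4096, 512, 1, 8, 8, false, false, true, false, false);

        String reply = LM.inference(model, ctx,
                0.8f,                            // temperature
                0.05f,                           // min_p
                40f,                             // top_k
                0.9f,                            // top_p
                0f,                              // dist (0 falls back to LLAMA_DEFAULT_SEED natively)
                64, 1.1f, 0.0f, 0.0f,            // penalty_last_n / repeat / freq / present
                "Hello", piece -> System.out.print(piece));  // pieces stream through onMessage

        System.out.println("\n--- full reply ---\n" + reply);

        LM.llamaFreeContext(ctx);
        LM.llamaFreeModel(model);
    }
}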
5
src/main/Cpp/LM/pch.cpp
Normal file
@@ -0,0 +1,5 @@
// pch.cpp: source file corresponding to the precompiled header

#include "pch.h"

// When precompiled headers are used, this source file is required for compilation to succeed.
13
src/main/Cpp/LM/pch.h
Normal file
@@ -0,0 +1,13 @@
// pch.h: This is a precompiled header file.
// Files listed below are compiled only once, improving build performance for future builds.
// This also affects IntelliSense performance, including code completion and many code browsing features.
// However, files listed here are ALL re-compiled if any one of them is updated between builds.
// Do not add files here that you will be updating frequently as this negates the performance advantage.

#ifndef PCH_H
#define PCH_H

// add headers that you want to pre-compile here
#include "framework.h"

#endif //PCH_H
269
src/main/Cpp/LM/tiktoken.h
Normal file
@@ -0,0 +1,269 @@
#pragma once

#include <re2/re2.h>
#include "unordered_dense.h"

#include <cassert>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <regex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace tiktoken {

static auto _byte_pair_merge(
    const std::string &piece,
    const ankerl::unordered_dense::map<std::string, int> &ranks,
    std::function<int (int, int)> func
) -> std::vector<int> {
    std::vector<std::pair<int, int>> parts;
    parts.reserve(piece.size() + 1);
    for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
        parts.emplace_back(idx, std::numeric_limits<int>::max());
    }

    auto get_rank = [&piece, &ranks](
        const std::vector<std::pair<int, int>> &parts,
        int start_idx,
        int skip
    ) -> std::optional<int> {
        if (start_idx + skip + 2 < parts.size()) {
            auto s = parts[start_idx].first;
            auto e = parts[start_idx + skip + 2].first;
            auto key = piece.substr(s, e - s);
            auto iter = ranks.find(key);
            if (iter != ranks.end()) {
                return iter->second;
            }
        }
        return std::nullopt;
    };

    for (auto i = 0U; i < parts.size() - 2; ++i) {
        auto rank = get_rank(parts, i, 0);
        if (rank) {
            assert(*rank != std::numeric_limits<int>::max());
            parts[i].second = *rank;
        }
    }

    while (true) {
        if (parts.size() == 1) break;

        auto min_rank = std::make_pair<int, int>(std::numeric_limits<int>::max(), 0);
        for (auto i = 0U; i < parts.size() - 1; ++i) {
            auto rank = parts[i].second;
            if (rank < min_rank.first) {
                min_rank = { rank, i };
            }
        }

        if (min_rank.first != std::numeric_limits<int>::max()) {
            auto i = min_rank.second;
            auto rank = get_rank(parts, i, 1);
            if (rank) {
                parts[i].second = *rank;
            } else {
                parts[i].second = std::numeric_limits<int>::max();
            }
            if (i > 0) {
                auto rank = get_rank(parts, i - 1, 1);
                if (rank) {
                    parts[i - 1].second = *rank;
                } else {
                    parts[i - 1].second = std::numeric_limits<int>::max();
                }
            }

            parts.erase(parts.begin() + (i + 1));
        } else {
            break;
        }
    }
    std::vector<int> out;
    out.reserve(parts.size() - 1);
    for (auto i = 0U; i < parts.size() - 1; ++i) {
        out.push_back(func(parts[i].first, parts[i + 1].first));
    }
    return out;
}

static auto byte_pair_encode(
    const std::string &piece,
    const ankerl::unordered_dense::map<std::string, int> &ranks
) -> std::vector<int> {
    if (piece.size() == 1) {
        return {ranks.at(piece)};
    }

    auto func = [&piece, &ranks](int start, int stop) -> int {
        std::string key = piece.substr(start, stop - start);
        return ranks.at(key);
    };

    return _byte_pair_merge(piece, ranks, func);
}

class tiktoken {
public:
    tiktoken() = default;
    tiktoken(
        ankerl::unordered_dense::map<std::string, int> encoder,
        ankerl::unordered_dense::map<std::string, int> special_encoder,
        const std::string &pattern
    ) {
        regex_ = std::make_unique<re2::RE2>("(" + pattern + ")");

        std::string special_pattern;
        for (const auto &item : special_encoder) {
            if (!special_pattern.empty()) {
                special_pattern += "|";
            }
            special_pattern += re2::RE2::QuoteMeta(item.first);
        }
        if (special_pattern.empty()) {
            special_regex_ = nullptr;
        } else {
            special_regex_ = std::make_unique<re2::RE2>("(" + special_pattern + ")");
        }

        encoder_ = std::move(encoder);
        special_tokens_encoder = std::move(special_encoder);

        for (const auto &[k, v] : encoder_) {
            decoder_.emplace(v, k);
        }
        assert(encoder_.size() == decoder_.size() && "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?");

        for (const auto &[k, v] : special_tokens_encoder) {
            special_tokens_decoder.emplace(v, k);
        }
    }

    auto encode_ordinary(const std::string &text) const -> std::vector<int> {
        return _encode_ordinary_native(text);
    }

    auto encode(const std::string &text) const -> std::vector<int> {
        return _encode_native(text, special_tokens_encoder).first;
    }

    auto encode_single_piece(const std::string &text) const -> std::vector<int> {
        auto iter = encoder_.find(text);
        if (iter != encoder_.end()) {
            return {iter->second};
        }
        return byte_pair_encode(text, encoder_);
    }

    auto decode(const std::vector<int> &tokens) const -> std::string {
        return _decode_native(tokens);
    }

private:
    auto split_with_allowed_special_token(
        re2::StringPiece &input,
        const ankerl::unordered_dense::map<std::string, int> &allowed_special
    ) const -> std::pair<std::optional<std::string>, re2::StringPiece> {
        if (special_regex_ == nullptr) return { std::nullopt, input };

        auto start = input.begin();
        std::string special;
        while (true) {
            if (!re2::RE2::FindAndConsume(&input, *special_regex_, &special)) {
                break;
            }

            if (allowed_special.count(special) == 1) {
                return { std::move(special), re2::StringPiece(start, input.begin() - start - special.size()) };
            }
        }

        return { std::nullopt, input };
    }

    auto _encode_ordinary_native(const std::string &text) const -> std::vector<int> {
        std::vector<int> ret;
        re2::StringPiece input(text);

        std::string piece;
        while (re2::RE2::FindAndConsume(&input, *regex_, &piece)) {
            auto iter = encoder_.find(piece);
            if (iter != encoder_.end()) {
                ret.push_back(iter->second);
                continue;
            }
            auto tokens = byte_pair_encode(piece, encoder_);
            ret.insert(ret.end(), tokens.begin(), tokens.end());
        }
        return ret;
    }

    auto _encode_native(
        const std::string &text,
        const ankerl::unordered_dense::map<std::string, int> &allowed_special
    ) const -> std::pair<std::vector<int>, int> {
        std::vector<int> ret;
        int last_piece_token_len = 0;
        re2::StringPiece input(text);

        while (true) {
            auto [special, sub_input] = split_with_allowed_special_token(input, allowed_special);
            std::string piece;
            while (re2::RE2::FindAndConsume(&sub_input, *regex_, &piece)) {
                auto iter = encoder_.find(piece);
                if (iter != encoder_.end()) {
                    last_piece_token_len = 1;
                    ret.push_back(iter->second);
                    continue;
                }
                auto tokens = byte_pair_encode(piece, encoder_);
                last_piece_token_len = tokens.size();
                ret.insert(ret.end(), tokens.begin(), tokens.end());
            }

            if (special) {
                int token = special_tokens_encoder.at(*special);
                ret.push_back(token);
                last_piece_token_len = 0;
            } else {
                break;
            }
        }

        return { ret, last_piece_token_len };
    }

    auto _decode_native(const std::vector<int> &tokens) const -> std::string {
        std::string ret;
        ret.reserve(tokens.size() * 2);
        for (auto token : tokens) {
            std::string token_bytes;
            auto iter = decoder_.find(token);
            if (iter != decoder_.end()) {
                token_bytes = iter->second;
            } else {
                iter = special_tokens_decoder.find(token);
                if (iter != special_tokens_decoder.end()) {
                    token_bytes = iter->second;
                } else {
                    throw std::runtime_error("unknown token: " + std::to_string(token));
                }
            }
            ret += token_bytes;
        }
        return ret;
    }

    ankerl::unordered_dense::map<std::string, int> encoder_;
    ankerl::unordered_dense::map<std::string, int> special_tokens_encoder;
    ankerl::unordered_dense::map<int, std::string> decoder_;
    ankerl::unordered_dense::map<int, std::string> special_tokens_decoder;
    std::unique_ptr<re2::RE2> regex_;
    std::unique_ptr<re2::RE2> special_regex_;
};

} // namespace tiktoken