Rust + Wasm + AI(二):让浏览器开始思考 —— 基于 Candle 的端侧情感引擎
上一回我们聊了 Rust + Wasm + AI 的宏大愿景,这次真的把 BERT 模型塞进了浏览器。想想看:当用户在输入框敲下"这服务太棒了"的瞬间,模型已经完成推理,满屏粒子爆发出青色光晕。全程零服务器请求,数据不出本地,甚至断网都能用。
1. 引言:打破 请求-响应 的旧枷锁
算力的南水北调
传统 AI 部署就像南水北调——把用户数据千里迢迢送到 GPU 集群,再把结果运回来。这个模式有三个硬伤:
- 延迟陷阱:网络抖动 100ms 就能毁掉输入流畅感,复杂推理直奔秒级。
- 隐私裸奔:每一句私密输入都在公网裸奔,数据必须出域。
- 成本黑洞:简单分类任务也在消耗昂贵 GPU 显存,断网即服务死亡。
但算力格局正在发生剧变。M2 Max 的神经网络引擎已达 15.8 TOPS,高端安卓机的 NPU 也轻松突破 5 TOPS。Rust + Wasm 的出现,让我们能把推理任务下放到用户的 CPU/GPU。这不仅是成本节约,更是人机交互体验的质变。
端侧 AI 的杀手锏
就拿本次情感分析引擎来说,在 MacBook Pro M1 的浏览器环境中,可以实现即时响应:
- 零延迟:用户松开键盘的瞬间,情感分数已出现在屏幕。
- 隐私设计:数据不出内存,连本地存储都不沾。
- 离线优先:一次加载,永久可用。
本篇核心
深度拆解如何利用 Rust 生态,让 uer/roberta-base-finetuned-jd-binary-chinese 模型在浏览器里实现毫秒级读心术。
2. 演示效果
先展示一下在浏览器中的运行效果:
3. 技术选型:为什么是 Candle?
Candle 的极致主义
Candle 是 HuggingFace 出品的纯 Rust 框架,专为轻量化推理而生。关键优势在于:
- 真正的按需加载:模型结构代码编译进 Wasm,权重按需 fetch,无冗余运行时。
- 零拷贝架构:通过
Safetensors格式,Rust 可以直接将 Wasm 内存映射为张量,无需在 JS 和 Rust 之间进行昂贵的序列化。 - Wasm 友好:纯 Rust 实现,无 C++ FFI,编译产物干净利落。
Safetensors:Wasm 时代的权重协议
传统 PyTorch 的 .bin 格式,本质是 Pickle——可以执行任意代码,在浏览器里加载等于引狼入室。Safetensors 是新标准,核心优势在于:
- 零拷贝加载:内存映射后直接算,无需反序列化。
- 安全:纯数据,无代码执行风险。
- 自描述:JSON 头信息让浏览器提前知道内存布局。
3. 架构设计:四层流水线
整个引擎分为四层,每层都是性能战场:
┌─────────────┐
│资源层(fetch) │ ← 模型/分词器加载
├─────────────┤
│转换层(Wasm) │ ← 二进制流注入内存
├─────────────┤
│计算层(Candle)│ ← 动态图构建与推理
├─────────────┤
│交互层(Canvas)│ ← 粒子渲染与反馈
└─────────────┘
关键设计决策有三:
- Tokenizer 预处理:将
tokenizer.json提前序列化为静态数组,避免运行时 JSON 解析开销。 - 零拷贝张量映射 (Zero-copy Mapping):利用
Safetensors内存对齐特性,将模型权重直接从ArrayBuffer映射为Candle张量,实现首屏启动零内存拷贝。 - 内存池复用:推理中间结果复用同一块 Wasm 内存,避免 GC 压力。
4. 工程实战:从零构建 Wasm 推理核
Python 端:模型选择与转换
最初选的是 jackietung/bert-base-chinese-finetuned-sentiment 模型,因为它支持 .safetensors 格式且支持中文环境。但试用后发现效果不行——比如输入 "难过",结果推理出 "正向"。后来换上 uer/roberta-base-finetuned-jd-binary-chinese 模型,它对中文情感分析能力更强,且在京东真实评论数据上微调过,电商场景下很准。不过该模型好久没更新,没提供 .safetensors 格式,只能自己动手转换。
转换代码在 candle-senti-pulse/model2safetensor 目录中,使用 uv 管理的 Python 脚本,负责将 uer/roberta-base-finetuned-jd-binary-chinese 模型转换为 Safetensors 格式。实现如下:
MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese"
SA VE_DIR = "./converted_model"
# 强制使用镜像
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
def convert():
if not os.path.exists(SA VE_DIR):
os.makedirs(SA VE_DIR)
try:
# 加载
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# 1. 导出 config
print("? 正在生成 config.json...")
config = model.config.to_dict()
with open(os.path.join(SA VE_DIR, "config.json"), "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
# 2. 导出 tokenizer
print("? 正在生成 tokenizer.json...")
tokenizer.sa ve_pretrained(SA VE_DIR)
# 3. 导出权重
print("? 正在生成 model.safetensors...")
state_dict = model.state_dict()
# 移除可能存在的 _orig_mod 等前缀(如果使用了 torch.compile)
clean_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
sa ve_file(clean_state_dict, os.path.join(SA VE_DIR, "model.safetensors"))
except Exception as e:
print(f"n❌ 下载失败: {e}")
转换完成后,将模型复制到 www/models 目录下即可。
Rust 侧:SentiPulseEngine 设计
通过 Candle 加载模型权重,使用 VarBuilder 构建计算图。核心代码如下:
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
#[wasm_bindgen]
#[derive(Debug)]
pub struct SentiPulseResult {
negative: f32,
positive: f32,
neutral: f32,
}
#[wasm_bindgen]
impl SentiPulseResult {
#[wasm_bindgen(getter)]
pub fn negative(&self) -> f32 { self.negative }
#[wasm_bindgen(getter)]
pub fn positive(&self) -> f32 { self.positive }
#[wasm_bindgen(getter)]
pub fn neutral(&self) -> f32 { self.neutral }
}
#[wasm_bindgen]
pub struct SentiPulseEngine {
model: BertModel,
tokenizer: Tokenizer,
// 分类头
w_out: Tensor,
b_out: Tensor,
// 新增:Pooler 层 (用于处理 CLS 向量)
w_pooler: Option,
b_pooler: Option,
}
#[wasm_bindgen]
impl SentiPulseEngine {
#[wasm_bindgen(constructor)]
pub fn new(
weights: &[u8],
tokenizer_data: &[u8],
config_str: &str,
) -> Result {
console_error_panic_hook::set_once();
let device = &Device::Cpu;
let tokenizer = Tokenizer::from_bytes(tokenizer_data)
.map_err(|e| JsError::new(&e.to_string()))?;
let config: Config = serde_json::from_str(config_str)
.map_err(|e| JsError::new(&e.to_string()))?;
let vb = VarBuilder::from_buffered_safetensors(
weights.to_vec(), DType::F32, device)?;
// 1. 加载 BERT
let model = BertModel::load(vb.pp("bert"), &config)?;
let w_pooler = vb.pp("bert")
.get((config.hidden_size, config.hidden_size), "pooler.dense.weight")
.ok();
let b_pooler = vb.pp("bert")
.get(config.hidden_size, "pooler.dense.bias")
.ok();
// 3. 加载 Classifier (带兼容逻辑)
let num_labels = 2;
let w_out = vb.get((num_labels, config.hidden_size), "classifier.weight")
.or_else(|_| vb.get((num_labels, config.hidden_size), "classifier.out_proj.weight"))
.or_else(|_| vb.get((num_labels, config.hidden_size), "classifier.dense.weight"))
.map_err(|_| JsError::new("权重文件中缺少分类层 (classifier weight)"))?;
let b_out = vb.get(num_labels, "classifier.bias")
.or_else(|_| vb.get(num_labels, "classifier.out_proj.bias"))
.or_else(|_| vb.get(num_labels, "classifier.dense.bias"))
.map_err(|_| JsError::new("权重文件中缺少分类层偏置 (classifier bias)"))?;
Ok(Self { model, tokenizer, w_out, b_out, w_pooler, b_pooler })
}
pub fn predict(&self, text: &str) -> Result {
let device = &Device::Cpu;
let tokens = self.tokenizer.encode(text, true)
.map_err(|e| JsError::new(&e.to_string()))?;
let input_ids = Tensor::new(tokens.get_ids(), device)?.unsqueeze(0)?;
let token_type_ids = Tensor::new(tokens.get_type_ids(), device)?.unsqueeze(0)?;
let enc = self.model.forward(&input_ids, &token_type_ids, None)?;
let mut cls_token = enc.get(0)?.get(0)?.unsqueeze(0)?;
if let (Some(w), Some(b)) = (&self.w_pooler, &self.b_pooler) {
cls_token = cls_token.matmul(&w.t()?)?.broadcast_add(b)?.tanh()?;
}
let logits = cls_token.matmul(&self.w_out.t()?)?.broadcast_add(&self.b_out)?;
let scale_factor = 1.0;
let scaled_logits = (logits * scale_factor as f64)?;
let pr = candle_nn::ops::softmax(&scaled_logits.flatten_all()?, 0)?;
let scores = pr.to_vec1::<f32>()?;
let (neg, pos, neu) = if scores.len() >= 3 {
(scores[0], scores[1], scores[2])
} else {
let mut n = scores[0];
let mut p = scores[1];
let diff = (n - p).abs();
let mut m = if diff < 0.2 { 0.8 } else if diff < 0.4 { 0.3 } else { 0.0 };
let total = n + p + m;
if total > 0.0 {
n = n / total; p = p / total; m = m / total;
(n, p, m)
} else {
(0.33, 0.33, 0.34)
}
};
web_sys::console::log_1(&format!("Raw Text: {}, Raw Scores: {:?}", text, scores).into());
let result = SentiPulseResult { negative: neg, positive: pos, neutral: neu };
web_sys::console::log_1(&format!("Raw Text: {}, result: {:?}", text, result).into());
Ok(result)
}
}
几个关键点:VarBuilder::from_buffered_safetensors 避免了在内存中反复拷贝大文件,直接在内存池中构建权重;Result 是 Rust 与 JS 交互的最佳实践,让 JS 端的 try-catch 能捕获详细错误;unsqueeze(0) 将一维 Token 序列升维为模型需要的 Batch 张量。
JS 侧:模型推理与粒子风暴
粒子风暴系统根据情感分析分数动态渲染不同的粒子颜色和速度:
// =========================================
// PART 1: 粒子风暴系统 (Particle System)
// =========================================
const canvas = document.getElementById("particle-canvas");
const ctx = canvas.getContext("2d");
// 设置画布大小
function resizeCanvas() {
canvas.width = window.innerWidth;
canvas.height = window.innerHeight;
}
window.addEventListener("resize", resizeCanvas);
resizeCanvas();
// 粒子参数全局状态 (受 AI 情绪驱动)
let globalMood = {
neg: 0.1, // 初始平静状态
pos: 0.9,
neu: 0.1,
targetSpeed: 0.5,
currentSpeed: 0.5,
chaos: 0.2, // 混乱度
};
class Particle {
constructor() {
this.reset();
this.y = Math.random() * canvas.height; // 初始随机分布
}
reset() {
this.x = Math.random() * canvas.width;
this.y = canvas.height + Math.random() * 100; // 从底部生成
this.size = Math.random() * 2 + 1;
this.baseSpeedY = Math.random() * 1 + 0.5;
this.vx = (Math.random() - 0.5) * 0.5;
this.vy = -this.baseSpeedY;
this.alpha = Math.random() * 0.5 + 0.2;
}
update() {
globalMood.currentSpeed += (globalMood.targetSpeed - globalMood.currentSpeed) * 0.05;
this.x += this.vx * (1 + globalMood.chaos * 5);
this.y += this.vy * globalMood.currentSpeed;
if (this.y < -10) this.reset();
}
draw() {
const r = Math.floor(globalMood.neg * 255 + globalMood.neu * 168);
const g = Math.floor(globalMood.pos * 242 + globalMood.neu * 85);
const b = Math.floor(globalMood.pos * 255 + globalMood.neu * 247);
const dynamicSize = this.size * (1 + globalMood.neg * 1.5);
ctx.fillStyle = `rgba(${r}, ${g}, ${b}, ${this.alpha + globalMood.neg * 0.3})`;
ctx.beginPath();
ctx.arc(this.x, this.y, dynamicSize, 0, Math.PI * 2);
ctx.fill();
}
}
const particles = Array.from({ length: 150 }, () => new Particle());
function animateParticles() {
ctx.fillStyle = "rgba(10, 11, 16, 0.2)";
ctx.fillRect(0, 0, canvas.width, canvas.height);
particles.forEach((p) => { p.update(); p.draw(); });
requestAnimationFrame(animateParticles);
}
animateParticles();
模型初始化通过 fetch 并行加载模型资源后实例化:
// 中文模型资源
const baseUrl = "./model/uer/roberta-base-finetuned-jd-binary-chinese/";
const [weights, tokenizer, config] = await Promise.all([
fetch(baseUrl + "model.safetensors").then((r) => r.arrayBuffer()),
fetch(baseUrl + "tokenizer.json").then((r) => r.arrayBuffer()),
fetch(baseUrl + "config.json").then((r) => r.text()),
]);
const engine = new SentiPulseEngine(
new Uint8Array(weights),
new Uint8Array(tokenizer),
config,
);
模型推理
输入中文评论,调用 Rust Wasm 模块进行情感分析,返回情感分数:
const t0 = performance.now();
const result = engine.predict(text);
const { negative: neg, positive: pos, neutral: neu } = result;
const t1 = performance.now();
5. 运行情感分析推理
# 1. 模型下载及转换
cd model2safetensor && uv run main.py
# 2. 构建 Wasm 模块
cargo build --target web --release
# 3. 启动本地服务器
miniserve .
访问 https://127.0.0.1:8080/www/index.html,在输入框输入中文评论,就能看到情感分析分数和情绪粒子风暴的变化了。
6. 总结:开启 Web 推理的新纪元
从 调包侠 到 推理架构师,这一步的跨越在于:我们已可以掌控算力的分配权。通过 Rust + Wasm,我们证明了即使是复杂的 Transformer 模型,也能在用户的指尖轻盈跃动。下一步,我们将引入 WebGPU,探索如何在浏览器里运行 3B 参数量级的端侧大模型(LLM)。
