Duan Yao's Blog

思无涯,行无疆,言无忌,行无羁

Spring Boot 主要开发流程

  1. 创建实体类以及数据库中的表(sql/init.sql)
  2. 新增或修改数据库映射类,类似 Django 的模型类
    • main/此处省略/mapper/*Mapper.java 中写好接口
    • resources/mapper/*Mapper.xml 中根据各接口编写查询语句
    • 注意接口和查询语句必须一一对应(不一定按顺序)
  3. 新增或修改服务类
    • main/此处省略/service/*Service.java 中写好接口
    • main/此处省略/service/impl/*ServiceImpl.java 中实现类
  4. 新增或修改控制器,类似 Django 的视图函数
  5. 完成前端模板(采用 thymeleaf 模板引擎)

1018 tplink 算法二面 面试题

  1. 自我介绍
  2. CUB200项目中遇到了什么问题,欠拟合(预训练不能解决欠拟合,但是大模型可以)与过拟合怎么解决的
  3. 为什么你觉得模型攻击能解决过拟合问题,一般用于测试,但是加入到训练集中确实可以减弱过拟合。
  4. PGD的计算方法
  5. 解决过拟合的方法
  6. 个性化搜索的细节,如何实现的,如何将dm应用到这个任务上面,如何构造训练的数据对
  7. fixmatch 原理
  8. transformer 计算方法
  9. BN 层的计算方法,参数量,训练与推理时的区别。

重写答案

CUB200项目中遇到的问题与解决方法

主要问题是欠拟合。

解决欠拟合的方法: 1. 增加模型复杂度,加深加宽换架构 2. 增加训练时间 3. 减少正则化,降低正则化系数 4. 特征工程:提升数据质量和丰富度每增加有意义的特征,通过数据转换提取更有代表性的特征。使用一些特征提取方法,比如PCA等 5. (不确定)使用更好的优化算法,如 Adan RMSprop 等等 6. (不确定)减少噪声或者平滑数据

解决过拟合的方法: 1. 增加数据量,一般采用数据增广的方式 2. 使用正则化: 1. L1/L2正则化 2. Dropout随机丢弃神经元 3. 权重衰减(Weight Decay)在优化器中对权重加惩罚项减少模型复杂度的一种正则化方法 3. 简化模型架构 4. 早停法,验证集的误差开始上升时,使用早停 5. 使用K-Fold评估模型的泛化能力 6. 数据归一化和标准化,比如BN层 7. 添加噪声,或采用模型攻击的方式 8. 减小学习率以解决发散问题。

两类问题,基本上都可以通过更好地特征工程、合理的模型架构、数据处理解决。

注意,模型攻击一般不是解决过拟合的方法,但是如果将攻击的数据加入到训练集中,是可以的。也就是对抗训练,将对抗样本加入到训练集中,使得模型在标准样本上表现良好,也能在对抗样本上表现稳健。

PGD 的计算方法

PGD 攻击是 FGSM(Fast Gradient Sign Method)的扩展,FGSM 是一种一次性基于梯度的对抗攻击,而 PGD 则通过多次迭代更新对抗样本来增强攻击效果。

  1. 确认损失函数
  2. 初始化对抗样本
  3. 迭代更新样本

\[ x^{(k+1)}=\mathrm{Proj}_{\mathcal{B}_\epsilon(x_0)}\left(x^{(k)}+\alpha\cdot\mathrm{sign}\left(\nabla_xL(\theta,x^{(k)},y)\right)\right) \]

其中:

  • \(x^{(k)}\) 是第\(k\) 次迭代的对抗样本

  • \(\alpha\) 是每次迭代的步长,决定更新的幅度

  • \(\nabla_xL\) 是损失函数 对 输入\(x\) 的梯度

  • \(\mathrm{sign}\) 表示符号函数,取梯度的符号

  • \(\mathrm{Proj_{\mathcal{B}_\epsilon }}_{(x_0)}(x^{(k))}\) 表示将更新后的样本重新投影到以 \(x_0\) 为中心,半径为 \(\epsilon\)\(L_\infty\) 范围内,即: \[ \mathrm{Proj}_{\mathcal{B}_\epsilon(x_0)}(x^{(k)})=\min\left(\max(x^{(k)},x_0-\epsilon),x_0+\epsilon\right) \] 这个投影操作限制了对抗样本的扰动在合法的范围内,不会超过 \(\epsilon\) 限制,即最大扰动限制

  • 一般来讲,\(K,\alpha,\epsilon\) 是关键参数。

解决过拟合的方法

这个在面经二中说过啦。(对的我是后续才整理的

DiffuPPS 的细节

这里暂时不能公开。

主要问了,如何将 DM 应用到这个任务上,如何构造训练数据对。

FixMatch 原理

详见 半监督学习-分类

Transformer 的计算方法

Transformer模型详解(图解最完整版)-CSDN博客 参考

从每一个计算细节了解 transformer_transformer loss函数-CSDN博客 参考

我忘记了 \(\div \sqrt{d_k} 和 softmax\) 部分。之后有空详细整理一下。以前觉得,这玩意儿会用就行。。

BN 层的计算方法,参数量,在训练与推理的不同

计算方法

假设输入到 BN 层的数据维度是 \(X\in\mathbb{R}^{N\times C\times H\times W}\)

  1. 计算每个通道的均值与方差 \[ \mu_{c}=\frac1{N\times H\times W}\sum_{i=1}^N\sum_{h=1}^H\sum_{w=1}^WX_{i,c,h,w}\\\sigma_{c}^{2}=\frac1{N\times H\times W}\sum_{i=1}^N\sum_{h=1}^H\sum_{w=1}^W(X_{i,c,h,w}-\mu_c)^2 \]

  2. 标准化 \[ \hat{X}_{i,c,h,w}=\frac{X_{i,c,h,w}-\mu_c}{\sqrt{\sigma_c^2+\epsilon}} \] 其中,\(\epsilon\) 是一个小正数,防止处以0

  3. 缩放与平移 \[ Y_{i,c,h,w}=\gamma_c\hat{X}_{i,c,h,w}+\beta_c \] 其中,\(\gamma _c\)\(\beta _c\) 是每个通道独立的,每个通道各自有两个的,可以学习的参数。

BN 层的参数量

BN层中有两个主要的可学习参数:缩放系数 \(\gamma\)偏移量 \(\beta\)。它们的数量和通道数 \(C\)一致,因为每个通道有一个 \(\gamma_c\) 和一个 \(\beta_c\)

因此,总参数量是: \(2*C\),其中 \(C\) 表示输入通道数。

BN层在训练与推理阶段的不同

这个我记错了,记成了 Dropout 在两个阶段的不同。

1. Batch Normalization (BN)

训练阶段
  • 均值和方差计算:在训练阶段,BN层使用当前批次的均值和方差来标准化输入。也就是说,BN层根据当前小批量数据计算出每个通道的均值 μ和方差 σ2^2σ2,并使用这些统计量进行标准化。

  • 移动平均:为了在推理阶段使用,BN层在训练过程中对每个批次的均值和方差进行指数加权移动平均,逐步估计全局均值和方差。这些全局统计量将在推理阶段使用。

  • 标准化与学习:在训练中,输入会被标准化为零均值和单位方差,然后再应用可学习的缩放系数 γ和偏移量 β,即: \[ \hat{X}=\frac{X-\mu}{\sqrt{\sigma^2+\epsilon}},\quad Y=\gamma\hat{X}+\beta \]

推理阶段
  • 固定的均值和方差:在推理阶段,BN层使用训练时估计得到的全局均值和方差(而不是当前输入批次的统计量),来进行标准化。这确保了模型在推理时具有确定性和一致性。
  • 缩放和平移:推理阶段依然使用训练时学到的缩放系数 \(\gamma\) 和偏移量 \(\beta\),对标准化后的数据进行线性变换,但不更新这些参数。

总结:BN在训练时动态计算均值和方差并更新移动平均值,而推理时使用固定的全局统计量进行标准化。


2. Dropout

训练阶段
  • 随机丢弃神经元:在训练阶段,Dropout通过以一定的概率(如 0.5)随机丢弃神经元,即将它们的激活值置为0。被丢弃的神经元不参与当前批次的前向传播和反向传播。
  • 缩放激活值:为了保持输出的一致性,在训练中,未被丢弃的神经元的激活值会被按丢弃率的倒数进行缩放。例如,如果丢弃率是 0.5,未被丢弃的神经元输出会被除以 0.5(即乘以2),以弥补丢弃神经元带来的激活减少。
推理阶段
  • 所有神经元都工作:在推理阶段,Dropout 被关闭,所有的神经元都参与前向传播,不再有任何神经元被丢弃。
  • 不进行缩放:推理时,神经元的输出值不再被缩放,因为在训练时已经通过丢弃神经元和缩放保持了输出值的期望一致性。

总结:在训练时,Dropout随机丢弃一部分神经元来防止过拟合;在推理时,Dropout被关闭,所有神经元正常工作。


3. BN层与Dropout的训练和推理阶段对比总结

特性 训练阶段 推理阶段
BN层 1. 使用小批次的均值和方差进行标准化
2. 更新全局均值和方差
1. 使用全局均值和方差进行标准化
2. 固定缩放参数 \(\gamma\)\(\beta\)
Dropout 1. 以一定概率随机丢弃神经元
2. 对未丢弃神经元进行缩放
1. 不丢弃任何神经元
2. 不进行缩放处理

这两者的不同使得 BN 主要作用于加速收敛、提高稳定性,而 Dropout 则主要用于防止过拟合。

注意,BN 层与 Dropout 共同使用可能导致不好的结果。参考:BN和Dropout在训练和测试时的差别 - 知乎

线程状态与优先级

一图胜千言。

线程状态-菜鸟教程

其中,阻塞状态分为三种:

  • 等待阻塞:运行状态中的线程执行 wait() 方法,使线程进入到等待阻塞状态。
  • 同步阻塞:线程在获取 synchronized 同步锁失败(因为同步锁被其他线程占用)。
  • 其他阻塞:通过调用线程的 sleep()join() 发出了 I/O 请求时,线程就会进入到阻塞状态。当sleep() 状态超时,join() 等待线程终止或超时,或者 I/O 处理完毕,线程重新转入就绪状态。

线程优先级是整数:1 (Thread.MIN_PRIORITY ) - 10 (Thread.MAX_PRIORITY )。一般默认分配 NORM_PRIORITY(5)。优先级并不完全代表调度顺序。总不能把线程直接饿死吧。

创建线程

三种方法:

  • 通过实现 Runnable 接口;
  • 通过继承 Thread 类本身;(本质上也是实现了 Runable 接口的一个实例)
  • 通过 Callable 和 Future 创建线程。

Runnable

本文中主要说明文档注释。

注释类型

Java 本身包含三种注释

1
2
3
4
5
6
7
8
9
/**
文档注释,一般用在开头
*/

// 单行注释

/*
多行注释,似乎不太常用
*/

注意: 文档注释的结尾不是 */ ,而是 **/

Javadoc

Javadoc is a tool for generating API documentation in HTML format from doc comments in source code.

菜鸟的一个文档注释示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import java.io.*;

/**
* 这个类演示了文档注释
* @author Ayan Amhed
* @version 1.2
*/
public class SquareNum {
/**
* This method returns the square of num.
* This is a multiline description. You can use
* as many lines as you like.
* @param num The value to be squared.
* @return num squared.
*/
public double square(double num) {
return num * num;
}
/**
* This method inputs a number from the user.
* @return The value input as a double.
* @exception IOException On input error.
* @see IOException
*/
public double getNumber() throws IOException {
InputStreamReader isr = new InputStreamReader(System.in);
BufferedReader inData = new BufferedReader(isr);
String str;
str = inData.readLine();
return (new Double(str)).doubleValue();
}
/**
* This method demonstrates square().
* @param args Unused.
* @return Nothing.
* @exception IOException On input error.
* @see IOException
*/
public static void main(String args[]) throws IOException
{
SquareNum ob = new SquareNum();
double val;
System.out.println("Enter value to be squared: ");
val = ob.getNumber();
val = ob.square(val);
System.out.println("Squared value is " + val);
}
}

而后我们使用 javadoc 工具进行处理 javadoc <name>.java ,之后会生成包含程序注释的HTML文件。

一个详细的教程 小白都看得懂的Javadoc使用教程 - 说人话 - 博客园

如何在IDE中使用模板? 在 IDEA 上优雅地使用 Javadoc_javadoc 异常描述-CSDN博客

标签小全 Java 文档注释 | 菜鸟教程

序列化

序列化:一般等价于持久化,将对象转换为字节流存储在磁盘上、网络上、内存里。

通过实现 标记接口 java.io.Serializable

如果类中的一个变量不想被序列化,可以将其标记为短暂变量,使用 transient。静态变量是属于类的,而不是属于对象的,所以不会被序列化。

一般将序列化输出的文件名命名为 <name>.ser

一个简单的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// 首先定义一个可序列化的类
// Employee.java 文件
public class Employee implements java.io.Serializable
{
public String name;
public String address;
public transient int SSN; // 这个值是短暂的,不会被序列化存储,不会被发送到输出流,所以在反序列化时直接被初始化为0
public int number;
public void mailCheck()
{
System.out.println("Mailing a check to " + name
+ " " + address);
}
}

// 序列化这个类
// SerializeDemo.java 文件
import java.io.*;

public class SerializeDemo
{
public static void main(String [] args)
{
Employee e = new Employee();
e.name = "Reyan Ali";
e.address = "Phokka Kuan, Ambehta Peer";
e.SSN = 11122333;
e.number = 101;
try
{
FileOutputStream fileOut =
new FileOutputStream("/tmp/employee.ser"); // 指定输出文件流
ObjectOutputStream out = new ObjectOutputStream(fileOut); // 高层次的数据输出流
out.writeObject(e); // 序列化对象并 发送到输出流
out.close();
fileOut.close();
System.out.printf("Serialized data is saved in /tmp/employee.ser");
}catch(IOException i)
{
i.printStackTrace();
}
}
}


// 实现反序列化
// DeserializeDemo.java 文件

import java.io.*;

public class DeserializeDemo
{
public static void main(String [] args)
{
Employee e = null;
try
{
FileInputStream fileIn = new FileInputStream("/tmp/employee.ser"); // 输入文件流
ObjectInputStream in = new ObjectInputStream(fileIn); // 高层次的数据输入流
e = (Employee) in.readObject(); // 反序列化对象并发送到输入流
in.close();
fileIn.close();
}catch(IOException i)
{
i.printStackTrace();
return;
}catch(ClassNotFoundException c) // 注意这里,我们后面解释
{
System.out.println("Employee class not found");
c.printStackTrace();
return;
}
System.out.println("Deserialized Employee...");
System.out.println("Name: " + e.name);
System.out.println("Address: " + e.address);
System.out.println("SSN: " + e.SSN);
System.out.println("Number: " + e.number);
}
}

注意到其中我们用到了 :

1
2
3
4
5
public final void writeObject(Object x) throws IOException
// 用于序列化一个对象,并将其发送到输出流
public final Object readObject() throws IOException,
ClassNotFoundException
// 用于反序列化一个对象并发送到输入流

另外 ClassNotFoundException 异常:若JVM反序列化过程中,找不到该类的字节码,会抛出这个异常。

FixMatch 是一种半监督学习 (Semi-Supervised Learning, SSL) 方法,它通过极少量标注数据和大量未标注数据来提升模型的性能。本文将详细介绍 FixMatch 的工作原理、核心思想以及其在机器学习任务中的应用。

半监督学习背景

在监督学习中,模型通常依赖大量的标注数据来学习任务。然而,获取标注数据往往成本高昂。半监督学习通过结合少量的标注数据和大量的未标注数据,极大地减少了标注成本,同时在性能上接近全监督学习。

FixMatch 的核心思想(待补充详细细节)

FixMatch 的核心思想是利用 一致性正则化 (Consistency Regularization)伪标签 (Pseudo-Labeling) 技术来处理未标注数据。

  • 一致性正则化:假设模型对同一输入的不同扰动(增强)应保持一致的预测结果。
  • 伪标签:为未标注数据生成预测标签(即伪标签),并将高置信度的伪标签作为训练目标。

FixMatch 的整体架构

FixMatch 的训练过程包括以下步骤:

  1. 强弱数据增强:对输入数据进行弱增强 \(\mathcal{A}_{weak}\) 和强增强 \(\mathcal{A}_{strong}\),并分别得到增强后的数据。

  2. 伪标签生成:对弱增强数据 \(\mathcal{A}_{weak}(x)\) 通过模型预测生成伪标签 \(\hat{y}\)。若预测的置信度超过阈值 \(\tau\),则将该伪标签应用于强增强后的样本 \(\mathcal{A}_{strong}(x)\)

  3. 损失函数计算

    • 对于标注数据,采用常规的交叉熵损失 \(\mathcal{L}_{sup}\)
  • 对于未标注数据,若伪标签置信度足够高,则使用伪标签计算无监督损失 \(\mathcal{L}_{unsup}\)

    • 最终损失函数是标注数据损失和未标注数据损失的加权和:

    \[ \mathcal{L}=\mathcal{L}_{sup}+\lambda\cdot\mathcal{L}_{unsup} \]

    其中, \(\lambda\) 是调节两者权重的超参数。

FixMatch 的算法流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
FixMatch(模型 f, 标注数据集 D_L, 未标注数据集 D_U, 阈值 τ, 超参数 λ):
for mini-batch (x_L, y_L) ∈ D_L and x_U ∈ D_U do:
# 对标注数据进行训练
L_sup = CrossEntropy(f(x_L), y_L)

# 对未标注数据进行弱增强和强增强
x_U^weak = A_weak(x_U)
x_U^strong = A_strong(x_U)

# 生成伪标签并过滤置信度低的样本
with torch.no_grad():
p = f(x_U^weak)
p_hat = argmax(p)
mask = (max(p) >= τ)

# 计算无监督损失
L_unsup = CrossEntropy(f(x_U^strong), p_hat) * mask

# 计算总损失
L_total = L_sup + λ * L_unsup

# 反向传播和优化
optimizer.zero_grad()
L_total.backward()
optimizer.step()

FixMatch 的优点

  1. 高效利用未标注数据:通过强弱增强和伪标签机制,FixMatch 能有效地利用未标注数据,提升模型性能。
  2. 简单高效:与其他复杂的半监督学习方法相比,FixMatch 的实现较为简单,但性能却表现优异。
  3. 无需大量超参数调节:FixMatch 主要依赖一个阈值 τ来过滤伪标签,其他超参数较少,易于部署。

实验与结果

FixMatch 在多个标准的半监督学习数据集上展现了强大的性能,例如 CIFAR-10、CIFAR-100 和 SVHN 等。具体表现如下:

  • CIFAR-10:在仅使用 40 张标注样本的情况下,FixMatch 能达到超过 94% 的准确率。
  • SVHN:在仅使用 250 张标注样本的情况下,FixMatch 达到了 96.53% 的准确率。

实验结果表格

数据集 标注样本数量 准确率
CIFAR-10 40 94.10%
SVHN 250 96.53%
CIFAR-100 400 72.39%

代码实现

FixMatch 的 PyTorch 实现可以通过以下代码片段简单展示:

1
2
3
4
5
6
7
8
9
10
import torch
import torch.nn.functional as F

def fixmatch_loss(logits_weak, logits_strong, labels, threshold):
"""计算FixMatch损失"""
pseudo_labels = torch.argmax(logits_weak, dim=-1)
mask = (torch.max(F.softmax(logits_weak, dim=-1), dim=-1)[0] >= threshold)
loss = F.cross_entropy(logits_strong, pseudo_labels, reduction='none')
loss = torch.mean(loss * mask)
return loss

总结

FixMatch 是一种简单高效的半监督学习算法,它通过一致性正则化和伪标签的结合,显著提升了模型在仅有少量标注数据时的表现。由于其实现简单、性能强大,FixMatch 已经成为半监督学习领域的重要基准方法。

在实际应用中,FixMatch 适用于各类需要降低标注成本、利用大量未标注数据的任务,如图像分类、语义分割等。

半监督学习主要用于分类与检测两个任务。

分类

\(\pi\) model

参考文献

[1] https://mp.weixin.qq.com/s/kvqic9Qpvnz0BDvrIViZlg [2] https://mp.weixin.qq.com/s/z5-saMrt7m4goRp4DyChDQ

隐式扩散模型 (DENOISING DIFFUSION IMPLICIT MODELS)

DDIM图示

在 DDPM 中,我们似乎用到了后验概率 ,这个后验概率怎么来的?是通过假定马尔可夫过程、得到 、进一步采用贝叶斯公式求出来的。

那我们能不能跳过假定马尔可夫过程,直接推导出这个后验概率公式形式?这有什么好处?我们在推理时可以不用一步一步生成采样,可以实现跨步采样,DDIM 已经证明了。

DDIM 证明,后验概率如果满足以下条件: 之后 采用待定系数法求解后验概率密度公式。我们假设后验概率符合高斯分布,这一假设是合理的:

阅读全文 »