逻辑回归中的多分类问题:Softmax回归详解
描述
逻辑回归(Logistic Regression)本身是用于解决二分类问题的模型,但现实中的分类任务往往涉及多个类别(如手写数字识别、物体分类等)。Softmax回归(或称多项逻辑回归)是逻辑回归在多分类问题上的推广,它通过Softmax函数将多个线性输出转换为概率分布,从而实现对多个类别的分类。
解题过程
1. 多分类问题的数学表达
假设有 \(K\) 个类别(\(K \geq 3\)),每个样本的特征向量为 \(\mathbf{x} \in \mathbb{R}^n\),标签 \(y\) 取值为 \(\{1, 2, ..., K\}\)。Softmax回归需要为每个类别 \(k\) 学习一个参数向量 \(\mathbf{w}_k \in \mathbb{R}^n\),并计算样本属于每个类别的分数:
\[z_k = \mathbf{w}_k^\top \mathbf{x} \quad (\text{注意:通常省略偏置项,或已将偏置并入}\mathbf{w}_k) \]
2. Softmax函数:将分数转化为概率
Softmax函数将 \(K\) 个分数 \(z_1, z_2, ..., z_K\) 映射为概率分布:
\[P(y=k \mid \mathbf{x}) = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}} = \frac{e^{\mathbf{w}_k^\top \mathbf{x}}}{\sum_{j=1}^{K} e^{\mathbf{w}_j^\top \mathbf{x}}} \]
特性:
- 所有概率之和为 1(\(\sum_{k=1}^{K} P(y=k \mid \mathbf{x}) = 1\))。
- 概率值非负,且受分数相对大小影响(分数越高,概率越大)。
3. 损失函数:交叉熵损失
对于真实标签 \(y=i\),我们希望模型预测的概率 \(P(y=i \mid \mathbf{x})\) 尽量接近 1。使用交叉熵损失衡量预测概率与真实分布的差距:
\[L(\mathbf{W}) = -\sum_{k=1}^{K} \mathbb{I}(y=k) \log P(y=k \mid \mathbf{x}) \]
其中 \(\mathbb{I}(y=k)\) 是指示函数(当 \(y=k\) 时为 1,否则为 0)。实际计算时,只需考虑真实类别 \(i\) 对应的项:
\[L(\mathbf{W}) = -\log P(y=i \mid \mathbf{x}) \]
示例:若真实类别为 \(i=2\),模型预测的概率为 \([0.1, 0.7, 0.2]\),则损失为 \(-\log(0.7) \approx 0.357\)。
4. 梯度下降优化
目标是最小化所有训练样本的损失之和。对参数 \(\mathbf{w}_k\) 求梯度(推导过程略):
\[\frac{\partial L}{\partial \mathbf{w}_k} = \left( P(y=k \mid \mathbf{x}) - \mathbb{I}(y=k) \right) \mathbf{x} \]
物理意义:
- 如果样本属于类别 \(k\)(即 \(y=k\)),梯度为 \((P(y=k \mid \mathbf{x}) - 1)\mathbf{x}\),模型会更新 \(\mathbf{w}_k\) 以提高 \(P(y=k \mid \mathbf{x})\)。
- 如果样本不属于类别 \(k\)(即 \(y \neq k\)),梯度为 \(P(y=k \mid \mathbf{x})\mathbf{x}\),模型会降低其他类别的概率。
通过迭代更新参数(\(\mathbf{w}_k \leftarrow \mathbf{w}_k - \eta \frac{\partial L}{\partial \mathbf{w}_k}\)),逐步优化模型。
5. 与二分类逻辑回归的关系
当 \(K=2\) 时,Softmax回归等价于二分类逻辑回归(其中一个参数向量可设为零向量)。但实际中,二分类问题通常直接使用Sigmoid函数,避免冗余参数。
总结
Softmax回归通过扩展逻辑回归的概率输出机制,利用交叉熵损失和梯度下降,实现了多分类问题的有效解决。其核心在于Softmax函数将分数转化为概率,以及梯度更新时对正确类别的强化和对错误类别的抑制。