在数据分析与机器学习任务中,随机抽样是一项基础但至关重要的操作。然而,许多初学者在尝试实现“分层抽样”时,容易陷入一个常见的误区:直接寻找一个名为“分层抽样”的特定函数。实际上,Python标准库中的random.sample()函数仅支持简单随机抽样,并不具备分层功能。这揭示了数据处理中的一个关键认知:我们需要清晰区分“业务目标”与“技术实现路径”。

理解random.sample()的局限性
首先需要明确:random.sample()是一个纯粹的简单随机抽样工具。它作用于一个扁平的数据序列(如列表或元组),并从中无差别地抽取指定数量的元素。该函数无法识别数据内部的结构化分组信息(如类别、部门、地区等)。若将混合了不同类别的数据直接传入,它只会进行整体随机抓取,这很可能导致抽样结果中某些类别代表性过强或过弱,甚至完全缺失。这并非函数设计的缺陷,而是方法选择不当的结果。
分层抽样的标准流程应该是:
- 第一步:分组。根据一个关键字段(例如“用户等级”、“产品类型”、“地区代码”)将原始数据集划分为多个互斥的子组。
- 第二步:组内抽样。在每个独立的子组内,分别执行随机抽样。可以设定固定的样本数量,也可以按照比例抽取。
- 第三步:合并结果。将所有子组抽取出的样本合并,形成最终的分层样本集合。
基于pandas的groupby与sample实现分层抽样
对于结构化的表格数据,使用pandas库是最为高效和直观的方法。假设你有一个DataFramedf,其中“category”列是用于分层的依据字段。
import pandas as pd
# 方法一:按 category 列分层,每组固定抽取5行数据
stratified_sample = df.groupby("category", group_keys=False).apply(lambda x: x.sample(n=5))
# 方法二:按比例抽取,例如每组抽取20%的数据
stratified_sample = df.groupby("category", group_keys=False).apply(lambda x: x.sample(frac=0.2))
在实施过程中,有几个技术细节需要特别注意:
- 参数
group_keys=False:此参数至关重要。若保持默认值True,结果会包含分组键构成的多级索引,可能给后续的数据处理带来不必要的麻烦。 - 比例与数量的权衡:参数
frac(抽样比例)和n(固定样本数)不可同时使用。使用frac时需注意,pandas会将其计算结果向下取整。如果某个分组的样本量极小(例如仅3行),设置frac=0.1意味着抽取0.3行,取整后为0,程序将抛出ValueError: Cannot take a larger sample than population when 'replace=False'错误。 - 抽样方式选择:默认情况下,
sample执行的是无放回抽样。如果某个分组内的行数少于你指定的抽样数量,同样会触发上述错误。虽然可以通过设置replace=True启用有放回抽样,但这通常违背了分层抽样的设计初衷。更稳妥的做法是在抽样前,使用df["category"].value_counts()预先检查各分组的基数。
利用scikit-learn的train_test_split进行分层数据划分
如果你的核心目标是为机器学习模型准备训练集和测试集,并希望保持数据集的类别分布一致,那么scikit-learn库中的train_test_split函数提供了更便捷的解决方案。
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.3,
stratify=y, # 核心参数:依据标签y进行分层划分
random_state=42
)
该方法在底层自动完成了分层逻辑,能确保训练集和测试集中各个类别的比例与原始数据集高度一致。然而,它主要服务于建模流程,输出的是特征矩阵和标签数组,原始的DataFrame索引和结构信息不会保留。因此,它不适用于需要自定义每层样本数量或需要保留完整行信息的通用分层抽样场景。
使用时请注意:
stratify参数:需要传入一个一维数组(通常是标签列y),其长度必须与特征数据X的行数严格相等。- 处理稀有类别:如果数据集中存在样本量极少的类别,而划分比例又可能导致该类别在某个子集中数量为0,函数会报错
ValueError: The least populated class in y has only 1 member。此时需要考虑合并稀有类别或采用过采样/欠采样等策略进行调整。
手动实现时需警惕的索引问题
使用groupby().apply(sample)得到的结果,默认会沿用原始数据的索引。这在需要回溯原始记录时是个优点,但若将结果作为独立的新数据集进行后续分析,可能会遇到索引重复或排序混乱的问题。
- 重置索引:一个简单的解决方案是在抽样后立即执行
stratified_sample.reset_index(drop=True)来重置索引。 - 保留业务索引:但如果原始索引本身具有重要的业务含义(如订单ID、时间戳),则不应直接丢弃,而应将其作为一列显式地保留在新的DataFrame中。
- 处理空分组:在某些pandas版本中,如果某个分组为空,或者
sample操作意外返回了一个空的DataFrame,整个apply过程可能会失败。建议在分组前检查并过滤掉空组,或者在apply函数内部增加异常处理逻辑。
总而言之,分层抽样的代码实现并不复杂,真正的挑战在于前期的业务分析与设计:选择哪个字段作为分层依据?每层是抽取固定数量还是按比例抽取?如何应对数据量不足的分组?将这些业务逻辑梳理清晰后,技术实现就变成了在pandas.groupby与sklearn.train_test_split之间做出合适选择的问题。其核心原则始终是:先深入理解你的数据结构和业务目标,再选择恰当的工具进行操作。
