机器学习中的数据集切分(代码片段)

LaurenceGeng LaurenceGeng     2023-03-02     631

关键词:

文章目录

1. 常规切分思路

应用有监督的机器学习算法时,需要将数据集切分成训练数据集和测试数据集两部分。在《Handson ML》一书中,使用了numpy.random.permutation( https://numpy.org/doc/stable/reference/random/generated/numpy.random.permutation.html#numpy-random-permutation)对数据集进行了切分。其思路是:利用permutation生成shuffle后记录索引(打乱顺序的索引集合),然后按比例提取训练数据集的索引子集和测试数据集的测试数据集,最后通过DataFrame的iloc方法按索引完成了数据集的切割,操作示例代码如下:

import numpy as np
# to make this notebook's output identical at every run
np.random.seed(42)
# For illustration only. Sklearn has train_test_split()
def split_train_test(data, test_ratio):
    # permutation就是生成随机数组,参数是n就生成从0到n-1的n个元素组成的数组,但是顺序是打乱是,不是0,1,2...这样的
    shuffled_indices = np.random.permutation(len(data))
    test_set_size = int(len(data) * test_ratio)
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    # 根据切分好的索引,分别提取对应记录,形成训练集和测试集
    return data.iloc[train_indices], data.iloc[test_indices]

该切分方法有一个问题:当多次执行时,每次切分出的数据集都是不固定的,这不是算法期望的行为,有两种方法可以修正这一问题:

  • 将切分好的索引集合保存下来,这样每次提取的就是固定的数据集了
  • 在调用permutation之前,显式地设置随机数发生器的seed,这样就能保证每次生成的随机索引数组都是一致的

对于第二种方式,可以使用如下代码进行验证:

import numpy as np
# 设置seed,可以保证每次生成的随机数组是固定的!这对于切分数据集来说是必要的
np.random.seed(42)
shuffled_indices = np.random.permutation(5)
print(shuffled_indices)

但是,当数据集发生变更后,上述基于固定索引的切分方法就不再可靠了(原始记录的条数和位置都有可能发生变化),较为稳妥的切分方法应该是基于数据自身特征进行切分,这样无论这条记录出现在什么位置上,都可以被精准划拨到指定的数据集上。

最简单的做法是基于记录的ID进行切分,我们可以对ID使用crc32循环冗余编码,然后将求出的hash值和按期望切分比例算出的hash值进行比较,以决定数据是进入训练集还是测试集,以下是参考代码:

from zlib import crc32
def test_set_check(identifier, test_ratio):
    return crc32(np.int64(identifier)) & 0xffffffff < test_ratio * 2**32
def split_train_test_by_id(data, test_ratio, id_column):
    ids = data[id_column]
    # 针对每一个id计算crc32并判定是否落在测试数据集区间中,是则放入in_test_set
    in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio))
    return data.loc[~in_test_set], data.loc[in_test_set]

但是,并不是所有的记录都有ID,此时就需要根据一些相对稳定的属性列来构建ID,例如,在《Handson ML》一书使用的房地产数据中,根据经纬度可以构建中相对稳定的ID值:ID = 经度*1000 + 维度

2. Scikit-Learn内置的切分函数

鉴于切分数据集是如些基础和必要,Scikit-Learn也就直接内置了现成的函数,其功能和做法与最开始时介绍的手动切分方法是一致的:

from sklearn.model_selection import train_test_split
train_set, test_set = train_test_split(housing, test_size=0.2, random_state=42)

参数的参数都可以直接猜出其作用,test_size=0.2表示期望测试数据集占20%,random_state=42就是此前说的设置固定的seed:np.random.seed(42),以确保每次执行时切分的记录是一致的。

机器学习新手项目之n-gram分词(代码片段)

概述对机器学习感兴趣的小伙伴,可以借助python,实现一个N-gram分词中的Unigram和Bigram分词器,来进行入门,github地址此项目并将前向最大切词FMM和后向最大切词的结果作为Baseline,对比分析N-gram分词器在词语切分正确率、词义消... 查看详情

机器学习中的数据预处理方法与步骤(代码片段)

数据预处理是准备原始数据并使其适用于机器学习模型的过程。这是创建机器学习模型的第一步,也是至关重要的一步。在创建机器学习项目时,我们并不总是遇到干净且格式化的数据。并且在对数据进行任何操作时ÿ... 查看详情

机器学习irisdataset(鸢尾属植物数据集)(代码片段)

注:数据是机器学习模型的原材料,当下机器学习的热潮离不开大数据的支撑。在机器学习领域,有大量的公开数据集可以使用,从几百个样本到几十万个样本的数据集都有。有些数据集被用来教学,有些被当做机器学习模型性... 查看详情

基于sklearn和keras的数据切分与交叉验证(代码片段)

在训练深度学习模型的时候,通常将数据集切分为训练集和验证集.Keras提供了两种评估模型性能的方法:使用自动切分的验证集使用手动切分的验证集 一.自动切分在Keras中,可以从数据集中切分出一部分作为验证集,并... 查看详情

机器学习sklearn数据集(代码片段)

目录1数据集1.1可用数据集1.1.1Scikit-learn工具介绍1.1.2安装1.1.3Scikit-learn包含的内容1.2sklearn数据集1.2.1scikit-learn数据集API介绍1.2.2sklearn小数据集1.2.3sklearn大数据集1.2.4sklearn数据集的使用1.3数据集的划分2pip/pip3下载慢解决1数据集知道... 查看详情

机器学习项目流程(代码片段)

机器学习项目流程在这我们会从头开始做一个机器学习项目,向大家展示一个机器学习项目的一个基本流程与方法。一个机器学习主要分为以下几个步骤:从整体上了解项目获取数据发现并可视化数据,以深入了解数据为机器学... 查看详情

机器学习实战基础(二十七):sklearn中的降维算法pca和svdpca对手写数字数据集的降维(代码片段)

PCA对手写数字数据集的降维1.导入需要的模块和库fromsklearn.decompositionimportPCAfromsklearn.ensembleimportRandomForestClassifierasRFCfromsklearn.model_selectionimportcross_val_scoreimportmatplotlib.pyplotaspltimportpandas 查看详情

机器学习sklearn的快速使用--周振洋(代码片段)

ML神器:sklearn的快速使用传统的机器学习任务从开始到建模的一般流程是:获取数据->数据预处理->训练建模->模型评估->预测,分类。本文我们将依据传统机器学习的流程,看看在每一步流程中都有哪些常用的函数以及... 查看详情

机器学习算法的sklearn实现(代码片段)

...据集  sklearn中包含了大量的优质的数据集,在你学习机器学习的过程中,你可以通过使用这些数据集实现出不同的模型,从而提高你的动手实践能力,同时这个过程也可以加深你对理论知识的理解和把握。(这一步我也亟需... 查看详情

机器学习基础(代码片段)

...arn的数据集数据集划分数据集接口介绍数据集划分前提:机器学习就是从数据中自动分析获得规律,并利用规律对未知数据进行预测。换句话说,我们的模型一定是要经过样本数据对其进行训练,才可以对未知数据进行预测的。... 查看详情

机器学习项目实战识别mnist数据集识别图片数字(代码片段)

目录1机器学习分析步骤2问题定义2数据的收集和预处理3选择机器学习模型4训练机器,确定参数5超参数调试和性能优化5.1训练集、验证集和测试集5.2K折验证5.3模型的优化和泛化5.4验证预测结果1机器学习分析步骤机器学习项... 查看详情

数据集切分(代码片段)

...Ytrain,Ytest=train_test_split(X,Y,test_size=0.3,random_state=420)#切分前数据标签的分布情况train_data.SeriousDlqin2yrs.value_counts()#切分后,训练集数据标签的分布情况pd.Series(Ytrain).value_counts()#如果切分后的数据集分布不均衡,则需要对Xtrain,Ytrain和Xtest... 查看详情

从iris数据集开始---机器学习入门(代码片段)

...特征,数据和目标结果之间的关系是什么?而且这可能是机器学习过程中最重要的部分。在开始使用机器学习实际应用时,有必要先回答下面几个问题:解决的问题是什么?现在收集的数据能够解决目前的问题吗?该 查看详情

机器学习100天:002数据预处理之导入数据集(代码片段)

机器学习100天,今天讲的是:数据预处理之导入数据集。首先,我们打开spyder。新建一个load_data.py脚本。第一步,导入标准库机器学习常用的标准库有3个:第一个:numpy,用于数据处理。第二个:matplotlib.pyplot,用于画图。第三... 查看详情

机器学习100天:002数据预处理之导入数据集(代码片段)

机器学习100天,今天讲的是:数据预处理之导入数据集。首先,我们打开spyder。新建一个load_data.py脚本。第一步,导入标准库机器学习常用的标准库有3个:第一个:numpy,用于数据处理。第二个:matplotlib.pyplot,用于画图。第三... 查看详情

机器学习特征工程(代码片段)

目录数据集可用数据集sklearn数据集特征提取字典文本特征预处理无量纲化归一化标准化特征降维特征选择主成分分析(PCA降维)数据集下面列举了一些示例来说明哪些内容能算作数据集:包含某些数据的表格或CSV文... 查看详情

fashionmnist数据集简要分析---深度学习&机器学习第五天(代码片段)

图像分类数据集—FashionMNIST数据集①简介:fashionmnist数据集中共有10种类别的服饰,分别为:['t-shirt','toruser','pullover','dress','coat','sandal','shirt',& 查看详情

强烈推荐机器学习之算法篇(代码片段)

机器学习算法机器学习算法数据类型:可用数据集:监督学习和无监督学习:算法分类:scikit-learn数据集获取数据集:获取数据集方式:数据集的划分:本地数据集:分类数据集:回归数据集&#x... 查看详情