HugginFace 使用数据集(学习笔记) 电脑版发表于:2023/10/20 17:23 ![](https://img.tnblog.net/arcimg/hb/782c293bc3904ab0bb30af5ff454beae.png) >#HugginFace 使用数据集(学习笔记) [TOC] ## 数据集工具介绍 tn2>HuggingFace 提供了统一的数据集处理工具,让不同的数据集通过统一的API处理。 访问HuggingFace官网,点击顶部的Datasets,可以看到它提供的所有数据集。 ![](https://img.tnblog.net/arcimg/hb/c71bd3aa2d4142cf885e31099164936e.png) tn2>左边可以通过不同的任务类型、语言、体积许可来筛选数据集,右侧为具体的数据集列表,其中有glue、super_glue数据集,问答数据集squad,情感分类数据集imdb,纯文本数据集wikitext。 点开其中一个可以查看到其中的子集数据。 ![](https://img.tnblog.net/arcimg/hb/876a95ec38cf41119191316d7a9a5471.png) ## 使用数据集工具 ### 数据集加载和保存 tn2>使用HuggingFace数据集工具加载数据只需要简单的一行代码就可以进行加载了。 以加载名为`lansinuote/ChnSentiCorp`数据集为例。 ```python #第3章/加载数据集 from datasets import load_dataset dataset = load_dataset(path='lansinuote/ChnSentiCorp') dataset ``` ![](https://img.tnblog.net/arcimg/hb/7df91e604753473c9fc8dccf9a7b8690.png) tn>由于HuggingFace把数据集存储在谷歌云盘上,从国内加载就需要梯子,所以可以进行特殊方式下载后通过`load_from_disk()`函数进行本地加载。 ```python #第3章/从磁盘加载数据集 from datasets import load_from_disk dataset = load_from_disk('./data/ChnSentiCorp') dataset ``` tn2>我们可以看到`lansinuote/ChnSentiCorp`分为3个部分,`train`、`test`、`validation`,分别表示训练集、验证集和测试集,并且每条数据都有`text`和`label`,分别代表文本和标签。 `load_dataset()`函数还有一些其他参数,可以通过下面的一个例子来举例说明。 ```python load_dataset(path='glue', name='sst2', split='train') ``` tn2>这里加载了`glue`数据集,通过`name`参数指定加载`sst2`数据子集,还通过`split`指定要加载`train`部分的。 ### 将数据集保存到本地磁盘 tn2>加载了数据集后,可以使用`save_to_disk()`函数将数据集保存到本地磁盘,代码如下: ```python dataset.save_to_disk(dataset_dict_path="./ChnSentiCorp") ``` ## 数据基本操作 ### 取出数据部分 tn2>为了便于做后续实验,首先取出train部分的数据集。 ```python dataset = dataset['train'] ``` tn2>查看数据取样。 ```python for i in [12, 17, 20, 26, 56]: print(dataset[i]) ``` ![](https://img.tnblog.net/arcimg/hb/c71db61978d0413f8986089621f07459.png) tn2>可以看出来这是一份消费评论数据,字段text表示消费者的评论,这段lael表明这是一段好评还是差评。 ### 数据排序 tn2>可以使用`sort()`函数让数据按照某个字段排序,代码如下: ```python #第3章/排序数据 #数据中的label是无序的 print(dataset['label'][:10]) #让数据按照label排序 sorted_dataset = dataset.sort('label') print(sorted_dataset['label'][:10]) print(sorted_dataset['label'][-10:]) ``` ![](https://img.tnblog.net/arcimg/hb/4692790a19af4afc8224871d06f20b7d.png) tn2>可以看到,初始数据是乱序的,使用`sort()`函数后,数据按照`label`排列为有序的了。 ### 打乱数据 tn2>我们还可以使用`shuffle()`函数再次打乱数据,代码如下: ```python shuffled_dataset = sorted_dataset.shuffle(seed=42) shuffled_dataset['label'][:10] ``` ![](https://img.tnblog.net/arcimg/hb/1d1457f258f64ed087d0d6a56aed010c.png) tn2>可以看到再次打乱为无序。 ### 数据抽样 tn2>可以使用`select()`函数从数据集中选择某些数据,得到一个新的数据集,代码如下: ```python dataset.select([0, 10, 20, 30, 40, 50]) ``` tn2>运行结果如下: ![](https://img.tnblog.net/arcimg/hb/9e564cab72824d6e8033998391d66e62.png) ### 过滤数据 tn2>使用`filter()`函数可以按照自定义的规则过滤数据,代码如下: ```python def f(data): return data['text'].startswith('非常不错') dataset.filter(f) ``` tn2>`filter()`函数接受一个函数作为参数,在该函数中确定过滤数据的条件,在上面的例子中数据过滤的条件是评价以`非常不错`开头,代码如下: ![](https://img.tnblog.net/arcimg/hb/c1188f4a2c2c4a16b4cdbdaa85f8f67d.png) tn2>可以看到,满足评价以`非常不错`开头的数据共有13条。 ### 训练测试集拆分 tn2>可以使用`train_test_split()`函数将数据集切分为训练集和测试集,代码如下: ```python dataset.train_test_split(test_size=0.1) ``` ![](https://img.tnblog.net/arcimg/hb/36e6e7fc1bab4c619e9c2fa81861ea2d.png) tn2>参数`test_size`表明测试集占数据总体的比例,例子中占10%,训练集占90%。 ### 数据分桶 tn2>可以使用`shared()`函数吧数据均匀地分为n部分,代码如下: ```python dataset.shard(num_shards=4, index=0) ``` | 参数 | 描述 | | ------------ | ------------ | | `num_shards` | 表明要把数据均匀地分为几部分,例子中为4部分。 | | `index` | 表示要取出第几份数据,例子中为取出第`0`份 | ![](https://img.tnblog.net/arcimg/hb/cc47d5763cef461cb1e98e8feab3d481.png) tn2>因为原数据集数据为9600条,均匀地分为4份后每一份是2400条,和上面输出的一致 ### 重名名字段 tn2>使用`rename_column()`函数可以重命名字段,代码如下: ```python dataset.rename_column('text', 'text_rename') ``` ![](https://img.tnblog.net/arcimg/hb/b26dad097fa8446688cfbded57a0e1ab.png) tn2>原始字段`text`现在已经被重命名为`text_rename`。 ### 删除字段 tn2>使用`remove_columns()`函数可以删除字段,代码如下: ```python dataset.remove_columns(['text']) ``` ![](https://img.tnblog.net/arcimg/hb/4d8e1a6aeb884b14bed627ef0b823710.png) tn2>可以看到text字段已经被删除了 ### 映射函数 tn2>`map()`函数,可以对每一条数据进行一定的修改,代码如下: ```python def f(data): data['text'] = 'My sentence: ' + data['text'] return data maped_datatset = dataset.map(f) print(dataset['text'][20]) print(maped_datatset['text'][20]) ``` ![](https://img.tnblog.net/arcimg/hb/97ff234e3f1f41a69e72711eb786fcb1.png) tn2>增删改查都可以。 ### 使用批处理加速 tn2>在使用过滤和映射这一类需要使用一个函数遍历数据集的方法时,可以使用批处理减少函数的调用次数。一批次处理多个数据,代码如下: ```python def f(data): text = data['text'] text = ['My sentence: ' + i for i in text] data['text'] = text return data maped_datatset = dataset.map(function=f, batched=True, batch_size=1000, num_proc=4) print(dataset['text'][20]) print(maped_datatset['text'][20]) ``` | 参数 | 描述 | | ------------ | ------------ | | `batched` | 当为True时,表示进行批次处理。 | | `batch_size` | 以`1000`条数据为一个批次进行一次处理,这样大大提高了性能。 | | `num_proc` | 表示在4条线程上执行该任务,一般也为设置为CPU核心数量。 | ![](https://img.tnblog.net/arcimg/hb/3ebb88da02604763ad5df371ed92fba6.png) ### 设置数据格式 tn2>使用`set_format()`函数修改数据格式,代码如下: ```python dataset.set_format(type='torch', columns=['label'], output_all_columns=True) dataset[20] ``` | 参数 | 描述 | | ------------ | ------------ | | `type` | 表明要修改的数据类型,常用的取值有numpy、torch、tensorflow、pandas等。 | | `columns` | 要修改格式的字段 | | `cutput_all_columns` | 表明是否要保留其他字段,设置为True表明要保留的。 | ![](https://img.tnblog.net/arcimg/hb/f750133517f04a82836926bc5ed48bbf.png) tn2>此时字段label已经被修改为PyTorch的Tensor格式了。 ## 将数据保存为其他格式 ### CSV格式 tn2> 将数据集保存为CSV格式,便于分享,同时数据集工具也有加载CSV格式数据方法,代码如下: ```python dataset = load_dataset(path='lansinuote/ChnSentiCorp', split='train') dataset.to_csv(path_or_buf='./ChnSentiCorp.csv') #加载csv格式数据 csv_dataset = load_dataset(path='csv', data_files='./ChnSentiCorp.csv', split='train') csv_dataset[20] ``` ![](https://img.tnblog.net/arcimg/hb/664cf849699f48cb9cd6c5d391d2b20a.png) ![](https://img.tnblog.net/arcimg/hb/fdff59d8977c4a2d8fb195abfb8dd6b9.png) tn2>可以看到,保存为CSV格式后再加载,多了一个Unnamed字段,在这一列中实际保存的是数据的序列,这和保存的CSV文件内容有关系。如果不想要这一列,则可以直接到CSV文件去删除1列,删除时可以使用数据集的删除列功能,在此不再叙述 ## 保存数据为JSON格式 tn2>除了可以保存CSV格式以外还可以保存为JSON格式,方法和CSV差不多,代码如下: ```python dataset = load_dataset(path='lansinuote/ChnSentiCorp', split='train') dataset.to_json(path_or_buf='./ChnSentiCorp.json') #加载json格式数据 json_dataset = load_dataset(path='json', data_files='./ChnSentiCorp.json', split='train') json_dataset[20] ``` ![](https://img.tnblog.net/arcimg/hb/062d9bb9bb284502bf42549c4cbdbf61.png) ![](https://img.tnblog.net/arcimg/hb/a19468f4d6914c4588bef73f1d906d71.png)