WordCount

WordCount Basics

Let's start with a WordCount example.

Suppose we have a file words.txt with the following contents:

hello world
hello friend
hello my friend
hello

We want to use Spark to count how many times each word appears.

This breaks down into the following steps:

  • Take the words out of every line and flatten them into a single array
  • Use map to turn each word into a key-value pair, e.g. [(hello, 1), (world, 1), (hello, 1)...]
  • Use reduceByKey to add up the values for each key (word)

Here are the steps in detail:

  1. First, split the lines with flatMap. flatMap flattens the nested results into a single sequence of words
words = lines.flatMap(lambda x: x.split())
words.collect()
['hello', 'world', 'hello', 'friend', 'hello', 'my', 'friend', 'hello']

Why not use map? If we used map instead, what would the result look like?

words = lines.map(lambda x: x.split())
words.collect()
[['hello', 'world'], ['hello', 'friend'], ['hello', 'my', 'friend'], ['hello']]
  2. With the words flattened, convert each word into a key-value pair
words.map(lambda x: (x, 1)).collect()
[('hello', 1),
 ('world', 1),
 ('hello', 1),
 ('friend', 1),
 ('hello', 1),
 ('my', 1),
 ('friend', 1),
 ('hello', 1)]
  3. Reduce the key-value pairs with reduceByKey, using lambda x,y:x+y to add up the values for each key
words.map(lambda x: (x, 1)).reduceByKey(lambda x,y:x+y).collect()
[('world', 1), ('hello', 4), ('friend', 2), ('my', 1)]

The complete program looks like this:

import findspark
findspark.init()

import pyspark

sc = pyspark.SparkContext(appName="WordCount")

lines = sc.textFile('file:///home/ubuntu/words.txt')

words = lines.flatMap(lambda x: x.split())

wordcounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x,y:x+y)
wordcounts.collect()
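
If you want to keep the counts instead of just collecting them to the driver, RDDs also have saveAsTextFile. A minimal sketch (the output path is just an example; Spark writes a directory of part files, not a single file):

# Writes one part-XXXXX file per partition under the given directory
wordcounts.saveAsTextFile('file:///home/ubuntu/wordcount_output')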

Sorting the WordCount Results

What if we want to sort the WordCount results by count? We can use sortByKey, but we first need to swap the key and the value so that the count becomes the key we sort on.

This is the result from the previous step:

wordcounts.collect()
[('world', 1), ('hello', 4), ('friend', 2), ('my', 1)]

So we first swap the keys and values:

wordcounts.map(lambda x:(x[1],x[0])).collect()
[(1, 'world'), (1, 'my'), (4, 'hello'), (2, 'friend')]

After the swap we can sort; passing False to sortByKey means descending order.

wordcounts = sc.textFile('file:///home/ubuntu/words.txt') \
        .flatMap(lambda x: x.split()) \
        .map(lambda x: (x, 1)) \
        .reduceByKey(lambda x,y:x+y) \
        .map(lambda x:(x[1],x[0])) \
        .sortByKey(False)
wordcounts.collect()
[(4, 'hello'), (2, 'friend'), (1, 'world'), (1, 'my')]
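
As an alternative sketch, the RDD API also has sortBy, which takes a key function, so you can sort by the count directly without swapping key and value:

wordcounts = sc.textFile('file:///home/ubuntu/words.txt') \
        .flatMap(lambda x: x.split()) \
        .map(lambda x: (x, 1)) \
        .reduceByKey(lambda x,y:x+y) \
        .sortBy(lambda x: x[1], ascending=False)   # sort by the count, descending
wordcounts.collect()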

Inspecting RDD Partitions

Spark provides efficient parallel computation, and in the example above it partitions the data automatically. The code below shows how many objects each partition of an RDD contains:

def countPartitions(id,iterator): 
    c = 0 
    for _ in iterator: 
        c += 1 
    yield (id,c) 
lines.mapPartitionsWithIndex(countPartitions).collectAsMap()  # mapPartitionsWithSplit is the older, deprecated name

Check Spark's default parallelism:

sc.defaultParallelism
8

Spark distributes an RDD's partitions across multiple worker nodes, so if a task on one node fails, the tasks on other nodes are unaffected; the failed task simply gets rerun on another node.
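
If you want to check or change the number of partitions yourself, the RDD API has getNumPartitions and repartition. A small sketch:

lines.getNumPartitions()          # how many partitions this RDD currently has
lines4 = lines.repartition(4)     # reshuffle the data into 4 partitions
lines4.getNumPartitions()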


Reading CSV with PySpark

First, download the txt file we are going to read:

! wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/datasets/Performance_2015Q1.txt

Then initialize the environment and create the SparkContext and SparkSession:

import findspark
findspark.init()

import pyspark
import random

sc = pyspark.SparkContext(appName="WordCount")
spark = pyspark.sql.session.SparkSession(sc)

Read the CSV with the SparkSession:

  • With inferSchema=True, Spark scans the data to infer each column's type, which costs an extra pass over the file; with inferSchema=False every column is read as a string
df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')
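
To see which types were actually inferred, printSchema is a quick sanity check:

df.printSchema()   # prints each column with its inferred type and nullability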

Take the first 3 rows:

df.take(3)
[Row(_c0=100002091588, _c1='01/01/2015', _c2='OTHER', _c3=4.125, _c4=None, _c5=0, _c6=360, _c7=360, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None),
 Row(_c0=100002091588, _c1='02/01/2015', _c2=None, _c3=4.125, _c4=None, _c5=1, _c6=359, _c7=359, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None),
 Row(_c0=100002091588, _c1='03/01/2015', _c2=None, _c3=4.125, _c4=None, _c5=2, _c6=358, _c7=358, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None)]

Another way to read the file is to supply a schema for just the columns you need. Note that when a declared type does not match the data (here _c0 and _c2), those values come back as null.

from pyspark.sql.types import DateType, TimestampType, IntegerType, FloatType, LongType, DoubleType, StringType
from pyspark.sql.types import StructType, StructField

custom_schema = StructType([StructField('_c0', DateType(), True),
                           StructField('_c1', StringType(), True),
                           StructField('_c2', DoubleType(), True),
                           StructField('_c3', DoubleType(), True)])
                           
df_part = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, schema=custom_schema, sep='|')
df_part.take(3)
[Row(_c0=None, _c1='01/01/2015', _c2=None, _c3=4.125),
 Row(_c0=None, _c1='02/01/2015', _c2=None, _c3=4.125),
 Row(_c0=None, _c1='03/01/2015', _c2=None, _c3=4.125)]

Count the number of rows:

df.count()
3526154

Selecting and Renaming Columns

Columns can be selected with df.select('col1', 'col2', 'col3'):

df_lim = df.select('_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13')
df_lim.take(1)
[Row(_c0=100002091588, _c1=u'01/01/2015', _c2=u'OTHER', _c3=4.125, _c4=None, _c5=0, _c6=360, _c7=360, _c8=u'01/2045', _c9=16740, _c10=u'0', _c11=u'N', _c12=None, _c13=None)]

Renaming columns:

df_lim = df_lim.withColumnRenamed('_c0','loan_id').withColumnRenamed('_c1','period')
df_lim
DataFrame[loan_id: bigint, period: string, _c2: string, _c3: double, _c4: double, _c5: int, _c6: int, _c7: int, _c8: string, _c9: int, _c10: string, _c11: string, _c12: int, _c13: string]

Use a for loop to rename columns in bulk:

old_names = ['_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13']
new_names = ['servicer_name', 'new_int_rt', 'act_endg_upb', 'loan_age', 'mths_remng', 'aj_mths_remng', 'dt_matr', 'cd_msa', 'delq_sts', 'flag_mod', 'cd_zero_bal', 'dt_zero_bal']

for old, new in zip(old_names, new_names):
    df_lim = df_lim.withColumnRenamed(old, new)
df_lim
DataFrame[loan_id: bigint, period: string, servicer_name: string, new_int_rt: double, act_endg_upb: double, loan_age: int, mths_remng: int, aj_mths_remng: int, dt_matr: string, cd_msa: int, delq_sts: string, flag_mod: string, cd_zero_bal: int, dt_zero_bal: string]
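
As an alternative to the renaming loop, DataFrame.toDF accepts the complete list of new names in one call. A sketch that starts again from the raw column selection (df_lim2 is just an illustrative name):

all_new_names = ['loan_id', 'period', 'servicer_name', 'new_int_rt', 'act_endg_upb',
                 'loan_age', 'mths_remng', 'aj_mths_remng', 'dt_matr', 'cd_msa',
                 'delq_sts', 'flag_mod', 'cd_zero_bal', 'dt_zero_bal']
df_lim2 = df.select('_c0', '_c1', '_c2', '_c3', '_c4', '_c5', '_c6',
                    '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13').toDF(*all_new_names)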

Take one row to check:

df_lim.take(1)
[Row(loan_id=100002091588, period='01/01/2015', servicer_name='OTHER', new_int_rt=4.125, act_endg_upb=None, loan_age=0, mths_remng=360, aj_mths_remng=360, dt_matr='01/2045', cd_msa=16740, delq_sts='0', flag_mod='N', cd_zero_bal=None, dt_zero_bal=None)]

describe

describe works much like the pandas command of the same name; if you don't pass any column names, it returns statistics for every column (which gets messy when there are many columns).

df_described = df_lim.describe('servicer_name', 'new_int_rt', 'loan_age')
df_described.show()
+-------+--------------------+-------------------+------------------+
|summary|       servicer_name|         new_int_rt|          loan_age|
+-------+--------------------+-------------------+------------------+
|  count|              382039|            3526154|           3526154|
|   mean|                null|  4.178168090219519| 5.134865351881966|
| stddev|                null|0.34382335723646673|3.3833930336063465|
|    min|  CITIMORTGAGE, INC.|               2.75|                -1|
|    max|WELLS FARGO BANK,...|              6.125|                34|
+-------+--------------------+-------------------+------------------+

Saving Files

df_described.write.format('com.databricks.spark.csv').option("header","true").save('file:///home/ubuntu/result.csv')
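
On Spark 2.0 and later, the built-in csv writer does the same thing without naming the external package. A sketch (the output path is just an example; mode='overwrite' replaces an existing output directory, and the result is a directory of part files):

df_described.write.csv('file:///home/ubuntu/result_csv', header=True, mode='overwrite')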

Basic Operations

Adding Columns

We only need some of the columns:

df_lim = df.select('_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13')

old_names = ['_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13']
new_names = ['loan_id','period','servicer_name', 'new_int_rt', 'act_endg_upb', 'loan_age', 'mths_remng', 'aj_mths_remng', 'dt_matr', 'cd_msa', 'delq_sts', 'flag_mod', 'cd_zero_bal', 'dt_zero_bal']
for old, new in zip(old_names, new_names):
    df_lim = df_lim.withColumnRenamed(old, new)

A new column can be added with withColumn(new_column_name, column_expression):

df_lim = df_lim.withColumn('loan_length', df_lim['loan_age'] + df_lim['mths_remng'])

group by

## group by servicer name:
df_grp = df_lim.groupBy('servicer_name')

## compute the average loan age, months remaining, and loan length per servicer:
df_avg = df_grp.avg('loan_age', 'mths_remng', 'loan_length')

Now display the result with show. This call takes noticeably longer, because Spark evaluates lazily: transformations are not actually computed until an action appears, and then everything runs at once.

df_avg.show()
+--------------------+--------------------+------------------+------------------+
|       servicer_name|       avg(loan_age)|   avg(mths_remng)|  avg(loan_length)|
+--------------------+--------------------+------------------+------------------+
|  QUICKEN LOANS INC.|-0.08899247348614438| 358.5689787889155|358.47998631542936|
|NATIONSTAR MORTGA...| 0.39047125841532887| 359.5821853961678| 359.9726566545831|
|                null|  5.6264681794400015|354.21486809483747| 359.8413362742775|
|WELLS FARGO BANK,...|  0.6704475572258285|359.25937820293814|359.92982576016396|
|FANNIE MAE/SETERU...|   9.333333333333334| 350.6666666666667|             360.0|
|DITECH FINANCIAL LLC|   5.147629653197582| 354.7811008590519|359.92873051224944|
|SENECA MORTGAGE S...| -0.2048814025438295|360.20075627363354| 359.9958748710897|
|SUNTRUST MORTGAGE...|  0.8241234756097561| 359.1453887195122|  359.969512195122|
|ROUNDPOINT MORTGA...|   5.153408024034549| 354.8269387244163|359.98034674845087|
|      PENNYMAC CORP.| 0.14966740576496673| 359.8470066518847|359.99667405764967|
|PHH MORTGAGE CORP...|  0.9780420860018298|359.02195791399816|             360.0|
|MATRIX FINANCIAL ...|   6.566794707639778| 353.4229620145113|359.98975672215107|
|               OTHER| 0.11480465916297489| 359.8345750772193|359.94937973638224|
|  CITIMORTGAGE, INC.|   0.338498789346247|359.41670702179175|  359.755205811138|
|PINGORA LOAN SERV...|   7.573573382530696|352.40886824861633|  359.982441631147|
|JP MORGAN CHASE B...|  1.6553418987669224| 358.3384495990342|359.99379149780117|
|      PNC BANK, N.A.|  1.1707779886148009|358.78747628083494| 359.9582542694497|
|FREEDOM MORTGAGE ...|    8.56265812109968|351.29583403609377|359.85849215719344|
+--------------------+--------------------+------------------+------------------+
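
A quick way to see lazy evaluation for yourself is to time a transformation against an action: the transformation returns almost immediately because nothing has run yet, while the action triggers the real computation. A minimal sketch:

import time

t0 = time.time()
df_tmp = df_lim.groupBy('servicer_name').avg('loan_age')   # transformation only
print('transformation: %.3f s' % (time.time() - t0))

t0 = time.time()
df_tmp.count()                                             # action: runs the whole pipeline
print('action: %.3f s' % (time.time() - t0))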

You can use persist to cache an intermediate result, telling Spark that you will reuse it later:

df_keep = df_lim.withColumn('loan_length', df_lim['loan_age'] + df_lim['mths_remng'])

df_keep.persist()

df_grp = df_keep.groupBy('servicer_name')
df_avg = df_grp.avg('loan_age', 'mths_remng', 'loan_length')

This is how long the first computation takes:

%%time
df_avg.show()
+--------------------+--------------------+------------------+------------------+
|       servicer_name|       avg(loan_age)|   avg(mths_remng)|  avg(loan_length)|
+--------------------+--------------------+------------------+------------------+
|  QUICKEN LOANS INC.|-0.08899247348614438| 358.5689787889155|358.47998631542936|
|NATIONSTAR MORTGA...| 0.39047125841532887| 359.5821853961678| 359.9726566545831|
|                null|  5.6264681794400015|354.21486809483747| 359.8413362742775|
|WELLS FARGO BANK,...|  0.6704475572258285|359.25937820293814|359.92982576016396|
|FANNIE MAE/SETERU...|   9.333333333333334| 350.6666666666667|             360.0|
|DITECH FINANCIAL LLC|   5.147629653197582| 354.7811008590519|359.92873051224944|
|SENECA MORTGAGE S...| -0.2048814025438295|360.20075627363354| 359.9958748710897|
|SUNTRUST MORTGAGE...|  0.8241234756097561| 359.1453887195122|  359.969512195122|
|ROUNDPOINT MORTGA...|   5.153408024034549| 354.8269387244163|359.98034674845087|
|      PENNYMAC CORP.| 0.14966740576496673| 359.8470066518847|359.99667405764967|
|PHH MORTGAGE CORP...|  0.9780420860018298|359.02195791399816|             360.0|
|MATRIX FINANCIAL ...|   6.566794707639778| 353.4229620145113|359.98975672215107|
|               OTHER| 0.11480465916297489| 359.8345750772193|359.94937973638224|
|  CITIMORTGAGE, INC.|   0.338498789346247|359.41670702179175|  359.755205811138|
|PINGORA LOAN SERV...|   7.573573382530696|352.40886824861633|  359.982441631147|
|JP MORGAN CHASE B...|  1.6553418987669224| 358.3384495990342|359.99379149780117|
|      PNC BANK, N.A.|  1.1707779886148009|358.78747628083494| 359.9582542694497|
|FREEDOM MORTGAGE ...|    8.56265812109968|351.29583403609377|359.85849215719344|
+--------------------+--------------------+------------------+------------------+

CPU times: user 2.67 ms, sys: 0 ns, total: 2.67 ms
Wall time: 5.3 s

Run another aggregation on the same grouped data and it is many times faster, because persist cached the intermediate result:

%%time
df_sum = df_grp.sum('new_int_rt', 'loan_age', 'mths_remng', 'cd_zero_bal', 'loan_length')
df_sum.show()
+--------------------+--------------------+-------------+---------------+----------------+----------------+
|       servicer_name|     sum(new_int_rt)|sum(loan_age)|sum(mths_remng)|sum(cd_zero_bal)|sum(loan_length)|
+--------------------+--------------------+-------------+---------------+----------------+----------------+
|  QUICKEN LOANS INC.|    101801.764999999|        -2081|        8384777|            null|         8382696|
|NATIONSTAR MORTGA...|  40287.497999999956|         3770|        3471766|               2|         3475536|
|                null|1.3139130894999936E7|     17690263|     1113692280|           16932|      1131382543|
|WELLS FARGO BANK,...|  187326.36500000005|        29436|       15773283|            null|        15802719|
|FANNIE MAE/SETERU...|                26.6|           56|           2104|            null|            2160|
|DITECH FINANCIAL LLC|  39531.709999999934|        48537|        3345231|              41|         3393768|
|SENECA MORTGAGE S...|  24093.559999999987|        -1192|        2095648|            null|         2094456|
|SUNTRUST MORTGAGE...|  21530.767999999953|         4325|        1884795|            null|         1889120|
|ROUNDPOINT MORTGA...|   67708.25999999992|        82336|        5669070|              74|         5751406|
|      PENNYMAC CORP.|  15209.140000000001|          540|        1298328|            null|         1298868|
|PHH MORTGAGE CORP...|            9086.066|         2138|         784822|            null|          786960|
|MATRIX FINANCIAL ...|  19212.932999999997|        30772|        1656140|              16|         1686912|
|               OTHER|   904855.0440000098|        25163|       78868902|              21|        78894065|
|  CITIMORTGAGE, INC.|            16939.33|         1398|        1484391|            null|         1485789|
|PINGORA LOAN SERV...|   64224.70499999986|       119049|        5539515|             111|         5658564|
|JP MORGAN CHASE B...|   50187.15499999998|        19197|        4155651|            null|         4174848|
|      PNC BANK, N.A.|   6911.724999999999|         1851|         567243|               1|          569094|
|FREEDOM MORTGAGE ...|  24800.604999999992|        50768|        2082833|              60|         2133601|
+--------------------+--------------------+-------------+---------------+----------------+----------------+

CPU times: user 1.48 ms, sys: 930 µs, total: 2.41 ms
Wall time: 882 ms
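
When you no longer need a cached DataFrame, unpersist frees the memory; persist also accepts an explicit storage level if you want spill-to-disk behaviour. A sketch:

df_keep.unpersist()                                       # release the cached data
df_keep.persist(pyspark.StorageLevel.MEMORY_AND_DISK)     # cache again, spilling to disk if needed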

Merging Operations

As before, initialize first:

import findspark
findspark.init()

import pyspark
import random

sc = pyspark.SparkContext(appName="Test")
spark = pyspark.sql.session.SparkSession(sc)

union

Build some DataFrames with the following syntax:

from pyspark.sql import Row

row = Row("name", "pet", "count")

df1 = sc.parallelize([
    row("Sue", "cat", 16),
    row("Kim", "dog", 1),    
    row("Bob", "fish", 5)
    ]).toDF()

df2 = sc.parallelize([
    row("Fred", "cat", 2),
    row("Kate", "ant", 179),    
    row("Marc", "lizard", 5)
    ]).toDF()

df3 = sc.parallelize([
    row("Sarah", "shark", 3),
    row("Jason", "kids", 2),    
    row("Scott", "squirrel", 1)
    ]).toDF()

Union two DataFrames directly:

df_union = df1.unionAll(df2)
df_union.show()
+----+------+-----+
|name|   pet|count|
+----+------+-----+
| Sue|   cat|   16|
| Kim|   dog|    1|
| Bob|  fish|    5|
|Fred|   cat|    2|
|Kate|   ant|  179|
|Marc|lizard|    5|
+----+------+-----+

Union several DataFrames:

from pyspark.sql import DataFrame
from functools import reduce

def union_many(*dfs):
    return reduce(DataFrame.unionAll, dfs)

df_union = union_many(df1, df2, df3)

df_union.show()
+-----+--------+-----+
| name|     pet|count|
+-----+--------+-----+
|  Sue|     cat|   16|
|  Kim|     dog|    1|
|  Bob|    fish|    5|
| Fred|     cat|    2|
| Kate|     ant|  179|
| Marc|  lizard|    5|
|Sarah|   shark|    3|
|Jason|    kids|    2|
|Scott|squirrel|    1|
+-----+--------+-----+
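
A note on naming: in Spark 2.0+ unionAll is simply an alias for union, which matches columns by position; from Spark 2.3 onward unionByName matches them by name instead, which is safer if the two DataFrames list their columns in different orders. A sketch:

df_union = df1.union(df2)                # positional match
df_union_by_name = df1.unionByName(df2)  # match on column names (Spark 2.3+)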

join

Rebuild two DataFrames:

row1 = Row("name", "pet", "count")
row2 = Row("name", "pet2", "count2")

df1 = sc.parallelize([
    row1("Sue", "cat", 16),
    row1("Kim", "dog", 1),    
    row1("Bob", "fish", 5),
    row1("Libuse", "horse", 1)
    ]).toDF()

df2 = sc.parallelize([
    row2("Sue", "eagle", 2),
    row2("Kim", "ant", 179),    
    row2("Bob", "lizard", 5),
    row2("Ferdinand", "bees", 23)
    ]).toDF()

inner join

df1.join(df2, 'name', how='inner').show()
+----+----+-----+------+------+
|name| pet|count|  pet2|count2|
+----+----+-----+------+------+
| Sue| cat|   16| eagle|     2|
| Bob|fish|    5|lizard|     5|
| Kim| dog|    1|   ant|   179|
+----+----+-----+------+------+

outer join

df1.join(df2, 'name', how='outer').show()
+---------+-----+-----+------+------+
|     name|  pet|count|  pet2|count2|
+---------+-----+-----+------+------+
|      Sue|  cat|   16| eagle|     2|
|      Bob| fish|    5|lizard|     5|
|      Kim|  dog|    1|   ant|   179|
|   Libuse|horse|    1|  null|  null|
|Ferdinand| null| null|  bees|    23|
+---------+-----+-----+------+------+

left join

df1.join(df2, 'name', how='left').show()
+------+-----+-----+------+------+
|  name|  pet|count|  pet2|count2|
+------+-----+-----+------+------+
|   Sue|  cat|   16| eagle|     2|
|   Bob| fish|    5|lizard|     5|
|   Kim|  dog|    1|   ant|   179|
|Libuse|horse|    1|  null|  null|
+------+-----+-----+------+------+
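
For completeness, a right join keeps every row of the right-hand DataFrame instead; a sketch (here Ferdinand would appear with null pet and count, while Libuse would be dropped):

df1.join(df2, 'name', how='right').show()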

missing data

First look at the total number of rows in the dataset:

df.count()
3526154

Count the missing values in one column:

df.where( df['_c12'].isNull() ).count()
3510294

Count the missing values in every column. Here Spark did not produce nulls for the string-typed columns, so we skip them:

def count_nulls(df):
    null_counts = []          # empty list to hold the results
    for col in df.dtypes:     # iterate over (column name, type) tuples, e.g. ('_c0', 'bigint')
        cname = col[0]        # column name
        ctype = col[1]        # column type
        if ctype != 'string': # skip string columns
            nulls = df.where( df[cname].isNull() ).count()
            result = tuple([cname, nulls])  # (column name, null count)
            null_counts.append(result)      # append to the list
    return null_counts

null_counts = count_nulls(df)
null_counts
[('_c0', 0),
 ('_c3', 0),
 ('_c4', 1945752),
 ('_c5', 0),
 ('_c6', 0),
 ('_c7', 1),
 ('_c9', 0),
 ('_c12', 3510294),
 ('_c26', 3526153)]
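
The loop above launches one Spark job per column. A sketch of a single-pass alternative that sums a 0/1 null indicator for every non-string column in one select:

from pyspark.sql.functions import col, sum as spark_sum, when

non_string_cols = [c for c, t in df.dtypes if t != 'string']
df.select([spark_sum(when(col(c).isNull(), 1).otherwise(0)).alias(c)
           for c in non_string_cols]).show()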

Dropping missing values with dropna

With how='all', a row is dropped only when all of the specified columns are null:

df_drops = df.dropna(how='all', subset=['_c4', '_c12', '_c26'])
df_drops.count()
1580403

You can also set a threshold: the minimum number of non-null values, among the specified columns, that a row needs in order to be kept. Below, at least 2 of the columns must be non-null (thresh=1 is equivalent to how='all'):

df_drops2 = df.dropna(thresh=2, subset=['_c4', '_c12', '_c26'])
df_drops2.count()
15860

Filling Missing Values

df_fill = df.fillna(0, subset=['_c12'])

Check whether any nulls remain:

df_fill.where( df_fill['_c12'].isNull() ).count()
0
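
fillna also accepts a dict mapping column names to fill values, so several columns can be filled with different defaults in one call; a sketch:

df_fill2 = df.fillna({'_c12': 0, '_c4': 0.0})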

Filling Missing Values with a Moving Average

First download a new dataset:

! wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/datasets/diamonds_nulls.csv

Read it:

df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv', 
                    inferSchema=True, header=True, sep=',', nullValue='')
df.show()
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|    Ideal|    E|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
| 0.21|  Premium|    E|    SI1| 59.8| 61.0|  326|3.89|3.84|2.31|
| 0.23|     Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
| 0.29|  Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
| 0.31|     Good|    J|    SI2| 63.3| 58.0|  335|4.34|4.35|2.75|
| 0.24|Very Good|    J|   VVS2| 62.8| 57.0|  336|3.94|3.96|2.48|
| 0.24|Very Good|    I|   VVS1| 62.3| 57.0|  336|3.95|3.98|2.47|
| 0.26|Very Good|    H|    SI1| 61.9| 55.0|  337|4.07|4.11|2.53|
| 0.22|     Fair|    E|    VS2| 65.1| 61.0|  337|3.87|3.78|2.49|
| 0.23|Very Good|    H|    VS1| 59.4| 61.0|  338| 4.0|4.05|2.39|
|  0.3|     Good|    J|    SI1| 64.0| 55.0|  339|4.25|4.28|2.73|
| 0.23|    Ideal|    J|    VS1| 62.8| 56.0|  340|3.93| 3.9|2.46|
| 0.22|  Premium|    F|    SI1| 60.4| 61.0|  342|3.88|3.84|2.33|
| 0.31|    Ideal|    J|    SI2| 62.2| 54.0|  344|4.35|4.37|2.71|
|  0.2|  Premium|    E|    SI2| 60.2| 62.0|  345|3.79|3.75|2.27|
| 0.32|  Premium|    E|     I1| 60.9| 58.0|  345|4.38|4.42|2.68|
|  0.3|    Ideal|    I|    SI2| 62.0| 54.0|  348|4.31|4.34|2.68|
|  0.3|     Good|    J|    SI1| 63.4| 54.0|  351|4.23|4.29| 2.7|
|  0.3|     Good|    J|    SI1| 63.8| 56.0|  351|4.23|4.26|2.71|
|  0.3|Very Good|    J|    SI1| 62.7| 59.0|  351|4.21|4.27|2.66|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 20 rows

Look at the rows where price is missing, showing the first 50:

df.where(df['price'].isNull()).show(50)
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat|      cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.24|  Premium|    I|    VS1| 62.5| 57.0| null|3.97|3.94|2.47|
| 0.35|    Ideal|    I|    VS1| 60.9| 57.0| null|4.54|4.59|2.78|
|  0.7|     Good|    F|    VS1| 59.4| 62.0| null|5.71|5.76| 3.4|
|  0.7|     Fair|    F|    VS2| 64.5| 57.0| null|5.57|5.53|3.58|
|  0.7|  Premium|    E|    SI1| 61.2| 57.0| null|5.73|5.68|3.49|
| 0.73|  Premium|    F|    VS2| 62.5| 57.0| null|5.75| 5.7|3.58|
| 1.01|    Ideal|    F|    SI1| 62.7| 55.0| null|6.45| 6.4|4.03|
| 1.03|    Ideal|    H|    SI1| 61.1| 56.0| null| 6.5|6.53|3.98|
| 1.28|    Ideal|    I|    SI2| 61.7| 59.0| null|6.96|6.92|4.28|
| 0.37|  Premium|    D|    SI1| 60.4| 59.0| null|4.68|4.62|2.81|
|  0.5|    Ideal|    J|    VS2| 61.7| 57.0| null|5.09|5.12|3.15|
| 0.34|    Ideal|    E|    VS1| 61.2| 55.0| null|4.52|4.56|2.77|
| 0.52|    Ideal|    D|    VS2| 61.8| 55.0| null|5.19|5.23|3.22|
| 0.71|Very Good|    J|   VVS2| 61.1| 58.0| null| 5.7|5.75| 3.5|
| 0.76|  Premium|    H|    SI1| 59.8| 57.0| null|5.93|5.91|3.54|
| 0.58|    Ideal|    F|    VS1| 60.3| 57.0| null|5.47|5.44|3.29|
|  0.7|Very Good|    E|    VS1| 63.4| 62.0| null|5.64|5.56|3.55|
| 0.92|  Premium|    D|     I1| 63.0| 58.0| null|6.18|6.13|3.88|
| 0.88|Very Good|    I|    SI1| 62.5| 56.0| null|6.06|6.19|3.83|
|  0.7|     Good|    H|   VVS2| 58.9| 61.5| null|5.77|5.84|3.42|
|  0.7|Very Good|    D|    SI1| 62.8| 60.0| null|5.66|5.68|3.56|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+

First define a window, similar to SQL, with partition by and order by clauses:

from pyspark.sql import Window

window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)

Compute the mean of price over this window:

from pyspark.sql.functions import mean

moving_avg = mean(df['price']).over(window)

Create a new column holding the moving average:

df = df.withColumn('moving_avg', moving_avg)
df.show()
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|        moving_avg|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
| 0.73|Premium|    F|    VS2| 62.5| 57.0| null|5.75| 5.7|3.58|             356.0|
| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|            358.75|
|  0.2|Premium|    E|    VS2| 59.8| 62.0|  367|3.79|3.77|2.26|             360.4|
|  0.2|Premium|    E|    VS2| 59.0| 60.0|  367|3.81|3.78|2.24|             361.5|
|  0.2|Premium|    E|    VS2| 61.1| 59.0|  367|3.81|3.78|2.32| 362.2857142857143|
|  0.2|Premium|    E|    VS2| 59.7| 62.0|  367|3.84| 3.8|2.28|             367.0|
|  0.2|Premium|    F|    VS2| 62.6| 59.0|  367|3.73|3.71|2.33|367.14285714285717|
|  0.2|Premium|    D|    VS2| 62.3| 60.0|  367|3.73|3.68|2.31| 367.2857142857143|
|  0.2|Premium|    D|    VS2| 61.7| 60.0|  367|3.77|3.72|2.31|369.14285714285717|
|  0.3|Premium|    J|    VS2| 62.2| 58.0|  368|4.28| 4.3|2.67|             371.0|
|  0.3|Premium|    J|    VS2| 60.6| 59.0|  368|4.34|4.38|2.64| 373.7142857142857|
| 0.31|Premium|    J|    VS2| 62.5| 60.0|  380|4.31|4.36|2.71|376.42857142857144|
| 0.31|Premium|    J|    VS2| 62.4| 60.0|  380|4.29|4.33|2.69|379.14285714285717|
| 0.21|Premium|    E|    VS2| 60.5| 59.0|  386|3.87|3.83|2.33| 381.7142857142857|
| 0.21|Premium|    E|    VS2| 59.6| 56.0|  386|3.93|3.89|2.33| 384.2857142857143|
| 0.21|Premium|    D|    VS2| 61.6| 59.0|  386|3.82|3.78|2.34|385.14285714285717|
| 0.21|Premium|    D|    VS2| 60.6| 60.0|  386|3.85|3.81|2.32|             387.0|
| 0.21|Premium|    D|    VS2| 59.1| 62.0|  386|3.89|3.86|2.29|             388.0|
| 0.21|Premium|    D|    VS2| 58.3| 59.0|  386|3.96|3.93| 2.3|389.57142857142856|
| 0.32|Premium|    J|    VS2| 61.9| 58.0|  393|4.35|4.38| 2.7|392.14285714285717|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
only showing top 20 rows

Use the replace_null helper below, which decides, based on whether the column value is null, whether to keep the original value or fill in the moving average:

from pyspark.sql.functions import when, col

def replace_null(orig, ma):
    return when(orig.isNull(), ma).otherwise(orig)

df_new = df.withColumn('imputed', 
                       replace_null(col('price'), col('moving_avg')) 
                      )
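
The same keep-or-fallback logic is built into pyspark.sql.functions.coalesce, which returns the first non-null of its arguments; a sketch:

from pyspark.sql.functions import coalesce

df_new2 = df.withColumn('imputed', coalesce(col('price'), col('moving_avg')))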

Show the result: rows whose price was null (such as the first row) now take the moving-average value in the imputed column, while rows with an observed price keep their original value.

df_new.show()
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|        moving_avg|imputed|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
| 0.73|Premium|    F|    VS2| 62.5| 57.0| null|5.75| 5.7|3.58|             356.0|  356.0|
| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|            358.75|  334.0|
|  0.2|Premium|    E|    VS2| 59.8| 62.0|  367|3.79|3.77|2.26|             360.4|  367.0|
|  0.2|Premium|    E|    VS2| 59.0| 60.0|  367|3.81|3.78|2.24|             361.5|  367.0|
|  0.2|Premium|    E|    VS2| 61.1| 59.0|  367|3.81|3.78|2.32| 362.2857142857143|  367.0|
|  0.2|Premium|    E|    VS2| 59.7| 62.0|  367|3.84| 3.8|2.28|             367.0|  367.0|
|  0.2|Premium|    F|    VS2| 62.6| 59.0|  367|3.73|3.71|2.33|367.14285714285717|  367.0|
|  0.2|Premium|    D|    VS2| 62.3| 60.0|  367|3.73|3.68|2.31| 367.2857142857143|  367.0|
|  0.2|Premium|    D|    VS2| 61.7| 60.0|  367|3.77|3.72|2.31|369.14285714285717|  367.0|
|  0.3|Premium|    J|    VS2| 62.2| 58.0|  368|4.28| 4.3|2.67|             371.0|  368.0|
|  0.3|Premium|    J|    VS2| 60.6| 59.0|  368|4.34|4.38|2.64| 373.7142857142857|  368.0|
| 0.31|Premium|    J|    VS2| 62.5| 60.0|  380|4.31|4.36|2.71|376.42857142857144|  380.0|
| 0.31|Premium|    J|    VS2| 62.4| 60.0|  380|4.29|4.33|2.69|379.14285714285717|  380.0|
| 0.21|Premium|    E|    VS2| 60.5| 59.0|  386|3.87|3.83|2.33| 381.7142857142857|  386.0|
| 0.21|Premium|    E|    VS2| 59.6| 56.0|  386|3.93|3.89|2.33| 384.2857142857143|  386.0|
| 0.21|Premium|    D|    VS2| 61.6| 59.0|  386|3.82|3.78|2.34|385.14285714285717|  386.0|
| 0.21|Premium|    D|    VS2| 60.6| 60.0|  386|3.85|3.81|2.32|             387.0|  386.0|
| 0.21|Premium|    D|    VS2| 59.1| 62.0|  386|3.89|3.86|2.29|             388.0|  386.0|
| 0.21|Premium|    D|    VS2| 58.3| 59.0|  386|3.96|3.93| 2.3|389.57142857142856|  386.0|
| 0.32|Premium|    J|    VS2| 61.9| 58.0|  393|4.35|4.38| 2.7|392.14285714285717|  393.0|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
only showing top 20 rows

pivot table

Build a DataFrame:

from pyspark.sql import Row

row = Row('state', 'industry', 'hq', 'jobs')

df = sc.parallelize([
    row('MI', 'auto', 'domestic', 716),
    row('MI', 'auto', 'foreign', 123),
    row('MI', 'auto', 'domestic', 1340),
    row('MI', 'retail', 'foreign', 12),
    row('MI', 'retail', 'foreign', 33),
    row('OH', 'auto', 'domestic', 349),
    row('OH', 'auto', 'foreign', 101),
    row('OH', 'auto', 'foreign', 77),
    row('OH', 'retail', 'domestic', 45),
    row('OH', 'retail', 'foreign', 12)
    ]).toDF()

df.show()
+-----+--------+--------+----+
|state|industry|      hq|jobs|
+-----+--------+--------+----+
|   MI|    auto|domestic| 716|
|   MI|    auto| foreign| 123|
|   MI|    auto|domestic|1340|
|   MI|  retail| foreign|  12|
|   MI|  retail| foreign|  33|
|   OH|    auto|domestic| 349|
|   OH|    auto| foreign| 101|
|   OH|    auto| foreign|  77|
|   OH|  retail|domestic|  45|
|   OH|  retail| foreign|  12|
+-----+--------+--------+----+

First group by state (each state becomes one row), then pivot on hq to form the columns, with the sum of jobs as the values:

df_pivot1 = df.groupby('state').pivot('hq', values=['domestic', 'foreign']).sum('jobs')
df_pivot1.show()
+-----+--------+-------+
|state|domestic|foreign|
+-----+--------+-------+
|   MI|    2056|    168|
|   OH|     394|    190|
+-----+--------+-------+
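
If you leave out the values list, Spark first computes the distinct values of the pivot column itself; that costs an extra pass over the data but avoids hard-coding the categories. A sketch:

df_pivot_auto = df.groupby('state').pivot('hq').sum('jobs')
df_pivot_auto.show()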

The rows can also be keyed by multiple columns, here state + industry:

df_pivot = df.groupBy('state', 'industry').pivot('hq', values=['domestic', 'foreign']).sum('jobs')
df_pivot.show()
+-----+--------+--------+-------+
|state|industry|domestic|foreign|
+-----+--------+--------+-------+
|   OH|  retail|      45|     12|
|   MI|    auto|    2056|    123|
|   OH|    auto|     349|    178|
|   MI|  retail|    null|     45|
+-----+--------+--------+-------+

resampling

Build a DataFrame:

import datetime
from pyspark.sql import Row
from pyspark.sql.functions import col

row = Row("date", "name", "production")

df = sc.parallelize([
    row("08/01/2014", "Kim", 5),
    row("08/02/2014", "Kim", 14),
    row("08/01/2014", "Bob", 6),
    row("08/02/2014", "Bob", 3),
    row("08/01/2014", "Sue", 0),
    row("08/02/2014", "Sue", 22),
    row("08/01/2014", "Dan", 4),
    row("08/02/2014", "Dan", 4),
    row("08/01/2014", "Joe", 37),
    row("09/01/2014", "Kim", 6),
    row("09/02/2014", "Kim", 6),
    row("09/01/2014", "Bob", 4),
    row("09/02/2014", "Bob", 20),
    row("09/01/2014", "Sue", 11),
    row("09/02/2014", "Sue", 2),
    row("09/01/2014", "Dan", 1),
    row("09/02/2014", "Dan", 3),
    row("09/02/2014", "Joe", 29)
    ]).toDF()
df.show()
+----------+----+----------+
|      date|name|production|
+----------+----+----------+
|08/01/2014| Kim|         5|
|08/02/2014| Kim|        14|
|08/01/2014| Bob|         6|
|08/02/2014| Bob|         3|
|08/01/2014| Sue|         0|
|08/02/2014| Sue|        22|
|08/01/2014| Dan|         4|
|08/02/2014| Dan|         4|
|08/01/2014| Joe|        37|
|09/01/2014| Kim|         6|
|09/02/2014| Kim|         6|
|09/01/2014| Bob|         4|
|09/02/2014| Bob|        20|
|09/01/2014| Sue|        11|
|09/02/2014| Sue|         2|
|09/01/2014| Dan|         1|
|09/02/2014| Dan|         3|
|09/02/2014| Joe|        29|
+----------+----+----------+
df.dtypes
[('date', 'string'), ('name', 'string'), ('production', 'bigint')]

Use a user-defined function to convert the date column from MM/DD/YYYY to MM/YYYY:

# 'udf' stands for 'user defined function'; it wraps a Python function so it can be applied to a column
from pyspark.sql.functions import udf
# Define a function that splits an MM/DD/YYYY string and returns an MM/YYYY string
def split_date(whole_date):
    
    try:
        mo, day, yr = whole_date.split('/')
    except ValueError:
        return 'error'
    
    return mo + '/' + yr
# wrapper
udf_split_date = udf(split_date)

# Create the new column
df_new = df.withColumn('month_year', udf_split_date('date'))
df_new.show()
+----------+----+----------+----------+
|      date|name|production|month_year|
+----------+----+----------+----------+
|08/01/2014| Kim|         5|   08/2014|
|08/02/2014| Kim|        14|   08/2014|
|08/01/2014| Bob|         6|   08/2014|
|08/02/2014| Bob|         3|   08/2014|
|08/01/2014| Sue|         0|   08/2014|
|08/02/2014| Sue|        22|   08/2014|
|08/01/2014| Dan|         4|   08/2014|
|08/02/2014| Dan|         4|   08/2014|
|08/01/2014| Joe|        37|   08/2014|
|09/01/2014| Kim|         6|   09/2014|
|09/02/2014| Kim|         6|   09/2014|
|09/01/2014| Bob|         4|   09/2014|
|09/02/2014| Bob|        20|   09/2014|
|09/01/2014| Sue|        11|   09/2014|
|09/02/2014| Sue|         2|   09/2014|
|09/01/2014| Dan|         1|   09/2014|
|09/02/2014| Dan|         3|   09/2014|
|09/02/2014| Joe|        29|   09/2014|
+----------+----+----------+----------+

Drop the date column, then resample by month_year + name:

df_new = df_new.drop('date')
df_agg = df_new.groupBy('month_year', 'name').agg({'production' : 'sum'})
df_agg.show()
+----------+----+---------------+
|month_year|name|sum(production)|
+----------+----+---------------+
|   09/2014| Sue|             13|
|   09/2014| Kim|             12|
|   09/2014| Bob|             24|
|   09/2014| Joe|             29|
|   09/2014| Dan|              4|
|   08/2014| Kim|             19|
|   08/2014| Joe|             37|
|   08/2014| Dan|              8|
|   08/2014| Sue|             22|
|   08/2014| Bob|              9|
+----------+----+---------------+

Dates can also be parsed into a type Spark understands natively:

from pyspark.sql.functions import udf
from pyspark.sql.types import DateType
from datetime import datetime

dateFormat = udf(lambda x: datetime.strptime(x, '%m/%d/%Y'), DateType())  # %m is the month; %M would be minutes
    
df_d = df.withColumn('new_date', dateFormat(col('date')))
df_d.select('new_date').take(1)
[Row(new_date=datetime.date(2014, 8, 1))]
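
On Spark 2.2 and later, the built-in to_date function can parse the string without a UDF; a sketch (the pattern uses Java date-format letters, so MM is the month):

from pyspark.sql.functions import to_date

df_d2 = df.withColumn('new_date', to_date(col('date'), 'MM/dd/yyyy'))
df_d2.select('new_date').take(1)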

Subsetting

Read the dataset:

df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')
df.dtypes
[('_c0', 'bigint'),
 ('_c1', 'string'),
 ('_c2', 'string'),
 ('_c3', 'double'),
 ('_c4', 'double'),
 ('_c5', 'int'),
 ('_c6', 'int'),
 ('_c7', 'int'),
 ('_c8', 'string'),
 ('_c9', 'int'),
 ('_c10', 'string'),
 ('_c11', 'string'),
 ('_c12', 'int'),
 ('_c13', 'string'),
 ('_c14', 'string'),
 ('_c15', 'string'),
 ('_c16', 'string'),
 ('_c17', 'string'),
 ('_c18', 'string'),
 ('_c19', 'string'),
 ('_c20', 'string'),
 ('_c21', 'string'),
 ('_c22', 'string'),
 ('_c23', 'string'),
 ('_c24', 'string'),
 ('_c25', 'string'),
 ('_c26', 'int'),
 ('_c27', 'string')]

Column selection

Select columns, method 1:

from pyspark.sql.functions import col

df_select = df.select(col('_c0'), col('_c1'), col('_c3'), col('_c9'))
df_select.show(5)
+------------+----------+-----+-----+
|         _c0|       _c1|  _c3|  _c9|
+------------+----------+-----+-----+
|100002091588|01/01/2015|4.125|16740|
|100002091588|02/01/2015|4.125|16740|
|100002091588|03/01/2015|4.125|16740|
|100002091588|04/01/2015|4.125|16740|
|100002091588|05/01/2015|4.125|16740|
+------------+----------+-----+-----+
only showing top 5 rows

Select columns, method 2:

df_select = df[['_c0', '_c1', '_c3', '_c9']]
df_select.show(5)

Drop a column:

df_drop = df_select.drop(col('_c3'))
df_drop.show(5)
+------------+----------+-----+
|         _c0|       _c1|  _c9|
+------------+----------+-----+
|100002091588|01/01/2015|16740|
|100002091588|02/01/2015|16740|
|100002091588|03/01/2015|16740|
|100002091588|04/01/2015|16740|
|100002091588|05/01/2015|16740|
+------------+----------+-----+
only showing top 5 rows

Row selection

Suppose we want to filter rows on _c6; first look at its statistics:

df.describe('_c6').show()
+-------+-----------------+
|summary|              _c6|
+-------+-----------------+
|  count|          3526154|
|   mean|354.7084951479714|
| stddev| 4.01181251079202|
|    min|              292|
|    max|              480|
+-------+-----------------+

Use a where condition to select the rows with _c6 < 358:

df_sub = df.where(df['_c6'] < 358)
df_sub.describe('_c6').show()
+-------+------------------+
|summary|               _c6|
+-------+------------------+
|  count|           2598037|
|   mean|353.15604897081914|
| stddev|3.5170213056883988|
|    min|               292|
|    max|               357|
+-------+------------------+

where with multiple conditions:

df_filter = df.where((df['_c6'] > 340) & (df['_c5'] < 4))
df_filter.describe('_c6', '_c5').show()
+-------+------------------+------------------+
|summary|               _c6|               _c5|
+-------+------------------+------------------+
|  count|           1254131|           1254131|
|   mean|358.48713810598736| 1.474693632483369|
| stddev| 1.378961910349754|1.2067831502138422|
|    min|               341|                -1|
|    max|               361|                 3|
+-------+------------------+------------------+

Sampling with sample

The first argument is whether to sample with replacement, the second is the sampling fraction, and the third is an optional random seed:

df_sample = df.sample(False, 0.05, 99)
df_sample.describe('_c6').show()
+-------+-----------------+
|summary|              _c6|
+-------+-----------------+
|  count|           176316|
|   mean|354.7118072097824|
| stddev| 3.99195164816948|
|    min|              299|
|    max|              361|
+-------+-----------------+

Basic Statistical Analysis

Read the dataset:

df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')

The basic describe:

df_described = df.describe()
df_described.show()

[screenshot: df_described.show() output]

Import skewness, kurtosis, and other statistical functions:

from pyspark.sql.functions import skewness, kurtosis
from pyspark.sql.functions import var_pop, var_samp, stddev, stddev_pop, sumDistinct, ntile
df.select(skewness('_c3')).show()
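
Several of these statistics can be requested in a single select; a sketch:

df.select(skewness('_c3'), kurtosis('_c3'), stddev('_c3'), var_samp('_c3')).show()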

Build a function that produces a skewness/kurtosis table in the same shape as the basic describe table, so the two can be unioned:

from pyspark.sql import Row

columns = df_described.columns  # list of column names: ['summary', '_c0', '_c3', '_c4', '_c5', '_c6']
funcs   = [skewness, kurtosis]  # list of aggregation functions
fnames  = ['skew', 'kurtosis']  # function names, used for display

def new_item(func, column):
    """
    获取一个aggregation function和一个列名,
    之后对这个列执行aggregation,
    为了与describe的输出匹配,
    返回的是一个string而不是数字
    """
    return str(df.select(func(column)).collect()[0][0])

new_data = []
for func, fname in zip(funcs, fnames):
    row_dict = {'summary':fname}  # each row starts with a summary entry naming the function
    for column in columns[1:]:
        row_dict[column] = new_item(func, column)
    new_data.append(Row(**row_dict))  
    
print(new_data)
[Row(summary='skew', _c0='-0.00183847089866041', _c2='None', _c3='0.5197993394959906', _c4='0.7584115767562998', _c5='0.2864801560838491', _c6='-2.6976520156650614'), Row(summary='kurtosis', _c0='-1.1990072635132925', _c2='None', _c3='0.1260577268465326', _c4='0.5760856026559504', _c5='0.1951877800894679', _c6='24.723785894417404')]

As you can see, the format matches the basic describe table:

df_described.collect()
[Row(summary='count', _c0='3526154', _c2='382039', _c3='3526154', _c4='1580402', _c5='3526154', _c6='3526154'),
 Row(summary='mean', _c0='5.503885995001908E11', _c2=None, _c3='4.178168090219519', _c4='234846.78065481762', _c5='5.134865351881966', _c6='354.7084951479714'),
 Row(summary='stddev', _c0='2.596112361975215E11', _c2=None, _c3='0.34382335723646673', _c4='118170.6859226166', _c5='3.3833930336063465', _c6='4.01181251079202'),
 Row(summary='min', _c0='100002091588', _c2='CITIMORTGAGE, INC.', _c3='2.75', _c4='0.85', _c5='-1', _c6='292'),
 Row(summary='max', _c0='999995696635', _c2='WELLS FARGO BANK, N.A.', _c3='6.125', _c4='1193544.39', _c5='34', _c6='480')]

Union them together:

new_describe = sc.parallelize(new_data).toDF()           #turns the results from our loop into a dataframe
new_describe = new_describe.select(df_described.columns) #forces the columns into the same order

expanded_describe = df_described.unionAll(new_describe)  #merges the new stats with the original describe
expanded_describe.show()

Visualization

First read the dataset:

import pandas as pd
import matplotlib.pyplot as plt

spark_df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv', 
                    inferSchema=True, header=True, sep=',', nullValue='')
spark_df.show(5)
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
|carat|    cut|color|clarity|depth|table|price|   x|   y|   z|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
| 0.23|  Ideal|    E|    SI2| 61.5| 55.0|  326|3.95|3.98|2.43|
| 0.21|Premium|    E|    SI1| 59.8| 61.0|  326|3.89|3.84|2.31|
| 0.23|   Good|    E|    VS1| 56.9| 65.0|  327|4.05|4.07|2.31|
| 0.29|Premium|    I|    VS2| 62.4| 58.0|  334| 4.2|4.23|2.63|
| 0.31|   Good|    J|    SI2| 63.3| 58.0|  335|4.34|4.35|2.75|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
only showing top 5 rows
spark_df.dtypes
[('carat', 'double'),
 ('cut', 'string'),
 ('color', 'string'),
 ('clarity', 'string'),
 ('depth', 'double'),
 ('table', 'double'),
 ('price', 'int'),
 ('x', 'double'),
 ('y', 'double'),
 ('z', 'double')]
spark_df.describe(['carat', 'depth', 'table', 'price']).show()
+-------+------------------+------------------+------------------+------------------+
|summary|             carat|             depth|             table|             price|
+-------+------------------+------------------+------------------+------------------+
|  count|             53940|             53940|             53940|             53919|
|   mean|0.7979397478679852| 61.74940489432624| 57.45718390804603|3933.3421799365715|
| stddev|0.4740112444054196|1.4326213188336525|2.2344905628213247| 3990.022722699714|
|    min|               0.2|              43.0|              43.0|               326|
|    max|              5.01|              79.0|              95.0|             18823|
+-------+------------------+------------------+------------------+------------------+

Collect the columns we need:

carat = spark_df[['carat']].collect()
price = spark_df[['price']].collect()
print(carat[:5])
print(price[:5])
[Row(carat=0.23), Row(carat=0.21), Row(carat=0.23), Row(carat=0.29), Row(carat=0.31)]
[Row(price=326), Row(price=326), Row(price=327), Row(price=334), Row(price=335)]

Plot directly:

fig = plt.figure()
ax = fig.add_subplot(1,1,1)

ax.plot(carat, price, 'go', alpha=0.1)
ax.set_xlabel('Carat')
ax.set_ylabel('Price')
ax.set_title('Diamonds')

[screenshot: scatter plot of Price vs Carat, titled "Diamonds"]

You can also convert the Spark DataFrame to a pandas DataFrame:

pandas_df = spark_df.toPandas()
pandas_df.describe()

[screenshot: pandas_df.describe() output]
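
Keep in mind that toPandas pulls the entire dataset onto the driver; for a large DataFrame it is safer to sample first. A sketch:

pandas_sample = spark_df.sample(False, 0.1, 42).toPandas()   # roughly 10% sample, seed 42
pandas_sample.describe()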

Linear Regression with GLM

First load the dataset and select a few columns; the target we want to predict is log(price):

from pyspark.sql.functions import log

df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv', 
                    inferSchema=True, header=True, sep=',', nullValue='')
df = df[['carat', 'clarity', 'price']]
df = df.withColumn('lprice', log('price'))

df.show(5)
+-----+-------+-----+------------------+
|carat|clarity|price|            lprice|
+-----+-------+-----+------------------+
| 0.23|    SI2|  326| 5.786897381366708|
| 0.21|    SI1|  326| 5.786897381366708|
| 0.23|    VS1|  327|5.7899601708972535|
| 0.29|    VS2|  334|   5.8111409929767|
| 0.31|    SI2|  335| 5.814130531825066|
+-----+-------+-----+------------------+
only showing top 5 rows

Here is a helper function that builds the feature vector; we won't walk through it in detail:

  
"""
Program written by Jeff Levy (jlevy@urban.org) for the Urban Institute, last revised 8/24/2016.
Note that this is intended as a temporary work-around until pySpark improves its ML package.
Tested in pySpark 2.0.
"""

def build_indep_vars(df, independent_vars, categorical_vars=None, keep_intermediate=False, summarizer=True):

    """
    Data verification
    df               : DataFrame
    independent_vars : List of column names
    categorical_vars : None or list of column names, e.g. ['col1', 'col2']
    """
    assert(type(df) is pyspark.sql.dataframe.DataFrame), 'pypark_glm: A pySpark dataframe is required as the first argument.'
    assert(type(independent_vars) is list), 'pyspark_glm: List of independent variable column names must be the third argument.'
    for iv in independent_vars:
        assert(type(iv) is str), 'pyspark_glm: Independent variables must be column name strings.'
        assert(iv in df.columns), 'pyspark_glm: Independent variable name is not a dataframe column.'
    if categorical_vars:
        for cv in categorical_vars:
            assert(type(cv) is str), 'pyspark_glm: Categorical variables must be column name strings.'
            assert(cv in df.columns), 'pyspark_glm: Categorical variable name is not a dataframe column.'
            assert(cv in independent_vars), 'pyspark_glm: Categorical variables must be independent variables.'

    """
    Code
    """
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.ml.regression import GeneralizedLinearRegression

    if categorical_vars:
        string_indexer = [StringIndexer(inputCol=x, 
                                        outputCol='{}_index'.format(x))
                          for x in categorical_vars]

        encoder        = [OneHotEncoder(dropLast=True, 
                                        inputCol ='{}_index' .format(x), 
                                        outputCol='{}_vector'.format(x))
                          for x in categorical_vars]

        independent_vars = ['{}_vector'.format(x) if x in categorical_vars else x for x in independent_vars]
    else:
        string_indexer, encoder = [], []

    assembler = VectorAssembler(inputCols=independent_vars, 
                                outputCol='indep_vars')
    pipeline  = Pipeline(stages=string_indexer+encoder+[assembler])
    model = pipeline.fit(df)
    df = model.transform(df)

    #for building the crosswalk between indicies and column names
    if summarizer:
        param_crosswalk = {}

        i = 0
        for x in independent_vars:
            if '_vector' in x[-7:]:
                xrs = x.rstrip('_vector')
                dst = df[[xrs, '{}_index'.format(xrs)]].distinct().collect()

                for row in dst:
                    param_crosswalk[int(row['{}_index'.format(xrs)]+i)] = row[xrs]
                maxind = max(param_crosswalk.keys())
                del param_crosswalk[maxind] #for droplast
                i += len(dst)
            elif '_index' in x[:-6]:
                pass
            else:
                param_crosswalk[i] = x
                i += 1
        """
        {0: 'carat',
         1: u'SI1',
         2: u'VS2',
         3: u'SI2',
         4: u'VS1',
         5: u'VVS2',
         6: u'VVS1',
         7: u'IF'}
        """
        make_summary = Summarizer(param_crosswalk)


    if not keep_intermediate:
        fcols = [c for c in df.columns if '_index' not in c[-6:] and '_vector' not in c[-7:]]
        df = df[fcols]

    if summarizer:
        return df, make_summary
    else:
        return df

class Summarizer(object):
    def __init__(self, param_crosswalk):
        self.param_crosswalk = param_crosswalk
        self.precision = 4
        self.screen_width = 57
        self.hsep = '-'
        self.vsep = '|'

    def summarize(self, model, show=True, return_str=False):
        coefs = list(model.coefficients)
        inter = model.intercept
        tstat = model.summary.tValues
        stder = model.summary.coefficientStandardErrors
        pvals = model.summary.pValues

        #if model includes an intercept:
        if len(coefs) == len(tstat)-1:
            coefs.insert(0, inter)
            x = {0:'intercept'}
            for k, v in self.param_crosswalk.items():
                x[k+1] = v
        else:
            x = self.param_crosswalk

        assert(len(coefs) == len(tstat) == len(stder) == len(pvals))

        p = self.precision
        h = self.hsep
        v = self.vsep
        w = self.screen_width

        coefs = [str(round(e, p)).center(10) for e in coefs]
        tstat = [str(round(e, p)).center(10) for e in tstat]
        stder = [str(round(e, p)).center(10) for e in stder]
        pvals = [str(round(e, p)).center(10) for e in pvals]

        lines = ''
        for i in range(len(coefs)):
            lines += str(x[i]).rjust(15) + v + coefs[i] + stder[i] + tstat[i] + pvals[i] + '\n'

        labels = ''.rjust(15) + v + 'Coef'.center(10) + 'Std Err'.center(10) + 'T Stat'.center(10) + 'P Val'.center(10)
        pad    = ''.rjust(15) + v

        output = """{hline}\n{labels}\n{hline}\n{lines}{hline}""".format(
                    hline=h*w, 
                    labels=labels,
                    lines=lines)
        if show:
            print(output)
        if return_str:
            return output

Now generate the feature column:

df, summarizer = build_indep_vars(df, 
                                  ['carat', 'clarity'], 
                                  categorical_vars=['clarity'], 
                                  keep_intermediate=False, 
                                  summarizer=True)
df.show(5)
+-----+-------+-----+------------------+--------------------+
|carat|clarity|price|            lprice|          indep_vars|
+-----+-------+-----+------------------+--------------------+
| 0.23|    SI2|  326| 5.786897381366708|(8,[0,3],[0.23,1.0])|
| 0.21|    SI1|  326| 5.786897381366708|(8,[0,1],[0.21,1.0])|
| 0.23|    VS1|  327|5.7899601708972535|(8,[0,4],[0.23,1.0])|
| 0.29|    VS2|  334|   5.8111409929767|(8,[0,2],[0.29,1.0])|
| 0.31|    SI2|  335| 5.814130531825066|(8,[0,3],[0.31,1.0])|
+-----+-------+-----+------------------+--------------------+
only showing top 5 rows

First drop the rows where lprice is null:

df.where( df['lprice'].isNull()) .count()
21
df_drops = df.dropna(how='all', subset=['lprice'])

Fit the model:

from pyspark.ml.regression import GeneralizedLinearRegression

glm = GeneralizedLinearRegression(family='gaussian', 
                                  link='identity', 
                                  labelCol='lprice', 
                                  featuresCol='indep_vars', 
                                  fitIntercept=True)

model = glm.fit(df_drops)

Inspect the model results:

model.coefficients
DenseVector([2.0808, 0.7228, 0.818, 0.5691, 0.8564, 0.9349, 0.9203, 0.9988])
model.intercept
5.3553700003910265
model.summary.tValues
[573.770045512725,
 51.26011561154599,
 57.75466173814217,
 40.10391980395556,
 59.5623952901403,
 63.14279213211864,
 60.48868627589056,
 60.711257997408495,
 371.7558057494966]
model.summary.coefficientStandardErrors
[0.0036265388065588383,
 0.014100833468172893,
 0.014163700186194172,
 0.01418970865492269,
 0.014378787192379332,
 0.014805773323652266,
 0.015214235864428029,
 0.016451002067601285,
 0.014405612279797675]
summarizer.summarize(model)
---------------------------------------------------------
               |   Coef    Std Err    T Stat    P Val   
---------------------------------------------------------
      intercept|  5.3554    0.0036    573.77     0.0    
          carat|  2.0808    0.0141   51.2601     0.0    
            SI1|  0.7228    0.0142   57.7547     0.0    
            VS2|  0.818     0.0142   40.1039     0.0    
            SI2|  0.5691    0.0144   59.5624     0.0    
            VS1|  0.8564    0.0148   63.1428     0.0    
           VVS2|  0.9349    0.0152   60.4887     0.0    
           VVS1|  0.9203    0.0165   60.7113     0.0    
             IF|  0.9988    0.0144   371.7558    0.0    
---------------------------------------------------------
summarizer.param_crosswalk
{0: 'carat',
 1: 'SI1',
 2: 'VS2',
 6: 'VVS1',
 5: 'VVS2',
 7: 'IF',
 4: 'VS1',
 3: 'SI2'}
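
To get predictions from the fitted model, call transform on a DataFrame that has the same indep_vars column; a sketch:

preds = model.transform(df_drops)
preds.select('lprice', 'prediction').show(5)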