WordCount
WordCount basics
Let's start with a WordCount example.
Suppose there is a file words.txt with the following contents:
hello world
hello friend
hello my friend
hello
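We want to use Spark to count how many times each word occurs.
This breaks down into the following steps:
- Take the words out of every line and flatten them into a single collection
- Use map to turn each word into a key-value pair, e.g. [(hello, 1), (world, 1), (hello, 1)...]
- Use reduce to add up the values for each key (word)
The steps in detail:
- First, split the lines with flatMap. flatMap applies the function to every line and flattens the results, so it returns a single flat RDD of words rather than one list per line (lines is the RDD read from words.txt with sc.textFile, as in the complete listing further down)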
words = lines.flatMap(lambda x: x.split())
words.collect()
['hello', 'world', 'hello', 'friend', 'hello', 'my', 'friend', 'hello']
Why not use map? With map, the result would look like this:
words = lines.map(lambda x: x.split())
words.collect()
[['hello', 'world'], ['hello', 'friend'], ['hello', 'my', 'friend'], ['hello']]
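The difference between map and flatMap can be mimicked in plain Python. This is only a sketch of the semantics on an ordinary list (no Spark involved); the variable names are chosen for illustration.

```python
from itertools import chain

lines = ["hello world", "hello friend", "hello my friend", "hello"]

# map: one output element per input line, so the per-line lists stay nested
mapped = [line.split() for line in lines]

# flatMap: the per-line lists are concatenated into one flat sequence
flat_mapped = list(chain.from_iterable(line.split() for line in lines))

print(mapped)
print(flat_mapped)
```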
- After flattening, convert each word into a key-value pair
words.map(lambda x: (x, 1)).collect()
[('hello', 1),
('world', 1),
('hello', 1),
('friend', 1),
('hello', 1),
('my', 1),
('friend', 1),
('hello', 1)]
- Reduce by key, using lambda x,y:x+y to add up the values of each key
words.map(lambda x: (x, 1)).reduceByKey(lambda x,y:x+y).collect()
[('world', 1), ('hello', 4), ('friend', 2), ('my', 1)]
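What reduceByKey computes can be sketched in plain Python: group the values by key, then fold each group with the given function. The helper reduce_by_key below is hypothetical and only models the result; it says nothing about how Spark distributes the work.

```python
from collections import defaultdict
from functools import reduce

words = ['hello', 'world', 'hello', 'friend', 'hello', 'my', 'friend', 'hello']
pairs = [(w, 1) for w in words]

def reduce_by_key(pairs, func):
    """Hypothetical helper: group values by key, then fold each group with func."""
    groups = defaultdict(list)
    for key, value in pairs:
        groups[key].append(value)
    return {key: reduce(func, values) for key, values in groups.items()}

counts = reduce_by_key(pairs, lambda x, y: x + y)
print(counts)
```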
The complete program:
import findspark
findspark.init()
import pyspark
sc = pyspark.SparkContext(appName="WordCount")
lines = sc.textFile('file:///home/ubuntu/words.txt')
words = lines.flatMap(lambda x: x.split())
wordcounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x,y:x+y)
wordcounts.collect()
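Sorting the WordCount
How do we sort the WordCount result by count? sortByKey can do it, but the key and value have to be swapped first so that the count becomes the key.
This is the result from the previous step: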
wordcounts.collect()
[('world', 1), ('hello', 4), ('friend', 2), ('my', 1)]
So first swap key and value:
wordcounts.map(lambda x:(x[1],x[0])).collect()
[(1, 'world'), (1, 'my'), (4, 'hello'), (2, 'friend')]
After swapping we can sort. Passing False (that is, ascending=False) sorts in descending order.
wordcounts = sc.textFile('file:///home/ubuntu/words.txt') \
    .flatMap(lambda x: x.split()) \
    .map(lambda x: (x, 1)) \
    .reduceByKey(lambda x,y:x+y) \
    .map(lambda x:(x[1],x[0])) \
    .sortByKey(False)
wordcounts.collect()
[(4, 'hello'), (2, 'friend'), (1, 'world'), (1, 'my')]
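The swap-then-sort idea can be checked on a plain list. The sketch below also shows an alternative that sorts on the count without swapping; PySpark's RDD API has a comparable sortBy method that takes a key function, though the plain-Python version here does not use Spark at all.

```python
wordcounts = [('world', 1), ('hello', 4), ('friend', 2), ('my', 1)]

# Approach from the text: swap to (count, word), then sort by key, descending
swapped = [(count, word) for word, count in wordcounts]
by_key_desc = sorted(swapped, key=lambda kv: kv[0], reverse=True)

# Alternative: keep (word, count) and sort on the second element directly
by_value_desc = sorted(wordcounts, key=lambda kv: kv[1], reverse=True)

print(by_key_desc)
print(by_value_desc)
```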
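Inspecting RDD partitions
Spark provides efficient parallel computation; in the example above it partitions the data automatically. The code below counts the objects in each partition of an RDD
def countPartitions(id, iterator):
    c = 0
    for _ in iterator:
        c += 1
    yield (id, c)
lines.mapPartitionsWithIndex(countPartitions).collectAsMap()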
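To see what the partition-counting function does, it can be called by hand on ordinary iterators. The two-partition layout below is made up for illustration; Spark decides the real layout itself.

```python
def count_partitions(index, iterator):
    """Yield (partition index, number of elements) for one partition."""
    count = 0
    for _ in iterator:
        count += 1
    yield (index, count)

# Hypothetical layout: four lines split across two partitions
partitions = [["hello world", "hello friend"], ["hello my friend", "hello"]]

result = {}
for index, part in enumerate(partitions):
    # The function is called once per partition with that partition's iterator
    result.update(count_partitions(index, iter(part)))
print(result)
```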
Check Spark's default parallelism
sc.defaultParallelism
8
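Spark splits an RDD into partitions that are spread across the worker nodes, so a failed task on one node does not affect the tasks on other nodes; the failed task only needs to be rerun on another node.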
wordcounts
Reading CSV with PySpark
First download the txt file that we will read
! wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/datasets/Performance_2015Q1.txt
Then create the SparkContext and SparkSession
import findspark
findspark.init()
import pyspark
import random
sc = pyspark.SparkContext(appName="WordCount")
spark = pyspark.sql.session.SparkSession(sc)
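Read the CSV with the SparkSession
With inferSchema=True, Spark makes an extra pass over the data to infer each column's type; without it, every column is read as a string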
df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')
Get the first 3 rows
df.take(3)
[Row(_c0=100002091588, _c1='01/01/2015', _c2='OTHER', _c3=4.125, _c4=None, _c5=0, _c6=360, _c7=360, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None),
Row(_c0=100002091588, _c1='02/01/2015', _c2=None, _c3=4.125, _c4=None, _c5=1, _c6=359, _c7=359, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None),
Row(_c0=100002091588, _c1='03/01/2015', _c2=None, _c3=4.125, _c4=None, _c5=2, _c6=358, _c7=358, _c8='01/2045', _c9=16740, _c10='0', _c11='N', _c12=None, _c13=None, _c14=None, _c15=None, _c16=None, _c17=None, _c18=None, _c19=None, _c20=None, _c21=None, _c22=None, _c23=None, _c24=None, _c25=None, _c26=None, _c27=None)]
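Another way to read the file is to pass an explicit schema, which also limits the result to the listed columns. A field whose declared type does not match the data is read as null: below, _c0 (declared as DateType) and _c2 (declared as DoubleType) come back as None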
from pyspark.sql.types import DateType, TimestampType, IntegerType, FloatType, LongType, DoubleType, StringType
from pyspark.sql.types import StructType, StructField
custom_schema = StructType([StructField('_c0', DateType(), True),
StructField('_c1', StringType(), True),
StructField('_c2', DoubleType(), True),
StructField('_c3', DoubleType(), True)])
df_part = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, schema=custom_schema, sep='|')
df_part.take(3)
[Row(_c0=None, _c1='01/01/2015', _c2=None, _c3=4.125),
Row(_c0=None, _c1='02/01/2015', _c2=None, _c3=4.125),
Row(_c0=None, _c1='03/01/2015', _c2=None, _c3=4.125)]
Count the rows
df.count()
3526154
Selecting and renaming columns
Use df.select('col1', 'col2', 'col3') to select columns
df_lim = df.select('_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13')
df_lim.take(1)
[Row(_c0=100002091588, _c1=u'01/01/2015', _c2=u'OTHER', _c3=4.125, _c4=None, _c5=0, _c6=360, _c7=360, _c8=u'01/2045', _c9=16740, _c10=u'0', _c11=u'N', _c12=None, _c13=None)]
Renaming columns
df_lim = df_lim.withColumnRenamed('_c0','loan_id').withColumnRenamed('_c1','period')
df_lim
DataFrame[loan_id: bigint, period: string, _c2: string, _c3: double, _c4: double, _c5: int, _c6: int, _c7: int, _c8: string, _c9: int, _c10: string, _c11: string, _c12: int, _c13: string]
Rename in bulk with a for loop
old_names = ['_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13']
new_names = ['servicer_name', 'new_int_rt', 'act_endg_upb', 'loan_age', 'mths_remng', 'aj_mths_remng', 'dt_matr', 'cd_msa', 'delq_sts', 'flag_mod', 'cd_zero_bal', 'dt_zero_bal']
for old, new in zip(old_names, new_names):
    df_lim = df_lim.withColumnRenamed(old, new)
df_lim
DataFrame[loan_id: bigint, period: string, servicer_name: string, new_int_rt: double, act_endg_upb: double, loan_age: int, mths_remng: int, aj_mths_remng: int, dt_matr: string, cd_msa: int, delq_sts: string, flag_mod: string, cd_zero_bal: int, dt_zero_bal: string]
Fetch one row to check
df_lim.take(1)
[Row(loan_id=100002091588, period='01/01/2015', servicer_name='OTHER', new_int_rt=4.125, act_endg_upb=None, loan_age=0, mths_remng=360, aj_mths_remng=360, dt_matr='01/2045', cd_msa=16740, delq_sts='0', flag_mod='N', cd_zero_bal=None, dt_zero_bal=None)]
describe
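describe works like the pandas method of the same name. If no column names are given, it shows statistics for all columns (which gets messy when there are many columns)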
df_described = df_lim.describe('servicer_name', 'new_int_rt', 'loan_age')
df_described.show()
+-------+--------------------+-------------------+------------------+
|summary| servicer_name| new_int_rt| loan_age|
+-------+--------------------+-------------------+------------------+
| count| 382039| 3526154| 3526154|
| mean| null| 4.178168090219519| 5.134865351881966|
| stddev| null|0.34382335723646673|3.3833930336063465|
| min| CITIMORTGAGE, INC.| 2.75| -1|
| max|WELLS FARGO BANK,...| 6.125| 34|
+-------+--------------------+-------------------+------------------+
Saving to a file
df_described.write.format('com.databricks.spark.csv').option("header","true").save('file:///home/ubuntu/result.csv')
Basic operations
Adding a column
We only need some of the columns
df_lim = df.select('_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13')
old_names = ['_c0','_c1','_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13']
new_names = ['loan_id','period','servicer_name', 'new_int_rt', 'act_endg_upb', 'loan_age', 'mths_remng', 'aj_mths_remng', 'dt_matr', 'cd_msa', 'delq_sts', 'flag_mod', 'cd_zero_bal', 'dt_zero_bal']
for old, new in zip(old_names, new_names):
    df_lim = df_lim.withColumnRenamed(old, new)
Use withColumn(new column name, column value) to add a column
df_lim = df_lim.withColumn('loan_length', df_lim['loan_age'] + df_lim['mths_remng'])
group by
## Group by servicer name:
df_grp = df_lim.groupBy('servicer_name')
## Compute the average loan age, months remaining and loan length for each servicer:
df_avg = df_grp.avg('loan_age', 'mths_remng', 'loan_length')
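Display the result. show takes a while because Spark evaluates lazily: transformations are not computed when they are declared. Only when an action is called does Spark compute everything together.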
df_avg.show()
+--------------------+--------------------+------------------+------------------+
| servicer_name| avg(loan_age)| avg(mths_remng)| avg(loan_length)|
+--------------------+--------------------+------------------+------------------+
| QUICKEN LOANS INC.|-0.08899247348614438| 358.5689787889155|358.47998631542936|
|NATIONSTAR MORTGA...| 0.39047125841532887| 359.5821853961678| 359.9726566545831|
| null| 5.6264681794400015|354.21486809483747| 359.8413362742775|
|WELLS FARGO BANK,...| 0.6704475572258285|359.25937820293814|359.92982576016396|
|FANNIE MAE/SETERU...| 9.333333333333334| 350.6666666666667| 360.0|
|DITECH FINANCIAL LLC| 5.147629653197582| 354.7811008590519|359.92873051224944|
|SENECA MORTGAGE S...| -0.2048814025438295|360.20075627363354| 359.9958748710897|
|SUNTRUST MORTGAGE...| 0.8241234756097561| 359.1453887195122| 359.969512195122|
|ROUNDPOINT MORTGA...| 5.153408024034549| 354.8269387244163|359.98034674845087|
| PENNYMAC CORP.| 0.14966740576496673| 359.8470066518847|359.99667405764967|
|PHH MORTGAGE CORP...| 0.9780420860018298|359.02195791399816| 360.0|
|MATRIX FINANCIAL ...| 6.566794707639778| 353.4229620145113|359.98975672215107|
| OTHER| 0.11480465916297489| 359.8345750772193|359.94937973638224|
| CITIMORTGAGE, INC.| 0.338498789346247|359.41670702179175| 359.755205811138|
|PINGORA LOAN SERV...| 7.573573382530696|352.40886824861633| 359.982441631147|
|JP MORGAN CHASE B...| 1.6553418987669224| 358.3384495990342|359.99379149780117|
| PNC BANK, N.A.| 1.1707779886148009|358.78747628083494| 359.9582542694497|
|FREEDOM MORTGAGE ...| 8.56265812109968|351.29583403609377|359.85849215719344|
+--------------------+--------------------+------------------+------------------+
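Lazy evaluation can be imitated with Python generators. This is an analogy rather than Spark's actual mechanism: building the pipeline does no work, and consuming it (the "action") runs every step.

```python
log = []

def load():
    """Stand-in for reading data; records each element it produces."""
    for value in [1, 2, 3, 4]:
        log.append(f"load {value}")
        yield value

# "Transformations": building the pipeline runs nothing yet
doubled = (x * 2 for x in load())
filtered = (x for x in doubled if x > 2)
steps_before_action = len(log)

# "Action": consuming the pipeline triggers all the work at once
result = list(filtered)
steps_after_action = len(log)

print(steps_before_action, steps_after_action, result)
```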
Use persist to cache an intermediate result, which tells Spark that this step will be reused later
df_keep = df_lim.withColumn('loan_length', df_lim['loan_age'] + df_lim['mths_remng'])
df_keep.persist()
df_grp = df_keep.groupBy('servicer_name')
df_avg = df_grp.avg('loan_age', 'mths_remng', 'loan_length')
Timing of the first computation
%%time
df_avg.show()
+--------------------+--------------------+------------------+------------------+
| servicer_name| avg(loan_age)| avg(mths_remng)| avg(loan_length)|
+--------------------+--------------------+------------------+------------------+
| QUICKEN LOANS INC.|-0.08899247348614438| 358.5689787889155|358.47998631542936|
|NATIONSTAR MORTGA...| 0.39047125841532887| 359.5821853961678| 359.9726566545831|
| null| 5.6264681794400015|354.21486809483747| 359.8413362742775|
|WELLS FARGO BANK,...| 0.6704475572258285|359.25937820293814|359.92982576016396|
|FANNIE MAE/SETERU...| 9.333333333333334| 350.6666666666667| 360.0|
|DITECH FINANCIAL LLC| 5.147629653197582| 354.7811008590519|359.92873051224944|
|SENECA MORTGAGE S...| -0.2048814025438295|360.20075627363354| 359.9958748710897|
|SUNTRUST MORTGAGE...| 0.8241234756097561| 359.1453887195122| 359.969512195122|
|ROUNDPOINT MORTGA...| 5.153408024034549| 354.8269387244163|359.98034674845087|
| PENNYMAC CORP.| 0.14966740576496673| 359.8470066518847|359.99667405764967|
|PHH MORTGAGE CORP...| 0.9780420860018298|359.02195791399816| 360.0|
|MATRIX FINANCIAL ...| 6.566794707639778| 353.4229620145113|359.98975672215107|
| OTHER| 0.11480465916297489| 359.8345750772193|359.94937973638224|
| CITIMORTGAGE, INC.| 0.338498789346247|359.41670702179175| 359.755205811138|
|PINGORA LOAN SERV...| 7.573573382530696|352.40886824861633| 359.982441631147|
|JP MORGAN CHASE B...| 1.6553418987669224| 358.3384495990342|359.99379149780117|
| PNC BANK, N.A.| 1.1707779886148009|358.78747628083494| 359.9582542694497|
|FREEDOM MORTGAGE ...| 8.56265812109968|351.29583403609377|359.85849215719344|
+--------------------+--------------------+------------------+------------------+
CPU times: user 2.67 ms, sys: 0 ns, total: 2.67 ms
Wall time: 5.3 s
Run another aggregation: it is several times faster, because persist cached the intermediate step
%%time
df_sum = df_grp.sum('new_int_rt', 'loan_age', 'mths_remng', 'cd_zero_bal', 'loan_length')
df_sum.show()
+--------------------+--------------------+-------------+---------------+----------------+----------------+
| servicer_name| sum(new_int_rt)|sum(loan_age)|sum(mths_remng)|sum(cd_zero_bal)|sum(loan_length)|
+--------------------+--------------------+-------------+---------------+----------------+----------------+
| QUICKEN LOANS INC.| 101801.764999999| -2081| 8384777| null| 8382696|
|NATIONSTAR MORTGA...| 40287.497999999956| 3770| 3471766| 2| 3475536|
| null|1.3139130894999936E7| 17690263| 1113692280| 16932| 1131382543|
|WELLS FARGO BANK,...| 187326.36500000005| 29436| 15773283| null| 15802719|
|FANNIE MAE/SETERU...| 26.6| 56| 2104| null| 2160|
|DITECH FINANCIAL LLC| 39531.709999999934| 48537| 3345231| 41| 3393768|
|SENECA MORTGAGE S...| 24093.559999999987| -1192| 2095648| null| 2094456|
|SUNTRUST MORTGAGE...| 21530.767999999953| 4325| 1884795| null| 1889120|
|ROUNDPOINT MORTGA...| 67708.25999999992| 82336| 5669070| 74| 5751406|
| PENNYMAC CORP.| 15209.140000000001| 540| 1298328| null| 1298868|
|PHH MORTGAGE CORP...| 9086.066| 2138| 784822| null| 786960|
|MATRIX FINANCIAL ...| 19212.932999999997| 30772| 1656140| 16| 1686912|
| OTHER| 904855.0440000098| 25163| 78868902| 21| 78894065|
| CITIMORTGAGE, INC.| 16939.33| 1398| 1484391| null| 1485789|
|PINGORA LOAN SERV...| 64224.70499999986| 119049| 5539515| 111| 5658564|
|JP MORGAN CHASE B...| 50187.15499999998| 19197| 4155651| null| 4174848|
| PNC BANK, N.A.| 6911.724999999999| 1851| 567243| 1| 569094|
|FREEDOM MORTGAGE ...| 24800.604999999992| 50768| 2082833| 60| 2133601|
+--------------------+--------------------+-------------+---------------+----------------+----------------+
CPU times: user 1.48 ms, sys: 930 µs, total: 2.41 ms
Wall time: 882 ms
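The effect of caching an intermediate result can be sketched without Spark. The function expensive_step below is a hypothetical stand-in for the withColumn step; the counter shows how often it runs with and without reuse. In Spark the cache is filled by the first action after persist, which is why only the second timing above improves.

```python
calls = {"expensive": 0}

def expensive_step(rows):
    """Stand-in for the withColumn step; counts how often it runs."""
    calls["expensive"] += 1
    return [(age, remaining, age + remaining) for age, remaining in rows]

rows = [(0, 360), (1, 359), (2, 358)]

# Without caching: every aggregation recomputes the step
avg_length = sum(r[2] for r in expensive_step(rows)) / len(rows)
sum_length = sum(r[2] for r in expensive_step(rows))
calls_without_cache = calls["expensive"]

# With caching: compute once, keep the result, reuse it
calls["expensive"] = 0
cached = expensive_step(rows)
avg_length_cached = sum(r[2] for r in cached) / len(cached)
sum_length_cached = sum(r[2] for r in cached)
calls_with_cache = calls["expensive"]

print(calls_without_cache, calls_with_cache)
```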
Merging
Initialize again first
import findspark
findspark.init()
import pyspark
import random
sc = pyspark.SparkContext(appName="Test")
spark = pyspark.sql.session.SparkSession(sc)
union
Build the DataFrames with the following syntax:
from pyspark.sql import Row
row = Row("name", "pet", "count")
df1 = sc.parallelize([
row("Sue", "cat", 16),
row("Kim", "dog", 1),
row("Bob", "fish", 5)
]).toDF()
df2 = sc.parallelize([
row("Fred", "cat", 2),
row("Kate", "ant", 179),
row("Marc", "lizard", 5)
]).toDF()
df3 = sc.parallelize([
row("Sarah", "shark", 3),
row("Jason", "kids", 2),
row("Scott", "squirrel", 1)
]).toDF()
Union two tables directly
df_union = df1.unionAll(df2)
df_union.show()
+----+------+-----+
|name| pet|count|
+----+------+-----+
| Sue| cat| 16|
| Kim| dog| 1|
| Bob| fish| 5|
|Fred| cat| 2|
|Kate| ant| 179|
|Marc|lizard| 5|
+----+------+-----+
Union multiple tables
from pyspark.sql import DataFrame
from functools import reduce
def union_many(*dfs):
    return reduce(DataFrame.unionAll, dfs)
df_union = union_many(df1, df2, df3)
df_union.show()
+-----+--------+-----+
| name| pet|count|
+-----+--------+-----+
| Sue| cat| 16|
| Kim| dog| 1|
| Bob| fish| 5|
| Fred| cat| 2|
| Kate| ant| 179|
| Marc| lizard| 5|
|Sarah| shark| 3|
|Jason| kids| 2|
|Scott|squirrel| 1|
+-----+--------+-----+
join
Create two new tables
row1 = Row("name", "pet", "count")
row2 = Row("name", "pet2", "count2")
df1 = sc.parallelize([
row1("Sue", "cat", 16),
row1("Kim", "dog", 1),
row1("Bob", "fish", 5),
row1("Libuse", "horse", 1)
]).toDF()
df2 = sc.parallelize([
row2("Sue", "eagle", 2),
row2("Kim", "ant", 179),
row2("Bob", "lizard", 5),
row2("Ferdinand", "bees", 23)
]).toDF()
inner join
df1.join(df2, 'name', how='inner').show()
+----+----+-----+------+------+
|name| pet|count| pet2|count2|
+----+----+-----+------+------+
| Sue| cat| 16| eagle| 2|
| Bob|fish| 5|lizard| 5|
| Kim| dog| 1| ant| 179|
+----+----+-----+------+------+
outer join
df1.join(df2, 'name', how='outer').show()
+---------+-----+-----+------+------+
| name| pet|count| pet2|count2|
+---------+-----+-----+------+------+
| Sue| cat| 16| eagle| 2|
| Bob| fish| 5|lizard| 5|
| Kim| dog| 1| ant| 179|
| Libuse|horse| 1| null| null|
|Ferdinand| null| null| bees| 23|
+---------+-----+-----+------+------+
left join
df1.join(df2, 'name', how='left').show()
+------+-----+-----+------+------+
| name| pet|count| pet2|count2|
+------+-----+-----+------+------+
| Sue| cat| 16| eagle| 2|
| Bob| fish| 5|lizard| 5|
| Kim| dog| 1| ant| 179|
|Libuse|horse| 1| null| null|
+------+-----+-----+------+------+
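The three join types differ only in which names survive. The plain-Python sketch below models that with dictionaries keyed by name; the join helper is hypothetical and stands in for DataFrame.join only as far as the row set is concerned.

```python
df1 = {"Sue": ("cat", 16), "Kim": ("dog", 1), "Bob": ("fish", 5), "Libuse": ("horse", 1)}
df2 = {"Sue": ("eagle", 2), "Kim": ("ant", 179), "Bob": ("lizard", 5), "Ferdinand": ("bees", 23)}

def join(left, right, how):
    """Hypothetical helper: join two name-keyed dicts; a missing side becomes None."""
    if how == "inner":
        names = [n for n in left if n in right]
    elif how == "left":
        names = list(left)
    elif how == "outer":
        names = list(left) + [n for n in right if n not in left]
    else:
        raise ValueError(f"unsupported join type: {how}")
    empty = (None, None)
    return [(n, *left.get(n, empty), *right.get(n, empty)) for n in names]

inner_rows = join(df1, df2, "inner")
left_rows = join(df1, df2, "left")
outer_rows = join(df1, df2, "outer")

for row in outer_rows:
    print(row)
```

An inner join keeps only names present on both sides, a left join keeps every name from the left table, and an outer join keeps every name from either table.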
missing data
First look at the number of rows in the full dataset
df.count()
3526154
Count the missing values in one column
df.where( df['_c12'].isNull() ).count()
3510294
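Count the missing values in every column. The function below skips string columns and reports only the other types (string columns can still hold nulls: _c2 does in the take(3) output earlier)
def count_nulls(df):
    null_counts = []  # list that collects the results
    for col in df.dtypes:  # iterate over (column name, type) pairs, e.g. ('_c0', 'bigint')
        cname = col[0]  # column name
        ctype = col[1]  # column type
        if ctype != 'string':  # skip string columns
            nulls = df.where(df[cname].isNull()).count()
            result = (cname, nulls)  # (column name, number of nulls)
            null_counts.append(result)
    return null_counts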
null_counts = count_nulls(df)
null_counts
[('_c0', 0),
('_c3', 0),
('_c4', 1945752),
('_c5', 0),
('_c6', 0),
('_c7', 1),
('_c9', 0),
('_c12', 3510294),
('_c26', 3526153)]
Dropping missing values
With how='all', a row is dropped only when all of the listed columns are null
df_drops = df.dropna(how='all', subset=['_c4', '_c12', '_c26'])
df_drops.count()
1580403
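Alternatively, set a threshold: the minimum number of non-null columns a row must have to be kept. Below, a row needs at least 2 non-null columns (thresh=1 is equivalent to how='all')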
df_drops2 = df.dropna(thresh=2, subset=['_c4', '_c12', '_c26'])
df_drops2.count()
15860
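The dropping rules can be checked on a few hand-made rows. This sketch models the how and thresh options in plain Python on invented data; it only mirrors the keep/drop decision, not Spark's implementation.

```python
rows = [
    {"_c4": 1.0, "_c12": None, "_c26": None},   # one non-null value
    {"_c4": 2.0, "_c12": 1, "_c26": None},      # two non-null values
    {"_c4": None, "_c12": None, "_c26": None},  # all null
]
subset = ["_c4", "_c12", "_c26"]

def non_null_count(row):
    """Number of subset columns that hold a value in this row."""
    return sum(row[c] is not None for c in subset)

# how='all': drop a row only if every subset column is null
kept_all = [r for r in rows if non_null_count(r) >= 1]
# how='any' (the default): drop a row if any subset column is null
kept_any = [r for r in rows if non_null_count(r) == len(subset)]
# thresh=2: keep a row only if at least 2 subset columns are non-null
kept_thresh2 = [r for r in rows if non_null_count(r) >= 2]

print(len(kept_all), len(kept_any), len(kept_thresh2))
```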
Filling missing values
df_fill = df.fillna(0, subset=['_c12'])
Check whether any missing values remain
df_fill.where( df_fill['_c12'].isNull() ).count()
0
Filling missing values with a moving average
First download a new dataset
! wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/datasets/diamonds_nulls.csv
Read it
df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv',
inferSchema=True, header=True, sep=',', nullValue='')
df.show()
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat| cut|color|clarity|depth|table|price| x| y| z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.23| Ideal| E| SI2| 61.5| 55.0| 326|3.95|3.98|2.43|
| 0.21| Premium| E| SI1| 59.8| 61.0| 326|3.89|3.84|2.31|
| 0.23| Good| E| VS1| 56.9| 65.0| 327|4.05|4.07|2.31|
| 0.29| Premium| I| VS2| 62.4| 58.0| 334| 4.2|4.23|2.63|
| 0.31| Good| J| SI2| 63.3| 58.0| 335|4.34|4.35|2.75|
| 0.24|Very Good| J| VVS2| 62.8| 57.0| 336|3.94|3.96|2.48|
| 0.24|Very Good| I| VVS1| 62.3| 57.0| 336|3.95|3.98|2.47|
| 0.26|Very Good| H| SI1| 61.9| 55.0| 337|4.07|4.11|2.53|
| 0.22| Fair| E| VS2| 65.1| 61.0| 337|3.87|3.78|2.49|
| 0.23|Very Good| H| VS1| 59.4| 61.0| 338| 4.0|4.05|2.39|
| 0.3| Good| J| SI1| 64.0| 55.0| 339|4.25|4.28|2.73|
| 0.23| Ideal| J| VS1| 62.8| 56.0| 340|3.93| 3.9|2.46|
| 0.22| Premium| F| SI1| 60.4| 61.0| 342|3.88|3.84|2.33|
| 0.31| Ideal| J| SI2| 62.2| 54.0| 344|4.35|4.37|2.71|
| 0.2| Premium| E| SI2| 60.2| 62.0| 345|3.79|3.75|2.27|
| 0.32| Premium| E| I1| 60.9| 58.0| 345|4.38|4.42|2.68|
| 0.3| Ideal| I| SI2| 62.0| 54.0| 348|4.31|4.34|2.68|
| 0.3| Good| J| SI1| 63.4| 54.0| 351|4.23|4.29| 2.7|
| 0.3| Good| J| SI1| 63.8| 56.0| 351|4.23|4.26|2.71|
| 0.3|Very Good| J| SI1| 62.7| 59.0| 351|4.21|4.27|2.66|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
only showing top 20 rows
Look at the rows where price is missing, showing up to 50 of them
df.where(df['price'].isNull()).show(50)
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
|carat| cut|color|clarity|depth|table|price| x| y| z|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
| 0.24| Premium| I| VS1| 62.5| 57.0| null|3.97|3.94|2.47|
| 0.35| Ideal| I| VS1| 60.9| 57.0| null|4.54|4.59|2.78|
| 0.7| Good| F| VS1| 59.4| 62.0| null|5.71|5.76| 3.4|
| 0.7| Fair| F| VS2| 64.5| 57.0| null|5.57|5.53|3.58|
| 0.7| Premium| E| SI1| 61.2| 57.0| null|5.73|5.68|3.49|
| 0.73| Premium| F| VS2| 62.5| 57.0| null|5.75| 5.7|3.58|
| 1.01| Ideal| F| SI1| 62.7| 55.0| null|6.45| 6.4|4.03|
| 1.03| Ideal| H| SI1| 61.1| 56.0| null| 6.5|6.53|3.98|
| 1.28| Ideal| I| SI2| 61.7| 59.0| null|6.96|6.92|4.28|
| 0.37| Premium| D| SI1| 60.4| 59.0| null|4.68|4.62|2.81|
| 0.5| Ideal| J| VS2| 61.7| 57.0| null|5.09|5.12|3.15|
| 0.34| Ideal| E| VS1| 61.2| 55.0| null|4.52|4.56|2.77|
| 0.52| Ideal| D| VS2| 61.8| 55.0| null|5.19|5.23|3.22|
| 0.71|Very Good| J| VVS2| 61.1| 58.0| null| 5.7|5.75| 3.5|
| 0.76| Premium| H| SI1| 59.8| 57.0| null|5.93|5.91|3.54|
| 0.58| Ideal| F| VS1| 60.3| 57.0| null|5.47|5.44|3.29|
| 0.7|Very Good| E| VS1| 63.4| 62.0| null|5.64|5.56|3.55|
| 0.92| Premium| D| I1| 63.0| 58.0| null|6.18|6.13|3.88|
| 0.88|Very Good| I| SI1| 62.5| 56.0| null|6.06|6.19|3.83|
| 0.7| Good| H| VVS2| 58.9| 61.5| null|5.77|5.84|3.42|
| 0.7|Very Good| D| SI1| 62.8| 60.0| null|5.66|5.68|3.56|
+-----+---------+-----+-------+-----+-----+-----+----+----+----+
First define a window, which, as in SQL, has a partition by and an order by
from pyspark.sql import Window
window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
Compute the mean of price over this window
from pyspark.sql.functions import mean
moving_avg = mean(df['price']).over(window)
Create a new column, moving_avg
df = df.withColumn('moving_avg', moving_avg)
df.show()
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
|carat| cut|color|clarity|depth|table|price| x| y| z| moving_avg|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
| 0.73|Premium| F| VS2| 62.5| 57.0| null|5.75| 5.7|3.58| 356.0|
| 0.29|Premium| I| VS2| 62.4| 58.0| 334| 4.2|4.23|2.63| 358.75|
| 0.2|Premium| E| VS2| 59.8| 62.0| 367|3.79|3.77|2.26| 360.4|
| 0.2|Premium| E| VS2| 59.0| 60.0| 367|3.81|3.78|2.24| 361.5|
| 0.2|Premium| E| VS2| 61.1| 59.0| 367|3.81|3.78|2.32| 362.2857142857143|
| 0.2|Premium| E| VS2| 59.7| 62.0| 367|3.84| 3.8|2.28| 367.0|
| 0.2|Premium| F| VS2| 62.6| 59.0| 367|3.73|3.71|2.33|367.14285714285717|
| 0.2|Premium| D| VS2| 62.3| 60.0| 367|3.73|3.68|2.31| 367.2857142857143|
| 0.2|Premium| D| VS2| 61.7| 60.0| 367|3.77|3.72|2.31|369.14285714285717|
| 0.3|Premium| J| VS2| 62.2| 58.0| 368|4.28| 4.3|2.67| 371.0|
| 0.3|Premium| J| VS2| 60.6| 59.0| 368|4.34|4.38|2.64| 373.7142857142857|
| 0.31|Premium| J| VS2| 62.5| 60.0| 380|4.31|4.36|2.71|376.42857142857144|
| 0.31|Premium| J| VS2| 62.4| 60.0| 380|4.29|4.33|2.69|379.14285714285717|
| 0.21|Premium| E| VS2| 60.5| 59.0| 386|3.87|3.83|2.33| 381.7142857142857|
| 0.21|Premium| E| VS2| 59.6| 56.0| 386|3.93|3.89|2.33| 384.2857142857143|
| 0.21|Premium| D| VS2| 61.6| 59.0| 386|3.82|3.78|2.34|385.14285714285717|
| 0.21|Premium| D| VS2| 60.6| 60.0| 386|3.85|3.81|2.32| 387.0|
| 0.21|Premium| D| VS2| 59.1| 62.0| 386|3.89|3.86|2.29| 388.0|
| 0.21|Premium| D| VS2| 58.3| 59.0| 386|3.96|3.93| 2.3|389.57142857142856|
| 0.32|Premium| J| VS2| 61.9| 58.0| 393|4.35|4.38| 2.7|392.14285714285717|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+
only showing top 20 rows
Use replace_null, defined below, to pick either the original value or the moving average depending on whether the column is null
from pyspark.sql.functions import when, col
def replace_null(orig, ma):
    return when(orig.isNull(), ma).otherwise(orig)
df_new = df.withColumn('imputed',
    replace_null(col('price'), col('moving_avg'))
)
Display the result. In the first row price is still null, but imputed now holds the moving average (356.0)
df_new.show()
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
|carat| cut|color|clarity|depth|table|price| x| y| z| moving_avg|imputed|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
| 0.73|Premium| F| VS2| 62.5| 57.0| null|5.75| 5.7|3.58| 356.0| 356.0|
| 0.29|Premium| I| VS2| 62.4| 58.0| 334| 4.2|4.23|2.63| 358.75| 334.0|
| 0.2|Premium| E| VS2| 59.8| 62.0| 367|3.79|3.77|2.26| 360.4| 367.0|
| 0.2|Premium| E| VS2| 59.0| 60.0| 367|3.81|3.78|2.24| 361.5| 367.0|
| 0.2|Premium| E| VS2| 61.1| 59.0| 367|3.81|3.78|2.32| 362.2857142857143| 367.0|
| 0.2|Premium| E| VS2| 59.7| 62.0| 367|3.84| 3.8|2.28| 367.0| 367.0|
| 0.2|Premium| F| VS2| 62.6| 59.0| 367|3.73|3.71|2.33|367.14285714285717| 367.0|
| 0.2|Premium| D| VS2| 62.3| 60.0| 367|3.73|3.68|2.31| 367.2857142857143| 367.0|
| 0.2|Premium| D| VS2| 61.7| 60.0| 367|3.77|3.72|2.31|369.14285714285717| 367.0|
| 0.3|Premium| J| VS2| 62.2| 58.0| 368|4.28| 4.3|2.67| 371.0| 368.0|
| 0.3|Premium| J| VS2| 60.6| 59.0| 368|4.34|4.38|2.64| 373.7142857142857| 368.0|
| 0.31|Premium| J| VS2| 62.5| 60.0| 380|4.31|4.36|2.71|376.42857142857144| 380.0|
| 0.31|Premium| J| VS2| 62.4| 60.0| 380|4.29|4.33|2.69|379.14285714285717| 380.0|
| 0.21|Premium| E| VS2| 60.5| 59.0| 386|3.87|3.83|2.33| 381.7142857142857| 386.0|
| 0.21|Premium| E| VS2| 59.6| 56.0| 386|3.93|3.89|2.33| 384.2857142857143| 386.0|
| 0.21|Premium| D| VS2| 61.6| 59.0| 386|3.82|3.78|2.34|385.14285714285717| 386.0|
| 0.21|Premium| D| VS2| 60.6| 60.0| 386|3.85|3.81|2.32| 387.0| 386.0|
| 0.21|Premium| D| VS2| 59.1| 62.0| 386|3.89|3.86|2.29| 388.0| 386.0|
| 0.21|Premium| D| VS2| 58.3| 59.0| 386|3.96|3.93| 2.3|389.57142857142856| 386.0|
| 0.32|Premium| J| VS2| 61.9| 58.0| 393|4.35|4.38| 2.7|392.14285714285717| 393.0|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+------------------+-------+
only showing top 20 rows
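The numbers in the moving_avg column can be reproduced by hand. The sketch below takes the first seven prices of the Premium/VS2 partition shown above and averages the non-null values in a window of 3 rows before and 3 rows after; it assumes, as the output suggests, that the mean skips nulls. The function names are made up for illustration.

```python
prices = [None, 334, 367, 367, 367, 367, 367]

def moving_average(values, before=3, after=3):
    """Mean of the non-null values in the window [i - before, i + after]."""
    result = []
    for i in range(len(values)):
        window = values[max(0, i - before): i + after + 1]
        present = [v for v in window if v is not None]
        result.append(sum(present) / len(present) if present else None)
    return result

moving_avg = moving_average(prices)

# Keep the original price where it exists, otherwise use the moving average
imputed = [ma if p is None else float(p) for p, ma in zip(prices, moving_avg)]

print(moving_avg[:4])
print(imputed[:2])
```

The first four values match the table: 356.0, 358.75, 360.4 and 361.5.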
pivot table
Create a table
from pyspark.sql import Row
row = Row('state', 'industry', 'hq', 'jobs')
df = sc.parallelize([
row('MI', 'auto', 'domestic', 716),
row('MI', 'auto', 'foreign', 123),
row('MI', 'auto', 'domestic', 1340),
row('MI', 'retail', 'foreign', 12),
row('MI', 'retail', 'foreign', 33),
row('OH', 'auto', 'domestic', 349),
row('OH', 'auto', 'foreign', 101),
row('OH', 'auto', 'foreign', 77),
row('OH', 'retail', 'domestic', 45),
row('OH', 'retail', 'foreign', 12)
]).toDF()
df.show()
+-----+--------+--------+----+
|state|industry| hq|jobs|
+-----+--------+--------+----+
| MI| auto|domestic| 716|
| MI| auto| foreign| 123|
| MI| auto|domestic|1340|
| MI| retail| foreign| 12|
| MI| retail| foreign| 33|
| OH| auto|domestic| 349|
| OH| auto| foreign| 101|
| OH| auto| foreign| 77|
| OH| retail|domestic| 45|
| OH| retail| foreign| 12|
+-----+--------+--------+----+
Group by state (one row per group), pivot on hq to form the columns, and use the sum of jobs as the values
df_pivot1 = df.groupby('state').pivot('hq', values=['domestic', 'foreign']).sum('jobs')
df_pivot1.show()
+-----+--------+-------+
|state|domestic|foreign|
+-----+--------+-------+
| MI| 2056| 168|
| OH| 394| 190|
+-----+--------+-------+
The rows can be keyed by several columns, here state + industry
df_pivot = df.groupBy('state', 'industry').pivot('hq', values=['domestic', 'foreign']).sum('jobs')
df_pivot.show()
+-----+--------+--------+-------+
|state|industry|domestic|foreign|
+-----+--------+--------+-------+
| OH| retail| 45| 12|
| MI| auto| 2056| 123|
| OH| auto| 349| 178|
| MI| retail| null| 45|
+-----+--------+--------+-------+
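The pivot tables above can be verified with a small plain-Python sketch on the same ten rows. The helper pivot_sum is hypothetical; it groups by an arbitrary key, spreads hq into columns and sums jobs, leaving None where a group has no rows for a column.

```python
rows = [
    ('MI', 'auto', 'domestic', 716), ('MI', 'auto', 'foreign', 123),
    ('MI', 'auto', 'domestic', 1340), ('MI', 'retail', 'foreign', 12),
    ('MI', 'retail', 'foreign', 33), ('OH', 'auto', 'domestic', 349),
    ('OH', 'auto', 'foreign', 101), ('OH', 'auto', 'foreign', 77),
    ('OH', 'retail', 'domestic', 45), ('OH', 'retail', 'foreign', 12),
]

def pivot_sum(rows, key):
    """Hypothetical helper: group rows by key(row), one summed cell per hq value."""
    table = {}
    for row in rows:
        _, _, hq, jobs = row
        cell = table.setdefault(key(row), {'domestic': None, 'foreign': None})
        cell[hq] = jobs if cell[hq] is None else cell[hq] + jobs
    return table

by_state = pivot_sum(rows, key=lambda r: r[0])
by_state_industry = pivot_sum(rows, key=lambda r: (r[0], r[1]))

print(by_state)
print(by_state_industry)
```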
resampling
Create a table
import datetime
from pyspark.sql import Row
from pyspark.sql.functions import col
row = Row("date", "name", "production")
df = sc.parallelize([
row("08/01/2014", "Kim", 5),
row("08/02/2014", "Kim", 14),
row("08/01/2014", "Bob", 6),
row("08/02/2014", "Bob", 3),
row("08/01/2014", "Sue", 0),
row("08/02/2014", "Sue", 22),
row("08/01/2014", "Dan", 4),
row("08/02/2014", "Dan", 4),
row("08/01/2014", "Joe", 37),
row("09/01/2014", "Kim", 6),
row("09/02/2014", "Kim", 6),
row("09/01/2014", "Bob", 4),
row("09/02/2014", "Bob", 20),
row("09/01/2014", "Sue", 11),
row("09/02/2014", "Sue", 2),
row("09/01/2014", "Dan", 1),
row("09/02/2014", "Dan", 3),
row("09/02/2014", "Joe", 29)
]).toDF()
df.show()
+----------+----+----------+
| date|name|production|
+----------+----+----------+
|08/01/2014| Kim| 5|
|08/02/2014| Kim| 14|
|08/01/2014| Bob| 6|
|08/02/2014| Bob| 3|
|08/01/2014| Sue| 0|
|08/02/2014| Sue| 22|
|08/01/2014| Dan| 4|
|08/02/2014| Dan| 4|
|08/01/2014| Joe| 37|
|09/01/2014| Kim| 6|
|09/02/2014| Kim| 6|
|09/01/2014| Bob| 4|
|09/02/2014| Bob| 20|
|09/01/2014| Sue| 11|
|09/02/2014| Sue| 2|
|09/01/2014| Dan| 1|
|09/02/2014| Dan| 3|
|09/02/2014| Joe| 29|
+----------+----+----------+
df.dtypes
[('date', 'string'), ('name', 'string'), ('production', 'bigint')]
Use a user-defined function to convert the date column from MM/DD/YYYY to MM/YYYY
# 'udf' stands for 'user defined function'; it wraps a Python function so it can be applied to a column
from pyspark.sql.functions import udf
# Split a MM/DD/YYYY string and return a MM/YYYY string
def split_date(whole_date):
    try:
        mo, day, yr = whole_date.split('/')
    except ValueError:
        return 'error'
    return mo + '/' + yr
# Wrap the function as a UDF
udf_split_date = udf(split_date)
# Create the new column
df_new = df.withColumn('month_year', udf_split_date('date'))
df_new.show()
+----------+----+----------+----------+
| date|name|production|month_year|
+----------+----+----------+----------+
|08/01/2014| Kim| 5| 08/2014|
|08/02/2014| Kim| 14| 08/2014|
|08/01/2014| Bob| 6| 08/2014|
|08/02/2014| Bob| 3| 08/2014|
|08/01/2014| Sue| 0| 08/2014|
|08/02/2014| Sue| 22| 08/2014|
|08/01/2014| Dan| 4| 08/2014|
|08/02/2014| Dan| 4| 08/2014|
|08/01/2014| Joe| 37| 08/2014|
|09/01/2014| Kim| 6| 09/2014|
|09/02/2014| Kim| 6| 09/2014|
|09/01/2014| Bob| 4| 09/2014|
|09/02/2014| Bob| 20| 09/2014|
|09/01/2014| Sue| 11| 09/2014|
|09/02/2014| Sue| 2| 09/2014|
|09/01/2014| Dan| 1| 09/2014|
|09/02/2014| Dan| 3| 09/2014|
|09/02/2014| Joe| 29| 09/2014|
+----------+----+----------+----------+
Drop the date column and resample by month_year + name
df_new = df_new.drop('date')
df_agg = df_new.groupBy('month_year', 'name').agg({'production' : 'sum'})
df_agg.show()
+----------+----+---------------+
|month_year|name|sum(production)|
+----------+----+---------------+
| 09/2014| Sue| 13|
| 09/2014| Kim| 12|
| 09/2014| Bob| 24|
| 09/2014| Joe| 29|
| 09/2014| Dan| 4|
| 08/2014| Kim| 19|
| 08/2014| Joe| 37|
| 08/2014| Dan| 8|
| 08/2014| Sue| 22|
| 08/2014| Bob| 9|
+----------+----+---------------+
The date can also be parsed into a date type that Spark supports
from pyspark.sql.functions import udf
from pyspark.sql.types import DateType
from datetime import datetime
dateFormat = udf(lambda x: datetime.strptime(x, '%m/%d/%Y'), DateType())
df_d = df.withColumn('new_date', dateFormat(col('date')))
df_d.select('new_date').take(1)
[Row(new_date=datetime.date(2014, 8, 1))]
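One pitfall with strptime format strings is the case of the month directive: %m is the month, while %M is the minute. The sketch below parses the first date of the table both ways using only the standard library, to show how the wrong directive silently yields January.

```python
from datetime import datetime

text = "08/01/2014"

# %m reads "08" as the month
correct = datetime.strptime(text, "%m/%d/%Y").date()

# %M reads "08" as minutes, so the month falls back to its default of 1
wrong = datetime.strptime(text, "%M/%d/%Y").date()

print(correct.isoformat())
print(wrong.isoformat())
```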
Subsetting
Load the dataset
df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')
df.dtypes
[('_c0', 'bigint'),
('_c1', 'string'),
('_c2', 'string'),
('_c3', 'double'),
('_c4', 'double'),
('_c5', 'int'),
('_c6', 'int'),
('_c7', 'int'),
('_c8', 'string'),
('_c9', 'int'),
('_c10', 'string'),
('_c11', 'string'),
('_c12', 'int'),
('_c13', 'string'),
('_c14', 'string'),
('_c15', 'string'),
('_c16', 'string'),
('_c17', 'string'),
('_c18', 'string'),
('_c19', 'string'),
('_c20', 'string'),
('_c21', 'string'),
('_c22', 'string'),
('_c23', 'string'),
('_c24', 'string'),
('_c25', 'string'),
('_c26', 'int'),
('_c27', 'string')]
Selecting columns
Select columns, method 1
from pyspark.sql.functions import col
df_select = df.select(col('_c0'), col('_c1'), col('_c3'), col('_c9'))
df_select.show(5)
+------------+----------+-----+-----+
| _c0| _c1| _c3| _c9|
+------------+----------+-----+-----+
|100002091588|01/01/2015|4.125|16740|
|100002091588|02/01/2015|4.125|16740|
|100002091588|03/01/2015|4.125|16740|
|100002091588|04/01/2015|4.125|16740|
|100002091588|05/01/2015|4.125|16740|
+------------+----------+-----+-----+
only showing top 5 rows
Select columns, method 2
df_select = df[['_c0', '_c1', '_c3', '_c9']]
df_select.show(5)
Dropping a column
df_drop = df_select.drop(col('_c3'))
df_drop.show(5)
+------------+----------+-----+
| _c0| _c1| _c9|
+------------+----------+-----+
|100002091588|01/01/2015|16740|
|100002091588|02/01/2015|16740|
|100002091588|03/01/2015|16740|
|100002091588|04/01/2015|16740|
|100002091588|05/01/2015|16740|
+------------+----------+-----+
only showing top 5 rows
Selecting rows
Suppose we want to filter rows on _c6
df.describe('_c6').show()
+-------+-----------------+
|summary| _c6|
+-------+-----------------+
| count| 3526154|
| mean|354.7084951479714|
| stddev| 4.01181251079202|
| min| 292|
| max| 480|
+-------+-----------------+
Use where to select the rows with _c6<358
df_sub = df.where(df['_c6'] < 358)
df_sub.describe('_c6').show()
+-------+------------------+
|summary| _c6|
+-------+------------------+
| count| 2598037|
| mean|353.15604897081914|
| stddev|3.5170213056883988|
| min| 292|
| max| 357|
+-------+------------------+
where with multiple conditions
df_filter = df.where((df['_c6'] > 340) & (df['_c5'] < 4))
df_filter.describe('_c6', '_c5').show()
+-------+------------------+------------------+
|summary| _c6| _c5|
+-------+------------------+------------------+
| count| 1254131| 1254131|
| mean|358.48713810598736| 1.474693632483369|
| stddev| 1.378961910349754|1.2067831502138422|
| min| 341| -1|
| max| 361| 3|
+-------+------------------+------------------+
Sampling
The first argument says whether to sample with replacement, the second is the fraction to sample, and the third is an optional random seed
df_sample = df.sample(False, 0.05, 99)
df_sample.describe('_c6').show()
+-------+-----------------+
|summary| _c6|
+-------+-----------------+
| count| 176316|
| mean|354.7118072097824|
| stddev| 3.99195164816948|
| min| 299|
| max| 361|
+-------+-----------------+
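Note that the sample has 176316 rows, which is close to but not exactly 5% of 3526154 (about 176308). That is expected: when sampling without replacement, Spark keeps each row independently with the given probability, so the fraction is only approximate. Below is a minimal plain-Python sketch of that idea. It uses Python's random module rather than Spark's sampler, so the selected rows will not match Spark's, and bernoulli_sample is a hypothetical helper name.

```python
import random

def bernoulli_sample(rows, fraction, seed=None):
    """Keep each row independently with probability `fraction`."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]

rows = list(range(100_000))  # stand-in for the rows of a DataFrame

sample_a = bernoulli_sample(rows, 0.05, seed=99)
sample_b = bernoulli_sample(rows, 0.05, seed=99)

print(len(sample_a))         # close to 5000, but usually not exactly 5000
print(sample_a == sample_b)  # True: the same seed reproduces the same sample
```

If an exact number of rows is needed, a common approach is to oversample slightly and then apply limit.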
Basic statistical analysis
Read the dataset
df = spark.read.csv('file:///home/ubuntu/SageMaker/Performance_2015Q1.txt', header=False, inferSchema=True, sep='|')
The basic describe
df_described = df.describe()
df_described.show()
Import skewness, kurtosis, and other aggregate functions
from pyspark.sql.functions import skewness, kurtosis
from pyspark.sql.functions import var_pop, var_samp, stddev, stddev_pop, sumDistinct, ntile
df.select(skewness('_c3')).show()
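To make these two statistics concrete, here is a small plain-Python sketch of how they can be computed. It assumes the population-moment definitions and excess kurtosis (kurtosis minus 3), which is my understanding of what Spark's skewness and kurtosis return. The function name and the sample values are made up for illustration.

```python
import math

def skew_and_kurtosis(values):
    """Population skewness and excess kurtosis from central moments."""
    n = len(values)
    mean = sum(values) / n
    # Sums of powers of the deviations from the mean.
    m2 = sum((v - mean) ** 2 for v in values)
    m3 = sum((v - mean) ** 3 for v in values)
    m4 = sum((v - mean) ** 4 for v in values)
    skew = math.sqrt(n) * m3 / m2 ** 1.5
    kurt = n * m4 / m2 ** 2 - 3.0
    return skew, kurt

# One large value on the right gives a positive skew.
skew, kurt = skew_and_kurtosis([1, 2, 3, 4, 10])
print(round(skew, 4), round(kurt, 4))

# A symmetric sample has zero skew.
sym_skew, _ = skew_and_kurtosis([1, 2, 3])
print(sym_skew)
```

Under this definition a uniform distribution has an excess kurtosis of about -1.2, which fits the value of roughly -1.199 reported below for the ID column _c0.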
Build a function that produces a table of skewness and kurtosis statistics, which we will union with the basic describe table
from pyspark.sql import Row
columns = df_described.columns  # list of column names: ['summary', '_c0', '_c2', '_c3', '_c4', '_c5', '_c6']
funcs = [skewness, kurtosis]  # the aggregation functions
fnames = ['skew', 'kurtosis']  # display names for the functions
def new_item(func, column):
    """
    Take an aggregation function and a column name,
    apply the aggregation to that column,
    and return the result as a string rather than a number
    so that it matches the output of describe.
    """
    return str(df.select(func(column)).collect()[0][0])
new_data = []
for func, fname in zip(funcs, fnames):
    row_dict = {'summary': fname}  # each row starts with 'summary', which names the statistic
    for column in columns[1:]:
        row_dict[column] = new_item(func, column)
    new_data.append(Row(**row_dict))
print(new_data)
[Row(summary='skew', _c0='-0.00183847089866041', _c2='None', _c3='0.5197993394959906', _c4='0.7584115767562998', _c5='0.2864801560838491', _c6='-2.6976520156650614'), Row(summary='kurtosis', _c0='-1.1990072635132925', _c2='None', _c3='0.1260577268465326', _c4='0.5760856026559504', _c5='0.1951877800894679', _c6='24.723785894417404')]
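As you can see, the format matches the basic describe table. The only difference is that a missing value shows up as the string 'None' rather than None, because new_item wraps every result in str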
df_described.collect()
[Row(summary='count', _c0='3526154', _c2='382039', _c3='3526154', _c4='1580402', _c5='3526154', _c6='3526154'),
Row(summary='mean', _c0='5.503885995001908E11', _c2=None, _c3='4.178168090219519', _c4='234846.78065481762', _c5='5.134865351881966', _c6='354.7084951479714'),
Row(summary='stddev', _c0='2.596112361975215E11', _c2=None, _c3='0.34382335723646673', _c4='118170.6859226166', _c5='3.3833930336063465', _c6='4.01181251079202'),
Row(summary='min', _c0='100002091588', _c2='CITIMORTGAGE, INC.', _c3='2.75', _c4='0.85', _c5='-1', _c6='292'),
Row(summary='max', _c0='999995696635', _c2='WELLS FARGO BANK, N.A.', _c3='6.125', _c4='1193544.39', _c5='34', _c6='480')]
Union them together
new_describe = sc.parallelize(new_data).toDF() #turns the results from our loop into a dataframe
new_describe = new_describe.select(df_described.columns) #forces the columns into the same order
expanded_describe = df_described.unionAll(new_describe) #merges the new stats with the original describe
expanded_describe.show()
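The select step is needed because a Spark union matches columns by position, not by name. As far as I know, in Spark 2.x a Row built from keyword arguments also sorts its fields by name, so the new rows would not arrive in the same order as the describe table. The plain-Python sketch below mimics the problem with tuples. The column subset and values are illustrative, and reorder is a hypothetical helper.

```python
# Column order of the describe table (illustrative subset).
describe_cols = ['summary', '_c0', '_c3']
describe_rows = [('mean', '5.5', '4.18')]

# The new rows arrive with their fields sorted by name.
new_cols = ['_c0', '_c3', 'summary']
new_rows = [('-0.0018', '0.52', 'skew')]

def reorder(rows, from_cols, to_cols):
    """Rearrange each row from the order in from_cols to the order in to_cols."""
    positions = [from_cols.index(col) for col in to_cols]
    return [tuple(row[p] for p in positions) for row in rows]

# A positional union without reordering puts values under the wrong headers.
naive = describe_rows + new_rows
# Reordering first keeps every value under its own column.
aligned = describe_rows + reorder(new_rows, new_cols, describe_cols)

print(naive[1])    # ('-0.0018', '0.52', 'skew')
print(aligned[1])  # ('skew', '-0.0018', '0.52')
```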
Visualization
First read in the dataset
import pandas as pd
import matplotlib.pyplot as plt
spark_df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv',
inferSchema=True, header=True, sep=',', nullValue='')
spark_df.show(5)
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
|carat| cut|color|clarity|depth|table|price| x| y| z|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
| 0.23| Ideal| E| SI2| 61.5| 55.0| 326|3.95|3.98|2.43|
| 0.21|Premium| E| SI1| 59.8| 61.0| 326|3.89|3.84|2.31|
| 0.23| Good| E| VS1| 56.9| 65.0| 327|4.05|4.07|2.31|
| 0.29|Premium| I| VS2| 62.4| 58.0| 334| 4.2|4.23|2.63|
| 0.31| Good| J| SI2| 63.3| 58.0| 335|4.34|4.35|2.75|
+-----+-------+-----+-------+-----+-----+-----+----+----+----+
only showing top 5 rows
spark_df.dtypes
[('carat', 'double'),
('cut', 'string'),
('color', 'string'),
('clarity', 'string'),
('depth', 'double'),
('table', 'double'),
('price', 'int'),
('x', 'double'),
('y', 'double'),
('z', 'double')]
spark_df.describe(['carat', 'depth', 'table', 'price']).show()
+-------+------------------+------------------+------------------+------------------+
|summary| carat| depth| table| price|
+-------+------------------+------------------+------------------+------------------+
| count| 53940| 53940| 53940| 53919|
| mean|0.7979397478679852| 61.74940489432624| 57.45718390804603|3933.3421799365715|
| stddev|0.4740112444054196|1.4326213188336525|2.2344905628213247| 3990.022722699714|
| min| 0.2| 43.0| 43.0| 326|
| max| 5.01| 79.0| 95.0| 18823|
+-------+------------------+------------------+------------------+------------------+
Collect the columns to the driver
carat = spark_df[['carat']].collect()
price = spark_df[['price']].collect()
print(carat[:5])
print(price[:5])
[Row(carat=0.23), Row(carat=0.21), Row(carat=0.23), Row(carat=0.29), Row(carat=0.31)]
[Row(price=326), Row(price=326), Row(price=327), Row(price=334), Row(price=335)]
Plot them directly
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.plot(carat, price, 'go', alpha=0.1)
ax.set_xlabel('Carat')
ax.set_ylabel('Price')
ax.set_title('Diamonds')
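One thing to watch for: the describe output above reports 53919 prices against 53940 carats, so 21 prices are null. Collecting the two columns separately also yields lists of Row objects rather than plain numbers. A minimal sketch of one way to prepare the data for plotting is shown below. It uses a namedtuple as a stand-in for pyspark.sql.Row so that it runs without Spark, and the three rows are made up.

```python
from collections import namedtuple

# Stand-in for pyspark.sql.Row; the rows below are illustrative only.
Row = namedtuple('Row', ['carat', 'price'])
rows = [Row(0.23, 326), Row(0.21, None), Row(0.29, 334)]

# Keep carat and price paired, and skip rows whose price is missing.
pairs = [(row.carat, row.price) for row in rows if row.price is not None]
carat_values, price_values = zip(*pairs)

print(carat_values)  # (0.23, 0.29)
print(price_values)  # (326, 334)
```

In Spark itself, something like spark_df.select('carat', 'price').dropna().collect() would return the paired rows with the nulls already removed.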
You can also convert the Spark DataFrame into a pandas DataFrame
pandas_df = spark_df.toPandas()
pandas_df.describe()
Linear regression with GLM
First load the dataset and select a few of its columns. The quantity we want to predict is log(price)
from pyspark.sql.functions import log
df = spark.read.csv('file:///home/ubuntu/SageMaker/diamonds_nulls.csv',
inferSchema=True, header=True, sep=',', nullValue='')
df = df[['carat', 'clarity', 'price']]
df = df.withColumn('lprice', log('price'))
df.show(5)
+-----+-------+-----+------------------+
|carat|clarity|price| lprice|
+-----+-------+-----+------------------+
| 0.23| SI2| 326| 5.786897381366708|
| 0.21| SI1| 326| 5.786897381366708|
| 0.23| VS1| 327|5.7899601708972535|
| 0.29| VS2| 334| 5.8111409929767|
| 0.31| SI2| 335| 5.814130531825066|
+-----+-------+-----+------------------+
only showing top 5 rows
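When called with a single argument, log from pyspark.sql.functions is the natural logarithm. A quick check in plain Python reproduces the first lprice value and shows how to get back to a price with exp.

```python
import math

price = 326
lprice = math.log(price)      # natural logarithm, as in the lprice column
recovered = math.exp(lprice)  # exp undoes the log

print(lprice)     # approximately 5.7869
print(recovered)  # approximately 326
```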
Here is a helper function that builds the feature column. We won't go through it in detail
"""
Program written by Jeff Levy (jlevy@urban.org) for the Urban Institute, last revised 8/24/2016.
Note that this is intended as a temporary work-around until pySpark improves its ML package.
Tested in pySpark 2.0.
"""
def build_indep_vars(df, independent_vars, categorical_vars=None, keep_intermediate=False, summarizer=True):
    """
    Data verification
    df               : DataFrame
    independent_vars : List of column names
    categorical_vars : None or list of column names, e.g. ['col1', 'col2']
    """
    assert(type(df) is pyspark.sql.dataframe.DataFrame), 'pyspark_glm: A pySpark dataframe is required as the first argument.'
    assert(type(independent_vars) is list), 'pyspark_glm: List of independent variable column names must be the second argument.'
    for iv in independent_vars:
        assert(type(iv) is str), 'pyspark_glm: Independent variables must be column name strings.'
        assert(iv in df.columns), 'pyspark_glm: Independent variable name is not a dataframe column.'
    if categorical_vars:
        for cv in categorical_vars:
            assert(type(cv) is str), 'pyspark_glm: Categorical variables must be column name strings.'
            assert(cv in df.columns), 'pyspark_glm: Categorical variable name is not a dataframe column.'
            assert(cv in independent_vars), 'pyspark_glm: Categorical variables must be independent variables.'
    """
    Code
    """
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    from pyspark.ml.regression import GeneralizedLinearRegression
    if categorical_vars:
        string_indexer = [StringIndexer(inputCol=x,
                                        outputCol='{}_index'.format(x))
                          for x in categorical_vars]
        encoder = [OneHotEncoder(dropLast=True,
                                 inputCol='{}_index'.format(x),
                                 outputCol='{}_vector'.format(x))
                   for x in categorical_vars]
        independent_vars = ['{}_vector'.format(x) if x in categorical_vars else x for x in independent_vars]
    else:
        string_indexer, encoder = [], []
    assembler = VectorAssembler(inputCols=independent_vars,
                                outputCol='indep_vars')
    pipeline = Pipeline(stages=string_indexer+encoder+[assembler])
    model = pipeline.fit(df)
    df = model.transform(df)
    # for building the crosswalk between indices and column names
    if summarizer:
        param_crosswalk = {}
        i = 0
        for x in independent_vars:
            if x.endswith('_vector'):
                # slice off the suffix; str.rstrip strips a set of characters, not a suffix
                xrs = x[:-len('_vector')]
                dst = df[[xrs, '{}_index'.format(xrs)]].distinct().collect()
                for row in dst:
                    param_crosswalk[int(row['{}_index'.format(xrs)]+i)] = row[xrs]
                maxind = max(param_crosswalk.keys())
                del param_crosswalk[maxind]  # for dropLast
                i += len(dst)
            elif x.endswith('_index'):
                pass
            else:
                param_crosswalk[i] = x
                i += 1
        """
        {0: 'carat',
         1: u'SI1',
         2: u'VS2',
         3: u'SI2',
         4: u'VS1',
         5: u'VVS2',
         6: u'VVS1',
         7: u'IF'}
        """
        make_summary = Summarizer(param_crosswalk)
    if not keep_intermediate:
        fcols = [c for c in df.columns if '_index' not in c[-6:] and '_vector' not in c[-7:]]
        df = df[fcols]
    if summarizer:
        return df, make_summary
    else:
        return df
class Summarizer(object):
    def __init__(self, param_crosswalk):
        self.param_crosswalk = param_crosswalk
        self.precision = 4
        self.screen_width = 57
        self.hsep = '-'
        self.vsep = '|'
    def summarize(self, model, show=True, return_str=False):
        coefs = list(model.coefficients)
        inter = model.intercept
        tstat = model.summary.tValues
        stder = model.summary.coefficientStandardErrors
        pvals = model.summary.pValues
        #if model includes an intercept:
        if len(coefs) == len(tstat)-1:
            coefs.insert(0, inter)
            x = {0:'intercept'}
            for k, v in self.param_crosswalk.items():
                x[k+1] = v
        else:
            x = self.param_crosswalk
        assert(len(coefs) == len(tstat) == len(stder) == len(pvals))
        p = self.precision
        h = self.hsep
        v = self.vsep
        w = self.screen_width
        coefs = [str(round(e, p)).center(10) for e in coefs]
        tstat = [str(round(e, p)).center(10) for e in tstat]
        stder = [str(round(e, p)).center(10) for e in stder]
        pvals = [str(round(e, p)).center(10) for e in pvals]
        lines = ''
        for i in range(len(coefs)):
            lines += str(x[i]).rjust(15) + v + coefs[i] + stder[i] + tstat[i] + pvals[i] + '\n'
        labels = ''.rjust(15) + v + 'Coef'.center(10) + 'Std Err'.center(10) + 'T Stat'.center(10) + 'P Val'.center(10)
        pad = ''.rjust(15) + v
        output = """{hline}\n{labels}\n{hline}\n{lines}{hline}""".format(
            hline=h*w,
            labels=labels,
            lines=lines)
        if show:
            print(output)
        if return_str:
            return output
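The part of the helper that matters most is how a categorical column becomes numeric: StringIndexer assigns each label an index (by default the most frequent label gets 0), and OneHotEncoder with dropLast=True turns the index into an indicator vector in which the last index is represented by all zeros. Below is a rough plain-Python sketch of that idea, not the Spark implementation. The labels are made up, and breaking ties alphabetically is an assumption of this sketch.

```python
from collections import Counter

def string_index(labels):
    """Map each label to an index, most frequent label first."""
    counts = Counter(labels)
    # Assumption of this sketch: ties are broken alphabetically.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}

def one_hot_drop_last(index, n_labels):
    """Indicator vector of length n_labels - 1; the last index is all zeros."""
    vector = [0.0] * (n_labels - 1)
    if index < n_labels - 1:
        vector[index] = 1.0
    return vector

labels = ['SI1', 'VS2', 'SI1', 'I1', 'VS2', 'SI1']  # illustrative data
mapping = string_index(labels)
encoded = {label: one_hot_drop_last(i, len(mapping)) for label, i in mapping.items()}

print(mapping)  # {'SI1': 0, 'VS2': 1, 'I1': 2}
print(encoded)  # {'SI1': [1.0, 0.0], 'VS2': [0.0, 1.0], 'I1': [0.0, 0.0]}
```

The dropped level acts as the baseline of the regression: the coefficients of the other levels are measured relative to it.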
Now build the features
df, summarizer = build_indep_vars(df,
['carat', 'clarity'],
categorical_vars=['clarity'],
keep_intermediate=False,
summarizer=True)
df.show(5)
+-----+-------+-----+------------------+--------------------+
|carat|clarity|price| lprice| indep_vars|
+-----+-------+-----+------------------+--------------------+
| 0.23| SI2| 326| 5.786897381366708|(8,[0,3],[0.23,1.0])|
| 0.21| SI1| 326| 5.786897381366708|(8,[0,1],[0.21,1.0])|
| 0.23| VS1| 327|5.7899601708972535|(8,[0,4],[0.23,1.0])|
| 0.29| VS2| 334| 5.8111409929767|(8,[0,2],[0.29,1.0])|
| 0.31| SI2| 335| 5.814130531825066|(8,[0,3],[0.31,1.0])|
+-----+-------+-----+------------------+--------------------+
only showing top 5 rows
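The indep_vars column is printed in sparse vector notation: (size, [indices], [values]). For the first row, (8,[0,3],[0.23,1.0]) means a vector of length 8 with 0.23 at position 0 (carat) and 1.0 at position 3 (the SI2 indicator, according to the crosswalk). The small sketch below expands this notation in plain Python. to_dense is a hypothetical helper, not a Spark function.

```python
def to_dense(size, indices, values):
    """Expand (size, indices, values) into a full list, zeros elsewhere."""
    dense = [0.0] * size
    for index, value in zip(indices, values):
        dense[index] = value
    return dense

# First row of the table above: carat 0.23, clarity SI2.
dense = to_dense(8, [0, 3], [0.23, 1.0])
print(dense)  # [0.23, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```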
First check how many rows have a null lprice, then drop those rows
df.where( df['lprice'].isNull()) .count()
21
df_drops = df.dropna(how='all', subset=['lprice'])
Fit the model
from pyspark.ml.regression import GeneralizedLinearRegression
glm = GeneralizedLinearRegression(family='gaussian',
link='identity',
labelCol='lprice',
featuresCol='indep_vars',
fitIntercept=True)
model = glm.fit(df_drops)
Inspect the model results
model.coefficients
DenseVector([2.0808, 0.7228, 0.818, 0.5691, 0.8564, 0.9349, 0.9203, 0.9988])
model.intercept
5.3553700003910265
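With the identity link, a prediction is simply the intercept plus the coefficient of each active feature. The sketch below applies the numbers printed above to the first row of the data (0.23 carat, clarity SI2). The mapping from coefficient position to clarity level follows the crosswalk built by the helper, and the coefficients are the rounded values from the output, so the result is approximate. predict_lprice is a hypothetical helper.

```python
import math

intercept = 5.3553700003910265
coef_carat = 2.0808
# Positions 1 to 7 of model.coefficients, labelled using the crosswalk.
coef_clarity = {'SI1': 0.7228, 'VS2': 0.818, 'SI2': 0.5691, 'VS1': 0.8564,
                'VVS2': 0.9349, 'VVS1': 0.9203, 'IF': 0.9988}

def predict_lprice(carat, clarity):
    # A level without an indicator (the dropped baseline) contributes 0.
    return intercept + coef_carat * carat + coef_clarity.get(clarity, 0.0)

lprice_hat = predict_lprice(0.23, 'SI2')
price_hat = math.exp(lprice_hat)  # back from log(price) to price

print(round(lprice_hat, 4))  # 6.4031
print(round(price_hat))
```

The observed lprice for this row is about 5.79, so the model overestimates this particular diamond.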
model.summary.tValues
[573.770045512725,
51.26011561154599,
57.75466173814217,
40.10391980395556,
59.5623952901403,
63.14279213211864,
60.48868627589056,
60.711257997408495,
371.7558057494966]
model.summary.coefficientStandardErrors
[0.0036265388065588383,
0.014100833468172893,
0.014163700186194172,
0.01418970865492269,
0.014378787192379332,
0.014805773323652266,
0.015214235864428029,
0.016451002067601285,
0.014405612279797675]
summarizer.summarize(model)
---------------------------------------------------------
| Coef Std Err T Stat P Val
---------------------------------------------------------
intercept| 5.3554 0.0036 573.77 0.0
carat| 2.0808 0.0141 51.2601 0.0
SI1| 0.7228 0.0142 57.7547 0.0
VS2| 0.818 0.0142 40.1039 0.0
SI2| 0.5691 0.0144 59.5624 0.0
VS1| 0.8564 0.0148 63.1428 0.0
VVS2| 0.9349 0.0152 60.4887 0.0
VVS1| 0.9203 0.0165 60.7113 0.0
IF| 0.9988 0.0144 371.7558 0.0
---------------------------------------------------------
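One caveat when reading this table: according to Spark's documentation for the GLM summary, when fitIntercept is true the last element of tValues, coefficientStandardErrors and pValues belongs to the intercept. The numbers above agree with that: 2.0808 / 0.0036 is about 573.8 (the carat row), while 5.3554 / 0.0144 is about 371.8 (the intercept). The Std Err and T Stat columns in the table therefore appear shifted by one row relative to the labels. The sketch below re-pairs the printed values with the intercept placed last and checks that each t value equals coefficient / standard error. The label order is taken from the crosswalk.

```python
coefficients = [2.0808, 0.7228, 0.818, 0.5691, 0.8564, 0.9349, 0.9203, 0.9988]
intercept = 5.3553700003910265
std_errors = [0.0036265388065588383, 0.014100833468172893, 0.014163700186194172,
              0.01418970865492269, 0.014378787192379332, 0.014805773323652266,
              0.015214235864428029, 0.016451002067601285, 0.014405612279797675]
t_values = [573.770045512725, 51.26011561154599, 57.75466173814217,
            40.10391980395556, 59.5623952901403, 63.14279213211864,
            60.48868627589056, 60.711257997408495, 371.7558057494966]
labels = ['carat', 'SI1', 'VS2', 'SI2', 'VS1', 'VVS2', 'VVS1', 'IF']

# Put the intercept last so it lines up with Spark's summary statistics.
names = labels + ['intercept']
estimates = coefficients + [intercept]

for name, est, se, t in zip(names, estimates, std_errors, t_values):
    print(f'{name:>10} {est:8.4f} {se:8.4f} {t:9.4f} {est / se:9.4f}')

# Largest relative gap between the reported t value and estimate / std error.
max_rel_gap = max(abs(est / se - t) / t
                  for est, se, t in zip(estimates, std_errors, t_values))
print(max_rel_gap < 1e-3)  # True: the pairing is consistent
```

The small remaining gaps come from the coefficients being rounded to four decimals in the printed DenseVector.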
summarizer.param_crosswalk
{0: 'carat',
1: 'SI1',
2: 'VS2',
6: 'VVS1',
5: 'VVS2',
7: 'IF',
4: 'VS1',
3: 'SI2'}