scikit-learn: Pipeline

Yao Yao on September 15, 2016

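As the name suggests, a Pipeline chains multiple processing units together. Pipeline requires the first $N-1$ units to be Transformers (they must implement fit and transform); the last unit is typically an Estimator and only needs to implement fit. An example makes this easy to understand: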

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline, make_pipeline


class DullTransformer(TransformerMixin, BaseEstimator):
    """Prints every call and adds 10 to its input."""

    def __init__(self, number):
        self.number = number

    def fit(self, X, y=None):
        print("Dull Transformer No.{}: fit X={}".format(self.number, X))
        return self

    def transform(self, X):
        print("Dull Transformer No.{}: transform X={} => X={}".format(self.number, X, X + 10))
        return X + 10


class DullEstimator(BaseEstimator):
    """Prints every call and always predicts 0."""

    def __init__(self, number):
        self.number = number

    def fit(self, X, y=None):
        print("Dull Estimator No.{}: fit X={}".format(self.number, X))
        # A trailing-underscore attribute marks the estimator as fitted;
        # newer scikit-learn versions check for one before Pipeline.predict()
        self.is_fitted_ = True
        return self

    def predict(self, X):
        print("Dull Estimator No.{}: predict X={} is...".format(self.number, X))
        return 0


if __name__ == '__main__':
    # Explicitly named steps: two transformers followed by the final estimator
    p1 = Pipeline(steps=[("Trans1", DullTransformer(1)),
                         ("Trans2", DullTransformer(2)),
                         ("Estmt1", DullEstimator(1))])
    print("#===== Pipeline 1 repr =====#")
    print(p1)
    print("#===== Pipeline 1 fitting =====#")
    p1.fit(X=100)
    print("#===== Pipeline 1 predicting =====#")
    pred = p1.predict(X=100)
    print(pred)

    # make_pipeline() names each step after its lowercased class name,
    # appending -1, -2, ... when a name repeats
    p2 = make_pipeline(DullTransformer(1), DullTransformer(2), DullEstimator(1))
    print("#===== Pipeline 2 repr =====#")
    print(p2)

# output (exact repr formatting varies across scikit-learn versions):
#    #===== Pipeline 1 repr =====#
#    Pipeline(steps=[('Trans1', DullTransformer(number=1)), ('Trans2', DullTransformer(number=2)), ('Estmt1', DullEstimator(number=1))])
#    #===== Pipeline 1 fitting =====#
#    Dull Transformer No.1: fit X=100
#    Dull Transformer No.1: transform X=100 => X=110
#    Dull Transformer No.2: fit X=110
#    Dull Transformer No.2: transform X=110 => X=120
#    Dull Estimator No.1: fit X=120
#    #===== Pipeline 1 predicting =====#
#    Dull Transformer No.1: transform X=100 => X=110
#    Dull Transformer No.2: transform X=110 => X=120
#    Dull Estimator No.1: predict X=120 is...
#    0
#    #===== Pipeline 2 repr =====#
#    Pipeline(steps=[('dulltransformer-1', DullTransformer(number=1)), ('dulltransformer-2', DullTransformer(number=2)), ('dullestimator', DullEstimator(number=1))])
```


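Pipeline fit() works as follows:

1. DullTransformer(1).fit(100).transform(100); its output 110 becomes the input of DullTransformer(2). (Strictly, Pipeline calls fit_transform(), which TransformerMixin implements as fit() followed by transform().)
2. DullTransformer(2).fit(110).transform(110); its output 120 becomes the input of DullEstimator(1)
3. DullEstimator(1).fit(120)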
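To make the loop explicit, here is a minimal pure-Python sketch of what Pipeline.fit() does with its steps. The toy classes AddTen and Recorder and the pipeline_fit() helper are hypothetical; scikit-learn's real implementation additionally handles caching, per-step fit parameters and "passthrough" steps.

```python
class AddTen:
    """Hypothetical toy transformer: logs its calls and adds 10, like DullTransformer."""

    def __init__(self, log):
        self.log = log

    def fit(self, X, y=None):
        self.log.append(("fit", X))
        return self

    def transform(self, X):
        self.log.append(("transform", X))
        return X + 10


class Recorder:
    """Hypothetical toy final estimator: only logs what it is fitted on."""

    def __init__(self, log):
        self.log = log

    def fit(self, X, y=None):
        self.log.append(("final_fit", X))
        return self


def pipeline_fit(steps, X, y=None):
    """Sketch of Pipeline.fit: fit + transform every step but the last, then fit the last."""
    Xt = X
    for step in steps[:-1]:
        # Each intermediate step is fitted on the current data, and its
        # transformed output becomes the next step's input
        Xt = step.fit(Xt, y).transform(Xt)
    steps[-1].fit(Xt, y)
    return steps


log = []
pipeline_fit([AddTen(log), AddTen(log), Recorder(log)], 100)
print(log)
# [('fit', 100), ('transform', 100), ('fit', 110), ('transform', 110), ('final_fit', 120)]
```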

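predict() works similarly, except that nothing is fitted: each transformer only calls transform(), and the final estimator calls predict(). Pipeline also offers methods such as fit_transform() and fit_predict() for chaining your processing steps, available when the final step supports them.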
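A minimal sketch of that predict() path, again with hypothetical stand-in classes for already-fitted steps rather than scikit-learn's actual code:

```python
class AddTen:
    """Hypothetical stand-in for an already-fitted DullTransformer."""

    def transform(self, X):
        return X + 10


class Constant:
    """Hypothetical stand-in for a fitted DullEstimator; remembers its input and predicts 0."""

    def __init__(self):
        self.seen = None

    def predict(self, X):
        self.seen = X
        return 0


def pipeline_predict(steps, X):
    """Sketch of Pipeline.predict: transform through all but the last step, then predict."""
    Xt = X
    for step in steps[:-1]:
        Xt = step.transform(Xt)  # transform only; no step is refitted
    return steps[-1].predict(Xt)


final = Constant()
pred = pipeline_predict([AddTen(), AddTen(), final], 100)
print(pred, final.seen)  # 0 120
```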

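Because the whole chain behaves as a single estimator, Pipeline also makes the following possible:

1. I can try different parameters of a Transformer inside GridSearchCV(). For example: apply $\log_2(X)$ and then predict, or apply $\ln(X)$ and then predict?
2. I can design a Transformer dispatcher that picks the concrete Transformer based on a parameter. For example, on receiving type="log2" it delegates to a $\log_2(X)$ transformer; on receiving type="abs" it delegates to a $\vert X \vert$ transformer.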
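Both ideas can be combined in one sketch: a hypothetical TransformerDispatcher whose kind parameter (named kind rather than type to avoid shadowing Python's builtin) selects $\log_2$, $\ln$ or $\vert X \vert$, placed in a Pipeline and tuned by GridSearchCV through the step-name__parameter syntax. The synthetic data and the LinearRegression final step are assumptions for illustration only. Note that $\log_2(X) = \ln(X) / \ln 2$, so a linear model downstream scores those two options identically up to rounding; the choice only matters when the downstream model is sensitive to feature scale.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class TransformerDispatcher(TransformerMixin, BaseEstimator):
    """Hypothetical dispatcher: `kind` selects which element-wise transform to apply."""

    _FUNCS = {"log2": np.log2, "ln": np.log, "abs": np.abs}

    def __init__(self, kind="log2"):
        self.kind = kind  # stored as-is so get_params()/set_params() work

    def fit(self, X, y=None):
        if self.kind not in self._FUNCS:
            raise ValueError("unknown kind: {!r}".format(self.kind))
        self.func_ = self._FUNCS[self.kind]  # delegate chosen at fit time
        return self

    def transform(self, X):
        return self.func_(np.asarray(X, dtype=float))


# Synthetic data (assumption): y depends on log(X), so "abs" should lose
rng = np.random.default_rng(0)
X = rng.uniform(1, 100, size=(200, 1))
y = 3 * np.log2(X).ravel() + rng.normal(scale=0.1, size=200)

pipe = Pipeline([("dispatch", TransformerDispatcher()), ("reg", LinearRegression())])
# "dispatch__kind" addresses the `kind` parameter of the step named "dispatch"
search = GridSearchCV(pipe, {"dispatch__kind": ["log2", "ln", "abs"]}, cv=5)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

GridSearchCV clones the pipeline for every candidate, which is why the dispatcher keeps kind as a plain constructor parameter and resolves the actual function only in fit().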