ColumnTransformer is a preprocessing tool in Python's Scikit-learn library that lets you apply different transformations to different columns of a dataset. For example, it can apply:
- Scaling or imputation, typically to numerical columns.
- One-hot encoding to categorical columns.
- Feature selection, to keep only the most relevant features.

The ColumnTransformer class takes a list of transformers, each given as a `(name, transformer, columns)` tuple and applied only to its own subset of the input columns. A transformer can be any Scikit-learn transformer, such as StandardScaler, OneHotEncoder, or PCA. ColumnTransformer fits and applies each transformer to its columns, then concatenates the transformed outputs side by side into a single result.
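A minimal sketch of that structure, using a hypothetical NumPy array with columns selected by integer position. The transformer names `scale` and `onehot` and the toy values are illustrative only. `sparse_threshold=0` forces a dense result so the concatenated shape is easy to inspect.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Hypothetical toy array: two numeric columns and one integer-coded categorical column
X = np.array([
    [1.0, 10.0, 0],
    [2.0, 20.0, 1],
    [3.0, 30.0, 2],
    [4.0, 40.0, 1],
])

ct = ColumnTransformer(
    transformers=[
        ("scale", StandardScaler(), [0, 1]),  # standardize columns 0 and 1
        ("onehot", OneHotEncoder(), [2]),     # one binary column per category in column 2
    ],
    sparse_threshold=0,  # always return a dense NumPy array
)

Xt = ct.fit_transform(X)
print(Xt.shape)  # 2 scaled columns + 3 one-hot columns, stacked side by side
print(Xt)
```

The output width is the sum of each transformer's output width, which is what "concatenates the transformed data back together" means in practice.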
In this thread, you’ll learn how to use ColumnTransformer and apply preprocessing to categorical and numerical columns in the dataset.
Loading a sample dataframe:
To illustrate how ColumnTransformer works, we load a simple sample dataframe, then print it along with the features selected for the transformation.
In this sample dataframe, the columns `Embarked` and `Sex` are categorical, and the columns `Fare` and `Age` are numerical.
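A sketch of such a dataframe, assuming pandas. The rows are hypothetical Titanic-style values chosen to match the description in this thread: `Embarked` has 3 unique values, `Sex` has 2, and `Age` contains one missing value.

```python
import numpy as np
import pandas as pd

# Hypothetical sample rows; the values are illustrative only
df = pd.DataFrame({
    "Embarked": ["S", "C", "Q", "S", "C"],
    "Sex": ["male", "female", "female", "male", "male"],
    "Age": [22.0, 38.0, np.nan, 35.0, 54.0],
    "Fare": [7.25, 71.28, 7.92, 53.10, 51.86],
})

# Select the features that will be transformed
X = df[["Embarked", "Sex", "Age", "Fare"]]
print(X)
print(X.dtypes)  # Embarked and Sex are object (categorical); Age and Fare are float64
```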
Applying ColumnTransformer to selected columns:
We will use two Scikit-learn transformers: `OneHotEncoder`, which converts categorical variables into a set of binary dummy variables, and `SimpleImputer`, which fills in missing values using a chosen strategy (for example, the mean, median, or most frequent value).
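To see what each transformer does on its own before combining them, here is a small sketch on hypothetical single-column arrays. `OneHotEncoder` returns a sparse matrix by default, so `.toarray()` is used for display.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

# OneHotEncoder: one binary column per unique category (categories are sorted)
sex = np.array([["male"], ["female"], ["female"], ["male"]])
enc = OneHotEncoder()
dummies = enc.fit_transform(sex).toarray()  # convert the sparse result to a dense array
print(enc.categories_)  # learned categories: female, male
print(dummies)          # "male" -> [0, 1], "female" -> [1, 0]

# SimpleImputer: replace NaN with a column statistic (the mean by default)
age = np.array([[22.0], [38.0], [np.nan], [35.0]])
imp = SimpleImputer()
age_imputed = imp.fit_transform(age)
print(age_imputed)  # the NaN becomes (22 + 38 + 35) / 3
```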
- After importing the libraries, `OneHotEncoder()` and `SimpleImputer()` objects are created and passed to `ColumnTransformer` as transformers named `cat` and `num`.
- `OneHotEncoder()` is applied to the columns `Embarked` and `Sex`, while `SimpleImputer()` is applied to `Age`.
- The argument `remainder="passthrough"` means that all other columns (here, `Fare`) are passed through unchanged. Without it, the default `remainder="drop"` would discard them.
- Finally, the `ColumnTransformer` is applied to `X` with the `fit_transform()` method, and the result is printed.
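A minimal sketch of these steps, using a small hypothetical dataframe with the same four columns (the values are illustrative, not the thread's sample data). `sparse_threshold=0` is an added assumption that makes the result a dense array for easy printing.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

# Hypothetical sample data with the columns described in the text
X = pd.DataFrame({
    "Embarked": ["S", "C", "Q", "S", "C"],
    "Sex": ["male", "female", "female", "male", "male"],
    "Age": [22.0, 38.0, np.nan, 35.0, 54.0],
    "Fare": [7.25, 71.28, 7.92, 53.10, 51.86],
})

ct = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(), ["Embarked", "Sex"]),  # one-hot encode the categorical columns
        ("num", SimpleImputer(), ["Age"]),              # fill the missing Age with the column mean
    ],
    remainder="passthrough",  # keep Fare unchanged; the default "drop" would discard it
    sparse_threshold=0,       # always return a dense NumPy array
)

result = ct.fit_transform(X)
print(result)
```

Each row has 3 `Embarked` columns, 2 `Sex` columns, the imputed `Age`, and finally `Fare`; the missing age becomes (22 + 38 + 35 + 54) / 4 = 37.25.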
The result has more columns than the input:
- The first three columns come from the `Embarked` column, which has 3 unique values.
- The next two columns come from the `Sex` column, which has 2 unique values.
- The `Age` column previously had a missing value, which is now imputed using the mean, the default strategy of `SimpleImputer()`.
- The `Fare` column is unchanged; as a remainder column, it is appended after the transformed columns.