Dataset

The Dataset class is the core component of fairml_datasets: it represents an individual dataset and provides methods for loading, transforming, and analyzing its data.
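
A minimal usage sketch (the identifier "adult" below is illustrative and may not match an id in the package's annotations; substitute any id that does):

from fairml_datasets.dataset import Dataset

ds = Dataset.from_id("adult")                  # look up a dataset by its annotation id
df = ds.load()                                 # "prepared" stage, returned as a pandas DataFrame
train_df, test_df = ds.train_test_split(df, test_size=0.3)
print(ds.name, ds.sensitive_columns, ds.get_target_column())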

Class Documentation

fairml_datasets.dataset.Dataset

Main class representing a fairness dataset with methods to download, load, transform and split data.

Provides interfaces to common fairness-related operations like identifying sensitive columns, preparing data for analysis, and generating metadata.

Source code in fairml_datasets/dataset.py
class Dataset:
    """
    Main class representing a fairness dataset with methods to download, load, transform and split data.

    Provides interfaces to common fairness-related operations like identifying sensitive columns,
    preparing data for analysis, and generating metadata.
    """

    info: pd.Series

    def __init__(self, info: pd.Series):
        """
        Initialize a Dataset, one will usually use Dataset.from_id() instead,
        unless you have a custom collection of dataset annotations.

        Args:
            info: A pandas Series containing annotations for the dataset
        """
        assert isinstance(info, pd.Series), "info must be a pandas.Series"
        self.info = info.replace({np.nan: None})
        self._sensitive_columns = None

    @staticmethod
    def from_id(id: str) -> "Dataset":
        """
        Create a Dataset object using the dataset identifier.

        Args:
            id: String identifier of the dataset

        Returns:
            Dataset: A Dataset object
        """
        df_info = annotations.load()
        return Dataset(df_info.loc[id])

    def __repr__(self) -> str:
        """
        Returns a string representation of the Dataset object (for printing)

        Returns:
            str: A string representation of the Dataset object
        """
        return f"fairml_datasets.Dataset(id={self.dataset_id})"

    @property
    def dataset_id(self) -> str:
        """
        Get the dataset identifier.

        Returns:
            str: The dataset identifier
        """
        return self.info.name

    @property
    def name(self) -> str:
        """
        Get the human-readable name of the dataset.

        Returns:
            str: The name of the dataset
        """
        return self.info["dataset_name"]

    def get_processing_script(
        self, processing_options: Dict[str, Any] = None
    ) -> ProcessingScript:
        """
        Get the ProcessingScript for this dataset with optional configuration.

        Args:
            processing_options: Dictionary of options to pass to the processing script

        Returns:
            ProcessingScript: The processing script for this dataset, or None if not available
        """
        ScriptClass = get_processing_script(self.dataset_id)
        if ScriptClass is not None:
            return ScriptClass(processing_options=processing_options)
        else:
            if processing_options is not None:
                warnings.warn(
                    "Processing_options provided, but no processing script available. Ignoring processing_options."
                )
            return None

    def get_urls(self) -> List[str]:
        """
        Get the download URLs for this dataset.

        Returns:
            List[str]: List of download URLs, empty if none available
        """
        if self.info["download_url"] is not None:
            return self.info["download_url"].split(";")
        else:
            return []

    def get_filenames(self) -> List[str]:
        """
        Get the expected filenames for the raw dataset files.

        Returns:
            List[str]: List of expected filenames
        """
        urls = self.get_urls()
        if self.info["filename_raw"] is not None:
            return self.info["filename_raw"].split(";")
        elif urls:
            return [url.rsplit("/", 1)[-1] for url in urls]
        else:
            return [self.dataset_id]

    def _download(self, directory: Path, read_cache: bool = True) -> List[Path]:
        """
        Download the dataset files to a specified directory.

        Args:
            directory: Directory to store downloaded files
            read_cache: Whether to use cached files if available

        Returns:
            List[Path]: Paths to downloaded files
        """
        target_dir = directory / str(self.dataset_id)
        target_dir.mkdir(parents=True, exist_ok=True)

        return download_dataset(
            urls=self.get_urls(),
            filenames=self.get_filenames(),
            target_directory=target_dir,
            is_zip=self.info["is_zip"],
            read_cache=read_cache,
        )

    def load(
        self,
        stage: Literal[
            "downloaded", "loaded", "prepared", "binarized", "transformed", "split"
        ] = "prepared",
        cache_at: Literal["downloaded", "prepared"] = "prepared",
        check_cache: bool = True,
        processing_options: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """
        Load the dataset at a specific processing stage.

        Args:
            stage: Processing stage at which to return the dataset
            cache_at: Stage at which to cache the dataset (downloaded or prepared)
            check_cache: Whether to check for cached data
            processing_options: Options to pass to the dataset's processing script (if available; optional)

        Returns:
            pd.DataFrame: The pandas dataframe with data at the specified stage of processing
        """
        if stage == "split":
            logger.info(
                "Please use .split_dataset(), .train_test_split() or .train_test_val_split to split the dataset. 'Binarized' data will be returned from this function."
            )
            stage = "binarized"

        # Load the custom processing script (if it exists)
        script = self.get_processing_script(processing_options=processing_options)
        has_script = script is not None

        # Check the "prepared" cache
        # Add a hash to the cache name if necessary, as the script might have options
        opt_hash_postfix = (
            ("-" + script.get_options_hash())
            if has_script and script.has_options
            else ""
        )
        cached_filename = f"{self.dataset_id}{opt_hash_postfix}.parquet"
        cached_filepath = DATASET_CACHE_DIR / cached_filename
        if check_cache and cached_filepath.exists():
            logger.info(f"Loading cached dataset from {cached_filepath}")
            return pd.read_parquet(cached_filepath, engine="fastparquet")

        if not self.info["custom_download"]:
            assert (
                self.info["download_url"] is not None
            ), "Dataset is missing download URL."

            with make_temp_directory() as temp_dir:
                cache_download = cache_at == "downloaded"
                download_dir = temp_dir if not cache_download else DOWNLOAD_CACHE_DIR

                # Step 1: Download the dataset
                locations = self._download(
                    directory=download_dir, read_cache=cache_download
                )
                if stage == "downloaded":
                    # Cast to DataFrame so we can stick to one return type
                    return pd.DataFrame({"file_paths": [str(loc) for loc in locations]})

                # Step 2: Load the dataset
                if has_script and isinstance(script, LoadingScript):
                    dataset = script.load(locations)
                    assert (
                        dataset is not None
                    ), "LoadingScript returned None, maybe a return was missing?"
                else:
                    assert (
                        len(locations) == 1
                    ), "Multiple files, but no custom script to handle them."

                    dataset = load_dataset(
                        location=locations[0],
                        data_format=self.info["format"],
                        colnames=self.info["colnames"],
                    )
        else:
            if stage == "downloaded":
                warnings.warn(
                    "Use stage == downloaded for a dataset that uses a custom download (most likely synthetic data). Returning an empty Dataframe."
                )
                return pd.DataFrame()

            # Alternative Step 1 & 2: Using a custom download & loading script
            if has_script and isinstance(script, LoadingScript):
                logger.debug(f"Detected LoadingScript for {self.dataset_id}.")
                dataset = script.load(locations=[])
            else:
                raise ValueError(
                    f"Dataset {self.dataset_id} flagged as custom download, "
                    "but is missing a loading script."
                )
        if stage == "loaded":
            return dataset

        # Step 3: Prepare the dataset (if applicable)
        if has_script and isinstance(script, PreparationScript):
            logger.debug(f"Detected PreparationScript for {self.dataset_id}.")
            dataset = script.prepare(dataset)
        # Cache the dataset (optional)
        if cache_at == "prepared":
            DATASET_CACHE_DIR.mkdir(exist_ok=True, parents=True)
            dataset.to_parquet(cached_filepath, index=False, engine="fastparquet")

        if stage == "prepared":
            return dataset
        if stage == "transformed":
            df, _ = self.transform(dataset)
            return df
        elif stage == "binarized":
            df, _ = self.binarize(
                dataset,
                transform_sensitive_columns="intersection_binary",
                transform_sensitive_values="majority_minority",
            )
            return df
        else:
            raise ValueError(f"Unsupported stage: {stage}")

    def to_pandas(self) -> pd.DataFrame:
        """
        Load the dataset as a pandas DataFrame.

        Use .load() if you want more control.

        Returns:
            pd.DataFrame: The dataset as a pandas DataFrame
        """
        return self.load()

    def to_numpy(self) -> np.ndarray:
        """
        Load the dataset as a numpy array.

        Use .load() if you want more control.

        Returns:
            np.ndarray: The dataset as a numpy array
        """
        return self.load().to_numpy()

    def binarize(
        self,
        df: Optional[pd.DataFrame] = None,
        sensitive_columns: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> tuple[pd.DataFrame, PreprocessingInfo]:
        """
        Apply binarization transformations to the dataset for fairness analysis.

        This is a convenience method that calls transform() with specific parameters
        to convert categorical sensitive attributes to binary format using
        'intersection_binary' and 'majority_minority' strategies.

        Args:
            df: Optional DataFrame to transform, if None the dataset is loaded
            sensitive_columns: List of sensitive attribute column names
            **kwargs: Additional arguments for transformation, passed to transform()

        Returns:
            tuple[pd.DataFrame, PreprocessingInfo]: Binarized DataFrame and preprocessing information
        """
        return self.transform(
            df=df,
            sensitive_columns=sensitive_columns,
            transform_sensitive_columns="intersection_binary",
            transform_sensitive_values="majority_minority",
            **kwargs,
        )

    def transform(
        self,
        df: Optional[pd.DataFrame] = None,
        sensitive_columns: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> tuple[pd.DataFrame, PreprocessingInfo]:
        """
        Apply default transformations to the dataset.

        Args:
            df: Optional DataFrame to transform, if None the dataset is loaded
            sensitive_columns: List of sensitive attribute column names
            **kwargs: Additional arguments for transformation, passed to transform()

        Returns:
            tuple[pd.DataFrame, PreprocessingInfo]: Transformed DataFrame and preprocessing information
        """
        logger.debug(f"Transforming: {self.dataset_id}")

        if df is None:
            df = self.load()
        target_column = self.get_target_column()
        if sensitive_columns is None:
            sensitive_columns = self.sensitive_columns

        feature_columns = self.get_feature_columns(df=df)
        target_lvl_good_bad = self.get_target_lvl_good_value()

        return transform(
            df=df,
            sensitive_columns=sensitive_columns,
            feature_columns=feature_columns,
            target_column=target_column,
            target_lvl_good_bad=target_lvl_good_bad,
            **kwargs,
        )

    def get_feature_columns(self, df: pd.DataFrame) -> List[str]:
        """
        Get the feature columns for the dataset.

        Args:
            df: DataFrame containing the dataset

        Returns:
            List[str]: List of feature column names
        """
        target_column = set([self.get_target_column()])
        sensitive_columns = set(self.sensitive_columns)
        typical_col_features = self.get_typical_col_features()
        colnames = set(df.columns)
        feature_columns = parse_feature_column_filter(
            colnames, target_column, sensitive_columns, typical_col_features
        )
        return feature_columns

    def filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Filter the dataset to include only the required columns.

        Args:
            df: DataFrame to filter

        Returns:
            pd.DataFrame: Filtered DataFrame
        """
        sensitive_columns = set(self.sensitive_columns)
        feature_columns = self.get_feature_columns()
        target_column = set([self.get_target_column()])

        return filter_columns(
            sensitive_columns=sensitive_columns,
            feature_columns=feature_columns,
            target_column=target_column,
        )

    def to_aif360_BinaryLabelDataset(
        self,
        sensitive_columns: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "BinaryLabelDataset":  # noqa: F821
        """
        Convert the dataset to AIF360 BinaryLabelDataset format.

        Args:
            sensitive_columns: List of sensitive attribute column names
            **kwargs: Additional arguments for transformation passed to .binarize() and then .transform()

        Returns:
            BinaryLabelDataset: The dataset in AIF360 format
        """
        df = self.load()

        # Preprocessing (this can affect columns / column names)
        df, info = self.binarize(df=df, sensitive_columns=sensitive_columns, **kwargs)
        sensitive_columns = info.sensitive_columns

        # Create the AIF360 dataset
        from aif360.datasets import BinaryLabelDataset

        dataset = BinaryLabelDataset(
            favorable_label=1,
            unfavorable_label=0,
            df=df,
            label_names=[self.get_target_column()],
            protected_attribute_names=sensitive_columns,
        )

        return dataset

    def generate_metadata(self) -> Dict[str, Union[str, int, float]]:
        """
        Generate metadata about the dataset for fairness analysis.

        Returns:
            Dict[str, Union[str, int, float]]: Dictionary of metadata
        """
        try:
            df_raw = self.load()
        except Exception as e:
            logger.exception(
                f"Data loading / preparation error in {self.dataset_id}: {e}"
            )
            return {
                "id": self.dataset_id,
                "debug_meta_status": "LOAD_ERROR",
                "debug_meta_error_message": str(e),
            }

        try:
            sensitive_columns = self.sensitive_columns
            target_column = self.get_target_column()
            target_lvl_good_value = self.get_target_lvl_good_value()
            feature_columns = self.get_feature_columns(df=df_raw)

            general_descriptives = generate_general_descriptives(
                df_raw=df_raw,
                sensitive_columns=sensitive_columns,
                target_column=target_column,
                target_lvl_good_value=target_lvl_good_value,
                feature_columns=list(feature_columns),
            )

            try:
                df, info = self.binarize(df=df_raw, sensitive_columns=sensitive_columns)
                sensitive_columns = info.sensitive_columns
                col_na_indicator = info.col_na_indicator
            except Exception as e:
                logger.exception(f"Transformation Error in {self.dataset_id}: {e}")

            binarized_descriptives = generate_binarized_descriptives(
                df=df,
                sensitive_columns=sensitive_columns,
                target_column=target_column,
                col_na_indicator=col_na_indicator,
            )

            meta = {
                "id": self.dataset_id,
                # Metadata used for debugging / development
                "debug_meta_status": "OK",
                "debug_meta_colnames": ";".join(df.columns.tolist()),
                "debug_meta_coltypes": ";".join(
                    [str(df[col].dtype) for col in df.columns]
                ),
            }
            meta.update(general_descriptives)
            meta.update(binarized_descriptives)

            return meta
        except Exception as e:
            logger.exception(f"Metadata error in {self.dataset_id}: {e}")
            return {
                "id": self.dataset_id,
                "debug_meta_status": "METADATA_ERROR",
                "debug_meta_error_message": str(e),
            }

    def _get_sensitive_parsed(self) -> Optional[Dict[str, str]]:
        """
        Get the parsed sensitive columns information.

        Returns:
            Optional[Dict[str, str]]: Dictionary mapping sensitive column names to descriptions
        """
        return (
            json.loads(self.info["typical_col_sensitive"])
            if self.info["typical_col_sensitive"] is not None
            else None
        )

    def get_all_sensitive_columns(self) -> List[str]:
        """
        Get all available (typically used) sensitive columns for this dataset.

        Returns:
            List[str]: List of all sensitive column names
        """
        sensitive_parsed = self._get_sensitive_parsed()
        return list(sensitive_parsed.keys()) if sensitive_parsed is not None else []

    def _get_default_scenario_sensitive_columns(self) -> List[str]:
        """
        Get the default sensitive columns for the standard Scenario.

        Returns:
            List[str]: List of sensitive column names
        """
        sensitive_cols = self.info["default_scenario_sensitive_cols"].split(";")
        if len(sensitive_cols) > 0:
            return sensitive_cols
        else:
            raise ValueError

    @property
    def sensitive_columns(self) -> List[str]:
        """
        Get the sensitive columns for this dataset.

        Returns:
            List[str]: List of sensitive column names
        """
        if self._sensitive_columns is not None:
            return self._sensitive_columns
        return self._get_default_scenario_sensitive_columns()

    def generate_sensitive_intersections(self) -> List[List[str]]:
        """
        Generate all possible intersections of sensitive attributes.

        Returns:
            List[List[str]]: List of all possible combinations of sensitive attributes
        """
        if self._sensitive_columns is not None:
            warnings.warn(
                "Generating sensitive intersections on a scenario. You will usually want to do this on the dataset itself, as they will be generated for ALL available sensitive attributes, not the ones in the scenario."
            )
        sensitive_columns = self.get_all_sensitive_columns()
        all_combinations = list(
            chain.from_iterable(
                combinations(sensitive_columns, r)
                for r in range(1, len(sensitive_columns) + 1)
            )
        )
        return [list(combo) for combo in all_combinations]

    def get_target_column(self) -> str:
        """
        Get the name of the target column for this dataset.

        Returns:
            str: Name of the target column
        """
        target_col = self.info["typical_col_target"]
        # Strip leading question mark
        if target_col.startswith("?"):
            target_col = target_col[1:]
        # Separate multiple columns by semicolon
        if ";" in target_col:
            target_cols = target_col.split(";")
            # The first column is the "most" typical, so use this one for now
            # This is currently only an issue for the Drug dataset
            target_col = target_cols[0]
        return target_col

    def get_target_lvl_good_value(self) -> Optional[str]:
        """
        Get the value in the target column that represents a favorable outcome.

        Returns:
            Optional[str]: The value representing a favorable outcome
        """
        target_lvl_good = str(self.info["target_lvl_good"])
        # Strip leading question mark
        if target_lvl_good.startswith("?"):
            target_lvl_good = target_lvl_good[1:]
        # Encode empty string as None
        if target_lvl_good == "":
            return None
        return target_lvl_good

    def citation(self) -> Optional[str]:
        """
        Get the citation for this dataset in BibTeX format.

        Returns:
            Optional[str]: The citation in BibTeX format, or None if not available
        """
        citation_text = self.info.get("citation")
        if citation_text and isinstance(citation_text, str):
            return citation_text.strip()
        return None

    def get_typical_col_features(self) -> str | None:
        """
        Get information about typical feature columns.

        Returns:
            str | None: Information about typical feature columns
        """
        typical_col_features = self.info["typical_col_features"]
        if typical_col_features.startswith("?"):
            typical_col_features = typical_col_features[1:]
        if typical_col_features == "":
            return None
        return typical_col_features

    def split_dataset(
        self,
        df: Union[pd.DataFrame, BinaryLabelDataset],
        splits: Tuple[float, ...],
        seed: int = DEFAULT_SEED,
        stratify: bool = True,
        stratify_manual: Optional[str] = None,
        sensitive_columns: Optional[List[str]] = None,
    ) -> Tuple[Union[pd.DataFrame, BinaryLabelDataset], ...]:
        """
        Split a dataset into multiple partitions based on specified ratios.

        Args:
            df: DataFrame or AIF360 BinaryLabelDataset to split
            splits: Tuple of fractions that sum to 1.0. For example, (0.6, 0.2, 0.2) for train/val/test
            seed: Random seed for reproducibility
            stratify: Whether to stratify the split by target column and sensitive attributes
            stratify_manual: Column to stratify by (defaults to combined target+sensitive if None and stratify=True)
            sensitive_columns: List of sensitive attribute column names

        Returns:
            Tuple of datasets with the same type as df
        """
        # Validate splits
        if not isinstance(splits, tuple) or len(splits) < 2:
            raise ValueError("splits must be a tuple with at least 2 elements")
        if round(sum(splits), 10) != 1.0:
            raise ValueError(f"Split values must sum to 1.0, got {sum(splits)}")

        if stratify_manual is not None and not stratify:
            logger.error(
                "stratify_manual is set but stratify is False. Setting stratify to True."
            )
            stratify = True

        # If AIF360 dataset, convert to pandas for splitting
        is_aif360 = isinstance(df, BinaryLabelDataset)
        pandas_df = df if not is_aif360 else df.convert_to_dataframe()[0]

        # Determine column to use for stratification
        stratify_col = None
        if stratify:
            # Manual column provided
            if stratify_manual is not None:
                if stratify_manual in pandas_df.columns:
                    stratify_col = pandas_df[stratify_manual]
                else:
                    raise ValueError(
                        f"Stratify column {stratify_manual} not found in dataframe."
                    )
            else:
                # Get target and sensitive columns for combined stratification
                target_column = self.get_target_column()
                if sensitive_columns is None:
                    sensitive_columns = self.sensitive_columns

                # Handle the case where sensitive columns were transformed
                # Check if original sensitive columns exist, if not try sensitive_intersection
                available_sensitive_columns = []
                for col in sensitive_columns:
                    if col in pandas_df.columns:
                        available_sensitive_columns.append(col)

                # If no original sensitive columns found, check for sensitive_intersection
                if (
                    not available_sensitive_columns
                    and "sensitive_intersection" in pandas_df.columns
                ):
                    sensitive_columns = ["sensitive_intersection"]

                strat_columns = [target_column] + sensitive_columns

                # Check if columns exist in the dataframe
                if not set(strat_columns).issubset(set(pandas_df.columns)):
                    missing_columns = set(strat_columns) - set(pandas_df.columns)
                    raise ValueError(
                        f"Target or sensitive columns not found in dataframe: {missing_columns}"
                    )

                logger.info(
                    f"Stratifying split by columns: {strat_columns}"
                )

                # Combine target and sensitive columns for stratification
                if len(strat_columns) == 1:
                    stratify_col = pandas_df[strat_columns[0]]
                else:
                    stratify_col = (
                        pandas_df[strat_columns]
                        .astype(str)
                        .apply(lambda x: "_".join(x.values), axis=1)
                    )

        result = []
        remaining_df = pandas_df.copy()
        remaining_prob = 1.0

        # Iteratively split dataset
        for i, split_ratio in enumerate(splits[:-1]):
            if len(remaining_df) == 0:
                # Handle edge case of empty dataframe
                result.append(remaining_df.copy())
                continue

            # Calculate the proportion for this split from the remaining data
            proportion = split_ratio / remaining_prob

            # Get stratification data for current split
            current_stratify = (
                stratify_col.loc[remaining_df.index]
                if stratify_col is not None
                else None
            )

            # If there's only one class in the stratification column, don't stratify
            if current_stratify is not None and len(current_stratify.unique()) < 2:
                logger.warning(
                    f"Only one class present in stratification column for split {i}. Not stratifying this split."
                )
                current_stratify = None

            # Split the data
            new_split, remaining_df = sk_train_test_split(
                remaining_df,
                test_size=1 - proportion,
                random_state=seed,
                stratify=current_stratify,
            )

            result.append(new_split)
            remaining_prob -= split_ratio

        # Add the final piece
        result.append(remaining_df)

        # Convert back to AIF360 if needed
        if is_aif360:
            return tuple(
                BinaryLabelDataset(
                    df=split_df,
                    label_names=df.label_names,
                    protected_attribute_names=df.protected_attribute_names,
                    favorable_label=df.favorable_label,
                    unfavorable_label=df.unfavorable_label,
                )
                for split_df in result
            )

        return tuple(result)

    def train_test_split(
        self,
        df: Union[pd.DataFrame, BinaryLabelDataset],
        test_size: float = 0.3,
        **kwargs: Any,
    ) -> Tuple[
        Union[pd.DataFrame, BinaryLabelDataset], Union[pd.DataFrame, BinaryLabelDataset]
    ]:
        """
        Split a dataset into train and test sets.

        Args:
            df: DataFrame or BinaryLabelDataset to split (if None, dataset is loaded)
            test_size: Fraction of data to use for testing
            **kwargs: Additional arguments to pass to split_dataset

        Returns:
            Tuple of (train_data, test_data)
        """
        train_size = 1.0 - test_size
        return self.split_dataset(df, splits=(train_size, test_size), **kwargs)

    def train_test_val_split(
        self,
        df: Union[pd.DataFrame, BinaryLabelDataset],
        test_size: float = 0.2,
        val_size: float = 0.2,
        **kwargs: Any,
    ) -> Tuple[
        Union[pd.DataFrame, BinaryLabelDataset],
        Union[pd.DataFrame, BinaryLabelDataset],
        Union[pd.DataFrame, BinaryLabelDataset],
    ]:
        """
        Split a dataset into train, validation, and test sets.

        Args:
            df: DataFrame or BinaryLabelDataset to split
            test_size: Fraction of data to use for testing
            val_size: Fraction of data to use for validation
            **kwargs: Additional arguments to pass to split_dataset

        Returns:
            Tuple of (train_data, val_data, test_data)
        """
        # Calculate train size based on test and validation sizes
        train_size = 1.0 - test_size - val_size

        if train_size <= 0:
            raise ValueError(
                f"Invalid split sizes: test_size ({test_size}) + val_size ({val_size}) must be less than 1.0"
            )

        return self.split_dataset(
            df, splits=(train_size, val_size, test_size), **kwargs
        )

Attributes

dataset_id property

Get the dataset identifier.

Returns:
    str: The dataset identifier

name property

Get the human-readable name of the dataset.

Returns:
    str: The name of the dataset

sensitive_columns property

Get the sensitive columns for this dataset.

Returns:
    List[str]: List of sensitive column names
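
For example (a sketch; the printed values depend on the dataset's annotations, and "adult" is an illustrative id):

from fairml_datasets.dataset import Dataset

ds = Dataset.from_id("adult")
print(ds.dataset_id)         # identifier from the annotations index
print(ds.name)               # human-readable dataset name
print(ds.sensitive_columns)  # default sensitive columns for the standard scenario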

Functions

__init__(info)

Initialize a Dataset, one will usually use Dataset.from_id() instead, unless you have a custom collection of dataset annotations.

Parameters:
    info (Series, required): A pandas Series containing annotations for the dataset

Source code in fairml_datasets/dataset.py
def __init__(self, info: pd.Series):
    """
    Initialize a Dataset, one will usually use Dataset.from_id() instead,
    unless you have a custom collection of dataset annotations.

    Args:
        info: A pandas Series containing annotations for the dataset
    """
    assert isinstance(info, pd.Series), "info must be a pandas.Series"
    self.info = info.replace({np.nan: None})
    self._sensitive_columns = None

__repr__()

Returns a string representation of the Dataset object (for printing)

Returns:
    str: A string representation of the Dataset object

Source code in fairml_datasets/dataset.py
def __repr__(self) -> str:
    """
    Returns a string representation of the Dataset object (for printing)

    Returns:
        str: A string representation of the Dataset object
    """
    return f"fairml_datasets.Dataset(id={self.dataset_id})"

binarize(df=None, sensitive_columns=None, **kwargs)

Apply binarization transformations to the dataset for fairness analysis.

This is a convenience method that calls transform() with specific parameters to convert categorical sensitive attributes to binary format using 'intersection_binary' and 'majority_minority' strategies.

Parameters:
    df (Optional[DataFrame], default None): Optional DataFrame to transform, if None the dataset is loaded
    sensitive_columns (Optional[List[str]], default None): List of sensitive attribute column names
    **kwargs (Any, default {}): Additional arguments for transformation, passed to transform()

Returns:
    tuple[pd.DataFrame, PreprocessingInfo]: Binarized DataFrame and preprocessing information

Source code in fairml_datasets/dataset.py
def binarize(
    self,
    df: Optional[pd.DataFrame] = None,
    sensitive_columns: Optional[List[str]] = None,
    **kwargs: Any,
) -> tuple[pd.DataFrame, PreprocessingInfo]:
    """
    Apply binarization transformations to the dataset for fairness analysis.

    This is a convenience method that calls transform() with specific parameters
    to convert categorical sensitive attributes to binary format using
    'intersection_binary' and 'majority_minority' strategies.

    Args:
        df: Optional DataFrame to transform, if None the dataset is loaded
        sensitive_columns: List of sensitive attribute column names
        **kwargs: Additional arguments for transformation, passed to transform()

    Returns:
        tuple[pd.DataFrame, PreprocessingInfo]: Binarized DataFrame and preprocessing information
    """
    return self.transform(
        df=df,
        sensitive_columns=sensitive_columns,
        transform_sensitive_columns="intersection_binary",
        transform_sensitive_values="majority_minority",
        **kwargs,
    )
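
A short sketch of typical binarize() usage (assumes ds was created via Dataset.from_id() with a valid id; "adult" is illustrative):

from fairml_datasets.dataset import Dataset

ds = Dataset.from_id("adult")
df_prepared = ds.load(stage="prepared")
df_bin, info = ds.binarize(df=df_prepared)     # transformed frame plus PreprocessingInfo
print(info.sensitive_columns)                  # sensitive column names after binarization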

citation()

Get the citation for this dataset in BibTeX format.

Returns:
    Optional[str]: The citation in BibTeX format, or None if not available

Source code in fairml_datasets/dataset.py
def citation(self) -> Optional[str]:
    """
    Get the citation for this dataset in BibTeX format.

    Returns:
        Optional[str]: The citation in BibTeX format, or None if not available
    """
    citation_text = self.info.get("citation")
    if citation_text and isinstance(citation_text, str):
        return citation_text.strip()
    return None
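
For example, to print the BibTeX entry when one is annotated (ds being any Dataset instance):

bibtex = ds.citation()
if bibtex is not None:
    print(bibtex)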

filter_columns(df)

Filter the dataset to include only the required columns.

Parameters:
    df (DataFrame, required): DataFrame to filter

Returns:
    pd.DataFrame: Filtered DataFrame

Source code in fairml_datasets/dataset.py
def filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Filter the dataset to include only the required columns.

    Args:
        df: DataFrame to filter

    Returns:
        pd.DataFrame: Filtered DataFrame
    """
    sensitive_columns = set(self.sensitive_columns)
    feature_columns = self.get_feature_columns()
    target_column = set([self.get_target_column()])

    return filter_columns(
        sensitive_columns=sensitive_columns,
        feature_columns=feature_columns,
        target_column=target_column,
    )

from_id(id) staticmethod

Create a Dataset object using the dataset identifier.

Parameters:
    id (str, required): String identifier of the dataset

Returns:
    Dataset: A Dataset object

Source code in fairml_datasets/dataset.py
@staticmethod
def from_id(id: str) -> "Dataset":
    """
    Create a Dataset object using the dataset identifier.

    Args:
        id: String identifier of the dataset

    Returns:
        Dataset: A Dataset object
    """
    df_info = annotations.load()
    return Dataset(df_info.loc[id])

generate_metadata()

Generate metadata about the dataset for fairness analysis.

Returns:
    Dict[str, Union[str, int, float]]: Dictionary of metadata

Source code in fairml_datasets/dataset.py
def generate_metadata(self) -> Dict[str, Union[str, int, float]]:
    """
    Generate metadata about the dataset for fairness analysis.

    Returns:
        Dict[str, Union[str, int, float]]: Dictionary of metadata
    """
    try:
        df_raw = self.load()
    except Exception as e:
        logger.exception(
            f"Data loading / preparation error in {self.dataset_id}: {e}"
        )
        return {
            "id": self.dataset_id,
            "debug_meta_status": "LOAD_ERROR",
            "debug_meta_error_message": str(e),
        }

    try:
        sensitive_columns = self.sensitive_columns
        target_column = self.get_target_column()
        target_lvl_good_value = self.get_target_lvl_good_value()
        feature_columns = self.get_feature_columns(df=df_raw)

        general_descriptives = generate_general_descriptives(
            df_raw=df_raw,
            sensitive_columns=sensitive_columns,
            target_column=target_column,
            target_lvl_good_value=target_lvl_good_value,
            feature_columns=list(feature_columns),
        )

        try:
            df, info = self.binarize(df=df_raw, sensitive_columns=sensitive_columns)
            sensitive_columns = info.sensitive_columns
            col_na_indicator = info.col_na_indicator
        except Exception as e:
            logger.exception(f"Transformation Error in {self.dataset_id}: {e}")

        binarized_descriptives = generate_binarized_descriptives(
            df=df,
            sensitive_columns=sensitive_columns,
            target_column=target_column,
            col_na_indicator=col_na_indicator,
        )

        meta = {
            "id": self.dataset_id,
            # Metadata used for debugging / development
            "debug_meta_status": "OK",
            "debug_meta_colnames": ";".join(df.columns.tolist()),
            "debug_meta_coltypes": ";".join(
                [str(df[col].dtype) for col in df.columns]
            ),
        }
        meta.update(general_descriptives)
        meta.update(binarized_descriptives)

        return meta
    except Exception as e:
        logger.exception(f"Metadata error in {self.dataset_id}: {e}")
        return {
            "id": self.dataset_id,
            "debug_meta_status": "METADATA_ERROR",
            "debug_meta_error_message": str(e),
        }
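
A sketch of inspecting the generated metadata; apart from "id" and the debug fields, the exact keys depend on which descriptive statistics could be computed:

meta = ds.generate_metadata()   # ds: a Dataset instance
if meta["debug_meta_status"] == "OK":
    print(meta["id"], "->", len(meta), "metadata fields")
else:
    print("Metadata generation failed:", meta.get("debug_meta_error_message"))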

generate_sensitive_intersections()

Generate all possible intersections of sensitive attributes.

Returns:
    List[List[str]]: List of all possible combinations of sensitive attributes

Source code in fairml_datasets/dataset.py
def generate_sensitive_intersections(self) -> List[List[str]]:
    """
    Generate all possible intersections of sensitive attributes.

    Returns:
        List[List[str]]: List of all possible combinations of sensitive attributes
    """
    if self._sensitive_columns is not None:
        warnings.warn(
            "Generating sensitive intersections on a scenario. You will usually want to do this on the dataset itself, as they will be generated for ALL available sensitive attributes, not the ones in the scenario."
        )
    sensitive_columns = self.get_all_sensitive_columns()
    all_combinations = list(
        chain.from_iterable(
            combinations(sensitive_columns, r)
            for r in range(1, len(sensitive_columns) + 1)
        )
    )
    return [list(combo) for combo in all_combinations]
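
For a dataset annotated with two sensitive columns, say "sex" and "race" (illustrative), this yields the three non-empty subsets:

combos = ds.generate_sensitive_intersections()
# e.g. [["sex"], ["race"], ["sex", "race"]]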

get_all_sensitive_columns()

Get all available (typically used) sensitive columns for this dataset.

Returns:
    List[str]: List of all sensitive column names

Source code in fairml_datasets/dataset.py
def get_all_sensitive_columns(self) -> List[str]:
    """
    Get all available (typically used) sensitive columns for this dataset.

    Returns:
        List[str]: List of all sensitive column names
    """
    sensitive_parsed = self._get_sensitive_parsed()
    return list(sensitive_parsed.keys()) if sensitive_parsed is not None else []

get_feature_columns(df)

Get the feature columns for the dataset.

Parameters:
    df (DataFrame, required): DataFrame containing the dataset

Returns:
    List[str]: List of feature column names

Source code in fairml_datasets/dataset.py
def get_feature_columns(self, df: pd.DataFrame) -> List[str]:
    """
    Get the feature columns for the dataset.

    Args:
        df: DataFrame containing the dataset

    Returns:
        List[str]: List of feature column names
    """
    target_column = set([self.get_target_column()])
    sensitive_columns = set(self.sensitive_columns)
    typical_col_features = self.get_typical_col_features()
    colnames = set(df.columns)
    feature_columns = parse_feature_column_filter(
        colnames, target_column, sensitive_columns, typical_col_features
    )
    return feature_columns

get_filenames()

Get the expected filenames for the raw dataset files.

Returns:
    List[str]: List of expected filenames

Source code in fairml_datasets/dataset.py
def get_filenames(self) -> List[str]:
    """
    Get the expected filenames for the raw dataset files.

    Returns:
        List[str]: List of expected filenames
    """
    urls = self.get_urls()
    if self.info["filename_raw"] is not None:
        return self.info["filename_raw"].split(";")
    elif urls:
        return [url.rsplit("/", 1)[-1] for url in urls]
    else:
        return [self.dataset_id]

get_processing_script(processing_options=None)

Get the ProcessingScript for this dataset with optional configuration.

Parameters:
    processing_options (Dict[str, Any], default None): Dictionary of options to pass to the processing script

Returns:
    ProcessingScript: The processing script for this dataset, or None if not available

Source code in fairml_datasets/dataset.py
def get_processing_script(
    self, processing_options: Dict[str, Any] = None
) -> ProcessingScript:
    """
    Get the ProcessingScript for this dataset with optional configuration.

    Args:
        processing_options: Dictionary of options to pass to the processing script

    Returns:
        ProcessingScript: The processing script for this dataset, or None if not available
    """
    ScriptClass = get_processing_script(self.dataset_id)
    if ScriptClass is not None:
        return ScriptClass(processing_options=processing_options)
    else:
        if processing_options is not None:
            warnings.warn(
                "Processing_options provided, but no processing script available. Ignoring processing_options."
            )
        return None

get_target_column()

Get the name of the target column for this dataset.

Returns:
    str: Name of the target column

Source code in fairml_datasets/dataset.py
def get_target_column(self) -> str:
    """
    Get the name of the target column for this dataset.

    Returns:
        str: Name of the target column
    """
    target_col = self.info["typical_col_target"]
    # Strip leading question mark
    if target_col.startswith("?"):
        target_col = target_col[1:]
    # Separate multiple columns by semicolon
    if ";" in target_col:
        target_cols = target_col.split(";")
        # The first column is the "most" typical, so use this one for now
        # This is currently only an issue for the Drug dataset
        target_col = target_cols[0]
    return target_col

get_target_lvl_good_value()

Get the value in the target column that represents a favorable outcome.

Returns:
    Optional[str]: The value representing a favorable outcome

Source code in fairml_datasets/dataset.py
def get_target_lvl_good_value(self) -> Optional[str]:
    """
    Get the value in the target column that represents a favorable outcome.

    Returns:
        Optional[str]: The value representing a favorable outcome
    """
    target_lvl_good = str(self.info["target_lvl_good"])
    # Strip leading question mark
    if target_lvl_good.startswith("?"):
        target_lvl_good = target_lvl_good[1:]
    # Encode empty string as None
    if target_lvl_good == "":
        return None
    return target_lvl_good
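
Used together, these two accessors describe the prediction target (a sketch; the column name shown is illustrative):

target = ds.get_target_column()             # e.g. "income" on an income-prediction dataset
favorable = ds.get_target_lvl_good_value()  # None if no favorable level is annotated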

get_typical_col_features()

Get information about typical feature columns.

Returns:
    str | None: Information about typical feature columns

Source code in fairml_datasets/dataset.py
def get_typical_col_features(self) -> str | None:
    """
    Get information about typical feature columns.

    Returns:
        str | None: Information about typical feature columns
    """
    typical_col_features = self.info["typical_col_features"]
    if typical_col_features.startswith("?"):
        typical_col_features = typical_col_features[1:]
    if typical_col_features == "":
        return None
    return typical_col_features

get_urls()

Get the download URLs for this dataset.

Returns:
    List[str]: List of download URLs, empty if none available

Source code in fairml_datasets/dataset.py
def get_urls(self) -> List[str]:
    """
    Get the download URLs for this dataset.

    Returns:
        List[str]: List of download URLs, empty if none available
    """
    if self.info["download_url"] is not None:
        return self.info["download_url"].split(";")
    else:
        return []
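
A quick way to check where the raw files come from and what they will be called locally (ds being any Dataset instance):

print(ds.get_urls())        # [] for datasets flagged as custom download
print(ds.get_filenames())   # from filename_raw, the URL basenames, or the dataset id as fallback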

load(stage='prepared', cache_at='prepared', check_cache=True, processing_options=None)

Load the dataset at a specific processing stage.

Parameters:
    stage (Literal['downloaded', 'loaded', 'prepared', 'binarized', 'transformed', 'split'], default 'prepared'): Processing stage at which to return the dataset
    cache_at (Literal['downloaded', 'prepared'], default 'prepared'): Stage at which to cache the dataset (downloaded or prepared)
    check_cache (bool, default True): Whether to check for cached data
    processing_options (Optional[Dict[str, Any]], default None): Options to pass to the dataset's processing script (if available; optional)

Returns:
    pd.DataFrame: The pandas dataframe with data at the specified stage of processing

Source code in fairml_datasets/dataset.py
def load(
    self,
    stage: Literal[
        "downloaded", "loaded", "prepared", "binarized", "transformed", "split"
    ] = "prepared",
    cache_at: Literal["downloaded", "prepared"] = "prepared",
    check_cache: bool = True,
    processing_options: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Load the dataset at a specific processing stage.

    Args:
        stage: Processing stage at which to return the dataset
        cache_at: Stage at which to cache the dataset (downloaded or prepared)
        check_cache: Whether to check for cached data
        processing_options: Options to pass to the dataset's processing script (if available; optional)

    Returns:
        pd.DataFrame: The pandas dataframe with data at the specified stage of processing
    """
    if stage == "split":
        logger.info(
            "Please use .split_dataset(), .train_test_split() or .train_test_val_split to split the dataset. 'Binarized' data will be returned from this function."
        )
        stage = "binarized"

    # Load the custom processing script (if it exists)
    script = self.get_processing_script(processing_options=processing_options)
    has_script = script is not None

    # Check the "prepared" cache
    # Add a hash to the cache name if necessary, as the script might have options
    opt_hash_postfix = (
        ("-" + script.get_options_hash())
        if has_script and script.has_options
        else ""
    )
    cached_filename = f"{self.dataset_id}{opt_hash_postfix}.parquet"
    cached_filepath = DATASET_CACHE_DIR / cached_filename
    if check_cache and cached_filepath.exists():
        logger.info(f"Loading cached dataset from {cached_filepath}")
        return pd.read_parquet(cached_filepath, engine="fastparquet")

    if not self.info["custom_download"]:
        assert (
            self.info["download_url"] is not None
        ), "Dataset is missing download URL."

        with make_temp_directory() as temp_dir:
            cache_download = cache_at == "downloaded"
            download_dir = temp_dir if not cache_download else DOWNLOAD_CACHE_DIR

            # Step 1: Download the dataset
            locations = self._download(
                directory=download_dir, read_cache=cache_download
            )
            if stage == "downloaded":
                # Cast to DataFrame so we can stick to one return type
                return pd.DataFrame({"file_paths": [str(loc) for loc in locations]})

            # Step 2: Load the dataset
            if has_script and isinstance(script, LoadingScript):
                dataset = script.load(locations)
                assert (
                    dataset is not None
                ), "LoadingScript returned None, maybe a return was missing?"
            else:
                assert (
                    len(locations) == 1
                ), "Multiple files, but no custom script to handle them."

                dataset = load_dataset(
                    location=locations[0],
                    data_format=self.info["format"],
                    colnames=self.info["colnames"],
                )
    else:
        if stage == "downloaded":
            warnings.warn(
                "Use stage == downloaded for a dataset that uses a custom download (most likely synthetic data). Returning an empty Dataframe."
            )
            return pd.DataFrame()

        # Alternative Step 1 & 2: Using a custom download & loading script
        if has_script and isinstance(script, LoadingScript):
            logger.debug(f"Detected LoadingScript for {self.dataset_id}.")
            dataset = script.load(locations=[])
        else:
            raise ValueError(
                f"Dataset {self.dataset_id} flagged as custom download, "
                "but is missing a loading script."
            )
    if stage == "loaded":
        return dataset

    # Step 3: Prepare the dataset (if applicable)
    if has_script and isinstance(script, PreparationScript):
        logger.debug(f"Detected PreparationScript for {self.dataset_id}.")
        dataset = script.prepare(dataset)
    # Cache the dataset (optional)
    if cache_at == "prepared":
        DATASET_CACHE_DIR.mkdir(exist_ok=True, parents=True)
        dataset.to_parquet(cached_filepath, index=False, engine="fastparquet")

    if stage == "prepared":
        return dataset
    if stage == "transformed":
        df, _ = self.transform(dataset)
        return df
    elif stage == "binarized":
        df, _ = self.binarize(
            dataset,
            transform_sensitive_columns="intersection_binary",
            transform_sensitive_values="majority_minority",
        )
        return df
    else:
        raise ValueError(f"Unsupported stage: {stage}")
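
A sketch of loading at different stages (assumes a dataset with download URLs; "adult" is an illustrative id):

from fairml_datasets.dataset import Dataset

ds = Dataset.from_id("adult")
paths = ds.load(stage="downloaded")     # DataFrame with a "file_paths" column
df_prepared = ds.load()                 # default stage: "prepared"
df_binarized = ds.load(stage="binarized", check_cache=False)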

split_dataset(df, splits, seed=DEFAULT_SEED, stratify=True, stratify_manual=None, sensitive_columns=None)

Split a dataset into multiple partitions based on specified ratios.

Parameters:
    df (Union[DataFrame, BinaryLabelDataset], required): DataFrame or AIF360 BinaryLabelDataset to split
    splits (Tuple[float, ...], required): Tuple of fractions that sum to 1.0. For example, (0.6, 0.2, 0.2) for train/val/test
    seed (int, default DEFAULT_SEED): Random seed for reproducibility
    stratify (bool, default True): Whether to stratify the split by target column and sensitive attributes
    stratify_manual (Optional[str], default None): Column to stratify by (defaults to combined target+sensitive if None and stratify=True)
    sensitive_columns (Optional[List[str]], default None): List of sensitive attribute column names

Returns:
    Tuple[Union[DataFrame, BinaryLabelDataset], ...]: Tuple of datasets with the same type as df
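
For instance, a stratified 60/20/20 split (a sketch; ds and df as in the earlier examples, and seed=42 is arbitrary):

train_df, val_df, test_df = ds.split_dataset(df, splits=(0.6, 0.2, 0.2), seed=42)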

Source code in fairml_datasets/dataset.py
def split_dataset(
    self,
    df: Union[pd.DataFrame, BinaryLabelDataset],
    splits: Tuple[float, ...],
    seed: int = DEFAULT_SEED,
    stratify: bool = True,
    stratify_manual: Optional[str] = None,
    sensitive_columns: Optional[List[str]] = None,
) -> Tuple[Union[pd.DataFrame, BinaryLabelDataset], ...]:
    """
    Split a dataset into multiple partitions based on specified ratios.

    Args:
        df: DataFrame or AIF360 BinaryLabelDataset to split
        splits: Tuple of fractions that sum to 1.0. For example, (0.6, 0.2, 0.2) for train/val/test
        seed: Random seed for reproducibility
        stratify: Whether to stratify the split by target column and sensitive attributes
        stratify_manual: Column to stratify by (defaults to combined target+sensitive if None and stratify=True)
        sensitive_columns: List of sensitive attribute column names

    Returns:
        Tuple of datasets with the same type as df
    """
    # Validate splits
    if not isinstance(splits, tuple) or len(splits) < 2:
        raise ValueError("splits must be a tuple with at least 2 elements")
    if round(sum(splits), 10) != 1.0:
        raise ValueError(f"Split values must sum to 1.0, got {sum(splits)}")

    if stratify_manual is not None and not stratify:
        logger.error(
            "stratify_manual is set but stratify is False. Setting stratify to True."
        )
        stratify = True

    # If AIF360 dataset, convert to pandas for splitting
    is_aif360 = isinstance(df, BinaryLabelDataset)
    pandas_df = df if not is_aif360 else df.convert_to_dataframe()[0]

    # Determine column to use for stratification
    stratify_col = None
    if stratify:
        # Manual column provided
        if stratify_manual is not None:
            if stratify_manual in pandas_df.columns:
                stratify_col = pandas_df[stratify_manual]
            else:
                raise ValueError(
                    f"Stratify column {stratify_manual} not found in dataframe."
                )
        else:
            # Get target and sensitive columns for combined stratification
            target_column = self.get_target_column()
            if sensitive_columns is None:
                sensitive_columns = self.sensitive_columns

            # Handle the case where sensitive columns were transformed
            # Check if original sensitive columns exist, if not try sensitive_intersection
            available_sensitive_columns = []
            for col in sensitive_columns:
                if col in pandas_df.columns:
                    available_sensitive_columns.append(col)

            # If no original sensitive columns found, check for sensitive_intersection
            if (
                not available_sensitive_columns
                and "sensitive_intersection" in pandas_df.columns
            ):
                sensitive_columns = ["sensitive_intersection"]

            strat_columns = [target_column] + sensitive_columns

            # Check if columns exist in the dataframe
            if not set(strat_columns).issubset(set(pandas_df.columns)):
                missing_columns = set(strat_columns) - set(pandas_df.columns)
                raise ValueError(
                    f"Target or sensitive columns not found in dataframe: {missing_columns}"
                )

            logger.info(
                f"Stratifying split by columns: {strat_columns}"
            )

            # Combine target and sensitive columns for stratification
            if len(strat_columns) == 1:
                stratify_col = pandas_df[strat_columns[0]]
            else:
                stratify_col = (
                    pandas_df[strat_columns]
                    .astype(str)
                    .apply(lambda x: "_".join(x.values), axis=1)
                )

    result = []
    remaining_df = pandas_df.copy()
    remaining_prob = 1.0

    # Iteratively split dataset
    for i, split_ratio in enumerate(splits[:-1]):
        if len(remaining_df) == 0:
            # Handle edge case of empty dataframe
            result.append(remaining_df.copy())
            continue

        # Calculate the proportion for this split from the remaining data
        proportion = split_ratio / remaining_prob

        # Get stratification data for current split
        current_stratify = (
            stratify_col.loc[remaining_df.index]
            if stratify_col is not None
            else None
        )

        # If there's only one class in the stratification column, don't stratify
        if current_stratify is not None and len(current_stratify.unique()) < 2:
            logger.warning(
                f"Only one class present in stratification column for split {i}. Not stratifying this split."
            )
            current_stratify = None

        # Split the data
        new_split, remaining_df = sk_train_test_split(
            remaining_df,
            test_size=1 - proportion,
            random_state=seed,
            stratify=current_stratify,
        )

        result.append(new_split)
        remaining_prob -= split_ratio

    # Add the final piece
    result.append(remaining_df)

    # Convert back to AIF360 if needed
    if is_aif360:
        return tuple(
            BinaryLabelDataset(
                df=split_df,
                label_names=df.label_names,
                protected_attribute_names=df.protected_attribute_names,
                favorable_label=df.favorable_label,
                unfavorable_label=df.unfavorable_label,
            )
            for split_df in result
        )

    return tuple(result)
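
As an illustration (a sketch reusing the example dataset id from the Usage Examples; any DataFrame produced by transform() or binarize() works the same way):

from fairml_datasets import Dataset

dataset = Dataset.from_id("folktables_acsincome_small")
df, _ = dataset.transform(dataset.load())

# 60/20/20 train/val/test split, stratified by target and sensitive columns
train, val, test = dataset.split_dataset(
    df,
    splits=(0.6, 0.2, 0.2),
    seed=42,
    stratify=True,
)
print(len(train), len(val), len(test))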

to_aif360_BinaryLabelDataset(sensitive_columns=None, **kwargs)

Convert the dataset to AIF360 BinaryLabelDataset format.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sensitive_columns | Optional[List[str]] | List of sensitive attribute column names | None |
| **kwargs | Any | Additional arguments for transformation passed to .binarize() and then .transform() | {} |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| BinaryLabelDataset | BinaryLabelDataset | The dataset in AIF360 format |

Source code in fairml_datasets/dataset.py
def to_aif360_BinaryLabelDataset(
    self,
    sensitive_columns: Optional[List[str]] = None,
    **kwargs: Any,
) -> "BinaryLabelDataset":  # noqa: F821
    """
    Convert the dataset to AIF360 BinaryLabelDataset format.

    Args:
        sensitive_columns: List of sensitive attribute column names
        **kwargs: Additional arguments for transformation passed to .binarize() and then .transform()

    Returns:
        BinaryLabelDataset: The dataset in AIF360 format
    """
    df = self.load()

    # Preprocessing (this can affect columns / column names)
    df, info = self.binarize(df=df, sensitive_columns=sensitive_columns, **kwargs)
    sensitive_columns = info.sensitive_columns

    # Create the AIF360 dataset
    from aif360.datasets import BinaryLabelDataset

    dataset = BinaryLabelDataset(
        favorable_label=1,
        unfavorable_label=0,
        df=df,
        label_names=[self.get_target_column()],
        protected_attribute_names=sensitive_columns,
    )

    return dataset
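
For example (a sketch; this requires the optional aif360 dependency to be installed and reuses the example dataset id from the Usage Examples):

from fairml_datasets import Dataset

dataset = Dataset.from_id("folktables_acsincome_small")
bld = dataset.to_aif360_BinaryLabelDataset()

# Label and protected attribute names as seen by AIF360
print(bld.label_names)
print(bld.protected_attribute_names)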

to_numpy()

Load the dataset as a numpy array.

Use .load() if you want more control.

Returns:

| Type | Description |
| --- | --- |
| ndarray | The dataset as a numpy array |

Source code in fairml_datasets/dataset.py
def to_numpy(self) -> np.ndarray:
    """
    Load the dataset as a numpy array.

    Use .load() if you want more control.

    Returns:
        np.ndarray: The dataset as a numpy array
    """
    return self.load().to_numpy()

to_pandas()

Load the dataset as a pandas DataFrame.

Use .load() if you want more control.

Returns:

| Type | Description |
| --- | --- |
| DataFrame | The dataset as a pandas DataFrame |

Source code in fairml_datasets/dataset.py
def to_pandas(self) -> pd.DataFrame:
    """
    Load the dataset as a pandas DataFrame.

    Use .load() if you want more control.

    Returns:
        pd.DataFrame: The dataset as a pandas DataFrame
    """
    return self.load()
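
Both convenience accessors are thin wrappers around load(), for example:

from fairml_datasets import Dataset

dataset = Dataset.from_id("folktables_acsincome_small")

df = dataset.to_pandas()   # same as dataset.load()
arr = dataset.to_numpy()   # same as dataset.load().to_numpy()
print(df.shape, arr.shape)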

train_test_split(df, test_size=0.3, **kwargs)

Split a dataset into train and test sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | Union[DataFrame, BinaryLabelDataset] | DataFrame or BinaryLabelDataset to split (if None, dataset is loaded) | required |
| test_size | float | Fraction of data to use for testing | 0.3 |
| **kwargs | Any | Additional arguments to pass to split_dataset | {} |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Union[DataFrame, BinaryLabelDataset], Union[DataFrame, BinaryLabelDataset]] | Tuple of (train_data, test_data) |

Source code in fairml_datasets/dataset.py
def train_test_split(
    self,
    df: Union[pd.DataFrame, BinaryLabelDataset],
    test_size: float = 0.3,
    **kwargs: Any,
) -> Tuple[
    Union[pd.DataFrame, BinaryLabelDataset], Union[pd.DataFrame, BinaryLabelDataset]
]:
    """
    Split a dataset into train and test sets.

    Args:
        df: DataFrame or BinaryLabelDataset to split (if None, dataset is loaded)
        test_size: Fraction of data to use for testing
        **kwargs: Additional arguments to pass to split_dataset

    Returns:
        Tuple of (train_data, test_data)
    """
    train_size = 1.0 - test_size
    return self.split_dataset(df, splits=(train_size, test_size), **kwargs)
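
For example (a sketch; keywords accepted by split_dataset, such as seed or stratify, can be forwarded via **kwargs):

from fairml_datasets import Dataset

dataset = Dataset.from_id("folktables_acsincome_small")
df, _ = dataset.transform(dataset.load())

train, test = dataset.train_test_split(df, test_size=0.3, seed=42)
print(f"Train: {len(train)} rows, Test: {len(test)} rows")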

train_test_val_split(df, test_size=0.2, val_size=0.2, **kwargs)

Split a dataset into train, validation, and test sets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | Union[DataFrame, BinaryLabelDataset] | DataFrame or BinaryLabelDataset to split | required |
| test_size | float | Fraction of data to use for testing | 0.2 |
| val_size | float | Fraction of data to use for validation | 0.2 |
| **kwargs | Any | Additional arguments to pass to split_dataset | {} |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Union[DataFrame, BinaryLabelDataset], Union[DataFrame, BinaryLabelDataset], Union[DataFrame, BinaryLabelDataset]] | Tuple of (train_data, val_data, test_data) |

Source code in fairml_datasets/dataset.py
def train_test_val_split(
    self,
    df: Union[pd.DataFrame, BinaryLabelDataset],
    test_size: float = 0.2,
    val_size: float = 0.2,
    **kwargs: Any,
) -> Tuple[
    Union[pd.DataFrame, BinaryLabelDataset],
    Union[pd.DataFrame, BinaryLabelDataset],
    Union[pd.DataFrame, BinaryLabelDataset],
]:
    """
    Split a dataset into train, validation, and test sets.

    Args:
        df: DataFrame or BinaryLabelDataset to split
        test_size: Fraction of data to use for testing
        val_size: Fraction of data to use for validation
        **kwargs: Additional arguments to pass to split_dataset

    Returns:
        Tuple of (train_data, val_data, test_data)
    """
    # Calculate train size based on test and validation sizes
    train_size = 1.0 - test_size - val_size

    if train_size <= 0:
        raise ValueError(
            f"Invalid split sizes: test_size ({test_size}) + val_size ({val_size}) must be less than 1.0"
        )

    return self.split_dataset(
        df, splits=(train_size, val_size, test_size), **kwargs
    )

transform(df=None, sensitive_columns=None, **kwargs)

Apply default transformations to the dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| df | Optional[DataFrame] | Optional DataFrame to transform, if None the dataset is loaded | None |
| sensitive_columns | Optional[List[str]] | List of sensitive attribute column names | None |
| **kwargs | Any | Additional arguments for transformation, passed to transform() | {} |

Returns:

| Type | Description |
| --- | --- |
| tuple[DataFrame, PreprocessingInfo] | Transformed DataFrame and preprocessing information |

Source code in fairml_datasets/dataset.py
def transform(
    self,
    df: Optional[pd.DataFrame] = None,
    sensitive_columns: Optional[List[str]] = None,
    **kwargs: Any,
) -> tuple[pd.DataFrame, PreprocessingInfo]:
    """
    Apply default transformations to the dataset.

    Args:
        df: Optional DataFrame to transform, if None the dataset is loaded
        sensitive_columns: List of sensitive attribute column names
        **kwargs: Additional arguments for transformation, passed to transform()

    Returns:
        tuple[pd.DataFrame, PreprocessingInfo]: Transformed DataFrame and preprocessing information
    """
    logger.debug(f"Transforming: {self.dataset_id}")

    if df is None:
        df = self.load()
    target_column = self.get_target_column()
    if sensitive_columns is None:
        sensitive_columns = self.sensitive_columns

    feature_columns = self.get_feature_columns(df=df)
    target_lvl_good_bad = self.get_target_lvl_good_value()

    return transform(
        df=df,
        sensitive_columns=sensitive_columns,
        feature_columns=feature_columns,
        target_column=target_column,
        target_lvl_good_bad=target_lvl_good_bad,
        **kwargs,
    )
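
A short sketch of the optional arguments (the commented-out sensitive column name is purely hypothetical; real column names depend on the dataset):

from fairml_datasets import Dataset

dataset = Dataset.from_id("folktables_acsincome_small")

# df is optional: when omitted, the dataset is loaded internally
df_transformed, info = dataset.transform()
print(info.sensitive_columns)

# Sensitive columns can also be overridden explicitly, e.g.:
# df_transformed, info = dataset.transform(sensitive_columns=["SEX"])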

Usage Examples

Basic Usage

from fairml_datasets import Dataset

# Get a dataset by its ID using Dataset.from_id (recommended)
dataset = Dataset.from_id("folktables_acsincome_small")

# Load the dataset
df = dataset.load()

# Print basic information
print(f"Dataset ID: {dataset.dataset_id}")
print(f"Sensitive columns: {dataset.sensitive_columns}")
print(f"Target column: {dataset.get_target_column()}")

Data Transformation

# Load a dataset
dataset = Dataset.from_id("folktables_acsincome_small")
df = dataset.load()

# Apply standard transformations
df_transformed, info = dataset.transform(df)

# Check transformation info
print(f"Original shape: {df.shape}")
print(f"Transformed shape: {df_transformed.shape}")
print(f"Transformed sensitive columns: {info.sensitive_columns}")

Train/Test/Validation Splitting

# Load and transform a dataset
dataset = Dataset.from_id("folktables_acsincome_small")
df = dataset.load()
df_transformed, _ = dataset.transform(df)

# Create train/test/validation split
train, val, test = dataset.train_test_val_split(
    df_transformed,
    test_size=0.2,
    val_size=0.1,
    seed=42
)

print(f"Train set size: {len(train)}")
print(f"Test set size: {len(test)}")
print(f"Validation set size: {len(val)}")

Metadata Generation

# Generate metadata for a dataset
dataset = Dataset.from_id("folktables_acsincome_small")
metadata = dataset.generate_metadata()

print(metadata)