
MMLU

Bases: DatasetBuilder

MMLU (Massive Multitask Language Understanding) is a multitask test of multiple-choice questions spanning 57 tasks, in domains such as elementary mathematics, US history, computer science and law.

We are interested in the tasks related to the medical domain, so we use the following subsets:

- Clinical knowledge
- Medical genetics
- Anatomy
- Professional medicine
- College biology
- College medicine
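The six subsets correspond to MMLU task names, which act as configuration names for `lukaemon/mmlu`. The builder's default, `MMLUSubsetConfig.clinical_knowledge`, follows this snake_case naming. The sketch below assumes the other five names follow the same pattern and collects them in a str-valued enum. `MedicalMMLUSubset` and `all_medical_configs` are hypothetical names for illustration, not part of medplexity's API.

```python
from enum import Enum


class MedicalMMLUSubset(str, Enum):
    """Hypothetical enum of the medical MMLU subsets used here.

    Values are assumed to be the snake_case MMLU task names that double as
    Hugging Face config names; medplexity's own MMLUSubsetConfig may differ.
    """

    clinical_knowledge = "clinical_knowledge"
    medical_genetics = "medical_genetics"
    anatomy = "anatomy"
    professional_medicine = "professional_medicine"
    college_biology = "college_biology"
    college_medicine = "college_medicine"


def all_medical_configs() -> list[str]:
    """Return the config names in declaration order (e.g. to build one dataset per subset)."""
    return [subset.value for subset in MedicalMMLUSubset]
```

Because the enum mixes in `str`, each member compares equal to its plain string value, so it can be passed wherever a config name string is expected.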

Paper: Measuring Massive Multitask Language Understanding

7 Sep 2020 · Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300

We use the version hosted on the Hugging Face Hub: https://huggingface.co/datasets/lukaemon/mmlu

Source code in medplexity/benchmarks/mmlu/mmlu_dataset_builder.py
class MMLUDatasetBuilder(DatasetBuilder):
    """MMLU (Massive Multitask Language Understanding) is a massive multitask test consisting of multiple-choice questions from various domains.

    We are interested in the tasks related to the medical domain, so we use the following subsets:
    - Clinical knowledge
    - Medical genetics
    - Anatomy
    - Professional medicine
    - College biology
    - College medicine

    Paper: Measuring Massive Multitask Language Understanding

    7 Sep 2020  ·  Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
    <https://arxiv.org/abs/2009.03300>

    We use the version hosted on the Hugging Face Hub: <https://huggingface.co/datasets/lukaemon/mmlu>
    """

    EXAMPLE_QUESTIONS_PATH = Path(__file__).resolve().parent / "examples.json"

    def build_dataset(
        self,
        split_type: MMLUQADatasetSplitType = "train",
        config=None,
    ) -> Dataset[MMLUDataPoint]:
        # Default to the clinical knowledge subset when no config is given.
        if config is None:
            config = {"subset": MMLUSubsetConfig.clinical_knowledge}

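        # Load the selected MMLU subset for the requested split.
        dataset = self.loader.load("lukaemon/mmlu", config["subset"], split=split_type)

        # Unpack each row into an MMLUQuestion (fields: input, A, B, C, D, target).
        questions = [MMLUQuestion(**row) for row in dataset]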

        data_points = [
            MMLUDataPoint(
                id=f"{split_type}-{i}",
                input=MultipleChoiceInput(
                    question=question.input,
                    options=[question.A, question.B, question.C, question.D],
                ),
                # The expected answer is the target letter in parentheses, e.g. "(A)".
                expected_output=f"({question.target})",
                metadata=None,
            )
            for i, question in enumerate(questions)
        ]

        return Dataset[MMLUDataPoint](data_points=data_points, description=self.__doc__)
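Without the medplexity types, the mapping in `build_dataset` can be sketched with plain dataclasses. `Choice`, `Point` and `rows_to_points` are hypothetical stand-ins for `MultipleChoiceInput`, `MMLUDataPoint` and the unpacking plus list comprehension above. The row fields (`input`, `A`, `B`, `C`, `D`, `target`) are the ones the builder reads. This is a simplified sketch, not the library's implementation.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Choice:  # hypothetical stand-in for MultipleChoiceInput
    question: str
    options: list[str]


@dataclass
class Point:  # hypothetical stand-in for MMLUDataPoint
    id: str
    input: Choice
    expected_output: str
    metadata: Optional[dict[str, Any]] = None


def rows_to_points(rows: list[dict[str, str]], split_type: str) -> list[Point]:
    """Mirror build_dataset: one point per row, with ids of the form "<split>-<index>"."""
    return [
        Point(
            id=f"{split_type}-{i}",
            input=Choice(
                question=row["input"],
                # Options keep the dataset's A, B, C, D order.
                options=[row["A"], row["B"], row["C"], row["D"]],
            ),
            # The gold answer letter is wrapped in parentheses, e.g. "B" -> "(B)".
            expected_output=f"({row['target']})",
        )
        for i, row in enumerate(rows)
    ]
```

Given a placeholder row such as `{"input": "Q1", "A": "a", "B": "b", "C": "c", "D": "d", "target": "B"}` on the `"test"` split, the first point has id `"test-0"` and expected output `"(B)"`.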