diff --git a/langfair/constants/word_lists.py b/langfair/constants/word_lists.py index bf69bf4..2b03144 100644 --- a/langfair/constants/word_lists.py +++ b/langfair/constants/word_lists.py @@ -19,15 +19,19 @@ """Bias words utilised to compute the bias metrics. -This file includes word lists for the following 4 categories: +This file includes word lists for the following categories: - Race (Asian, Hispanic, and White): Most common names for each race (Garg et al. 2018) - Gender (Female, Male): Gender specific words - such as woman, man etc. - for each gender (Bolukbasi et al. 2016) - Adjectives: List of adjectives (Garg et al. 2018) - Professions: List of professions (Bolukbasi et al. 2016) + - Health Condition, Physical Appearance, Sexual Orientation, + Socioeconomic Class, Nationality: Demographic descriptor terms + adapted from the HolisticBias dataset (Smith et al. 2022) References: - Garg et al. 2018 | https://arxiv.org/abs/1711.08412 - Bolukbasi et al. 2016 | https://arxiv.org/abs/1607.06520 + - Smith et al. 2022 | https://arxiv.org/abs/2205.09209 Adapted from https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/metrics/bias_word_lists.py @@ -135,9 +139,50 @@ "folks", ] +# QUEER_WORDS is index-aligned with MALE_WORDS / FEMALE_WORDS (both 29 items). +# Pronoun positions hold a list of alternatives (they/xe/ey/zie etc.); one is +# drawn at random for every matched token, so pronoun families may mix in one prompt. +# Non-pronoun positions use the same gender-neutral terms as GENDER_NEUTRAL_WORDS. 
+QUEER_WORDS: List = [ + ["they", "xe", "ey", "zie"], # 0 ← he / she (subject) + "child", # 1 ← son / daughter + ["their", "xyr", "eir", "zir"], # 2 ← his / hers (possessive) + ["them", "xem", "em", "zim"], # 3 ← him / her (object) + "parent", # 4 ← father / mother + "person", # 5 ← man / woman + "kid", # 6 ← boy / girl + ["themself", "xemself", "emself", "ziemself"], # 7 ← himself / herself (reflexive) + "individual", # 8 ← male / female + "sibling", # 9 ← brother / sister + "children", # 10 ← sons / daughters + "parents", # 11 ← fathers / mothers + "people", # 12 ← men / women + "kids", # 13 ← boys / girls + "individuals", # 14 ← males / females + "siblings", # 15 ← brothers / sisters + "cousin", # 16 ← uncle / aunt + "cousins", # 17 ← uncles / aunts + "cousin", # 18 ← nephew / niece + "cousins", # 19 ← nephews / nieces + "person", # 20 ← gentleman / lady + "people", # 21 ← gentlemen / ladies + "grandparent", # 22 ← grandfather / grandmother + "grandparents", # 23 ← grandfathers / grandmothers + "partner", # 24 ← boyfriend / girlfriend + "partners", # 25 ← boyfriends / girlfriends + "Mx.", # 26 ← Mr. / Mrs. 
+ "person", # 27 ← guy / gal + "folks", # 28 ← guys / gals +] + + +NONBINARY_WORDS: List[str] = GENDER_NEUTRAL_WORDS + GENDER_TO_WORD_LISTS: Dict[str, List[str]] = { "female": FEMALE_WORDS, "male": MALE_WORDS, + "nonbinary": NONBINARY_WORDS, + "queer": QUEER_WORDS, } ################################################################################ @@ -146,23 +191,43 @@ # For race, string search is done RACE_WORDS_NOT_REQUIRING_CONTEXT: List[str] = [ "caucasian", + "european american", + "european-american", + "alaska native", "asian-american", + "desi", + "east asian", + "east-asian", + "south asian", + "south-asian", + "southeast asian", + "southeast-asian", "african american", "african-american", "native american", "native-american", + "indigenous", + "american indian", + "american-indian", + "native hawaiian", + "native-hawaiian", + "aboriginal", "pacific islander", "pacific-islander", "latinx", "hispanic", "latino", "latina", + "latin american", + "latin-american", "whites", "blacks", "indians", "anglo-saxon", "anglo saxon", "asians", + "AAPI", + "BIPOC", ] RACE_WORDS_REQUIRING_CONTEXT: List[str] = [ @@ -170,6 +235,9 @@ "white", "indian", "asian", + "european", + "middle eastern", + "african", ] @@ -992,3 +1060,900 @@ + OTHER_PERSON_NOUNS ) PERSON_WORDS = list(set(PERSON_WORDS_LIST) - set(WORDS_TO_REMOVE)) + +YOUNG_WORDS: List[str] = [ + "young", + "younger", + "youth", + "adolescent", + "teenager", + "teen", + "teenage", + "teenaged", + "juvenile", + "toddler", + "infant", + "baby", + "child", +] + +YOUNG_WORDS_STRING_SEARCH: List[str] = [ + "twenty-year-old", + "20-year-old", + "twenty-five-year-old", + "25-year-old", + "thirty-year-old", + "30-year-old", + "thirty-five-year-old", + "35-year-old", + "forty-year-old", + "40-year-old", + "twenty-something", + "thirty-something", +] + +MIDDLE_AGED_WORDS: List[str] = [ + "midlife", +] + +MIDDLE_AGED_WORDS_STRING_SEARCH: List[str] = [ + "middle-aged", + "forty-five-year-old", + "45-year-old", + "fifty-year-old", + 
"50-year-old", + "fifty-five-year-old", + "55-year-old", + "sixty-year-old", + "60-year-old", + "forty-something", + "fifty-something", + "sixty-something", +] + +OLD_WORDS: List[str] = [ + "old", + "older", + "elderly", + "senior", + "aged", + "elder", + "geriatric", + "retiree", + "retired", + "spry", + "octogenarian", + "nonagenarian", + "centenarian", +] + +OLD_WORDS_STRING_SEARCH: List[str] = [ + "senior-citizen", + "young-at-heart", + "sixty-five-year-old", + "65-year-old", + "seventy-year-old", + "70-year-old", + "seventy-five-year-old", + "75-year-old", + "eighty-year-old", + "80-year-old", + "eighty-five-year-old", + "85-year-old", + "ninety-year-old", + "90-year-old", + "ninety-five-year-old", + "95-year-old", + "seventy-something", + "eighty-something", + "ninety-something", +] + +ADULT_WORDS: List[str] = [ + "adult", +] + +AGE_WORDS_STRING_SEARCH: List[str] = ( + YOUNG_WORDS_STRING_SEARCH + + MIDDLE_AGED_WORDS_STRING_SEARCH + + OLD_WORDS_STRING_SEARCH +) + +ALL_AGE_WORDS: List[str] = ( + YOUNG_WORDS + MIDDLE_AGED_WORDS + OLD_WORDS + ADULT_WORDS + + AGE_WORDS_STRING_SEARCH +) + + +HEALTHY_WORDS: List[str] = [ + "healthy", + "nondisabled", +] + +HEALTHY_WORDS_STRING_SEARCH: List[str] = [ + "able-bodied", + "non-disabled", +] + +DISABLED_WORDS: List[str] = [ + "disabled", + "handicapped", + "wheelchair", + "paraplegic", + "paralyzed", + "quadriplegic", + "amputee", +] + +DISABLED_WORDS_STRING_SEARCH: List[str] = [ + "differently-abled", + "differently abled", + "wheelchair-bound", + "wheelchair-user", + "who uses a wheelchair", + "with a disability", + "with a limb difference", + "who is an amputee", + "with an amputation", + "with limb loss", + "with a developmental disability", + "with a cognitive disability", + "with an intellectual disability", +] + +DEAF_WORDS: List[str] = [ + "deaf", +] + +DEAF_WORDS_STRING_SEARCH: List[str] = [ + "hard-of-hearing", + "hearing-impaired", + "cochlear-implant-using", + "with a cochlear implant", + "who is deaf", + "who 
is hard of hearing", + "with difficulty hearing", +] + +BLIND_WORDS: List[str] = [ + "blind", + "sightless", +] + +BLIND_WORDS_STRING_SEARCH: List[str] = [ + "visually-impaired", + "low-vision", + "vision-impaired", + "with low vision", + "with blindness", + "who is visually impaired", + "with difficulty seeing", +] + +MOBILITY_IMPAIRED_WORDS: List[str] = [] + +MOBILITY_IMPAIRED_WORDS_STRING_SEARCH: List[str] = [ + "with difficulty walking", + "with difficulty moving", + "who walks with crutches", +] + +NEUROLOGICAL_WORDS: List[str] = [ + "autistic", + "dyslexic", + "neuroatypical", + "neurodivergent", + "neurodiverse", + "neurotypical", + "depressed", +] + +NEUROLOGICAL_WORDS_STRING_SEARCH: List[str] = [ + "non-neurotypical", + "learning-disabled", + "on the spectrum", + "with autism", + "with dyslexia", + "who incurred a traumatic brain injury", + "managing depression", +] + +SPEECH_IMPAIRED_WORDS: List[str] = [ + "aphasic", + "mute", +] + +SPEECH_IMPAIRED_WORDS_STRING_SEARCH: List[str] = [ + "speech-impaired", + "with aphasia", + "with a lisp", + "who stutters", + "with a stutter", + "with a speech fluency disorder", + "without speech", +] + +GLASSES_WORDS: List[str] = [] + +GLASSES_WORDS_STRING_SEARCH: List[str] = [ + "glasses-wearing", + "with bifocals", +] + +MEMORY_IMPAIRED_WORDS: List[str] = [] + +MEMORY_IMPAIRED_WORDS_STRING_SEARCH: List[str] = [ + "with memory loss", +] + +CHRONICALLY_ILL_WORDS: List[str] = [ + "ill", + "sick", +] + +CHRONICALLY_ILL_WORDS_STRING_SEARCH: List[str] = [ + "chronically-ill", + "chronically-sick", +] + +PREGNANT_WORDS: List[str] = [ + "pregnant", + "expectant", +] + +HEALTH_CONDITION_WORDS_STRING_SEARCH: List[str] = ( + HEALTHY_WORDS_STRING_SEARCH + + DISABLED_WORDS_STRING_SEARCH + + DEAF_WORDS_STRING_SEARCH + + BLIND_WORDS_STRING_SEARCH + + MOBILITY_IMPAIRED_WORDS_STRING_SEARCH + + NEUROLOGICAL_WORDS_STRING_SEARCH + + SPEECH_IMPAIRED_WORDS_STRING_SEARCH + + GLASSES_WORDS_STRING_SEARCH + + MEMORY_IMPAIRED_WORDS_STRING_SEARCH + 
+ CHRONICALLY_ILL_WORDS_STRING_SEARCH +) + +ALL_HEALTH_CONDITION_WORDS: List[str] = ( + HEALTHY_WORDS + + DISABLED_WORDS + + DEAF_WORDS + + BLIND_WORDS + + MOBILITY_IMPAIRED_WORDS + + NEUROLOGICAL_WORDS + + SPEECH_IMPAIRED_WORDS + + GLASSES_WORDS + + MEMORY_IMPAIRED_WORDS + + CHRONICALLY_ILL_WORDS + + PREGNANT_WORDS + + HEALTH_CONDITION_WORDS_STRING_SEARCH +) + +NATIONALITY_WORDS: List[str] = [ + 'afghan', + 'albanian', + 'algerian', + 'andorran', + 'angolan', + 'antiguan', + 'barbudan', + 'argentine', + 'armenian', + 'australian', + 'austrian', + 'azerbaijani', + 'azeri', + 'bahamian', + 'bahraini', + 'bengali', + 'barbadian', + 'belarusian', + 'belgian', + 'belizean', + 'beninese', + 'beninois', + 'bhutanese', + 'bolivian', + 'bosnian', + 'herzegovinian', + 'motswana', + 'botswanan', + 'brazilian', + 'bruneian', + 'bulgarian', + 'burkinabé', + 'burmese', + 'burundian', + 'cambodian', + 'cameroonian', + 'canadian', + 'chadian', + 'chilean', + 'chinese', + 'colombian', + 'comoran', + 'comorian', + 'congolese', + 'ivorian', + 'croatian', + 'cuban', + 'cypriot', + 'czech', + 'danish', + 'djiboutian', + 'dominican', + 'timorese', + 'ecuadorian', + 'egyptian', + 'salvadoran', + 'equatoguinean', + 'eritrean', + 'estonian', + 'ethiopian', + 'fijian', + 'finnish', + 'french', + 'gabonese', + 'gambian', + 'georgian', + 'german', + 'ghanaian', + 'gibraltar', + 'greek', + 'hellenic', + 'grenadian', + 'guatemalan', + 'guinean', + 'guyanese', + 'haitian', + 'honduran', + 'hungarian', + 'magyar', + 'icelandic', + 'indian', + 'indonesian', + 'iranian', + 'persian', + 'iraqi', + 'irish', + 'israeli', + 'italian', + 'jamaican', + 'japanese', + 'jordanian', + 'kazakhstani', + 'kazakh', + 'kenyan', + 'korean', + 'kuwaiti', + 'kyrgyzstani', + 'kyrgyz', + 'kirgiz', + 'kirghiz', + 'lao', + 'laotian', + 'latvian', + 'lettish', + 'lebanese', + 'basotho', + 'liberian', + 'libyan', + 'liechtensteiner', + 'lithuanian', + 'luxembourg', + 'luxembourgish', + 'macedonian', + 'malagasy', + 
'malawian', + 'malaysian', + 'maldivian', + 'malian', + 'malinese', + 'maltese', + 'marshallese', + 'martiniquais', + 'martinican', + 'mauritanian', + 'mauritian', + 'mexican', + 'micronesian', + 'moldovan', + 'monégasque', + 'monacan', + 'mongolian', + 'montenegrin', + 'moroccan', + 'mozambican', + 'namibian', + 'nauruan', + 'nepali', + 'nepalese', + 'dutch', + 'netherlandic', + 'zelanian', + 'nicaraguan', + 'nigerien', + 'nigerian', + 'norwegian', + 'omani', + 'pakistani', + 'palauan', + 'palestinian', + 'panamanian', + 'papuan', + 'paraguayan', + 'peruvian', + 'filipino', + 'philippine', + 'polish', + 'portuguese', + 'qatari', + 'romanian', + 'russian', + 'rwandan', + 'kittitian', + 'nevisian', + 'vincentian', + 'samoan', + 'sammarinese', + 'saudi', + 'senegalese', + 'serbian', + 'seychellois', + 'singaporean', + 'slovak', + 'slovenian', + 'slovene', + 'somali', + 'spanish', + 'sudanese', + 'surinamese', + 'swazi', + 'swedish', + 'swiss', + 'syrian', + 'tajikistani', + 'tanzanian', + 'thai', + 'togolese', + 'tokelauan', + 'tongan', + 'trinidadian', + 'tobagonian', + 'tunisian', + 'turkish', + 'turkmen', + 'tuvaluan', + 'ugandan', + 'ukrainian', + 'emirati', + 'emirian', + 'emiri', + 'uk', + 'british', + 'us', + 'american', + 'uruguayan', + 'uzbekistani', + 'uzbek', + 'vanuatuan', + 'vatican', + 'venezuelan', + 'vietnamese', + 'yemeni', + 'zambian', + 'zimbabwean', +] + +NATIONALITY_WORDS_STRING_SEARCH: List[str] = [ + 'bissau-guinean', + 'cabo verdean', + 'central african', + 'costa rican', + 'equatorial guinean', + 'i-kiribati', + 'new zealand', + 'ni-vanuatu', + 'north korean', + 'northern marianan', + 'papua new guinean', + 'puerto rican', + 'saint lucian', + 'saint vincentian', + 'saudi arabian', + 'são toméan', + 'sierra leonean', + 'solomon island', + 'south african', + 'south sudanese', + 'sri lankan', + 'united states', +] + +ALL_NATIONALITY_WORDS: List[str] = NATIONALITY_WORDS + NATIONALITY_WORDS_STRING_SEARCH + +THIN_WORDS: List[str] = [ + "bony", + 
"gangly", + "lanky", + "skinny", + "slender", + "slim", + "svelte", + "thin", + "underweight", +] + +FIT_WORDS: List[str] = [ + "fit", + "athletic", + "muscular", + "toned", + "lean", + "ripped", + "swole", + "strong", +] + +FIT_WORDS_STRING_SEARCH: List[str] = [ + "in-shape", + "physically fit", + "well-built", +] + +OVERWEIGHT_WORDS: List[str] = [ + "overweight", + "obese", + "fat", + "heavy", + "chubby", + "chunky", + "curvy", + "heavyset", + "heftier", + "hefty", + "plump", + "potbellied", + "rotund", + "bulky", +] + +OVERWEIGHT_WORDS_STRING_SEARCH: List[str] = [ + "full-figured", + "heavy-set", + "mildly overweight", + "morbidly obese", + "plus-sized", + "slightly overweight", +] + +ATTRACTIVE_WORDS: List[str] = [ + "attractive", + "beautiful", + "handsome", + "pretty", + "gorgeous", + "adorable", + "cute", + "hot", + "hunky", + "sexy", +] + +ATTRACTIVE_WORDS_STRING_SEARCH: List[str] = [ + "good-looking", +] + +UNATTRACTIVE_WORDS: List[str] = [ + "unattractive", + "ugly", + "plain", + "homely", + "hideous", +] + +UNATTRACTIVE_WORDS_STRING_SEARCH: List[str] = [ + "plain-looking", +] + +LARGE_STATURE_WORDS: List[str] = [ + "tall", + "taller", + "towering", + "beefy", + "big", + "bigger", + "brawny", + "burly", + "giant", + "huge", + "large", + "larger", + "massive", + "stocky", + "gigantic", +] + +LARGE_STATURE_WORDS_STRING_SEARCH: List[str] = [ + "barrel-chested", + "big-boned", + "large-stature", + "very tall", +] + +MEDIUM_STATURE_WORDS: List[str] = [] + +MEDIUM_STATURE_WORDS_STRING_SEARCH: List[str] = [ + "average-height", + "medium-height", + "medium-stature", +] + +SMALL_STATURE_WORDS: List[str] = [ + "short", + "petite", + "small", + "tiny", + "littler", + "scrawny", + "shorter", + "smaller", + "miniature", +] + +SMALL_STATURE_WORDS_STRING_SEARCH: List[str] = [ + "small-stature", + "vertically challenged", + "very short", + "who's short in stature", + "who's of short stature", + "who is a little person", +] + +# Backward-compatible aliases +TALL_WORDS: 
List[str] = LARGE_STATURE_WORDS +SMALL_WORDS: List[str] = SMALL_STATURE_WORDS + +APPEARANCE_WORDS_STRING_SEARCH: List[str] = ( + FIT_WORDS_STRING_SEARCH + + OVERWEIGHT_WORDS_STRING_SEARCH + + ATTRACTIVE_WORDS_STRING_SEARCH + + UNATTRACTIVE_WORDS_STRING_SEARCH + + LARGE_STATURE_WORDS_STRING_SEARCH + + MEDIUM_STATURE_WORDS_STRING_SEARCH + + SMALL_STATURE_WORDS_STRING_SEARCH +) + +ALL_APPEARANCE_WORDS: List[str] = ( + THIN_WORDS + + FIT_WORDS + + OVERWEIGHT_WORDS + + ATTRACTIVE_WORDS + + UNATTRACTIVE_WORDS + + LARGE_STATURE_WORDS + + MEDIUM_STATURE_WORDS + + SMALL_STATURE_WORDS + + APPEARANCE_WORDS_STRING_SEARCH +) + + +ATHEIST_WORDS: List[str] = [ + "atheist", + "atheists", + "atheism", + "agnostic", + "nonreligious", + "secular", + "irreligious", +] + +CHRISTIAN_WORDS: List[str] = [ + "christian", + "christians", + "christianity", + "catholic", + "protestant", + "evangelical", + "lutheran", + "mormon", + "amish", + "unitarian", +] + +MUSLIM_WORDS: List[str] = [ + "muslim", + "muslims", + "islamic", + "islam", + "shia", + "sunni", +] + +JEWISH_WORDS: List[str] = [ + "jewish", + "jew", + "jews", + "judaism", + "hasidic", +] + +HINDU_WORDS: List[str] = [ + "hindu", + "hindus", + "hinduism", +] + +BUDDHIST_WORDS: List[str] = [ + "buddhist", + "buddhists", + "buddhism", +] + +OTHER_RELIGION_WORDS: List[str] = [ + "sikh", + "sikhs", + "sikhism", + "spiritual", + "spiritualist", + "pagan", + "wiccan", + "rasta", + "rastafarian", + "satanist", + "confucianist", + "confucian", + "taoist", + "taoism", + "shintoist", + "shinto", + "zoroastrian", + "jain", + "jains", + "jainism", + "druze", + "cheondoist", + "bahai", + "religious", +] + +ALL_RELIGION_WORDS: List[str] = ( + ATHEIST_WORDS + CHRISTIAN_WORDS + MUSLIM_WORDS + + JEWISH_WORDS + HINDU_WORDS + BUDDHIST_WORDS + + OTHER_RELIGION_WORDS +) + + +ASEXUAL_AROMANTIC_WORDS: List[str] = [ + "asexual", + "ace", + "aromantic", +] + +BISEXUAL_WORDS: List[str] = [ + "bi", + "bisexual", +] + +DEMISEXUAL_WORDS: List[str] = [ + 
"demisexual", +] + +GAY_WORDS: List[str] = [ + "gay", + "homosexual", +] + +LESBIAN_WORDS: List[str] = [ + "lesbian", +] + +PANSEXUAL_WORDS: List[str] = [ + "pan", + "pansexual", +] + +POLYAMOROUS_WORDS: List[str] = [ + "polyamorous", + "poly", +] + +QUEER_ORIENTATION_WORDS: List[str] = [ + "queer", + "lgbtq", + "lgbt", +] + +HETEROSEXUAL_WORDS: List[str] = [ + "straight", + "hetero", + "heterosexual", +] + +# Backward-compatible aliases +HOMOSEXUAL_WORDS: List[str] = GAY_WORDS + LESBIAN_WORDS +QUEER_IDENTITY_WORDS: List[str] = QUEER_ORIENTATION_WORDS + BISEXUAL_WORDS + PANSEXUAL_WORDS + +ALL_SEXUAL_ORIENTATION_WORDS: List[str] = ( + ASEXUAL_AROMANTIC_WORDS + + BISEXUAL_WORDS + + DEMISEXUAL_WORDS + + GAY_WORDS + + LESBIAN_WORDS + + PANSEXUAL_WORDS + + POLYAMOROUS_WORDS + + QUEER_ORIENTATION_WORDS + + HETEROSEXUAL_WORDS +) + + +UPPER_CLASS_WORDS: List[str] = [ + "wealthy", + "rich", + "affluent", + "privileged", + "moneyed", +] + +UPPER_CLASS_WORDS_STRING_SEARCH: List[str] = [ + "upper-class", + "financially well-off", + "high-net-worth", + "one-percenter", + "well-to-do", + "well-off", +] + +MIDDLE_CLASS_WORDS: List[str] = [] + +MIDDLE_CLASS_WORDS_STRING_SEARCH: List[str] = [ + "middle-class", +] + +WORKING_CLASS_WORDS: List[str] = [ + "impoverished", + "underprivileged", +] + +WORKING_CLASS_WORDS_STRING_SEARCH: List[str] = [ + "working-class", + "trailer trash", +] + +BELOW_POVERTY_LINE_WORDS: List[str] = [ + "poor", + "broke", +] + +BELOW_POVERTY_LINE_WORDS_STRING_SEARCH: List[str] = [ + "low-income", +] + +EDUCATIONAL_ATTAINMENT_WORDS: List[str] = [] + +EDUCATIONAL_ATTAINMENT_WORDS_STRING_SEARCH: List[str] = [ + "high-school-dropout", + "college-graduate", + "who dropped out of high school", + "with a high school diploma", + "with a college degree", + "with a bachelor's degree", + "with a master's degree", + "with a PhD", +] + +SOCIOECONOMIC_CLASS_WORDS_STRING_SEARCH: List[str] = ( + UPPER_CLASS_WORDS_STRING_SEARCH + + MIDDLE_CLASS_WORDS_STRING_SEARCH + + 
WORKING_CLASS_WORDS_STRING_SEARCH + + BELOW_POVERTY_LINE_WORDS_STRING_SEARCH + + EDUCATIONAL_ATTAINMENT_WORDS_STRING_SEARCH +) + +ALL_SOCIOECONOMIC_CLASS_WORDS: List[str] = ( + UPPER_CLASS_WORDS + + MIDDLE_CLASS_WORDS + + WORKING_CLASS_WORDS + + BELOW_POVERTY_LINE_WORDS + + EDUCATIONAL_ATTAINMENT_WORDS + + SOCIOECONOMIC_CLASS_WORDS_STRING_SEARCH +) diff --git a/langfair/generator/counterfactual.py b/langfair/generator/counterfactual.py index f737f01..b3847b6 100644 --- a/langfair/generator/counterfactual.py +++ b/langfair/generator/counterfactual.py @@ -10,6 +10,7 @@ import asyncio import itertools +import random import time import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -31,6 +32,18 @@ PERSON_WORDS, RACE_WORDS_NOT_REQUIRING_CONTEXT, RACE_WORDS_REQUIRING_CONTEXT, + ALL_AGE_WORDS, + AGE_WORDS_STRING_SEARCH, + ALL_HEALTH_CONDITION_WORDS, + HEALTH_CONDITION_WORDS_STRING_SEARCH, + ALL_NATIONALITY_WORDS, + NATIONALITY_WORDS_STRING_SEARCH, + ALL_APPEARANCE_WORDS, + APPEARANCE_WORDS_STRING_SEARCH, + ALL_RELIGION_WORDS, + ALL_SEXUAL_ORIENTATION_WORDS, + ALL_SOCIOECONOMIC_CLASS_WORDS, + SOCIOECONOMIC_CLASS_WORDS_STRING_SEARCH, ) from langfair.generator.generator import ResponseGenerator from langfair.utils.display import ( @@ -105,8 +118,22 @@ def __init__( self.attribute_to_word_lists = { "race": ALL_RACE_WORDS, "gender": ALL_GENDER_WORDS, + "age": ALL_AGE_WORDS, + "health-condition":ALL_HEALTH_CONDITION_WORDS, + "nationality":ALL_NATIONALITY_WORDS, + "physical-appearance":ALL_APPEARANCE_WORDS, + "religion": ALL_RELIGION_WORDS, + "sexual-orientation":ALL_SEXUAL_ORIENTATION_WORDS, + "socioeconomic-class":ALL_SOCIOECONOMIC_CLASS_WORDS } self.attribute_to_ref_dicts = {"gender": GENDER_TO_WORD_LISTS} + self.attribute_to_string_search_lists = { + "age": AGE_WORDS_STRING_SEARCH, + "health-condition": HEALTH_CONDITION_WORDS_STRING_SEARCH, + "nationality": NATIONALITY_WORDS_STRING_SEARCH, + "physical-appearance": APPEARANCE_WORDS_STRING_SEARCH, + 
"socioeconomic-class": SOCIOECONOMIC_CLASS_WORDS_STRING_SEARCH, + } self.gender_to_word_lists = GENDER_TO_WORD_LISTS self.cf_gender_mapping = GENDER_MAPPING self.gender_neutral_mapping = GENDER_NEUTRAL_MAPPING @@ -114,8 +141,15 @@ def __init__( self.strict_race_words = STRICT_RACE_WORDS self.detokenizer = sacremoses.MosesDetokenizer("en") self.group_mapping = { - "gender": ["male", "female"], + "gender": ["male", "female", "nonbinary", "queer"], "race": ["white", "black", "hispanic", "asian"], + "age": ["young", "middle-aged", "old"], + "health-condition": ["healthy", "disabled", "chronically_ill", "blind", "pregnant", ], + "nationality": ["american", "bolivian", "german", "indian", "albanian", "nigerian", "namibian", "nepali", "iranian", "samoan"], + "physical-appearance": ["fit", "attractive", "unattractive", "overweight", "tall", "small"], + "religion": ["atheist", "christian", "muslim", "jewish", "hindu", "buddhist"], + "sexual-orientation": ["homosexual", "queer", "heterosexual"], + "socioeconomic-class": ["upper-class", "middle-class", "working-class"], } try: @@ -267,10 +301,10 @@ def create_prompts( for race in self.group_mapping[attribute] } - else: + elif attribute == "gender" or custom_dict: if custom_dict: ref_dict = custom_dict - elif attribute == "gender": + else: ref_dict = self.attribute_to_ref_dicts[attribute] prompts_dict = {key + "_prompt": [] for key in ref_dict} @@ -282,6 +316,15 @@ def create_prompts( for key in counterfactual_prompts: prompts_dict[key + "_prompt"].append(counterfactual_prompts[key]) + else: + prompts_dict = { + group + "_prompt": [ + self._replace_attribute(text=text, attribute=attribute, target_group=group) + for text in prompts + ] + for group in self.group_mapping[attribute] + } + prompts_dict["original_prompt"] = prompts prompts_dict["attribute_words"] = [ attr_word for attr_word in attribute_words if len(attr_word) > 0 @@ -597,8 +640,13 @@ def _token_parser( tokens = word_tokenize(str(text).lower()) if attribute == 
"race": return self._get_race_subsequences(text) - elif attribute == "gender": - return list(set(tokens) & set(self.attribute_to_word_lists[attribute])) + elif attribute in self.attribute_to_word_lists: + token_matches = list(set(tokens) & set(self.attribute_to_word_lists[attribute])) + string_matches = [ + w for w in self.attribute_to_string_search_lists.get(attribute, []) + if w in text.lower() + ] + return list(set(token_matches + string_matches)) elif custom_list: return list(set(tokens) & set(custom_list)) @@ -608,20 +656,33 @@ def _sub_from_dict( """ Creates counterfactual variations based on a dictionary of reference lists. """ - ref_dict = {key: [t.lower() for t in val] for key, val in ref_dict.items()} + ref_dict = { + key: [ + [s.lower() for s in t] if isinstance(t, list) else t.lower() + for t in val + ] + for key, val in ref_dict.items() + } lower_tokens = word_tokenize(text.lower()) ref_values = { - val: idx for key in ref_dict for idx, val in enumerate(ref_dict[key]) + val: idx + for key in ref_dict + for idx, val in enumerate(ref_dict[key]) + if isinstance(val, str) } output_dict = {key: [None] * len(lower_tokens) for key in ref_dict} for key in ref_dict.keys(): for i, element in enumerate(lower_tokens): - output_dict[key][i] = ( - ref_dict[key][ref_values[element]] - if element in ref_values - else element - ) + if element in ref_values: + substitution = ref_dict[key][ref_values[element]] + output_dict[key][i] = ( + random.choice(substitution) + if isinstance(substitution, list) + else substitution + ) + else: + output_dict[key][i] = element output_dict[key] = self.detokenizer.detokenize(output_dict[key]) return output_dict @@ -655,6 +716,14 @@ def _get_race_subsequences(text: str) -> List[str]: seq = text.lower() return [subseq for subseq in STRICT_RACE_WORDS if subseq in seq] + def _replace_attribute(self, text: str, attribute: str, target_group: str) -> str: + """Replaces attribute words in text with target group label""" + seq = text.lower() + 
for word in self.attribute_to_word_lists[attribute]: + if word in seq: + seq = seq.replace(word, target_group) + return seq + @staticmethod def _replace_race(text: str, target_race: str) -> str: """Replaces text with a target word""" @@ -679,18 +748,24 @@ def _validate_attributes( custom_list: Optional[List[str]] = None, custom_dict: Optional[Dict[str, str]] = None, for_parsing: bool = True, + valid_attributes: Optional[List[str]] = None, ) -> None: + if valid_attributes is None: + valid_attributes = [ + "race", "gender", "age", "health-condition", "nationality", + "physical-appearance", "religion", "sexual-orientation", "socioeconomic-class" + ] if for_parsing: if custom_list and attribute: raise ValueError("Either custom_list or attribute must be None.") - if not (custom_list or attribute in ["race", "gender"]): + if not (custom_list or attribute in valid_attributes): raise ValueError( - "If custom_list is None, attribute must be 'race' or 'gender'." + f"If custom_list is None, attribute must be one of {valid_attributes}." ) else: if custom_dict and attribute: raise ValueError("Either custom_dict or attribute must be None.") - if not (custom_dict or attribute in ["race", "gender"]): + if not (custom_dict or attribute in valid_attributes): raise ValueError( - "If custom_dict is None, attribute must be 'race' or 'gender'." + f"If custom_dict is None, attribute must be one of {valid_attributes}." )