diff --git a/langfair/auto/auto.py b/langfair/auto/auto.py index ff78688..75b2d34 100644 --- a/langfair/auto/auto.py +++ b/langfair/auto/auto.py @@ -39,6 +39,7 @@ Protected_Attributes = { "race": ["white", "black", "asian", "hispanic"], "gender": ["male", "female"], + "sexual_orientation": ["heterosexual", "gay", "lesbian", "bisexual"], } @@ -179,7 +180,7 @@ async def evaluate( print("------------------------------------------") # 1. Check for Fairness Through Unawareness FTU # Parse prompts for protected attribute words - protected_words = {"race": 0, "gender": 0} + protected_words = {attr: 0 for attr in Protected_Attributes} total_protected_words = 0 for attribute in protected_words.keys(): diff --git a/langfair/constants/word_lists.py b/langfair/constants/word_lists.py index bf69bf4..850ed3b 100644 --- a/langfair/constants/word_lists.py +++ b/langfair/constants/word_lists.py @@ -22,6 +22,7 @@ This file includes word lists for the following 4 categories: - Race (Asian, Hispanic, and White): Most common names for each race (Garg et al. 2018) - Gender (Female, Male): Gender specific words - such as woman, man etc. - for each gender (Bolukbasi et al. 2016) + - Sexual Orientation (Gay, Straight, Bisexual, etc.): Terms influenced from HRC Glossary of Terms (https://www.hrc.org/resources/glossary-of-terms) - Adjectives: List of adjectives (Garg et al. 2018) - Professions: List of professions (Bolukbasi et al. 2016) @@ -173,6 +174,32 @@ ] +################################################################################ +# Define sexual orientation word lists +################################################################################ +# For sexual orientation, string search is done (similar to race) +SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT: List[str] = [ + "homosexual", + "heterosexual", + "bisexual", + "lesbian", + "queer", + "lgbtq", + "lgbt", + "homosexuals", + "heterosexuals", + "bisexuals", + "lesbians", + "nonheterosexual", + "non-heterosexual", +] + +SEXUAL_ORIENTATION_WORDS_REQUIRING_CONTEXT: List[str] = [ + "gay", + "straight", +] + + # Stereotype groups STEREOTYPE_GROUP_NUMBER_MAPPING = { "Homosexuals": "The homosexual", diff --git a/langfair/generator/counterfactual.py b/langfair/generator/counterfactual.py index f737f01..e7396d3 100644 --- a/langfair/generator/counterfactual.py +++ b/langfair/generator/counterfactual.py @@ -31,6 +31,8 @@ PERSON_WORDS, RACE_WORDS_NOT_REQUIRING_CONTEXT, RACE_WORDS_REQUIRING_CONTEXT, + SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT, + SEXUAL_ORIENTATION_WORDS_REQUIRING_CONTEXT, ) from langfair.generator.generator import ResponseGenerator from langfair.utils.display import ( @@ -60,6 +62,19 @@ ) # Extend to include words that indicate race whether or not a person word follows STRICT_RACE_WORDS = list(set(STRICT_RACE_WORDS)) ALL_RACE_WORDS = RACE_WORDS_REQUIRING_CONTEXT + RACE_WORDS_NOT_REQUIRING_CONTEXT + +STRICT_SEXUAL_ORIENTATION_WORDS = [] +for sow in SEXUAL_ORIENTATION_WORDS_REQUIRING_CONTEXT: + for pw in PERSON_WORDS: + STRICT_SEXUAL_ORIENTATION_WORDS.append(sow + " " + pw) + +# Extend to include words that indicate sexual orientation whether or not a person word follows +STRICT_SEXUAL_ORIENTATION_WORDS.extend(SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT) +STRICT_SEXUAL_ORIENTATION_WORDS = list(set(STRICT_SEXUAL_ORIENTATION_WORDS)) +ALL_SEXUAL_ORIENTATION_WORDS = ( + SEXUAL_ORIENTATION_WORDS_REQUIRING_CONTEXT + + SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT +) warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -105,6 +120,7 @@ def __init__( self.attribute_to_word_lists = { "race": ALL_RACE_WORDS, "gender": ALL_GENDER_WORDS, + "sexual_orientation": ALL_SEXUAL_ORIENTATION_WORDS, } self.attribute_to_ref_dicts = {"gender": GENDER_TO_WORD_LISTS} self.gender_to_word_lists = GENDER_TO_WORD_LISTS @@ -112,10 +128,12 @@ def __init__( self.gender_neutral_mapping = GENDER_NEUTRAL_MAPPING self.all_race_words = ALL_RACE_WORDS self.strict_race_words = STRICT_RACE_WORDS + self.strict_sexual_orientation_words = STRICT_SEXUAL_ORIENTATION_WORDS self.detokenizer = sacremoses.MosesDetokenizer("en") self.group_mapping = { "gender": ["male", "female"], "race": ["white", "black", "hispanic", "asian"], + "sexual_orientation": ["heterosexual", "gay", "lesbian", "bisexual"], } try: @@ -145,7 +163,7 @@ async def estimate_token_cost( tiktoken_model_name: str The name of the OpenAI model to use for token counting. - attribute: str, either 'gender' or 'race' + attribute: str, either 'gender', 'race', or 'sexual_orientation' Specifies attribute to be used for counterfactual generation example_responses : list of strings, default=None @@ -196,9 +214,9 @@ def parse_texts( texts : list of strings A list of texts to be parsed for protected attribute words - attribute : {'race','gender'}, default=None - Specifies what to parse for among race words and gender words. Must be specified - if custom_list is None + attribute : {'race','gender','sexual_orientation'}, default=None + Specifies what to parse for among race words, gender words, and sexual orientation + words. Must be specified if custom_list is None custom_list : List[str], default=None Custom list of tokens to use for parsing prompts. Must be provided if attribute is None. @@ -233,9 +251,9 @@ def create_prompts( prompts : List[str] A list of prompts on which counterfactual substitution and response generation will be done - attribute : {'gender', 'race'}, default=None - Specifies whether to use race or gender for counterfactual substitution. Must be provided if - custom_dict is None. + attribute : {'gender', 'race', 'sexual_orientation'}, default=None + Specifies whether to use race, gender, or sexual orientation for counterfactual + substitution. Must be provided if custom_dict is None. custom_dict : Dict[str, List[str]], default=None A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys @@ -267,6 +285,14 @@ def create_prompts( for race in self.group_mapping[attribute] } + elif attribute == "sexual_orientation": + prompts_dict = { + orientation + "_prompt": self._counterfactual_sub_sexual_orientation( + texts=prompts, target_orientation=orientation + ) + for orientation in self.group_mapping[attribute] + } + else: if custom_dict: ref_dict = custom_dict @@ -292,30 +318,36 @@ def neutralize_tokens( self, texts: List[str], attribute: str = "gender" ) -> List[str]: """ - Neutralize gender and race words contained in a list of texts. Replaces gender words with a - gender-neutral equivalent and race words with "[MASK]". + Neutralize gender, race, and sexual orientation words contained in a list of texts. + Replaces gender words with a gender-neutral equivalent and race or sexual orientation + words with "[MASK]". Parameters ---------- texts : List[str] - A list of texts on which gender or race neutralization will occur + A list of texts on which gender, race, or sexual orientation neutralization will occur - attribute : {'gender', 'race'}, default='gender' - Specifies whether to use race or gender for neutralization + attribute : {'gender', 'race', 'sexual_orientation'}, default='gender' + Specifies whether to use race, gender, or sexual orientation for neutralization Returns ------- list - List of texts neutralized for race or gender + List of texts neutralized for race, gender, or sexual orientation """ assert attribute in [ "gender", "race", - ], "Only gender and race attributes are supported." + "sexual_orientation", + ], "Only gender, race, and sexual_orientation attributes are supported." if attribute == "gender": return [self._neutralize_gender(text) for text in texts] elif attribute == "race": return self._counterfactual_sub_race(texts=texts, target_race="[MASK]") + elif attribute == "sexual_orientation": + return self._counterfactual_sub_sexual_orientation( + texts=texts, target_orientation="[MASK]" + ) async def generate_responses( self, @@ -335,9 +367,9 @@ async def generate_responses( prompts : list of strings A list of prompts on which counterfactual substitution and response generation will be done - attribute : {'gender', 'race'}, default=None - Specifies whether to use race or gender for counterfactual substitution. Must be provided if - custom_dict is None. + attribute : {'gender', 'race', 'sexual_orientation'}, default=None + Specifies whether to use race, gender, or sexual orientation for counterfactual + substitution. Must be provided if custom_dict is None. custom_dict : Dict[str, List[str]], default=None A dictionary containing corresponding lists of tokens for counterfactual substitution. Keys @@ -457,9 +489,9 @@ def check_ftu( prompts : list of strings A list of prompts to be parsed for protected attribute words - attribute : {'race','gender'}, default=None - Specifies what to parse for among race words and gender words. Must be specified - if custom_list is None + attribute : {'race','gender','sexual_orientation'}, default=None + Specifies what to parse for among race words, gender words, and sexual orientation + words. Must be specified if custom_list is None custom_list : List[str], default=None Custom list of tokens to use for parsing prompts. Must be provided if attribute is None. @@ -599,6 +631,8 @@ def _token_parser( return self._get_race_subsequences(text) elif attribute == "gender": return list(set(tokens) & set(self.attribute_to_word_lists[attribute])) + elif attribute == "sexual_orientation": + return self._get_sexual_orientation_subsequences(text) elif custom_list: return list(set(tokens) & set(custom_list)) @@ -673,6 +707,52 @@ def _replace_race(text: str, target_race: str) -> str: seq = seq.replace(subseq, race_replacement_mapping[subseq]) return seq + @staticmethod + def _get_sexual_orientation_subsequences(text: str) -> List[str]: + """Used to check for sexual orientation string sequences""" + seq = text.lower() + return [subseq for subseq in STRICT_SEXUAL_ORIENTATION_WORDS if subseq in seq] + + def _counterfactual_sub_sexual_orientation( + self, + texts: List[str], + target_orientation: str, + ) -> List[str]: + """Implements counterfactual substitution for sexual orientation""" + new_texts = [] + for text in texts: + new_text = self._replace_sexual_orientation(text, target_orientation) + new_texts.append(new_text) + return new_texts + + @staticmethod + def _replace_sexual_orientation(text: str, target_orientation: str) -> str: + """Replaces sexual orientation words with a target orientation word""" + seq = text.lower() + orientation_replacement_mapping = {} + # Build a set of singular NOT_REQUIRING_CONTEXT words for plural detection + singular_words = set(SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT) + for sow in SEXUAL_ORIENTATION_WORDS_REQUIRING_CONTEXT: + for pw in PERSON_WORDS: + key = sow + " " + pw + orientation_replacement_mapping[key] = target_orientation + " " + pw + for sow in SEXUAL_ORIENTATION_WORDS_NOT_REQUIRING_CONTEXT: + # Preserve plural form: if source word is a plural of another listed word, use plural target + if ( + sow.endswith("s") + and sow[:-1] in singular_words + and target_orientation != "[MASK]" + ): + orientation_replacement_mapping[sow] = target_orientation + "s" + else: + orientation_replacement_mapping[sow] = target_orientation + + # Replace longest matches first to avoid partial replacements + for subseq in sorted(STRICT_SEXUAL_ORIENTATION_WORDS, key=len, reverse=True): + if subseq in seq: + seq = seq.replace(subseq, orientation_replacement_mapping[subseq]) + return seq + @staticmethod def _validate_attributes( attribute: Optional[str] = None, @@ -683,14 +763,18 @@ def _validate_attributes( if for_parsing: if custom_list and attribute: raise ValueError("Either custom_list or attribute must be None.") - if not (custom_list or attribute in ["race", "gender"]): + if not ( + custom_list or attribute in ["race", "gender", "sexual_orientation"] + ): raise ValueError( - "If custom_list is None, attribute must be 'race' or 'gender'." + "If custom_list is None, attribute must be 'race', 'gender', or 'sexual_orientation'." ) else: if custom_dict and attribute: raise ValueError("Either custom_dict or attribute must be None.") - if not (custom_dict or attribute in ["race", "gender"]): + if not ( + custom_dict or attribute in ["race", "gender", "sexual_orientation"] + ): raise ValueError( - "If custom_dict is None, attribute must be 'race' or 'gender'." + "If custom_dict is None, attribute must be 'race', 'gender', or 'sexual_orientation'." ) diff --git a/tests/test_counterfactualgenerator.py b/tests/test_counterfactualgenerator.py index 935fb5a..2eb38c2 100644 --- a/tests/test_counterfactualgenerator.py +++ b/tests/test_counterfactualgenerator.py @@ -109,3 +109,135 @@ async def mock_async_api_call(prompt, *args, **kwargs): if "response" in key ] ) + + +@pytest.mark.asyncio +async def test_counterfactual_sexual_orientation(monkeypatch): + MOCKED_PROMPTS = [ + "prompt 1: male person", + "prompt 2: female person", + "prompt 3: homosexual person", + "prompt 4: lesbian couple", + ] + MOCKED_SEXUAL_ORIENTATION_PROMPTS = { + "heterosexual_prompt": [ + "prompt 3: heterosexual person", + "prompt 4: heterosexual couple", + ], + "gay_prompt": ["prompt 3: gay person", "prompt 4: gay couple"], + "lesbian_prompt": [ + "prompt 3: lesbian person", + "prompt 4: lesbian couple", + ], + "bisexual_prompt": [ + "prompt 3: bisexual person", + "prompt 4: bisexual couple", + ], + "attribute_words": [["homosexual"], ["lesbian"]], + "original_prompt": [ + "prompt 3: homosexual person", + "prompt 4: lesbian couple", + ], + } + MOCKED_RESPONSES = [ + "Gender response", + "Sexual orientation response", + ] + + async def mock_async_api_call(prompt, *args, **kwargs): + if "1" in prompt or "2" in prompt: + return [MOCKED_RESPONSES[0]] + elif "3" in prompt or "4" in prompt: + return [MOCKED_RESPONSES[-1]] + + mock_object = AzureChatOpenAI( + deployment_name="YOUR-DEPLOYMENT", + temperature=0, + api_key="SECRET_API_KEY", + api_version="2024-05-01-preview", + azure_endpoint="https://mocked.endpoint.com", + ) + + counterfactual_object = CounterfactualGenerator(langchain_llm=mock_object) + + monkeypatch.setattr(counterfactual_object, "_async_api_call", mock_async_api_call) + + # Test parse_texts for sexual orientation + sexual_orientation_prompts = counterfactual_object.parse_texts( + texts=MOCKED_PROMPTS, attribute="sexual_orientation" + ) + assert sexual_orientation_prompts == [[], [], ["homosexual"], ["lesbian"]] + + # Test create_prompts for sexual orientation + sexual_orientation_prompts = counterfactual_object.create_prompts( + prompts=MOCKED_PROMPTS, attribute="sexual_orientation" + ) + assert sexual_orientation_prompts == MOCKED_SEXUAL_ORIENTATION_PROMPTS + + # Test generate_responses for sexual orientation + cf_data = await counterfactual_object.generate_responses( + prompts=MOCKED_PROMPTS, attribute="sexual_orientation", count=1 + ) + assert all( + [ + cf_data["data"][key] == [MOCKED_RESPONSES[-1]] * 2 + for key in cf_data["data"] + if "response" in key + ] + ) + + # Test check_ftu for sexual orientation + ftu_result = counterfactual_object.check_ftu( + prompts=MOCKED_PROMPTS, attribute="sexual_orientation" + ) + assert ftu_result["metadata"]["ftu_satisfied"] is True + assert ftu_result["metadata"]["n_prompts_with_attribute_words"] == 2 + + # Test neutralize_tokens for sexual orientation + neutralized = counterfactual_object.neutralize_tokens( + texts=["prompt 3: homosexual person", "prompt 4: lesbian couple"], + attribute="sexual_orientation", + ) + assert "[MASK]" in neutralized[0] + assert "[MASK]" in neutralized[1] + + # Test validation error for invalid attribute + with pytest.raises(ValueError): + counterfactual_object.parse_texts( + texts=MOCKED_PROMPTS, attribute="invalid_attribute" + ) + + # Test that plural forms are preserved in substitution + plural_texts = [ + "There are many homosexuals in the city.", + "The bisexuals in this group are underrepresented.", + "lesbians face unique challenges.", + ] + neutralized_plural = counterfactual_object.neutralize_tokens( + texts=plural_texts, attribute="sexual_orientation" + ) + assert "[MASK]" in neutralized_plural[0] + assert "[MASK]" in neutralized_plural[1] + assert "[MASK]" in neutralized_plural[2] + + substituted_plural = counterfactual_object._replace_sexual_orientation( + plural_texts[0], "heterosexual" + ) + assert "heterosexuals" in substituted_plural + + substituted_plural2 = counterfactual_object._replace_sexual_orientation( + plural_texts[1], "gay" + ) + assert "gays" in substituted_plural2 + + # Test that REQUIRING_CONTEXT words match with person words + requiring_context_texts = [ + "The gay man was happy.", + "She is a straight woman.", + "A gay couple walked by.", + ] + parsed = counterfactual_object.parse_texts( + texts=requiring_context_texts, attribute="sexual_orientation" + ) + assert ["gay man"] in parsed + assert ["straight woman"] in parsed