Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added WIKI takeover feature #57

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added __init__.py
Empty file.
2 changes: 1 addition & 1 deletion agents/portforwarding.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ kubectl port-forward corenlp-7fd4974bb-8mq5g 4080:5001 -n chirpy
kubectl port-forward dialogact-849b4b67d8-ngzd5 4081:5001 -n chirpy &
kubectl port-forward g2p-7644ff75bd-cjj57 4082:5001 -n chirpy &
kubectl port-forward gpt2ed-68f849f64b-wr8zw 4083:5001 -n chirpy &
kubectl port-forward questionclassifier-668c4fd6c6-fd586 4084:5001 -n chirpy &
kubectl port-forward questionclassifier-668c4fd6c6-7nl2k 4084:5001 -n chirpy &
kubectl port-forward convpara-dbdc8dcfb-csktj 4085:5001 -n chirpy &
kubectl port-forward entitylinker-59b9678b8-nmwx9 4086:5001 -n chirpy &
kubectl port-forward blenderbot-695c7b5896-gkz2s 4087:5001 -n chirpy &
Expand Down
28 changes: 24 additions & 4 deletions chirpy/core/dialog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,16 @@ def update_rg_states(self, results: RankedResults, selected_rg: str):
# Get the args needed for the update_state_if_not_chosen fn. That's (state, conditional_state) for all RGs except selected_rg
other_rgs = [rg for rg in results.keys() if rg != selected_rg and not is_killed(results[rg])]
logger.info(f"now, current states are {rg_states}")
args_list = [[rg_states[rg], results[rg].conditional_state] for rg in other_rgs]

def rg_was_taken_over(rg):
if self.state_manager.last_state:
logger.debug(f"Rg that is selected is {selected_rg}. Currently evaluated rg is {rg}. "
f"rg == self.state_manager.last_state.active_rg is {rg == self.state_manager.last_state.active_rg}")
return rg_states[selected_rg].rg_that_was_taken_over and rg == self.state_manager.last_state.active_rg
else:
return None

args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs]

# Run update_state_if_not_chosen for other RGs
logger.info(f'Starting to run update_state_if_not_chosen for {other_rgs}...')
Expand Down Expand Up @@ -331,7 +340,6 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe

# Get the states for the RGs we'll run, which we'll use as input to the get_response/get_prompt fn
logger.debug('Copying RG states to use as input...')
input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts

# import pdb; pdb.set_trace()

Expand All @@ -343,10 +351,22 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe
priority_modules = [last_state_active_rg]
else:
priority_modules = []

rg_was_taken_over = None
if self.state_manager.last_state_response:
rg_was_taken_over = self.state_manager.last_state_response.state.rg_that_was_taken_over

def rg_to_resume(rg):
logger.debug(f"rg that was taken over is {rg_was_taken_over}. Currently evaluated rg is {rg}. "
f"rg == rg_was_taken_over is {rg == rg_was_taken_over}.")
return rg == rg_was_taken_over

function_name = 'get_prompt_wrapper' if phase == 'prompt' else 'get_response'
args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list])
results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list,
function_name=f'get_{phase}',
function_name=function_name,
timeout=timeout,
args_list=[[state] for state in input_rg_states],
args_list=args_list, # [[state] for state in input_rg_states],
priority_modules=priority_modules)

# Log the initial results
Expand Down
2 changes: 1 addition & 1 deletion chirpy/core/entity_linker/wiki_data_fetching.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ANCHORTEXT_QUERY_TIMEOUT = 3.0 # seconds
ENTITYNAME_QUERY_TIMEOUT = 1.0 # seconds

ARTICLES_INDEX_NAME = 'enwiki-20220107-articles'
ARTICLES_INDEX_NAME = 'enwiki-20200920-articles'

# These are the fields we DO want to fetch from ES
FIELDS_FILTER = ['doc_title', 'doc_id', 'categories', 'pageview', 'linkable_span_info', 'wikidata_categories_all', 'redirects', 'plural']
Expand Down
39 changes: 36 additions & 3 deletions chirpy/core/entity_tracker/entity_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class EntityTrackerState(object):

def __init__(self):
self.cur_entity = None # the current entity under discussion (can be None)
self.talked_unfinished = [] # entities that we have not finished talking about, but the rg is taken over
self.able_to_takeover_entities = [] # entities that are found in the response in that turn and can be used for wiki rg to takeover
self.talked_rejected = [] # entities we talked about in the past, and stopped talking about because the user indicated they didn't want to talk about it any more
self.talked_finished = [] # entities we talked about in the past, that aren't in talked_rejected
self.talked_transitionable = []
Expand Down Expand Up @@ -97,7 +99,7 @@ def finish_entity(self, entity: Optional[WikiEntity], transition_is_possible=Tru
logger.error(f"This is an error. This should be a WikiEntity object but {entity} is of type {type(entity)}")
entity = None

if entity is not None and entity not in self.talked_finished:
if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished:
logger.info(f'Putting entity {entity} on the talked_finished list')
self.talked_finished.append(entity)

Expand Down Expand Up @@ -277,16 +279,23 @@ def condition_fn(entity_linker_result, linked_span, entity) -> bool:
if nav_intent_output.neg_intent or nav_intent_output.pos_intent or last_answer_type in [AnswerType.QUESTION_SELFHANDLING, AnswerType.QUESTION_HANDOFF]:
self.cur_entity = self.entity_initiated_on_turn

logger.info(f'Resetting able_to_takeover_entities to empty list')
self.able_to_takeover_entities = []

for linked_span in current_state.entity_linker.high_prec:
if not self.talked(linked_span.top_ent):
logger.info(f'Adding {linked_span.top_ent} to user_mentioned_untalked')
self.user_mentioned_untalked.append(linked_span.top_ent)
logger.info(f'Adding {linked_span.top_ent} to able_to_takeover_entities')
self.able_to_takeover_entities.append(linked_span.top_ent)

logger.primary_info(f'The EntityTrackerState is now: {self}')
# logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}')

# Update the entity tracker history
self.history[-1]['user'] = self.cur_entity


def record_untalked_high_prec_entities(self, current_state):
"""
Take any entities in the entity linker's high precision set for this turn, and if they haven't been discussed,
Expand All @@ -313,6 +322,7 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
result: ResponseGeneratorResult, PromptResult, or UpdateEntity
rg: the name of the RG that provided the new entity
"""

if isinstance(result, UpdateEntity):
new_entity = result.cur_entity
phase = 'get_entity'
Expand All @@ -325,6 +335,14 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up

transition_is_possible = not getattr(result, 'no_transition', False)

if self.able_to_takeover_entities and result.state.takeover_entity:
self.talked_unfinished.append(self.cur_entity)
new_entity = self.able_to_takeover_entities.pop()
logger.primary_info(f'Removing {new_entity} from {self.able_to_takeover_entities}')
self.able_to_takeover_entities = [e for e in self.able_to_takeover_entities if e != new_entity]
logger.info(f'After takeover, self.talk_unfinished is {self.talked_unfinished}, self.able_to_takeover_entities is {self.able_to_takeover_entities}'
f' and self.talked_unfinished is {self.talked_finished}.')

if new_entity == self.cur_entity:
logger.primary_info(f'new_entity={new_entity} from {rg} RG {phase} is the same as cur_entity, so keeping EntityTrackerState the same')
else:
Expand All @@ -340,11 +358,18 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
self.cur_entity = new_entity
# Remove new_entity from user_mentioned_untalked
if new_entity in self.user_mentioned_untalked:
logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked}')
logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked} after conversation is resumed.')
self.user_mentioned_untalked = [e for e in self.user_mentioned_untalked if e != new_entity]

logger.primary_info(f'Set cur_entity to new_entity={new_entity} from {rg} RG {phase}')
logger.primary_info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}')

if new_entity in self.talked_unfinished:
archived_entity = new_entity
logger.info(
f"Removing archived_entity [{archived_entity}] from talked_unfinished [{self.talked_unfinished}]")
self.talked_unfinished.remove(archived_entity)

logger.info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}')

# If we're updating after receiving UpdateEntity from an RG, put any undiscussed high precision entities that
# the user mentioned this turn in user_mentioned_untalked
Expand All @@ -360,6 +385,8 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
def __repr__(self, show_history=False):
output = f"<EntityTrackerState: "
output += f"cur_entity={self.cur_entity.name if self.cur_entity else self.cur_entity}"
output += f", talked_unfinished={[ent.name for ent in self.talked_unfinished]}"
output += f", able_to_takeover_entities={[ent.name for ent in self.able_to_takeover_entities]}"
output += f", talked_finished={[ent.name for ent in self.talked_finished]}"
output += f", talked_rejected={[ent.name for ent in self.talked_rejected]}"
output += f", talked_transitionable={[ent.name for ent in self.talked_transitionable]}"
Expand All @@ -380,6 +407,8 @@ def keep_entity(ent: Optional[WikiEntity]) -> bool:
if ent is None:
return True
return ent in entities

self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)]
self.talked_finished = [ent for ent in self.talked_finished if keep_entity(ent)]
self.talked_rejected = [ent for ent in self.talked_rejected if keep_entity(ent)]
self.user_mentioned_untalked = [ent for ent in self.user_mentioned_untalked if keep_entity(ent)]
Expand All @@ -393,6 +422,8 @@ def reduce_size(self, max_size: int):
# Make a set (no duplicates) of all the WikiEntities stored in this EntityTrackerState
entity_set = set()
entity_set.add(self.cur_entity)
entity_set.update(self.talked_unfinished)
entity_set.update(self.able_to_takeover_entities)
entity_set.update(self.talked_finished)
entity_set.update(self.talked_rejected)
entity_set.update(self.user_mentioned_untalked)
Expand All @@ -408,6 +439,8 @@ def replace_ent(ent: Optional[WikiEntity]):
return None
return entname2ent[ent.name]
self.cur_entity = replace_ent(self.cur_entity)
self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished]
self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities]
self.talked_finished = [replace_ent(ent) for ent in self.talked_finished]
self.talked_rejected = [replace_ent(ent) for ent in self.talked_rejected]
self.user_mentioned_untalked = [replace_ent(ent) for ent in self.user_mentioned_untalked]
Expand Down
85 changes: 65 additions & 20 deletions chirpy/core/logging_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,52 @@
from typing import Optional
from colorama import Fore, Back

from pathlib import Path

def get_active_branch_name():
git_dir = Path(".") / ".git"
if git_dir.is_dir():
head_dir = git_dir / "HEAD"
with head_dir.open("r") as f: content = f.read().splitlines()

for line in content:
if line[0:4] == "ref:":
return line.partition("refs/heads/")[2]
else: # for integ. testing, we don't copy .git/ to the instance
return []

LINEBREAK = '<linebreak>'

# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-sensitive)
# The path_strings are strings we'll search for (case-sensitive) in the path of the file that does the log message
# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-insensitive)
# The path_strings are strings we'll search for (case-insensitive) in the path of the file that does the log message
# You can comment out parts of this dict and add your own components to make it easier to only see what you're working on
# See https://rich.readthedocs.io/en/stable/appendix/colors.html for list of rich colors
# See https://github.com/willmcgugan/rich/blob/master/rich/_emoji_codes.py for emoji codes
COLOR_SETTINGS = {
'WIKI': {'color': Fore.MAGENTA, 'path_strings': ['wiki']},
'MOVIES': {'color': Fore.GREEN, 'path_strings': ['movies']},
# 'NEWS': {'color': Fore.CYAN, 'path_strings': ['news']},
'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'path_strings': ['acknowledgment']},
# 'LAUNCH': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['launch']},
'CATEGORIES': {'color': Fore.YELLOW, 'path_strings': ['categories']},
'NEURAL_CHAT': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['neural_chat']},
'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'path_strings': ['entity_linker']},
'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'path_strings': ['entity_tracker']},
'experiments': {'color': Fore.LIGHTGREEN_EX, 'path_strings': ['experiments']},
'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['navigational_intent']}
'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'rich_color': '#0AAB42',
'emoji': ':white_heavy_check_mark:', 'path_strings': ['acknowledgment']},
'ALEXA_COMMANDS': {'emoji': ':speaking_head_in_silhouette:', 'path_strings': ['alexa_commands']},
'ALIENS': {'rich_color': '#1EA8B3', 'emoji': ':alien:'},
'CATEGORIES': {'rich_color': '#15EBCE', 'emoji': ':newspaper:', 'path_strings': ['categories']},
'CORONAVIRUS': {'rich_color': '#F70C6E', 'emoji': ':face_with_medical_mask:'},
'FOOD': {'rich_color': '#97F20F', 'emoji': ':sushi:'},
'LAUNCH': {'emoji': ':checkered_flag:', 'path_strings': ['launch']},
'MOVIES': {'rich_color': '#F0D718', 'emoji': ':movie_camera:', 'path_strings': ['movies']},
'MUSIC': {'rich_color': '#0586FF', 'emoji': ':musical_notes:'},
'NEURAL_CHAT': {'rich_color': '#0EE827', 'emoji': ':brain:', 'path_strings': ['neural_chat']},
'NEWS': {'rich_color': '#1C64FF', 'emoji': ':newspaper:'},
'OFFENSIVE_USER': {'rich_color': '#EB5215', 'emoji': ':prohibited:'},
'ONE_TURN_HACK': {'rich_color': '#88B0B3', 'emoji': ':hammer:'},
'OPINION': {'rich_color': '#D011ED', 'emoji': ':thinking_face:'},
'PERSONAL_ISSUES': {'rich_color': '#BC3BEB', 'emoji': ':slightly_frowning_face:'},
'SPORTS': {'rich_color': '#EB8715', 'emoji': ':football:'},
'WIKI': {'rich_color': '#42C2F5', 'emoji': ':books:'},
'TRANSITION': {'rich_color': '#5FD700', 'emoji': ':soon_arrow:'},
'REOPEN': {'rich_color': '##5F00FF', 'emoji': ':door:'},
'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'rich_color': '#0BC3E3', 'path_strings': ['entity_linker']},
'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'rich_color': '#DB960D', 'path_strings': ['entity_tracker']},
'experiments': {'color': Fore.LIGHTGREEN_EX, 'rich_color': '#CADB0D', 'path_strings': ['experiments']},
'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'rich_color': '#DB0D93', 'path_strings': ['navigational_intent']}
}

LOG_FORMAT = '[%(levelname)s] [%(asctime)s] [fn_vers: {function_version}] [session_id: {session_id}] [%(pathname)s:%(lineno)d]\n%(message)s\n'
Expand All @@ -33,16 +62,24 @@ def colored(str, fore=None, back=None, include_reset=True):
new_str = '{}{}{}'.format(back, new_str, Back.RESET if include_reset else '')
return new_str

def get_rich_color_for_rg(rg_name):
for component_name, settings in COLOR_SETTINGS.items():
if component_name.lower() == rg_name.lower() and settings.get('rich_color'):
color = settings['rich_color']
return f"[{color}]{rg_name}[/{color}]"
return rg_name

def get_line_color(line):
def get_line_color(line, branch_name):
"""
Given a line of logging (which is one line of a multiline log message), searches for component names at the
beginning of the line. If one is found, returns its color.
"""
first_part_line = line.strip().split()[0]
for component_name, settings in COLOR_SETTINGS.items():
if component_name in first_part_line:
return settings['color']
return settings.get('color')
if any(b.lower() in first_part_line.lower() for b in branch_name):
return Fore.BLUE
return None


Expand All @@ -62,7 +99,7 @@ def get_line_key(idx: int):

class ChirpyFormatter(logging.Formatter):
"""
A custom formatter that formats linebreaks and color according to logger_settings, and the context of each message.
A color formatter that formats linebreaks and color according to logger_settings, and the context of each message.

Based on this: https://stackoverflow.com/a/14859558
"""
Expand All @@ -72,6 +109,10 @@ def __init__(self, allow_multiline: bool, use_color: bool, session_id: Optional[
self.use_color = use_color
self.session_id = session_id
self.function_version = function_version
if self.use_color:
branch_name = get_active_branch_name()
branch_name = ''.join([x if x.isalpha() else ' ' for x in branch_name])
self.branch_name = branch_name.split()
self.update_format()

def update_format(self):
Expand Down Expand Up @@ -137,15 +178,19 @@ def format_color(self, record):
lines = record.msg.split('\n')
for idx, line in enumerate(lines):
setattr(record, get_line_key(idx), line) # e.g. record['line_5'] -> the text of the 5th line of logging
line_colors = [get_line_color(line) for line in lines] # get the color for each line
line_colors = [get_line_color(line, self.branch_name) for line in lines] # get the color for each line
self._style._fmt = self.fmt.replace('%(message)s', linecolored_msg_fmt(line_colors)) # this format string has keys for line_1, line_2, etc, along with line-specific colors

# If the filepath of the calling function contains a path string for a colored component, return its color
else:
for component, settings in COLOR_SETTINGS.items():
for path_string in settings['path_strings']:
if path_string in record.pathname:
self._style._fmt = colored(self.fmt, fore=settings['color'])
if settings.get('path_strings'):
for path_string in settings['path_strings']:
if path_string in record.pathname:
self._style._fmt = colored(self.fmt, fore=settings['color'])
continue
if any(b in record.pathname for b in self.branch_name):
self._style._fmt = colored(self.fmt, fore=Fore.BLUE)

# Use the formatter class to do the formatting (with a possibly modified format)
result = logging.Formatter.format(self, record)
Expand Down
Loading