-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🚨 Major Refactoring to Examples and Names
- Loading branch information
1 parent
0842e3f
commit 67cc7ab
Showing
44 changed files
with
1,066 additions
and
4,318 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,2 @@ | ||
"""Entry point for the botiverse package.""" | ||
from botiverse.gui.gui import chat_gui | ||
#from botiverse.TODS.DNN_DST.DNN_DST import DNNDST | ||
#from botiverse.TODS.DNN_TODS import DNNTODS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
79 changes: 0 additions & 79 deletions
79
botiverse/bots/Vocalizer/Vocalizer.py → botiverse/bots/VoiceBot/SpeechClassifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import numpy as np | ||
import json | ||
from gtts import gTTS | ||
import tempfile | ||
import os | ||
from botiverse.models import TTS | ||
from playsound import playsound | ||
|
||
from botiverse.models import LSTMClassifier | ||
from botiverse.preprocessors import Vocalize, Wav2Vec, Wav2Text, BertEmbedder, Frequency, BertSentenceEmbedder | ||
from botiverse.bots.VoiceBot.utils import voice_input | ||
|
||
|
||
class VoiceBot(): | ||
'''An interface for the vocalizer chatbot which simulates a call with a customer service bot.''' | ||
def __init__(self, call_json_path, repr='BERT-Sentence'): | ||
''' | ||
Load the call data from a json file. | ||
:param call_json_path: The path to the json file containing the call state machine. | ||
''' | ||
with open(call_json_path, 'r') as file: | ||
call_json = file.read() | ||
self.call_data = json.loads(call_json) | ||
self.current_node = 'A' | ||
self.wav2text = Wav2Text() | ||
if repr == 'BERT': | ||
self.bert_embeddings = BertEmbedder() | ||
elif repr == 'BERT-Sentence': | ||
self.bert_sentence_embeddings = BertSentenceEmbedder() | ||
else: | ||
raise Exception(f"Invalid representation {repr}. Expected BERT or BERT-Sentence.") | ||
|
||
def generate_speech(self, text, offline=False): | ||
'''Use Google's TTS or offline FastSpeech 1.0 to play speech from the given text. | ||
:param text: The text to be converted into speech. | ||
:param offline: Whether to use offline FastSpeech 1.0 to generate speech. | ||
''' | ||
if offline: | ||
tts = TTS() | ||
tts.speak(text) | ||
else: | ||
tts = gTTS(text=text, lang='en', tld="us", slow=False) | ||
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio: | ||
temp_filename = temp_audio.name | ||
tts.save(temp_filename) | ||
playsound(temp_filename) | ||
|
||
def simulate_call(self): | ||
''' | ||
Simulate a call with a customer service bot as driven by the call state machine. | ||
''' | ||
while True: | ||
if self.current_node == 'Z': | ||
# the final state has a different structure, bot only speaks and then the call ends | ||
bot_message = self.call_data[self.current_node]['Bot'] | ||
self.generate_speech(bot_message) | ||
break | ||
|
||
# 1 - get the current node's data and from that get the message the bot should speak | ||
node_data = self.call_data[self.current_node] | ||
bot_message = node_data['Bot'] | ||
self.generate_speech(bot_message) | ||
|
||
# 2 - get the intent options that the bot expects from the user and classify the user's response | ||
options = node_data['Options'] | ||
intents = [option['Intent'] for option in options] | ||
max_dur = node_data['max_duration'] | ||
human_resp = voice_input(record_time=int(max_dur)) | ||
human_resp = self.wav2text.transcribe(human_resp) | ||
selected_ind, score = self.bert_sentence_embeddings.closest_sentence(human_resp, intents, retun_ind=True) | ||
print(f"you said: {human_resp} and the bot decided that you meant {intents[selected_ind]} with a score of {score}") | ||
|
||
# 3 - speak according to the chosen option | ||
speak_message = options[selected_ind]['Speak'] | ||
self.generate_speech(speak_message) | ||
|
||
# 4 - go to the next state | ||
self.current_node = options[selected_ind]['Next'] | ||
|
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from botiverse.bots.WhizBot.WhizBot_GRU import WhizBot_GRU | ||
from botiverse.bots.WhizBot.WhizBot_BERT import WhizBot_BERT | ||
|
||
class WhizBot: | ||
''' | ||
A class that provides an interface for the WhizBot-BERT and WhizBot-GRU models. | ||
''' | ||
def __init__(self, repr='BERT'): | ||
""" | ||
Initializes WhizBot and sets its representation type. | ||
:param repr: The representation type of the WhizBot model. Either "BERT" or "GRU". | ||
:type repr: str | ||
""" | ||
if repr == 'BERT': | ||
self.bot = WhizBot_BERT() | ||
elif repr == 'GRU': | ||
self.bot = WhizBot_GRU() | ||
else: | ||
raise ValueError('Invalid representation type for WhizBot. Please choose either "BERT" or "GRU".') | ||
|
||
|
||
|
||
def read_data(self, file_path): | ||
""" | ||
Reads and pre-processes the data, sets up the model based on the data and prepares the train-validation split. | ||
:param file_path: The path to the file that contains the dataset. | ||
:type file_path: str | ||
:returns: None | ||
""" | ||
self.bot.read_data(file_path) | ||
|
||
def train(self, epochs=10, batch_size=32): | ||
""" | ||
Trains the model using the training dataset. | ||
:param epochs: The number of training epochs. | ||
:type epochs: int | ||
:param batch_size: The number of training examples utilized used to make one paramenters updat. | ||
:type batch_size: int | ||
:returns: None | ||
""" | ||
self.bot.train(epochs, batch_size) | ||
|
||
def validation(self, batch_size=32): | ||
""" | ||
Tests the model performance using the validation dataset and calculates the accuracy. | ||
:param batch_size: The number of training examples utilized used to make one paramenters updat. | ||
:type batch_size: int | ||
:returns: None | ||
""" | ||
self.bot.validation(batch_size) | ||
|
||
def infer(self, string): | ||
""" | ||
Performs inference using the model. | ||
:param string: The input string to perform inference on. | ||
:type string: str | ||
:returns: A random response from the response list of the predicted label. | ||
""" | ||
return self.bot.infer(string) | ||
|
||
def save(self, path): | ||
""" | ||
Saves the model parameters to the given path. | ||
:param path: The path where the model parameters will be saved. | ||
:type path: str | ||
:returns: None | ||
""" | ||
self.bot.save(path) | ||
|
||
def load(self, path): | ||
""" | ||
Loads the model parameters from the given path. | ||
:param path: The path from where the model parameters will be loaded. | ||
:type path: str | ||
:returns: None | ||
""" | ||
self.bot.load(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
from botiverse.bots.basic_chatbot.basic_chatbot import basic_chatbot | ||
from botiverse.bots.BasicBot.BasicBot import BasicBot | ||
from botiverse.bots.WhizBot.WhizBot import WhizBot | ||
from botiverse.bots.basic_TODS.basic_TODS import BasicTODS | ||
from botiverse.bots.deep_TODS.deep_TODS import DeepTODS | ||
from botiverse.bots.Vocalizer.Vocalizer import SpeechClassifier, Vocalizer | ||
from botiverse.bots.VoiceBot.SpeechClassifier import SpeechClassifier | ||
from botiverse.bots.VoiceBot.VoiceBot import VoiceBot |
Empty file.
Oops, something went wrong.