diff --git a/main.py b/main.py
new file mode 100644
index 0000000..638c065
--- /dev/null
+++ b/main.py
@@ -0,0 +1,139 @@
+from pathlib import Path
+import requests
+from llama_cpp import Llama
+from rich.console import Console
+from rich.layout import Layout
+from rich.panel import Panel
+from rich.text import Text
+
+def download_model(url, model_path):
+    """
+    Download the GGUF model from the given URL if it's not present locally.
+
+    Parameters:
+        url (str): URL to download the model from.
+        model_path (Path): Local path to save the downloaded model.
+    """
+    if not model_path.exists():
+        try:
+            # Stream the download so the multi-gigabyte model file is not held in memory.
+            with requests.get(url, stream=True, allow_redirects=True) as response:
+                response.raise_for_status()
+                with model_path.open("wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+            print(f"Model downloaded to {model_path}")
+        except requests.RequestException as e:
+            print(f"Error downloading the model: {e}")
+
+def load_model(model_path):
+    """
+    Load a GGUF format model for the chatbot.
+
+    Parameters:
+        model_path (Path): Path to the GGUF model file.
+
+    Returns:
+        Llama: The loaded GGUF model.
+    """
+    return Llama(model_path=str(model_path))
+
+def generate_text(model, prompt, max_tokens=256, temperature=0.1, top_p=0.5, echo=False, stop=None):
+    """
+    Generate a response from the LLM based on the given prompt.
+
+    Parameters:
+        model: The loaded GGUF model.
+        prompt (str): The input prompt for the model.
+        max_tokens (int): Maximum number of tokens for the response.
+        temperature (float): Token sampling temperature.
+        top_p (float): Nucleus sampling parameter.
+        echo (bool): If True, the prompt is included in the response.
+        stop (list): Tokens at which the model should stop generating text.
+
+    Returns:
+        str: The generated text response.
+    """
+    if stop is None:
+        stop = ["#"]
+    # Calling the Llama object runs a completion and returns an OpenAI-style response dict.
+    output = model(
+        prompt,
+        max_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        echo=echo,
+        stop=stop,
+    )
+    return output["choices"][0]["text"].strip()
+
+def generate_prompt_from_template(user_input):
+    """
+    Format the user input into a prompt suitable for the chatbot.
+
+    Parameters:
+        user_input (str): The user's input message.
+
+    Returns:
+        str: The formatted prompt for the chatbot.
+    """
+    return f"You are a helpful chatbot.\n{user_input}"
+
+def create_chat_layout():
+    """
+    Create the layout for the chatbot interface using rich.
+
+    Returns:
+        Layout: The layout object for the chat interface.
+    """
+    layout = Layout()
+
+    layout.split(
+        Layout(name="header", size=3),
+        Layout(ratio=1, name="main"),
+        Layout(name="footer", size=3),
+    )
+
+    layout["header"].update(Panel("[bold magenta]Llama.cpp Chatbot[/]", style="bold blue"))
+    layout["footer"].update(
+        Text.from_markup(
+            "Type your message and press [bold green]Enter[/]. Type 'exit' to end the chat.",
+            justify="center",
+        )
+    )
+    return layout
+
+def main():
+    """
+    The main function to run the chatbot.
+    """
+    model_path = Path("zephyr-7b-alpha.Q3_K_M.gguf")
+    model_url = "https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/resolve/main/zephyr-7b-alpha.Q3_K_M.gguf"
+
+    if not model_path.exists():
+        print("Model not found locally. Downloading...")
+        download_model(model_url, model_path)
+
+    model = load_model(model_path)
+
+    console = Console()
+    layout = create_chat_layout()
+    console.print(layout)
+
+    chat_history = ""
+
+    while True:
+        user_input = console.input("[bold green]You: [/]")
+        if user_input.lower() == "exit":
+            break
+
+        prompt = generate_prompt_from_template(user_input)
+        bot_response = generate_text(model, prompt, max_tokens=356)
+
+        chat_history += f"[bold green]You:[/] {user_input}\n[bold yellow]Bot:[/] {bot_response}\n"
+        chat_panel = Panel(chat_history, title="Chat History")
+        layout["main"].update(chat_panel)
+        console.print(layout)
+
+if __name__ == "__main__":
+    main()