it works!

Isaak Buslovich 2023-11-21 00:40:09 +01:00
parent f510b847f0
commit 27b9d53516
Signed by: Isaak
GPG Key ID: EEC31D6437FBCC63

main.py (136 changed lines)

@@ -1,130 +1,40 @@
-from pathlib import Path
-import requests
-import llama_cpp_python as llm
+"""A simple LLM chatbot"""
+import argparse
+from llama_cpp import Llama
 from rich.console import Console
 from rich.layout import Layout
 from rich.panel import Panel
 from rich.text import Text
-def download_model(url, model_path):
-    """
-    Download the GGUF model from the given URL if it's not present locally.
-    Parameters:
-        url (str): URL to download the model from.
-        model_path (Path): Local path to save the downloaded model.
-    """
-    if not model_path.exists():
-        try:
-            response = requests.get(url, allow_redirects=True)
-            response.raise_for_status()
-            model_path.write_bytes(response.content)
-            print(f"Model downloaded to {model_path}")
-        except requests.RequestException as e:
-            print(f"Error downloading the model: {e}")
+def create_arg_parser():
+    """Create and return the argument parser."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-m", "--model", type=str, default="./zephyr-7b-alpha.Q3_K_M.gguf")
+    return parser.parse_args()
-def load_model(model_path):
-    """
-    Load a GGUF format model for the chatbot.
-    Parameters:
-        model_path (Path): Path to the GGUF model file.
+def get_llama_response(llm, prompt):
+    """Get response from Llama model."""
+    try:
+        output = llm(prompt, max_tokens=60, stop=["Q:", "\n"], echo=False)
+        return output.get('choices', [{}])[0].get('text', "No response generated.")
+    except Exception as e:
+        return f"Error generating response: {e}"
-    Returns:
-        Model: The loaded GGUF model.
-    """
-    return llm.load(str(model_path))
-def generate_text(model, prompt, max_tokens=256, temperature=0.1, top_p=0.5, echo=False, stop=None):
-    """
-    Generate a response from the LLM based on the given prompt.
-    Parameters:
-        model: The loaded GGUF model.
-        prompt (str): The input prompt for the model.
-        max_tokens (int): Maximum number of tokens for the response.
-        temperature (float): Token sampling temperature.
-        top_p (float): Nucleus sampling parameter.
-        echo (bool): If True, the prompt is included in the response.
-        stop (list): Tokens at which the model should stop generating text.
-    Returns:
-        str: The generated text response.
-    """
-    if stop is None:
-        stop = ["#"]
-    output = model.generate(
-        prompt,
-        max_tokens=max_tokens,
-        temperature=temperature,
-        top_p=top_p,
-        echo=echo,
-        stop=stop,
-    )
-    return output["choices"][0]["text"].strip()
-def generate_prompt_from_template(user_input):
-    """
-    Format the user input into a prompt suitable for the chatbot.
-    Parameters:
-        user_input (str): The user's input message.
-    Returns:
-        str: The formatted prompt for the chatbot.
-    """
-    return f"You are a helpful chatbot.\n{user_input}"
 def create_chat_layout():
     """
     Create the layout for the chatbot interface using rich.
     Returns:
         Layout: The layout object for the chat interface.
     """
     layout = Layout()
     layout.split(
         Layout(name="header", size=3),
         Layout(ratio=1, name="main"),
         Layout(name="footer", size=3)
     )
     layout["header"].update(Panel("[bold magenta]Llama.cpp Chatbot[/]", style="bold blue"))
     layout["footer"].update(Text("Type your message and press [bold green]Enter[/]. Type 'exit' to end the chat.", justify="center"))
     return layout
 def main():
-    """
-    The main function to run the chatbot.
-    """
-    model_path = Path("zephyr-7b-alpha.Q3_K_M.gguf")
-    model_url = "https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/raw/main/zephyr-7b-alpha.Q3_K_M.gguf"
-    if not model_path.exists():
-        print("Model not found locally. Downloading...")
-        download_model(model_url, model_path)
-    model = load_model(model_path)
+    """Main function to run the chatbot."""
+    args = create_arg_parser()
+    llm = Llama(model_path=args.model, verbose=False)
     console = Console()
     layout = create_chat_layout()
     console.print(layout)
-    chat_history = ""
     while True:
-        user_input = console.input("[bold green]You: [/]")
-        if user_input.lower() == "exit":
+        user_input = console.input("[bold cyan]Your question: [/bold cyan]")
+        if user_input.lower() in ['exit', 'quit']:
             break
-        prompt = generate_prompt_from_template(user_input)
-        bot_response = generate_text(model, prompt, max_tokens=356)
+        prompt = f"Question: {user_input} Answer: "
+        response_text = get_llama_response(llm, prompt)
+        console.print(f"[blue]Answer: {response_text}[/blue]")
-        chat_history += f"[bold green]You:[/] {user_input}\n[bold yellow]Bot:[/] {bot_response}\n"
-        chat_panel = Panel(chat_history, title="Chat History")
-        layout["main"].update(chat_panel)
-        console.print(layout)
 if __name__ == "__main__":
     main()
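
For reference, a minimal standalone sketch of the completion call the new code relies on, assuming llama-cpp-python is installed and the default GGUF file from create_arg_parser is present next to the script; the example prompt is made up.

# Sketch only: exercises the same Llama completion pattern used by get_llama_response.
# Assumes ./zephyr-7b-alpha.Q3_K_M.gguf exists locally (the argparse default above).
from llama_cpp import Llama

llm = Llama(model_path="./zephyr-7b-alpha.Q3_K_M.gguf", verbose=False)

# llama-cpp-python returns a completion dict; the generated text sits at
# output["choices"][0]["text"], which get_llama_response extracts defensively.
output = llm("Question: What is llama.cpp? Answer: ", max_tokens=60, stop=["Q:", "\n"], echo=False)
print(output["choices"][0]["text"].strip())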