131 lines
3.9 KiB
Python
131 lines
3.9 KiB
Python
from pathlib import Path
|
|
import requests
|
|
import llama_cpp_python as llm
|
|
from rich.console import Console
|
|
from rich.layout import Layout
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
|
|
def download_model(url, model_path, *, chunk_size=1024 * 1024, timeout=30):
    """
    Download the GGUF model from the given URL if it's not present locally.

    The response is streamed to disk in chunks rather than buffered in memory
    (model files can be several gigabytes). Data is written to a temporary
    ``<name>.part`` file that is renamed onto ``model_path`` only after the
    transfer completes. An interrupted download therefore never leaves a
    truncated file at ``model_path`` that a later ``exists()`` check would
    mistake for a finished model.

    Parameters:
        url (str): URL to download the model from.
        model_path (Path | str): Local path to save the downloaded model.
        chunk_size (int): Bytes read from the response per write.
        timeout (float): Seconds to wait for the connection / each read
            before giving up.

    Returns:
        bool: True if the model is available at ``model_path`` (already
        present or downloaded successfully). False if the download failed
        with a network/HTTP error, which is reported on stdout.

    Raises:
        OSError: If writing the file fails. The partial ``.part`` file is
        removed before the error propagates.
    """
    model_path = Path(model_path)
    if model_path.exists():
        return True

    tmp_path = model_path.with_name(model_path.name + ".part")
    try:
        with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with tmp_path.open("wb") as fh:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        fh.write(chunk)
        # Only a fully written file ever appears under the final name.
        tmp_path.replace(model_path)
    except requests.RequestException as e:
        tmp_path.unlink(missing_ok=True)
        print(f"Error downloading the model: {e}")
        return False
    except BaseException:
        # Disk errors, Ctrl-C, etc.: drop the partial file, then propagate.
        tmp_path.unlink(missing_ok=True)
        raise

    print(f"Model downloaded to {model_path}")
    return True
|
|
|
|
def load_model(model_path):
    """
    Open a GGUF-format model file for use by the chatbot.

    Parameters:
        model_path (Path): Location of the GGUF model on disk.

    Returns:
        Model: The object produced by ``llm.load`` for that file.
    """
    # NOTE(review): the llama-cpp-python distribution is imported as
    # ``llama_cpp`` and loads models via ``llama_cpp.Llama(model_path=...)``;
    # confirm that ``llama_cpp_python.load`` exists in the installed package.
    path_as_text = str(model_path)
    return llm.load(path_as_text)
|
|
|
|
def generate_text(model, prompt, max_tokens=256, temperature=0.1, top_p=0.5, echo=False, stop=None):
    """
    Generate a response from the LLM based on the given prompt.

    Uses the model's ``create_completion`` method when it has one, and
    otherwise falls back to ``generate``. Either callable must accept the
    prompt plus the keyword arguments below and return a completion dict
    shaped like ``{"choices": [{"text": ...}, ...]}``.

    Parameters:
        model: The loaded GGUF model.
        prompt (str): The input prompt for the model.
        max_tokens (int): Maximum number of tokens for the response.
        temperature (float): Token sampling temperature.
        top_p (float): Nucleus sampling parameter.
        echo (bool): If True, the prompt is included in the response.
        stop (list): Strings at which the model should stop generating text.
            Defaults to ``["#"]``.

    Returns:
        str: The text of the first choice, with surrounding whitespace stripped.
    """
    if stop is None:
        stop = ["#"]
    # In llama-cpp-python, ``Llama.generate`` is a low-level generator over
    # token ids and does not accept max_tokens/echo/stop. The prompt-in,
    # dict-out completion API matching this call is ``create_completion``,
    # so prefer it and keep ``generate`` as a fallback for other model objects.
    complete = getattr(model, "create_completion", None) or model.generate
    output = complete(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        echo=echo,
        stop=stop,
    )
    return output["choices"][0]["text"].strip()
|
|
|
|
def generate_prompt_from_template(user_input):
    """
    Prefix the user's message with the chatbot's fixed persona line.

    Parameters:
        user_input (str): The message typed by the user.

    Returns:
        str: The persona sentence, a newline, then the user's message.
    """
    persona = "You are a helpful chatbot."
    return "{}\n{}".format(persona, user_input)
|
|
|
|
def create_chat_layout():
    """
    Create the layout for the chatbot interface using rich.

    The layout is split into three named regions:
    - ``header``: fixed size 3, holding the title panel.
    - ``main``: flexible (ratio 1); the caller fills it with the chat history.
    - ``footer``: fixed size 3, holding the usage hint.

    Returns:
        Layout: The layout object for the chat interface.
    """
    layout = Layout()
    layout.split(
        Layout(name="header", size=3),
        Layout(ratio=1, name="main"),
        Layout(name="footer", size=3),
    )

    layout["header"].update(Panel("[bold magenta]Llama.cpp Chatbot[/]", style="bold blue"))
    # ``Text(...)`` treats its argument as plain text, so the ``[bold green]``
    # tags would be displayed literally; ``Text.from_markup`` parses them.
    layout["footer"].update(
        Text.from_markup(
            "Type your message and press [bold green]Enter[/]. Type 'exit' to end the chat.",
            justify="center",
        )
    )
    return layout
|
|
|
|
def main():
    """
    The main function to run the chatbot.

    Startup:
    - Downloads the GGUF model on first run if it is not on disk.
    - Exits early with a message if the model is still missing afterwards.
    - Loads the model.

    Chat loop:
    - Reads a line from the user and generates a reply.
    - Appends both to the chat history panel and re-prints the layout.
    - Typing ``exit`` (case-insensitive, surrounding whitespace ignored),
      Ctrl-D, or Ctrl-C ends the session.
    - Blank input is ignored.
    """
    # Used to show user/model text verbatim inside rich markup.
    from rich.markup import escape

    model_path = Path("zephyr-7b-alpha.Q3_K_M.gguf")
    # ``resolve/`` serves the actual file contents; ``raw/`` returns the small
    # Git LFS pointer text for large LFS-tracked files such as this model.
    model_url = "https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/resolve/main/zephyr-7b-alpha.Q3_K_M.gguf"

    if not model_path.exists():
        print("Model not found locally. Downloading...")
        download_model(model_url, model_path)
        # download_model reports network errors by printing rather than
        # raising, so verify the file is there before trying to load it.
        if not model_path.exists():
            print(f"Model is unavailable at {model_path}; exiting.")
            return

    model = load_model(model_path)

    console = Console()
    layout = create_chat_layout()
    console.print(layout)

    transcript = []

    while True:
        try:
            user_input = console.input("[bold green]You: [/]")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the chat instead of dumping a traceback.
            break

        message = user_input.strip()
        if message.lower() == "exit":
            break
        if not message:
            continue

        prompt = generate_prompt_from_template(message)
        bot_response = generate_text(model, prompt, max_tokens=356)

        # Escape both sides: rich would otherwise interpret "[...]" in chat
        # text as markup tags, and raise MarkupError on stray closing tags.
        transcript.append(
            f"[bold green]You:[/] {escape(message)}\n"
            f"[bold yellow]Bot:[/] {escape(bot_response)}\n"
        )
        chat_panel = Panel("".join(transcript), title="Chat History")
        layout["main"].update(chat_panel)

        console.print(layout)
|
|
|
|
if __name__ == "__main__":
    # Start the interactive chat loop only when run as a script, not on import.
    main()
|