Empowering Website Conversations: Part 5

Introduction

Now that our QA and discriminator models are ready for deployment, it’s time to bring our chatbot to life. This entails building both a REST API and a user interface. In this section, I’ll walk you through the process of creating a streamlined chatbot REST API using Rust. This API will serve as the bridge between our application and the OpenAI API, allowing us to harness the power of our fine-tuned models.

Developing a custom REST API for interacting with the OpenAI API offers numerous benefits, including:

  1. Control: You have complete control over the API’s design, data formats, and storage. This means you can encapsulate calls to the discriminators and the QA model within a single user request, streamlining the interaction.
  2. Resource Protection: You can enforce usage and rate limits to safeguard your OpenAI resources, ensuring they are used efficiently and responsibly.
  3. Custom Error Handling and Logging: Implementing custom error handling and logging mechanisms allows you to provide meaningful feedback to users, making troubleshooting and maintenance more straightforward.

By crafting our REST API, we gain the flexibility and precision needed to tailor the integration to our unique use case, enhancing the overall functionality and user experience of our ChatBot application.

The previous section can be found here

Required routes

We need to implement a REST API with a single POST route that first vets the input submitted by the user, then generates the chatbot response, and finally verifies that the output is on topic:

POST /chat
Body:
{
    prompt: string,
    max_tokens: i32,
    temperature: f32,
}

The request body includes the following parameters:

  • prompt: This parameter represents the question or input submitted by the user.
  • max_tokens: Optional. The maximum number of tokens that can be used in the response, which controls the length of the generated text. The handler defaults to 500 when it is omitted.
  • temperature: Optional. The level of randomness to be incorporated into the response, which influences the creativity of the generated text.

The response you receive will either contain the generated text produced by our fine-tuned model or a pre-configured error message tailored to the situation.

Response
{
    data: {
        choice: string,
    }
}
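
Because max_tokens and temperature are optional in the handler, the server has to decide what to use when they are missing or out of range. The following is a minimal sketch of one way to resolve them. The 500-token default matches the handler in this post, but MAX_TOKENS_CAP and both function names are hypothetical and are not part of the original code.

```rust
/// Default used when the request omits `max_tokens` (same value as the handler).
const DEFAULT_MAX_TOKENS: i32 = 500;
/// Hypothetical upper bound, chosen only for this illustration.
const MAX_TOKENS_CAP: i32 = 1000;

/// Fall back to the default, then keep the value within 1..=MAX_TOKENS_CAP.
pub fn resolve_max_tokens(requested: Option<i32>) -> i32 {
    requested
        .unwrap_or(DEFAULT_MAX_TOKENS)
        .clamp(1, MAX_TOKENS_CAP)
}

/// Keep a supplied temperature within 0.0..=2.0; `None` leaves the
/// choice to the upstream API default.
pub fn resolve_temperature(requested: Option<f32>) -> Option<f32> {
    requested.map(|t| t.clamp(0.0, 2.0))
}
```

Clamping on the server side is one simple form of the resource protection mentioned in the introduction, since a client cannot request an arbitrarily long completion.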

OpenAI Libraries

When it comes to interfacing with the OpenAI API from your REST API, OpenAI offers official libraries for Python and Node.js. Additionally, there’s a wide array of community-supported libraries available in various programming languages. You can explore the full list of available libraries on the OpenAI website.

In this guide, we will be using Rust and a custom crate called ryst-openai, which is currently a ‘work in progress.’ This Rust crate is designed to provide support for OpenAI’s completions and chat functionalities.

Developers often choose Rust for its performance, safety features, and the ability to craft efficient and reliable code. For the remainder of this post, we assume a basic familiarity with Rust and that you have it installed on your system.

Create the Crate

Start by creating a new bin crate. We are going to call our crate chatbot-api:

cargo new --bin chatbot-api

Add the following dependencies to the Cargo.toml file:

[dependencies]
clap = { version = "4", features = ["derive", "env"] }
log = "0.4"
log4rs = "1"
rocket = { version = "0.5.0-rc.2", features = ["json"] }
ryst-error = { git = "https://github.com/embyrai/ryst.git", branch = "main" }
ryst-openai = { git = "https://github.com/embyrai/ryst.git", branch = "main" }
serde = "1"
toml = "0.7"

Most important crates:

  • ryst-openai - A work in progress Rust OpenAI Library. Currently supports completions and chat.
  • ryst-error - Reusable error crate
  • rocket - Rust web framework
  • clap - Command Line Argument Parser for Rust

Your crate’s resulting directory tree should resemble the following structure at the end of this blog:

Cargo.toml
src/
    resources/
              chat/
                   mod.rs
                   post.rs
              error.rs
              mod.rs
    routes/
           chat.rs
           mod.rs
    error.rs
    main.rs
    models.rs

Error Handling

We will be utilizing two distinct error types within our implementation:

  1. RestApiError: This error type is returned when any issues arise during the initialization or startup of the REST API.
  2. ErrorResponse: This is a generic error format intentionally designed to be convertible into an HTTP response. Its purpose is to facilitate standardized error handling in our API.

RestApiError is composed of reusable error definitions provided by the ryst-error crate. Please refer to the code for detailed information on when to employ each error type.

Add the following to src/error.rs:

use std::error::Error;

use ryst_error::{InternalError, InvalidArgumentError, InvalidStateError};

#[derive(Debug)]
pub enum RestApiError {
    Internal(InternalError),
    InvalidArgument(InvalidArgumentError),
    InvalidState(InvalidStateError),
}

impl Error for RestApiError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            RestApiError::Internal(e) => Some(e),
            RestApiError::InvalidArgument(e) => Some(e),
            RestApiError::InvalidState(e) => Some(e),
        }
    }
}

impl std::fmt::Display for RestApiError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            RestApiError::Internal(e) => e.fmt(f),
            RestApiError::InvalidArgument(e) => e.fmt(f),
            RestApiError::InvalidState(e) => e.fmt(f),
        }
    }
}

impl From<InternalError> for RestApiError {
    fn from(error: InternalError) -> Self {
        RestApiError::Internal(error)
    }
}

impl From<InvalidArgumentError> for RestApiError {
    fn from(error: InvalidArgumentError) -> Self {
        RestApiError::InvalidArgument(error)
    }
}

impl From<InvalidStateError> for RestApiError {
    fn from(error: InvalidStateError) -> Self {
        RestApiError::InvalidState(error)
    }
}
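
The From implementations are what allow the ? operator to convert a lower-level error into the wrapping error type without an explicit map_err. The following std-only sketch illustrates that pattern. StartupError and parse_port are hypothetical stand-ins used only for illustration; they are not types or functions from ryst-error.

```rust
use std::fmt;
use std::num::ParseIntError;

/// Hypothetical stand-in for a wrapping error such as RestApiError.
#[derive(Debug)]
pub enum StartupError {
    InvalidArgument(ParseIntError),
}

impl fmt::Display for StartupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StartupError::InvalidArgument(e) => write!(f, "invalid argument: {e}"),
        }
    }
}

impl std::error::Error for StartupError {
    // Expose the wrapped error so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            StartupError::InvalidArgument(e) => Some(e),
        }
    }
}

impl From<ParseIntError> for StartupError {
    fn from(error: ParseIntError) -> Self {
        StartupError::InvalidArgument(error)
    }
}

/// `?` calls `From::from` on the ParseIntError, so no map_err is needed here.
pub fn parse_port(raw: &str) -> Result<u16, StartupError> {
    let port = raw.parse::<u16>()?;
    Ok(port)
}
```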

And add the ErrorResponse to src/resources/error.rs:

use std::error;
use std::fmt;

use rocket::serde::{Deserialize, Serialize};
use ryst_openai::OpenAIError;

/// Generic error designed with the expectation that it may be converted into an HTTP response
#[derive(Debug, Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct ErrorResponse {
    /// A corresponding HTTP status code for the error
    pub status_code: u16,

    /// The message that would be returned in an HTTP response
    pub message: String,

    /// Wrapped error that is not exposed in the HTTP response
    #[serde(skip_serializing)]
    #[serde(skip_deserializing)]
    pub source: Option<Box<dyn error::Error>>,
}

impl ErrorResponse {
    /// Create a new ErrorResponse
    ///
    /// # Arguments
    ///
    /// * `status_code` - Corresponding HTTP status code
    /// * `message` - External message to display to the user
    ///
    /// # Examples
    /// ```
    /// use crate::resources::error::ErrorResponse;
    ///
    /// let response = ErrorResponse::new(404, "The requested user was not found");
    ///
    /// assert_eq!(404, response.status_code);
    /// assert_eq!("The requested user was not found", response.message);
    /// ```
    pub fn new(status_code: u16, message: &str) -> Self {
        Self {
            status_code,
            message: message.to_string(),
            source: None,
        }
    }
}

impl error::Error for ErrorResponse {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        self.source.as_ref().map(|s| s.as_ref())
    }
}

impl fmt::Display for ErrorResponse {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if let Some(ref src) = self.source {
            src.fmt(f)
        } else {
            write!(
                f,
                "Status Code {}: Message {}",
                self.status_code, self.message
            )
        }
    }
}

impl From<OpenAIError> for ErrorResponse {
    fn from(error: OpenAIError) -> Self {
        match error {
            OpenAIError::Internal(_) => ErrorResponse::new(500, "Internal Server Error"),
            OpenAIError::InvalidArgument(e) => ErrorResponse::new(400, &e.to_string()),
            OpenAIError::InvalidState(e) => ErrorResponse::new(400, &e.to_string()),
        }
    }
}

We will use these errors throughout the rest of the code.

Configuration

The first step of creating the chatbot REST API is deciding how we want to configure it. I wanted a REST API that can be reused with different fine-tuned models, rather than one that is only useful for the Embyr website. This means we need to be able to provide the list of models and prompts.

I implemented this by passing a model TOML file that contains the QA model name, the system message, and the discriminator prompts and model names. Each discriminator also requires an error message that can be returned to the user if the discriminator determines the prompt or answer is not related.

The TOML looks like:

# This file describes the models that should be used by the chatbot

# The chat model is used to answer the questions. The model should be trained to
# accept prompts in the following format where '{query}' is replaced with the
# prompt from the user:
#
#             "Question:{query} Answer:"
#
[chat]
# name of the model (default)
#
model = "gpt-3.5-turbo"

# Content for the system message used to preface a user's question
#
system_message = "You are a factual chatbot."

# Optional
# If provided this model will be used to check that the prompt submitted by the
# user is related to the subject the chat model has been trained for.
#
[input_discriminator]

# name of the model
#
model = "babbage-002"

# The prompt the discriminator has been trained with. The discriminator should
# return " yes" if related and " no" if not.
#
prompt = "Is the following question related to my subject?\nQuestion: {query}\nRelated:"

# The error message that should be returned from the ChatBot if the prompt is
# not related.
#
error = "I'm sorry. I can only answer questions about my subject."

# Optional
# If provided this model will be used to check that the completion returned by
# the chat model is related to the subject the chat model has been trained for.
# Because the model was fine-tuned on a base model, it can return unrelated
# completions.
#
[output_discriminator]

# name of the model
#
model = "babbage-002"

# The prompt the discriminator has been trained with. The discriminator should
# return " yes" if related and " no" if not.
#
prompt = "Is the following answer related to Embyr and AI?\nAnswer: {query}\nRelated:"

# The error message that should be returned from the ChatBot if the completion
# is not related.
#
error = "I'm sorry. I am unable to answer your question right now."
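
Each discriminator prompt is a template in which the {query} placeholder is replaced with the text to classify, either the user's question or the generated answer. A minimal sketch of that substitution follows; the render_prompt helper is a hypothetical name, since the handler in this post calls str::replace inline.

```rust
/// Fill a discriminator prompt template by replacing every `{query}`
/// placeholder with the text that should be classified.
pub fn render_prompt(template: &str, query: &str) -> String {
    // `str::replace` scans only the template, so a `{query}` that appears
    // inside the user's own text is inserted as-is and not expanded again.
    template.replace("{query}", query)
}
```

Note that the discriminators are expected to answer with a leading space (" yes" or " no"), which is why the template ends at "Related:" without a trailing space.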

To parse the TOML file, I implemented the following structs in the src/models.rs file:

use std::fs;

use log::debug;
use ryst_error::InvalidArgumentError;
use serde::Deserialize;

use crate::error::RestApiError;

const DEFAULT_MODEL: &str = "gpt-3.5-turbo";
const DEFAULT_SYSTEM_MESSAGE: &str = "You are a factual chatbot.";

#[derive(Deserialize, Clone, Debug)]
pub struct ChatModel {
    pub model: String,
    pub system_message: String,
}

#[derive(Deserialize, Clone, Debug)]
pub struct Discriminator {
    pub model: String,
    pub prompt: String,
    pub error: String,
}

#[derive(Deserialize, Clone, Debug)]
pub struct Models {
    pub chat: ChatModel,
    pub input_discriminator: Option<Discriminator>,
    pub output_discriminator: Option<Discriminator>,
}

impl Models {
    pub fn from_toml(path: &str) -> Result<Self, RestApiError> {
        debug!(
            "Loading config toml file: {:?}",
            fs::canonicalize(path).map_err(|err| RestApiError::InvalidArgument(
                InvalidArgumentError::new("model_path", err.to_string())
            ))?
        );

        let toml_string = fs::read_to_string(path).map_err(|err| {
            RestApiError::InvalidArgument(InvalidArgumentError::new("model", err.to_string()))
        })?;
        toml::from_str::<Models>(&toml_string).map_err(|error| {
            RestApiError::InvalidArgument(InvalidArgumentError::new("model", error.to_string()))
        })
    }
}

impl Default for Models {
    fn default() -> Self {
        Self {
            chat: ChatModel {
                model: DEFAULT_MODEL.to_string(),
                system_message: DEFAULT_SYSTEM_MESSAGE.to_string(),
            },
            input_discriminator: None,
            output_discriminator: None,
        }
    }
}

Using clap, define the command line arguments for the REST API in src/main.rs:

use clap::Parser;

/// ChatBot REST API
#[derive(Parser)]
#[command(version = option_env!("BUILD_VERSION").unwrap_or(env!("CARGO_PKG_VERSION")))]
pub struct RestApiArgs {
    /// REST API bind endpoint
    #[arg(
        long = "bind",
        env = "CHATBOT_REST_API_BIND",
        default_value = "127.0.0.1:8000",
        value_name = "REST_API_ENDPOINT"
    )]
    rest_api_bind: String,
    /// Allow origins for CORS
    #[arg(
        long = "allow-list",
        env = "CHATBOT_CORS_ALLOW_LIST",
        value_name = "ORIGIN",
        default_value = "*"
    )]
    allow_list: String,
    /// The path to the TOML file that defines the models that should be used
    #[arg(long = "models-path", env = "CHATBOT_MODELS_PATH", value_name = "PATH")]
    models_path: Option<String>,
}

This will result in the following command line options:

ChatBot REST API

Usage: chatbot-api [OPTIONS]

Options:
      --bind <REST_API_ENDPOINT>   REST API bind endpoint [env: CHATBOT_REST_API_BIND=] [default: 127.0.0.1:8000]
      --allow-list <ORIGIN>        Allow origins for CORS [env: CHATBOT_CORS_ALLOW_LIST=] [default: *]
      --models-path <PATH>         The path to the TOML file that defines the models that should be used [env: CHATBOT_MODELS_PATH=]

The path to the model TOML file should be passed with the --models-path option. Update your own models.toml file to include the names of the models you built in the previous sections. If no model TOML file is provided, the ChatBot will default to using gpt-3.5-turbo with no discriminator checks.

Handler

With our model configuration in place, the next step is to develop a handler that can be used within our POST route.

The following handler will check if there is an input discriminator and check if the prompt passed in is related, returning the associated error message if not. If it is, the prompt will be submitted against the fine-tuned QA model. The generated response will then be checked against the output discriminator. If related, the generated response is returned, otherwise the output discriminator error is returned.
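
The control flow described above can be sketched without any network calls. In the following std-only illustration, plain closures stand in for the OpenAI requests; Check, answer, and NEGATIVE are hypothetical names used only for this sketch, and the real handler is asynchronous and returns a Result.

```rust
/// Completion a discriminator returns when the text is unrelated.
const NEGATIVE: &str = " no";

/// One optional discriminator stage.
pub struct Check<'a> {
    /// Stand-in for the discriminator call; returns e.g. " yes" or " no".
    pub classify: &'a dyn Fn(&str) -> String,
    /// Message returned to the user when the text is unrelated.
    pub error: &'a str,
}

/// Run input check -> chat model -> output check, returning either the
/// generated reply or the error message of the stage that rejected it.
pub fn answer(
    prompt: &str,
    input: Option<&Check<'_>>,
    chat: &dyn Fn(&str) -> String,
    output: Option<&Check<'_>>,
) -> String {
    // Stage 1: reject unrelated prompts before spending tokens on the QA model.
    if let Some(check) = input {
        if (check.classify)(prompt) == NEGATIVE {
            return check.error.to_string();
        }
    }

    // Stage 2: generate the reply.
    let reply = chat(prompt);

    // Stage 3: make sure the generated reply is still on topic.
    if let Some(check) = output {
        if (check.classify)(&reply) == NEGATIVE {
            return check.error.to_string();
        }
    }

    reply
}
```

Running the input check first means an off-topic question costs only a single discriminator token instead of a full chat completion.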

The following code uses the CompletionRequest support in ryst-openai for the discriminators and its ChatCompletionRequest support for the QA model. Add it to the src/resources/chat/post.rs file:

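use log::error;
use rocket::serde::{Deserialize, Serialize};
// ChatCompletionRequest and Message are assumed to be exported from the
// crate root alongside CompletionRequest.
use ryst_openai::{ChatCompletionRequest, CompletionRequest, Message};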

use crate::models::Models;
use crate::resources::error::ErrorResponse;

const DEFAULT_MAX_TOKENS: i32 = 500;
const DISCRIMINATOR_NEGATIVE_RESPONSE: &str = " no";

#[derive(Debug, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct ChatPrompt {
    pub prompt: String,
    pub max_tokens: Option<i32>,
    pub temperature: Option<f32>,
}

#[derive(Debug, Serialize)]
#[serde(crate = "rocket::serde")]
pub struct ChatResponse {
    pub data: ChatCompletion,
}

#[derive(Debug, Serialize)]
#[serde(crate = "rocket::serde")]
pub struct ChatCompletion {
    pub choice: String,
}

pub async fn handler(
    models: &Models,
    prompt: ChatPrompt,
) -> Result<ChatResponse, ErrorResponse> {
    // check if prompt is related
    if let Some(input_discriminator) = &models.input_discriminator {
        let response = CompletionRequest::new(
            &input_discriminator.model,
            &input_discriminator
                .prompt
                .replace("{query}", &prompt.prompt),
        )
        .with_temperature(0.0)
        .with_max_tokens(1)
        .with_logprobs(2)
        .submit()
        .await?;

        if response.choices[0].text == DISCRIMINATOR_NEGATIVE_RESPONSE {
            return Ok(ChatResponse {
                data: ChatCompletion {
                    choice: input_discriminator.error.to_string(),
                },
            });
        }
    }

    let mut chat_request = ChatCompletionRequest::new(
        &models.chat.model,
        &[
            Message::new("system", &models.chat.system_message),
            Message::new("user", &prompt.prompt),
        ],
    )
    .with_stop("**STOP**")
    .with_max_tokens(DEFAULT_MAX_TOKENS);

    if let Some(max_tokens) = prompt.max_tokens {
        chat_request = chat_request.with_max_tokens(max_tokens)
    }

    if let Some(temperature) = prompt.temperature {
        chat_request = chat_request.with_temperature(temperature)
    }

    let chat_response = chat_request.submit().await?;

    let choice = &chat_response
        .choices
        .get(0)
        .ok_or_else(|| {
            error!("No choices were returned from OpenAI");
            ErrorResponse::new(500, "Internal Server Error")
        })?
        .message
        .content;

    // check if answer is related
    if let Some(output_discriminator) = &models.output_discriminator {
        let response = CompletionRequest::new(
            &output_discriminator.model,
            &output_discriminator.prompt.replace("{query}", choice),
        )
        .with_temperature(0.0)
        .with_max_tokens(1)
        .with_logprobs(2)
        .submit()
        .await?;

        if response.choices[0].text == DISCRIMINATOR_NEGATIVE_RESPONSE {
            return Ok(ChatResponse {
                data: ChatCompletion {
                    choice: output_discriminator.error.to_string(),
                },
            });
        }
    }

    // return response
    Ok(ChatResponse {
        data: ChatCompletion {
            choice: choice.to_string(),
        },
    })
}

The src/resources/chat/mod.rs file will contain:

pub mod post;

And add the following to src/resources/mod.rs:

pub mod chat;
pub mod error;

POST route

The following step defines the POST /chat route, leveraging the capabilities of the Rocket framework. Utilizing Rocket’s route API, we create an asynchronous function. This function receives the user-submitted prompt and model information retrieved from Rocket’s state. It then forwards these parameters to the handler we previously established.

For more information about implementing routes in Rocket, see the Request guide.

Add the following to a src/routes/chat.rs file:

use rocket::serde::json::Json;
use rocket::State;

use crate::models::Models;
use crate::resources::chat::post;
use crate::resources::error::ErrorResponse;

#[derive(Responder)]
pub enum ChatResponder {
    V1(Json<post::ChatResponse>),
    #[response(status = 400)]
    BadRequest(Json<ErrorResponse>),
    #[response(status = 500)]
    Internal(Json<ErrorResponse>),
}

impl From<ErrorResponse> for ChatResponder {
    fn from(err: ErrorResponse) -> Self {
        match err.status_code {
            400 => ChatResponder::BadRequest(Json(err)),
            _ => ChatResponder::Internal(Json(err)),
        }
    }
}

#[post("/chat", format = "json", data = "<chat_prompt>")]
pub async fn new_chat(
    chat_prompt: Json<post::ChatPrompt>,
    models: &State<Models>,
) -> ChatResponder {
    match post::handler(
        models,
        chat_prompt.into_inner(),
    )
    .await
    {
        Ok(new_chat_response) => ChatResponder::V1(Json(new_chat_response)),
        Err(err) => {
            // log the returned error before converting it into a response
            error!("{}", err);
            err.into()
        }
    }
}

The src/routes/mod.rs file will contain:

pub mod chat;

Putting it all together

Last but not least, we must construct a Rocket REST API that incorporates our route.

The create_config function takes our command-line arguments and constructs a rocket::Config instance, which sets the REST API’s bind address and port. Before launching Rocket, we parse the provided model file and hand the result to Rocket’s managed state, which gives every route access to the configured Models struct.

And, to complete the grand finale, we mount our chat route and initiate the launch sequence!

Add the following to your existing main.rs file:

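#[macro_use]
extern crate rocket;

// Declare the modules created earlier so the `use` paths below resolve.
mod error;
mod models;
mod resources;
mod routes;

use std::net::IpAddr;

// `clap::Parser` is already imported by the argument definitions above.
use rocket::Config;
use ryst_error::{InternalError, InvalidArgumentError};

use error::RestApiError;
use models::Models;
use routes::chat;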

pub fn create_config(args: &RestApiArgs) -> Result<Config, RestApiError> {
    let bind_endpoint: Vec<&str> = args.rest_api_bind.split(':').collect();

    if bind_endpoint.len() != 2 {
        return Err(RestApiError::InvalidArgument(InvalidArgumentError::new(
            "bind",
            "provided endpoint is in the wrong format, should be address:port".to_string(),
        )));
    }

    Ok(Config {
        address: IpAddr::V4(bind_endpoint[0].parse().map_err(|_| {
            RestApiError::InvalidArgument(InvalidArgumentError::new(
                "address",
                "provided address is in the wrong format, should be 127.0.0.1".to_string(),
            ))
        })?),
        port: bind_endpoint[1].parse().map_err(|_| {
            RestApiError::InvalidArgument(InvalidArgumentError::new(
                "port",
                "provided port is in the wrong format, should be u16 integer".to_string(),
            ))
        })?,

        ..Config::debug_default()
    })
}

#[rocket::main]
async fn main() -> Result<(), RestApiError> {
    let args = RestApiArgs::parse();
    let config = create_config(&args)?;

    let models = if let Some(path) = &args.models_path {
        let models = Models::from_toml(path)?;
        debug!("Using configured model: {}", models.chat.model);
        models
    } else {
        let models = Models::default();
        debug!("Using default model: {}", models.chat.model);
        models
    };

    let _rocket = rocket::custom(config)
        .mount("/", routes![chat::new_chat])
        .manage(models)
        .launch()
        .await
        .map_err(|e| RestApiError::Internal(InternalError::from_source(Box::new(e))))?;

    Ok(())
}
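
The create_config function above splits the bind string on ':' by hand. As an alternative sketch, the standard library’s SocketAddr parser can validate the address and the port in one step. The parse_bind helper below is a hypothetical name and is not part of the original code; it returns a plain String error to stay self-contained, where the real function would return RestApiError.

```rust
use std::net::{IpAddr, SocketAddr};

/// Split a bind string such as "127.0.0.1:8000" into an address and a port.
pub fn parse_bind(bind: &str) -> Result<(IpAddr, u16), String> {
    // SocketAddr only accepts IP literals, so host names such as
    // "localhost:8000" are rejected rather than resolved.
    let addr = bind
        .parse::<SocketAddr>()
        .map_err(|e| format!("invalid bind endpoint '{bind}', expected address:port: {e}"))?;

    Ok((addr.ip(), addr.port()))
}
```

The two returned values map directly onto the address and port fields of rocket::Config. This approach also accepts bracketed IPv6 endpoints such as [::1]:8000, which a plain split on ':' cannot handle.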

The REST API can be run with the following command after setting your OPENAI_API_KEY environment variable:

$ cargo run -- --models-path models.toml
[2023-09-18 10:04:19.954] T[main] WARN [rocket::rocket] 🚀 Rocket has launched from http://127.0.0.1:8000

This REST API is intentionally designed to be simple, and I haven’t covered every aspect of its development.

However, I strongly recommend the implementation of comprehensive logging support that can be configured to log to both a file and standard output. Effective logging is essential for monitoring and troubleshooting.

Additionally, one crucial consideration is rate limiting for token consumption. Since this REST API relies on your API key, you are responsible for the associated token costs. In our final solution, we have introduced methods to limit both the number of tokens consumed by individual users and the overall daily token usage by the API. When combined with the discriminators, these measures help ensure the responsible and secure usage of your ChatBot and tokens, guarding against misuse and excessive consumption.
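
To make the idea of per-user and daily limits concrete, here is a minimal in-memory sketch. It is not the implementation used in our final solution, which is not shown in this post; TokenBudget and its methods are hypothetical names, and how users are identified is left open.

```rust
use std::collections::HashMap;

/// Tracks token usage per user and in total for the current day.
pub struct TokenBudget {
    per_user_limit: u32,
    daily_limit: u32,
    used_by_user: HashMap<String, u32>,
    used_total: u32,
}

impl TokenBudget {
    pub fn new(per_user_limit: u32, daily_limit: u32) -> Self {
        Self {
            per_user_limit,
            daily_limit,
            used_by_user: HashMap::new(),
            used_total: 0,
        }
    }

    /// Record `tokens` for `user`. Returns false, and records nothing,
    /// if either the per-user or the daily limit would be exceeded.
    pub fn try_consume(&mut self, user: &str, tokens: u32) -> bool {
        let used = self.used_by_user.get(user).copied().unwrap_or(0);
        let user_total = used.saturating_add(tokens);
        let day_total = self.used_total.saturating_add(tokens);

        if user_total > self.per_user_limit || day_total > self.daily_limit {
            return false;
        }

        self.used_by_user.insert(user.to_string(), user_total);
        self.used_total = day_total;
        true
    }

    /// Clear all counters, for example from a task that runs once a day.
    pub fn reset(&mut self) {
        self.used_by_user.clear();
        self.used_total = 0;
    }
}
```

In a Rocket application, a structure like this would need to be wrapped in a Mutex or similar before being placed in managed state, because routes only receive shared references to it.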

Conclusion

This post has provided a rapid overview of creating a customizable chatbot REST API in Rust. In the upcoming installment, we will take the next step and dive into the implementation of a user interface that simulates a chat window. However, our dialogue partners will be our trained models instead of other individuals.

Part 1: What are Chatbots, and why would I want one?

Part 2: From Markdown to Training Data?

Part 3: Fine-tune a Chatbot QA model

Part 4: Safeguard your Chatbot with Discriminators

Part 6: Bring our Chatbot to Life: The User Interface

Resources

Rust

ryst-openai

Rocket

clap

OpenAI libraries.