Multimodal Search Engine Agents Powered by BLIP-2 and Gemini

Luís Roque, Rafael Guedes. Towards Data Science, 19 February 2025.
This post was co-authored with Rafael Guedes.

Introduction

Traditional (unimodal) models process only a single type of data, such as text, images, or tabular data. Multimodality is a trending concept in the AI research community, referring to a model's ability to learn from multiple types of data simultaneously. The idea is not really new, but it has improved significantly in recent months, and it has numerous potential applications that can transform the user experience of many products.

One good example is how search engines may work in the future, with users entering queries that combine modalities such as text, images, and audio. Another example is AI-powered customer support that handles both voice and text inputs. In e-commerce, multimodal models are already enhancing product discovery by letting users search with images and text. We will use the latter as our case study in this article.

Frontier AI research labs ship several new models that support multiple modalities every month. CLIP and DALL-E by OpenAI and BLIP-2 by Salesforce combine image and text. ImageBind by Meta expanded the multimodal concept to six modalities (text, audio, depth, thermal, image, and inertial measurement units).

In this article, we will explore BLIP-2 by explaining its architecture, the way its loss function works, and its training process. We also present a practical use case that combines BLIP-2 and Gemini to create a multimodal fashion search agent that can assist customers in finding the best outfit based on either text or text and image prompts.

Figure 1: Multimodal Search Agent (image by author with Gemini)

As always, the code is available on our GitHub.

BLIP-2: a multimodal model

BLIP-2 (Bootstrapping Language-Image Pre-training) [1] is a vision-language model designed to solve tasks such as visual question answering or multimodal reasoning based on inputs of both modalities: image and text. As we will see below, this model was developed to address two main challenges in the vision-language domain:

  1. Reducing computational cost by using frozen pre-trained visual encoders and LLMs, which drastically reduces the training resources needed compared with jointly training the vision and language networks.
  2. Improving vision-language alignment by introducing the Q-Former, which brings the visual and textual embeddings closer together, leading to better performance on reasoning tasks and enabling multimodal retrieval.

Architecture

The architecture of BLIP-2 follows a modular design that integrates three modules:

  1. Visual Encoder is a frozen visual model, such as ViT, that extracts visual embeddings from the input images (which are then used in downstream tasks).
  2. Querying Transformer (Q-Former) is the key to this architecture. It consists of a trainable lightweight transformer that acts as an intermediate layer between the visual and language models. It is responsible for generating contextualized queries from the visual embeddings so that they can be processed effectively by the language model.
  3. LLM is a frozen pre-trained LLM that processes refined visual embeddings to generate textual descriptions or answers.
Figure 2: BLIP-2 architecture (image by author)
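The core idea of the Q-Former can be illustrated with a toy numpy sketch: a fixed set of learnable queries cross-attends to the frozen image features, so the language side always receives the same number of vectors, however many patches the image encoder produces. This is a simplification under stated assumptions: a single attention head, random weights, scaled-down sizes, and none of the Q-Former's self-attention or feed-forward layers. The BLIP-2 paper uses 32 queries.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative, scaled-down sizes (assumptions, not BLIP-2's real dimensions)
VIS_DIM = 64                 # feature size of the frozen image encoder output
NUM_QUERIES, Q_DIM = 32, 32  # learnable queries inside the Q-Former


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


# Trainable parameters (randomly initialised here). The queries are themselves
# learnable, so this toy version uses them directly as attention queries.
queries = rng.normal(size=(NUM_QUERIES, Q_DIM))
W_k = rng.normal(size=(VIS_DIM, Q_DIM)) / np.sqrt(VIS_DIM)
W_v = rng.normal(size=(VIS_DIM, Q_DIM)) / np.sqrt(VIS_DIM)


def q_former_bottleneck(image_features: np.ndarray) -> np.ndarray:
    """Single-head cross-attention from the learnable queries to frozen image features.

    image_features: (num_patches, VIS_DIM), produced by the frozen encoder (not trained)
    returns:        (NUM_QUERIES, Q_DIM), a fixed-size summary of the image
    """
    keys = image_features @ W_k                                  # (num_patches, Q_DIM)
    values = image_features @ W_v                                # (num_patches, Q_DIM)
    attn = softmax(queries @ keys.T / np.sqrt(Q_DIM), axis=-1)   # (NUM_QUERIES, num_patches)
    return attn @ values                                         # (NUM_QUERIES, Q_DIM)


# Stand-in for a ViT output with 257 patch tokens; any number of patches works
frozen_features = rng.normal(size=(257, VIS_DIM))
print(q_former_bottleneck(frozen_features).shape)  # (32, 32)
```

Whatever the image resolution or encoder, the downstream LLM only ever sees NUM_QUERIES vectors, which is part of why training only this small module is cheap.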

Loss Functions

BLIP-2 has three loss functions to train the Q-Former module:

  • Image-text contrastive loss [2] enforces the alignment between visual and text embeddings by maximizing the similarity of paired image-text representations while pushing apart dissimilar pairs.
  • Image-text matching loss [3] is a binary classification loss that aims to make the model learn fine-grained alignments by predicting whether a text description matches the image (positive, i.e., target=1) or not (negative, i.e., target=0).
  • Image-grounded text generation loss [4] is the cross-entropy loss used in LLMs to predict the next token in the sequence. In the Q-Former, the text tokens cannot attend directly to the frozen image embeddings; they can only attend to the learnable queries. The information needed to generate the text must therefore first be extracted by the queries, forcing them to capture the visual features relevant to the text.
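In the BLIP-2 paper [1], the three objectives share the same Q-Former and differ mainly in the self-attention mask that controls which tokens can see each other: the contrastive loss keeps queries and text apart, the matching loss lets them attend to each other freely, and the generation loss uses a multimodal causal mask. The numpy sketch below builds the three masks; the function name and the convention of placing the queries before the text tokens are our own choices for illustration.

```python
import numpy as np


def qformer_attention_mask(num_queries: int, num_text: int, objective: str) -> np.ndarray:
    """Boolean mask where mask[i, j] is True if token i may attend to token j.

    Token order: the learnable queries first, then the text tokens.
    objective: "itc" (contrastive), "itm" (matching) or "itg" (generation).
    """
    n = num_queries + num_text
    q = slice(0, num_queries)
    t = slice(num_queries, n)
    mask = np.zeros((n, n), dtype=bool)

    if objective == "itc":    # uni-modal: queries and text never see each other
        mask[q, q] = True
        mask[t, t] = True
    elif objective == "itm":  # bi-directional: every token sees every token
        mask[:, :] = True
    elif objective == "itg":  # multimodal causal
        mask[q, q] = True     # queries see only other queries, never the text
        mask[t, q] = True     # text tokens see all queries...
        mask[t, t] = np.tril(np.ones((num_text, num_text), dtype=bool))  # ...and earlier text
    else:
        raise ValueError(f"unknown objective: {objective!r}")
    return mask


print(qformer_attention_mask(2, 3, "itg").astype(int))
```

With two queries and three text tokens, the generation mask lets each text token see both queries and the text tokens up to itself, while the two queries see only each other.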

For both the image-text contrastive and image-text matching losses, the authors draw negatives from within the batch: with a batch size of 512, each image-text pair has one positive and 511 negative candidates. For the matching loss, they additionally use hard-negative mining, favouring in-batch negatives that the model finds similar to the positive. Taking negatives from the batch is efficient because there is no need to search the entire dataset, and it provides a diverse set of comparisons, which leads to better gradient estimates and faster convergence.
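A minimal numpy sketch of the image-text contrastive loss with in-batch negatives follows. It assumes L2-normalized inputs and a fixed temperature of 0.07 (in BLIP-2 the temperature is a learned parameter). One BLIP-2-specific detail, taken from the paper [1]: because the Q-Former outputs several query embeddings per image, the similarity between an image and a text is the highest similarity over all of that image's queries.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def itc_loss(query_embeds: np.ndarray, text_embeds: np.ndarray, temperature: float = 0.07) -> float:
    """In-batch image-text contrastive loss.

    query_embeds: (B, Q, D) L2-normalized query outputs, Q of them per image
    text_embeds:  (B, D)    L2-normalized text embeddings ([CLS] token)
    Image i and text i form the positive pair; the other B - 1 texts are negatives.
    """
    # similarity of every query of every image with every text: (B_img, Q, B_txt)
    sim_per_query = np.einsum("iqd,jd->iqj", query_embeds, text_embeds)
    # keep the best-matching query as the image-text similarity: (B_img, B_txt)
    sim = sim_per_query.max(axis=1) / temperature
    diag = np.arange(sim.shape[0])
    loss_i2t = -log_softmax(sim, axis=1)[diag, diag].mean()  # each image picks its text
    loss_t2i = -log_softmax(sim, axis=0)[diag, diag].mean()  # each text picks its image
    return float((loss_i2t + loss_t2i) / 2)


def l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
B, Q, D = 4, 32, 16
texts = l2_normalize(rng.normal(size=(B, D)))
# matched batch: every query of image i lies close to caption i
images = l2_normalize(texts[:, None, :] + 0.1 * rng.normal(size=(B, Q, D)))
print(itc_loss(images, texts) < itc_loss(np.roll(images, 1, axis=0), texts))  # True
```

Shifting the images by one position breaks every positive pair, so the loss rises sharply; this is the signal that pulls matching pairs together and pushes the other 511 (here 3) candidates apart.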

Figure 3: Training losses explained (image by author)

Training Process

The training of BLIP-2 consists of two stages:

Stage 1 – Bootstrapping visual-language representation:

  1. The model receives images, which the frozen visual encoder converts into embeddings.
  2. Together with these images, the model receives their text descriptions, which the Q-Former's text transformer also converts into embeddings.
  3. The Q-Former is trained jointly on the three losses. The image-text contrastive loss pulls the visual embeddings towards their matching text embeddings and away from non-matching descriptions. The image-text matching loss helps the model develop fine-grained representations by learning to classify whether a given text correctly describes the image. The image-grounded text generation loss forces the queries to extract the visual information needed to generate the text.
Figure 4: Stage 1 training process (image by author)
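The hard-negative mining used for the matching loss can be sketched in numpy under simplifying assumptions: for each image, a non-matching text from the batch is sampled with probability proportional to the exponentiated contrastive similarity, so captions that look similar to the correct one are chosen more often. BLIP-2 also samples a hard negative image for each text and uses a two-class classifier; this sketch covers only the text side and uses a single sigmoid logit with binary cross-entropy instead. The function names are hypothetical.

```python
import numpy as np


def sample_hard_negative_texts(sim: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """For each image i, sample a text j != i with probability proportional to exp(sim[i, j])."""
    weights = np.exp(sim - sim.max(axis=1, keepdims=True))
    np.fill_diagonal(weights, 0.0)                 # never pick the true caption
    probs = weights / weights.sum(axis=1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])


def itm_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Binary cross-entropy for 'does this text describe this image?' (1 = match, 0 = no match)."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    eps = 1e-12
    return float(-np.mean(labels * np.log(probs + eps) + (1 - labels) * np.log(1 - probs + eps)))


rng = np.random.default_rng(0)
sim = np.zeros((5, 5))
sim[0] = [0.0, 10.0, -10.0, -10.0, -10.0]  # caption 1 looks very similar to image 0
print(sample_hard_negative_texts(sim, rng)[0])  # almost always 1
```

The sampled pairs are then labelled 0 and scored alongside the positive pairs (label 1), so the classifier has to learn the fine-grained differences between a caption and a near miss.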

Stage 2 – Bootstrapping vision-to-language generation:

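  1. The Q-Former (still attached to the frozen image encoder) is connected to the frozen LLM through a fully connected layer that projects the query outputs into the LLM's input embedding dimension. These projected embeddings are prepended to the text embeddings as soft visual prompts.
  2. The focus shifts from alignment to generation: the model is trained with a language modeling loss computed by the frozen LLM (a prefix language modeling loss for encoder-decoder LLMs such as FlanT5), which improves its reasoning and text generation capabilities.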
Figure 5: Stage 2 training process (image by the author)
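The stage 2 wiring can be sketched in a few lines of numpy. Following the paper's description, a trainable fully connected layer projects the Q-Former's query outputs into the LLM's input embedding size, and the result is prepended to the embeddings of the text tokens as soft visual prompts; the LLM and its embedding table stay frozen. Sizes, weights, and token ids below are made-up placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_QUERIES, Q_DIM, LLM_DIM, VOCAB_SIZE = 32, 32, 48, 100   # toy sizes

# Trainable in stage 2: one fully connected layer from Q-Former space to LLM input space
W_fc = rng.normal(size=(Q_DIM, LLM_DIM)) / np.sqrt(Q_DIM)
b_fc = np.zeros(LLM_DIM)
# Frozen: the LLM's token embedding table
token_embedding_table = rng.normal(size=(VOCAB_SIZE, LLM_DIM))


def build_llm_inputs(query_outputs: np.ndarray, token_ids: list[int]) -> np.ndarray:
    """Prepend the projected query outputs (soft visual prompts) to the text token embeddings."""
    visual_prompts = query_outputs @ W_fc + b_fc       # (NUM_QUERIES, LLM_DIM)
    text_embeds = token_embedding_table[token_ids]     # (len(token_ids), LLM_DIM)
    return np.concatenate([visual_prompts, text_embeds], axis=0)


query_outputs = rng.normal(size=(NUM_QUERIES, Q_DIM))  # stand-in for the Q-Former output
llm_inputs = build_llm_inputs(query_outputs, [5, 17, 42])
print(llm_inputs.shape)  # (35, 48)
```

The frozen LLM then processes this sequence and the language modeling loss is computed on the text positions; gradients flow back into W_fc and the Q-Former, while the LLM weights stay untouched.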

Creating a Multimodal Fashion Search Agent using BLIP-2 and Gemini

In this section, we will leverage the multimodal capabilities of BLIP-2 to build a fashion assistant search agent that can receive input text and/or images and return recommendations. For the conversation capabilities of the agent, we will use Gemini 1.5 Pro hosted in Vertex AI, and for the interface, we will build a Streamlit app.

The fashion dataset used in this use case is licensed under the MIT license and can be accessed through the following link: Fashion Product Images Dataset. It consists of more than 44k images of fashion products.

The first step is to set up a vector database (Vector DB). It lets the agent run a vector search that compares the image embeddings of the items available in the store with the text or image embedding of the customer's input. We use Docker and Docker Compose to set up the environment:

  • Docker-Compose with Postgres (the database) and the PGVector extension that allows vectorized search.
services:
  postgres:
    container_name: container-pg
    image: ankane/pgvector
    hostname: localhost
    ports:
      - "5432:5432"
    env_file:
      - ./env/postgres.env
    volumes:
      - postgres-data:/var/lib/postgresql/data
    restart: unless-stopped

  pgadmin:
    container_name: container-pgadmin
    image: dpage/pgadmin4
    depends_on:
      - postgres
    ports:
      - "5050:80"
    env_file:
      - ./env/pgadmin.env
    restart: unless-stopped

volumes:
  postgres-data:
  • Postgres env file with the variables to log into the database.
POSTGRES_DB=postgres
POSTGRES_USER=admin
POSTGRES_PASSWORD=root
  • pgAdmin env file with the variables to log into the UI for manually querying the database (optional).
PGADMIN_DEFAULT_EMAIL=admin@admin.com 
PGADMIN_DEFAULT_PASSWORD=root
  • Connection env file with all the parameters needed to connect to PGVector using LangChain.
DRIVER=psycopg
HOST=localhost
PORT=5432
DATABASE=postgres
USERNAME=admin
PASSWORD=root
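The variables in connection.env end up as a single SQLAlchemy-style URL. The helper below is a hypothetical stdlib stand-in that shows the shape of that URL (as far as we can tell, the same postgresql+driver:// form that PGVector.connection_string_from_db_params produces), percent-encoding the user and password in case they contain characters such as @ or /. One practical caveat: python-dotenv's load_dotenv does not override variables that already exist in the environment by default, and on Windows USERNAME is already set to the login name, so os.getenv("USERNAME") may not return admin; calling load_dotenv(..., override=True) or renaming the variable avoids this.

```python
from urllib.parse import quote_plus


def build_connection_string(env: dict) -> str:
    """Assemble a postgresql+<driver>:// URL from the keys used in env/connection.env."""
    user = quote_plus(env["USERNAME"])      # percent-encode in case of special characters
    password = quote_plus(env["PASSWORD"])
    return (
        f"postgresql+{env['DRIVER']}://{user}:{password}"
        f"@{env['HOST']}:{env['PORT']}/{env['DATABASE']}"
    )


settings = {
    "DRIVER": "psycopg", "HOST": "localhost", "PORT": "5432",
    "DATABASE": "postgres", "USERNAME": "admin", "PASSWORD": "root",
}
print(build_connection_string(settings))
# postgresql+psycopg://admin:root@localhost:5432/postgres
```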

Once the Vector DB is set up and running (docker-compose up -d), it is time to create the agents and tools to perform a multimodal search. We build two agents to solve this use case: one to understand what the user is requesting and another one to provide the recommendation:

  • The classifier receives the customer's message and extracts the category or categories of clothes they are looking for, for example, t-shirts, pants, jackets, shoes, jerseys, or shirts. It also returns the number of items the customer wants so that we can retrieve exactly that number from the Vector DB.
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_google_vertexai import ChatVertexAI
from pydantic import BaseModel, Field

class ClassifierOutput(BaseModel):
    """
    Data structure for the model's output.
    """

    category: list[str] = Field(
        description="A list of clothes categories to search for ('t-shirt', 'pants', 'jacket', 'shoes', 'jersey', 'shirt')."
    )
    number_of_items: int = Field(description="The number of items we should retrieve.")

class Classifier:
    """
    Classifier class for classification of input text.
    """

    def __init__(self, model: ChatVertexAI) -> None:
        """
        Initialize the Chain class by creating the chain.
        Args:
            model (ChatVertexAI): The LLM model.
        """
        super().__init__()

        parser = PydanticOutputParser(pydantic_object=ClassifierOutput)

        text_prompt = """
        You are a fashion assistant expert on understanding what a customer needs and on extracting the category or categories of clothes a customer wants from the given text.
        Text:
        {text}

        Instructions:
        1. Read the text carefully.
        2. Extract the category or categories of clothes the customer is looking for, it can be:
            - t-shirt if the customer is looking for a t-shirt.
            - pants if the customer is looking for pants.
            - jacket if the customer is looking for a jacket.
            - shoes if the customer is looking for shoes.
            - jersey if the customer is looking for a jersey.
            - shirt if the customer is looking for a shirt.
        3. If the customer is looking for multiple items of the same category, return the number of items we should retrieve. If not specified but the customer asked for more than 1, return 2.
        4. If the customer is looking for multiple categories, the number of items should be 1.
        5. Return a valid JSON with the categories found, the key must be 'category' and the value must be a list with the categories found and 'number_of_items' with the number of items we should retrieve.

        Provide the output as a valid JSON object without any additional formatting, such as backticks or extra text. Ensure the JSON is correctly structured according to the schema provided below.
        {format_instructions}

        Answer:
        """

        prompt = PromptTemplate.from_template(
            text_prompt, partial_variables={"format_instructions": parser.get_format_instructions()}
        )
        self.chain = prompt | model | parser

    def classify(self, text: str) -> ClassifierOutput:
        """
        Get the category from the model based on the text context.
        Args:
            text (str): user message.
        Returns:
            ClassifierOutput: The model's answer.
        """
        try:
            return self.chain.invoke({"text": text})
        except Exception as e:
            raise RuntimeError(f"Error invoking the chain: {e}")
  • The assistant answers with a personalized recommendation based on the items retrieved from the Vector DB. In this case, we also pass the retrieved images to Gemini so that it can use its multimodal capabilities to produce a better answer.
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_google_vertexai import ChatVertexAI
from pydantic import BaseModel, Field

class AssistantOutput(BaseModel):
    """
    Data structure for the model's output.
    """

    answer: str = Field(description="A string with the fashion advice for the customer.")

class Assistant:
    """
    Assistant class for providing fashion advice.
    """

    def __init__(self, model: ChatVertexAI) -> None:
        """
        Initialize the Chain class by creating the chain.
        Args:
            model (ChatVertexAI): The LLM model.
        """
        super().__init__()

        parser = PydanticOutputParser(pydantic_object=AssistantOutput)

        text_prompt = """
        You work for a fashion store and you are a fashion assistant expert on understanding what a customer needs.
        Based on the items that are available in the store and the customer message below, provide fashion advice for the customer.
        Number of items: {number_of_items}
        
        Images of items:
        {items}

        Customer message:
        {customer_message}

        Instructions:
        1. Check the images provided carefully.
        2. Read the customer's needs carefully.
        3. Provide fashion advice for the customer based on the items and the customer message.
        4. Return a valid JSON with the advice, the key must be 'answer' and the value must be a string with your advice.

        Provide the output as a valid JSON object without any additional formatting, such as backticks or extra text. Ensure the JSON is correctly structured according to the schema provided below.
        {format_instructions}

        Answer:
        """

        prompt = PromptTemplate.from_template(
            text_prompt, partial_variables={"format_instructions": parser.get_format_instructions()}
        )
        self.chain = prompt | model | parser

    def get_advice(self, text: str, items: list, number_of_items: int) -> AssistantOutput:
        """
        Get advice from the model based on the text and items context.
        Args:
            text (str): user message.
            items (list): items found for the customer.
            number_of_items (int): number of items to be retrieved.
        Returns:
            AssistantOutput: The model's answer.
        """
        try:
            return self.chain.invoke({"customer_message": text, "items": items, "number_of_items": number_of_items})
        except Exception as e:
            raise RuntimeError(f"Error invoking the chain: {e}")
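PydanticOutputParser already turns the classifier's reply into a ClassifierOutput, but it does not check that the categories are ones the vector DB can filter on. A hypothetical stdlib sketch of that extra validation follows; it assumes the folder names under images/ (stored as the category metadata) match the labels listed in the classifier prompt.

```python
import json
from dataclasses import dataclass

# Labels from the classifier prompt; assumed to match the category metadata in the vector DB
ALLOWED_CATEGORIES = {"t-shirt", "pants", "jacket", "shoes", "jersey", "shirt"}


@dataclass
class ParsedClassification:
    category: list[str]
    number_of_items: int


def parse_classifier_output(raw: str) -> ParsedClassification:
    """Parse the classifier's JSON reply and reject values the vector DB cannot filter on."""
    data = json.loads(raw)
    categories = [c.strip().lower() for c in data["category"]]
    unknown = sorted(set(categories) - ALLOWED_CATEGORIES)
    if unknown:
        raise ValueError(f"unsupported categories: {unknown}")
    number_of_items = int(data["number_of_items"])
    if number_of_items < 1:
        raise ValueError("number_of_items must be at least 1")
    return ParsedClassification(categories, number_of_items)


print(parse_classifier_output('{"category": ["T-shirt", "pants"], "number_of_items": 1}'))
```

Failing early on an unknown label is cheaper than running a filtered search that silently returns nothing.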
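One detail worth checking in the assistant: PromptTemplate substitutes {items} with the string form of the list, so the base64 images reach Gemini as plain characters rather than as image inputs. A possible alternative, sketched below under the assumption that ChatVertexAI accepts LangChain's standard multimodal message format (a text part plus image_url parts carrying base64 data URLs; please check the langchain-google-vertexai documentation for your version), is to build the message content explicitly. The helper name is hypothetical, and the JSON-format instructions from the original prompt would still need to be added to the text part.

```python
import base64


def build_multimodal_content(customer_message: str, images_b64: list[str],
                             mime_type: str = "image/jpeg") -> list[dict]:
    """Return message content with one text part followed by one image part per retrieved item."""
    parts = [{
        "type": "text",
        "text": f"Number of items: {len(images_b64)}\nCustomer message:\n{customer_message}",
    }]
    for b64 in images_b64:
        parts.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}})
    return parts


# Placeholder bytes stand in for real JPEG files
fake_images = [base64.b64encode(b"jpeg-bytes-1").decode("ascii"),
               base64.b64encode(b"jpeg-bytes-2").decode("ascii")]
content = build_multimodal_content("Complete my outfit with a red t-shirt", fake_images)
print(len(content))  # 3
```

In the app, content could be wrapped as langchain_core.messages.HumanMessage(content=content) and passed to model.invoke([...]), with the output parser applied to the reply afterwards.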

In terms of tools, we define one based on BLIP-2: a function that receives a text or an image as input and returns a normalized embedding. Depending on the input, the embedding is produced by BLIP-2's text model or image model; both project into the same 256-dimensional space learned with the image-text contrastive loss, which is what allows a text query to retrieve images.

from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, Blip2TextModelWithProjection, Blip2VisionModelWithProjection

PROCESSOR = AutoProcessor.from_pretrained("Salesforce/blip2-itm-vit-g")
TEXT_MODEL = Blip2TextModelWithProjection.from_pretrained("Salesforce/blip2-itm-vit-g", torch_dtype=torch.float32).to(
    "cpu"
)
IMAGE_MODEL = Blip2VisionModelWithProjection.from_pretrained(
    "Salesforce/blip2-itm-vit-g", torch_dtype=torch.float32
).to("cpu")

@torch.no_grad()
def generate_embeddings(text: Optional[str] = None, image: Optional[Image.Image] = None) -> np.ndarray:
    """
    Generate a unit-length embedding from text or image using the BLIP-2 model.
    Args:
        text (Optional[str]): customer input text
        image (Optional[Image.Image]): customer input image
    Returns:
        np.ndarray: embedding vector
    """
    if text:
        inputs = PROCESSOR(text=text, return_tensors="pt").to("cpu")
        outputs = TEXT_MODEL(**inputs)
        # take the [CLS] token and L2-normalize along the embedding dimension
        embedding = F.normalize(outputs.text_embeds[:, 0, :], p=2, dim=-1)
    elif image is not None:
        # keep inputs in float32 to match the model weights
        inputs = PROCESSOR(images=image, return_tensors="pt").to("cpu", torch.float32)
        outputs = IMAGE_MODEL(**inputs)
        # average the query embeddings, then re-normalize so the result has unit length
        embedding = F.normalize(outputs.image_embeds.mean(dim=1), p=2, dim=-1)
    else:
        raise ValueError("Provide either text or an image.")

    return embedding.numpy().flatten()
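Why normalize at all? For unit-length vectors the dot product equals the cosine similarity, which keeps text and image scores comparable. The numpy sketch below, with random stand-in vectors of BLIP-2's 256-dimensional projection size, also shows a subtlety of averaging query embeddings: the mean of unit vectors is generally shorter than 1, so it needs to be re-normalized after averaging.

```python
import numpy as np

rng = np.random.default_rng(0)


def l2_normalize(x: np.ndarray, axis: int = -1) -> np.ndarray:
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


# Random stand-ins for the 32 projected query embeddings of one image, each of unit length
query_embeds = l2_normalize(rng.normal(size=(32, 256)))
pooled = query_embeds.mean(axis=0)
print(round(float(np.linalg.norm(pooled)), 3))  # well below 1: averaging shrinks the vector
pooled = l2_normalize(pooled)                   # back to unit length

text_embed = l2_normalize(rng.normal(size=256))
dot = float(pooled @ text_embed)
cosine = dot / float(np.linalg.norm(pooled) * np.linalg.norm(text_embed))
print(np.isclose(dot, cosine))  # True: for unit vectors, dot product == cosine similarity
```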

Note that PGVector requires an embedding model when the connection is created, so we pass a different one; it is never used, because we store the embeddings produced by BLIP-2 directly.

In the loop below, we iterate over all clothing categories, load each image, and collect its embedding in a list to be stored in the vector DB. We also store the image path as the document text so that we can render the image in our Streamlit app, and the category as metadata so that we can filter the results by the category predicted by the classifier agent.

import glob
import os

from dotenv import load_dotenv
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from PIL import Image

from blip2 import generate_embeddings

load_dotenv("env/connection.env")

CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver=os.getenv("DRIVER"),
    host=os.getenv("HOST"),
    port=os.getenv("PORT"),
    database=os.getenv("DATABASE"),
    user=os.getenv("USERNAME"),
    password=os.getenv("PASSWORD"),
)

vector_db = PGVector(
    embeddings=HuggingFaceEmbeddings(model_name="nomic-ai/modernbert-embed-base"),  # does not matter for our use case
    collection_name="fashion",
    connection=CONNECTION_STRING,
    use_jsonb=True,
)

if __name__ == "__main__":

    # generate image embeddings
    # save path to image in text
    # save category in metadata
    texts = []
    embeddings = []
    metadatas = []

    for category in glob.glob("images/*"):
        cat = os.path.basename(category)  # the folder name is the category
        for img in glob.glob(f"{category}/*"):
            texts.append(img)
            embeddings.append(generate_embeddings(image=Image.open(img)).tolist())
            metadatas.append({"category": cat})

    vector_db.add_embeddings(texts, embeddings, metadatas)
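Conceptually, the filtered search we run later with similarity_search_by_vector ranks the stored vectors of one category by their distance to the query vector. The brute-force numpy sketch below mimics that behaviour in memory for a handful of vectors, assuming cosine distance (1 minus cosine similarity, which is what pgvector's <=> operator computes and, as far as we know, PGVector's default distance strategy). It only makes the retrieval step concrete; the database does this in SQL, optionally with an index.

```python
import numpy as np


def filtered_top_k(query: np.ndarray, embeddings: np.ndarray, categories: list[str],
                   wanted: str, k: int) -> list[int]:
    """Indices of the k stored vectors in category `wanted` closest to `query` by cosine distance."""
    candidates = np.flatnonzero(np.array(categories) == wanted)   # metadata filter
    if candidates.size == 0:
        return []
    vectors = embeddings[candidates]
    cos_sim = vectors @ query / (np.linalg.norm(vectors, axis=1) * np.linalg.norm(query))
    cos_dist = 1.0 - cos_sim                                       # same quantity as pgvector's <=>
    return candidates[np.argsort(cos_dist)[:k]].tolist()


embeddings = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [1.0, 0.05]])
categories = ["shirt", "shirt", "pants", "pants"]
print(filtered_top_k(np.array([1.0, 0.0]), embeddings, categories, "pants", k=1))  # [3]
```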

We can now build our Streamlit app to chat with our assistant and ask for recommendations. The chat starts with the agent asking how it can help and providing a box for the customer to write a message and/or to upload a file.

Once the customer replies, the workflow is the following:

  • The classifier agent identifies which categories of clothes the customer is looking for and how many units they want.
  • If the customer uploads a file, this file is going to be converted into an embedding, and we will look for similar items in the vector db, conditioned by the category of clothes the customer wants and the number of units.
  • The items retrieved and the customer’s input message are then sent to the assistant agent to produce the recommendation message that is rendered together with the images retrieved.
  • If the customer did not upload a file, the process is the same, but instead of generating image embeddings for retrieval, we create text embeddings.
import os

import streamlit as st
from dotenv import load_dotenv
from langchain_google_vertexai import ChatVertexAI
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from PIL import Image

import utils
from assistant import Assistant
from blip2 import generate_embeddings
from classifier import Classifier

load_dotenv("env/connection.env")
load_dotenv("env/llm.env")

CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver=os.getenv("DRIVER"),
    host=os.getenv("HOST"),
    port=os.getenv("PORT"),
    database=os.getenv("DATABASE"),
    user=os.getenv("USERNAME"),
    password=os.getenv("PASSWORD"),
)

vector_db = PGVector(
    embeddings=HuggingFaceEmbeddings(model_name="nomic-ai/modernbert-embed-base"),  # does not matter for our use case
    collection_name="fashion",
    connection=CONNECTION_STRING,
    use_jsonb=True,
)

model = ChatVertexAI(model_name=os.getenv("MODEL_NAME"), project=os.getenv("PROJECT_ID"), temperature=0.0)
classifier = Classifier(model)
assistant = Assistant(model)

st.title("Welcome to ZAAI's Fashion Assistant")

user_input = st.text_input("Hi, I'm ZAAI's Fashion Assistant. How can I help you today?")

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if st.button("Submit"):

    if not user_input:
        st.warning("Please provide text.")
        st.stop()

    # understand what the user is asking for
    classification = classifier.classify(user_input)

    images_to_show = []
    if uploaded_file:

        # JPEG cannot store an alpha channel, so convert uploaded PNGs to RGB first
        image = Image.open(uploaded_file).convert("RGB")
        image.save("input_image.jpg")
        images_to_show.append("input_image.jpg")
        embedding = generate_embeddings(image=image)

    else:

        # create text embeddings in case the user does not upload an image
        embedding = generate_embeddings(text=user_input)

    # create a list of items to be retrieved and the path
    retrieved_items = []
    retrieved_items_path = []
    for category in classification.category:
        docs = vector_db.similarity_search_by_vector(
            embedding.tolist(), k=classification.number_of_items, filter={"category": {"$in": [category]}}
        )
        for doc in docs:
            retrieved_items.append({"bytesBase64Encoded": utils.encode_image_to_base64(doc.page_content)})
            retrieved_items_path.append(doc.page_content)

    # get assistant's recommendation
    assistant_output = assistant.get_advice(user_input, retrieved_items, len(retrieved_items))
    st.write(assistant_output.answer)

    # show the uploaded image (if any) next to the retrieved items
    images_to_show += retrieved_items_path
    if images_to_show:
        cols = st.columns(len(images_to_show))
        for col, path in zip(cols, images_to_show):
            col.image(path)
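The utils module is not shown in the article. A minimal, hypothetical version of encode_image_to_base64, assuming it only needs to read an image file from disk and return its contents as a base64 string, could look like this:

```python
import base64
from pathlib import Path


def encode_image_to_base64(path: str) -> str:
    """Read an image file and return its raw bytes as a base64-encoded ASCII string."""
    return base64.b64encode(Path(path).read_bytes()).decode("ascii")
```

As a quick sanity check, real JPEG files start with the bytes FF D8 FF, so their base64 encoding starts with /9j/.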

Both examples can be seen below:

Figure 6 shows an example where the customer uploaded an image of a red t-shirt and asked the agent to complete the outfit.

Figure 6: Example of text and image input (image by author)

Figure 7 shows a more straightforward example where the customer asked the agent to show them black t-shirts.

Figure 7: Example of text input (image by author)

Conclusion

Multimodal AI is no longer just a research topic. It is being used in the industry to reshape the way customers interact with company catalogs. In this article, we explored how multimodal models like BLIP-2 and Gemini can be combined to address real-world problems and provide a more personalized experience to customers in a scalable way.

We explored the architecture of BLIP-2 in depth, demonstrating how it bridges the gap between text and image modalities. To extend its capabilities, we developed a system of agents, each specializing in different tasks. This system integrates an LLM (Gemini) and a vector database, enabling retrieval of the product catalog using text and image embeddings. We also leveraged Gemini's multimodal reasoning to make the sales assistant agent's responses more human-like.

With tools like BLIP-2, Gemini, and PGVector, the future of multimodal search and retrieval is already happening, and the search engines of the future will look very different from the ones we use today.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Junnan Li, Dongxu Li, Silvio Savarese, Steven Hoi. 2023. BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models. arXiv:2301.12597

[2] Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, Dilip Krishnan. 2020. Supervised Contrastive Learning. arXiv:2004.11362

[3] Junnan Li, Ramprasaath R. Selvaraju, Akhilesh Deepak Gotmare, Shafiq Joty, Caiming Xiong, Steven Hoi. 2021. Align before Fuse: Vision and Language Representation Learning with Momentum Distillation. arXiv:2107.07651

[4] Li Dong, Nan Yang, Wenhui Wang, Furu Wei, Xiaodong Liu, Yu Wang, Jianfeng Gao, Ming Zhou, Hsiao-Wuen Hon. 2019. Unified Language Model Pre-training for Natural Language Understanding and Generation. arXiv:1905.03197

Building Visual Agents that can Navigate the Web Autonomously

Luís Roque, Rafael Guedes. Towards Data Science, 11 January 2025.

A step-by-step guide to creating visual agents that can navigate the web autonomously.
This post was co-authored with Rafael Guedes.

Introduction

In the age of exponential growth in artificial intelligence, the topic of the moment is the rise of agentic AI. These AI systems use large language models (LLMs) to make decisions, plan, and collaborate with other agents or humans.

When we wrap an LLM with a role, a set of tools, and a specific goal, we create what we call an agent. By focusing on a well-defined objective and having access to relevant APIs or external tools (like search engines, databases, or even browser interfaces, which we cover later), agents can autonomously explore paths to achieve their targets. Thus, agentic AI opens up a new paradigm where multiple agents can tackle complex, multi-step workflows.
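The definition above (an LLM plus a role, a set of tools, and a goal) can be made concrete with a framework-free Python sketch. Everything here is hypothetical: the llm argument is any function from prompt to reply, the reply convention ("tool_name: argument" or "FINAL: answer") is made up for illustration, and a scripted stand-in replaces a real model so the loop runs offline.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Agent:
    role: str
    goal: str
    tools: dict[str, Callable[[str], str]]
    llm: Callable[[str], str]            # maps a prompt to the model's reply
    max_steps: int = 5
    history: list[str] = field(default_factory=list)

    def run(self) -> str:
        for _ in range(self.max_steps):
            prompt = (f"Role: {self.role}\nGoal: {self.goal}\n"
                      f"Tools: {', '.join(self.tools)}\nHistory:\n" + "\n".join(self.history))
            reply = self.llm(prompt)     # either "tool_name: argument" or "FINAL: answer"
            if reply.startswith("FINAL:"):
                return reply.removeprefix("FINAL:").strip()
            name, _, arg = reply.partition(":")
            tool = self.tools.get(name.strip())
            observation = tool(arg.strip()) if tool else f"unknown tool {name.strip()!r}"
            self.history.append(f"{reply} -> {observation}")
        return "stopped: step limit reached"


# Scripted stand-in for an LLM: first calls a tool, then answers
scripted = iter(["lookup: capital of Portugal", "FINAL: Lisbon"])
agent = Agent(
    role="research assistant",
    goal="find the capital of Portugal",
    tools={"lookup": {"capital of Portugal": "Lisbon"}.get},
    llm=lambda prompt: next(scripted),
)
print(agent.run())  # Lisbon
```

A real agent would replace the scripted iterator with a call to an LLM and the dictionary lookup with tools such as a browser controller; the step limit guards against loops that never reach a final answer.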

John Carmack and Andrej Karpathy recently discussed a topic on X (formerly Twitter) that inspired this article. Carmack suggested that AI-powered assistants could push applications to expose their features through text-based interfaces. In that world, LLMs would talk to a command-line interface wrapped under the graphical user interface (GUI), sidestepping some of the complexity of pure vision-based navigation, which exists only because we humans need a GUI. Karpathy raised the valid point that advanced AI systems may become better at automating GUIs before developers can provide comprehensive textual interfaces to every application. We agree with Karpathy on this one.

Building on these ideas, this article explores how to implement AI agents with visual navigation capabilities. We walk through how to build agents that navigate the web autonomously by relying solely on their vision skills (no APIs or scraping). These agents can browse websites, move through pages to achieve a predefined goal, and retrieve the necessary information without human intervention.

Figure 2: Visual Agentic AI (image by author with DALL-E)

As always, the code is available on our GitHub.

Multimodal LLMs: how do they work?

Multimodal LLMs (MLLMs) were developed to address the limitations of LLMs, which perform well in zero-shot reasoning on most NLP tasks but fall short when dealing with visual elements. MLLMs also complement Large Vision Models (LVMs), which can process visual elements but lack the advanced reasoning capabilities of LLMs. By combining both, MLLMs integrate LLM reasoning with LVM visual processing, enabling the analysis of different inputs such as text and images [1]. Figure 3 shows the current state-of-the-art MLLMs and their evolution over time.

Figure 3: Existing MLLMs landscape (source)

Architecture

A typical MLLM architecture consists of three elements:

1. The pre-trained modality encoder is responsible for understanding the relationship between text and any other modality, such as audio or image. It aligns their respective representations in a shared latent space.

In this case, our model receives images and text as input. Thus, it has two encoders, one targeting images and the other text. An image encoder is typically a convolutional neural network (CNN) or a vision transformer (ViT) that converts the image into a high-dimensional vector representation, i.e., an embedding. A text encoder is usually a transformer-based language model that likewise converts text into an embedding.

Afterward, the model aligns the outputs of both encoders in the shared latent space so that the embeddings of similar images and text descriptions are closer in that space.

This alignment is crucial for the model to understand which images match the text descriptions, and it is achieved by training the model using a contrastive loss that:

  • Computes the similarity (dot product) between every image-text pair in the batch.
  • Applies a softmax function to create a probability distribution over the pairs.
  • Optimizes the model using, for example, cross-entropy loss. It maximizes the similarity between correct image-text pairs and minimizes the similarity between unrelated image-text pairs.
Figure 4: Training a multimodal model to align text and image embeddings (image by author)

2. Modality interface consists of a learnable connector. It is responsible for bridging the gap between modalities. This learnable connector is faster and cheaper to train than training an MLLM in an end-to-end manner, and its objective consists of aligning the output of the visual/audio encoder and the input text. It can be implemented in two ways:

  • Token-level fusion is where the output of the image/audio encoders is transformed into tokens (through query-based learning or by simply using a linear MLP) and concatenated with the text tokens.
  • Feature-level fusion adds extra modules to capture deeper interactions between text and visual/audio features through cross-attention layers.
Figure 5: Token and Feature level fusion (image by author)

3. Pre-trained LLMs are responsible for receiving the aligned representation of the different input modalities as input to reason and generate a text answer. One can also add an optional generator to create more modalities besides text.

We can use any LLM in this layer, such as GPT-4o, LLaMA, Mixtral, Gemini, or Qwen. The choice depends on the specific use case, as these models come in varying sizes (larger models usually perform better). Some models are multilingual, while others focus on a single language, most commonly English. Certain models, such as Mixtral, achieve faster inference by using a Mixture of Experts (MoE): a router sends each token to only a few experts (2 of 8 in Mixtral), so the model can increase its total number of parameters, and with it its expressiveness, without a proportional increase in the compute spent per token.

Figure 6: MLLM architecture (source)
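The two fusion options of the modality interface (point 2 above) can be contrasted with a toy numpy sketch using random weights, made-up sizes, a single attention head, and no layer norms. Token-level fusion changes the sequence the LLM reads; feature-level fusion keeps the text sequence and injects visual information into its hidden states through cross-attention.

```python
import numpy as np

rng = np.random.default_rng(0)
N_IMG, D_IMG = 16, 24   # image encoder output: 16 patch features of size 24 (toy sizes)
N_TXT, D_TXT = 5, 32    # 5 text tokens with hidden size 32 inside the LLM (toy sizes)

image_feats = rng.normal(size=(N_IMG, D_IMG))
text_hidden = rng.normal(size=(N_TXT, D_TXT))


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Token-level fusion: map image features into the text embedding space with a linear
# projection and concatenate them with the text tokens, giving one longer sequence.
W_proj = rng.normal(size=(D_IMG, D_TXT)) / np.sqrt(D_IMG)
token_level = np.concatenate([image_feats @ W_proj, text_hidden], axis=0)

# Feature-level fusion: keep the text sequence, let each text position cross-attend
# to the image features, and add the result back through a residual connection.
W_k = rng.normal(size=(D_IMG, D_TXT)) / np.sqrt(D_IMG)
W_v = rng.normal(size=(D_IMG, D_TXT)) / np.sqrt(D_IMG)
attn = softmax(text_hidden @ (image_feats @ W_k).T / np.sqrt(D_TXT))   # (N_TXT, N_IMG)
feature_level = text_hidden + attn @ (image_feats @ W_v)

print(token_level.shape, feature_level.shape)  # (21, 32) (5, 32)
```

The trade-off shows up in the shapes: token-level fusion lengthens the context the LLM must process, while feature-level fusion keeps it fixed but requires extra cross-attention layers inside the model.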
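The Mixture of Experts idea mentioned for Mixtral (point 3 above) can be sketched for a single token in numpy: a gate scores all experts, only the top two are evaluated, and their outputs are mixed with softmax weights computed over the chosen experts. Sizes are toy values and each expert is a single linear map, whereas real experts are feed-forward blocks; the 8-experts, top-2 setting mirrors Mixtral.

```python
import numpy as np

rng = np.random.default_rng(0)
D, NUM_EXPERTS, TOP_K = 8, 8, 2   # toy hidden size; 8 experts with 2 active per token, as in Mixtral

W_gate = rng.normal(size=(D, NUM_EXPERTS))
# Toy experts: one linear map each (real experts are feed-forward blocks)
experts = [rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(NUM_EXPERTS)]


def moe_layer(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Route one token vector x to its TOP_K experts and mix their outputs."""
    gate_logits = x @ W_gate                          # one score per expert
    chosen = np.argsort(gate_logits)[-TOP_K:]         # indices of the highest-scoring experts
    weights = np.exp(gate_logits[chosen] - gate_logits[chosen].max())
    weights /= weights.sum()                          # softmax over the chosen experts only
    out = sum(w * (x @ experts[e]) for w, e in zip(weights, chosen))
    return out, chosen


token = rng.normal(size=D)
out, chosen = moe_layer(token)
print(out.shape, sorted(chosen.tolist()))
```

All eight experts count towards the total parameter count, but each token only pays for two of them, which is where the faster inference comes from.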

Google GenAI SDK

Google recently launched its Gen AI SDK, which can be installed by running pip install google-genai. This new package is the easiest way for developers to interact with Gemini, Google's multimodal model developed by Google DeepMind [2].

The team at Google developed several notebooks with different use cases that can be used with this new package [3]. Some of the most interesting use cases are:

  • Object Detection can be easily performed with gemini-1.5-flash-002 by simply sending the image and the object you want to extract as input. In the example below, the user targeted the umbrellas in the image.
Figure 6: Extracting bounding boxes with Gemini (source)

  • Video Analysis is possible by leveraging Gemini's multimodal capabilities. One can use it to identify elements of interest in different frames or to summarize a video's content.

  • Entity extraction from a large corpus can be a hard and time-consuming task for a human, but not for Gemini. In this example, the team at Google shows how to extract streets and means of transportation from a long text by just prompting the following:
directions_prompt = f"""
  From the given text, extract the following entities and return a list of them.
  Entities to extract: street name, form of transport.
  Text: {directions}
  Street = []
  Transport = [] """

{ "Street": [ "Via dei Fori Imperiali", "Via del Corso", "Via della Lungara" ], "Transport": [ "train", "metro", "bus", "Terravision shuttle", "taxi", "private transfer service" ] }

  • Interpreting Charts, Graphs, and Slides. As we saw earlier, being multimodal means that Gemini can process both text and image inputs and reason across both. In this example, Gemini is asked to explain the different Pub/Sub patterns shown in the image below:

The image shows three different Pub/Sub patterns:

Many-to-One (Fan-In): In this pattern, multiple publishers send messages to a single topic. A single subscriber listens to this topic and consumes all the messages.

Many-to-Many: This pattern allows multiple publishers to send messages to a single topic and multiple subscribers can consume messages from that topic. It is useful for load balancing and ensuring that messages are distributed across multiple consumers.

One-to-Many (Fan-Out): This pattern has a single publisher sending messages to a topic, and multiple subscribers are listening to that topic. This pattern is used to distribute messages to a large number of consumers.

Figure 7: Slide with Cloud Pub/Sub explanation (source)

Navigating web pages using Google GenAI SDK

In this section, we work through our use case using the Google GenAI SDK. We create an agent capable of navigating web pages and extracting content from them. The goal is to have the agent navigate to ZAAI's web page, find its blog, and get some metadata from the latest published article. The agent leverages Gemini's multimodal capabilities to extract the necessary information. It relies only on screenshots of the website (images) and the instructions we provide as text.

Agents like this open the door to many use cases that were impractical before. Until a few months ago, agents relied only on APIs to get information from third-party systems. Here, there is no need to spend time building an API, because the agent navigates and extracts information the way a human would. Unlike traditional scraping, the agent should adapt to changes in a website's UI/UX. It also does not require custom code to locate specific elements on the page.

We start by importing the libraries, defining global variables, and loading the Gemini API key from a .env file (a key can be created in Google AI Studio):

import subprocess
import time
import pyautogui
import base64
import google.generativeai as genai
import json
import re
import os
from dotenv import load_dotenv
from PIL import Image, ImageDraw
CHROME_PATH = "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"  # macOS install location
SCREENSHOT_PATH = "assets/zaai_homepage.png"
SCREENSHOT_BBOXED_PATH = "assets/zaai_homepage_bboxed.png"
SCREENSHOT_BLOG_PATH = "assets/zaai_lab.png"

ZAAI_URL = "https://zaai.ai"

load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    raise ValueError("API key not found. Please set GEMINI_API_KEY in your .env file.")
genai.configure(api_key=api_key)

We create three utility functions to handle Gemini's input and output:

- The first reads an image file and converts it into a Base64-encoded string, the format in which the image is sent inline to Gemini's API here.
- The second extracts JSON from the LLM's response, stripping code-fence markers so that the data can be parsed.
- The third converts bounding-box coordinates, which the model returns on a 0 to 1000 scale, into actual pixel values.

def encode_image_to_base64(image_path: str) -> str:
    """ Read an image file and return its Base64-encoded string. """
    with open(image_path, "rb") as img_file:
        image_data = img_file.read()
    return base64.b64encode(image_data).decode("utf-8")

def extract_json_from_response(response) -> dict:
    """
    Extract JSON content from the LLM response, removing any code fence markers.
    """
    if not hasattr(response, "candidates") or not response.candidates:
        raise ValueError("Response does not contain valid candidates.")
    raw_text = response.candidates[0].content.parts[0].text
    json_str = re.sub(r"^```json|```$", "", raw_text.strip(), flags=re.MULTILINE)

    try:
        parsed_data = json.loads(json_str)
        return parsed_data
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse JSON: {e}\nRaw LLM Response:\n{raw_text}") from e

def update_coordinates_to_pixels(detection_info: dict, width: int, height: int) -> None:
    """
    Convert normalized bounding box coordinates ([0..1000]) to actual pixel values.
    """
    for key, value in detection_info.items():
        coords = value["coordinates"]
        xmin, ymin, xmax, ymax = coords
        value["coordinates"] = [
            (xmin / 1000.0) * width,
            (ymin / 1000.0) * height,
            (xmax / 1000.0) * width,
            (ymax / 1000.0) * height
        ]
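Gemini's JSON is not guaranteed to follow the requested schema. Meanwhile, update_coordinates_to_pixels assumes that every entry has four numeric coordinates. A defensive filter could therefore be applied to the model's output before the conversion. The validate_boxes helper below is hypothetical and not part of the original code. The checks it performs are assumptions about what counts as a usable box.

```python
from typing import Any, Dict


def validate_boxes(boxes: Dict[str, Any], scale: int = 1000) -> Dict[str, dict]:
    """Keep only entries whose 'coordinates' form a plausible [xmin, ymin, xmax, ymax]
    box on a 0..scale grid; everything else is dropped rather than crashing later."""
    valid = {}
    for label, info in boxes.items():
        if not isinstance(info, dict):
            continue                                     # e.g. the model returned a bare string
        coords = info.get("coordinates")
        if not isinstance(coords, list) or len(coords) != 4:
            continue                                     # missing or wrongly shaped box
        if not all(isinstance(c, (int, float)) and 0 <= c <= scale for c in coords):
            continue                                     # non-numeric or outside the normalized range
        xmin, ymin, xmax, ymax = coords
        if xmin >= xmax or ymin >= ymax:
            continue                                     # inverted or empty box
        valid[label] = {"coordinates": coords, "description": info.get("description", "")}
    return valid


raw = {
    "Lab Link": {"coordinates": [800, 20, 860, 50], "description": "Navigation link to the blog"},
    "Inverted": {"coordinates": [500, 50, 400, 60]},
    "Too short": {"coordinates": [1, 2, 3]},
}
print(validate_boxes(raw))
# {'Lab Link': {'coordinates': [800, 20, 860, 50], 'description': 'Navigation link to the blog'}}
```

Dropping bad entries silently is a design choice. An alternative is to send them back to the model with a request to correct them.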

Next, we develop the functions that serve as the agent's tools:

- The first overlays bounding boxes on an image and saves it.
- The second captures a screenshot of the screen.
- The third launches Chrome and navigates to a given URL.
- The last one finds the bounding box that points to the blog and clicks on it, enabling interaction with specific UI elements.

def draw_bounding_boxes(image_path: str, detection_info: dict, output_path: str, color: str = "red") -> None:
    """
    Draw bounding boxes on the image and save to output_path.
    Expects detection_info to have pixel coordinates already.
    """
    try:
        with Image.open(image_path) as img:
            draw = ImageDraw.Draw(img)

            for label, details in detection_info.items():
                coords = details["coordinates"]
                # Outline the element and write its label just above the box
                draw.rectangle(coords, outline=color, width=2)
                draw.text((coords[0], coords[1] - 10), label, fill=color)

            img.save(output_path)
            print(f"Image saved with bounding boxes at: {output_path}")
    except Exception as e:
        print(f"Failed to create image with bounding boxes: {e}")

def take_screenshot(output_path: str):
    """Take a screenshot of the entire main screen."""
    time.sleep(2)  # wait a bit for the page to load
    screenshot = pyautogui.screenshot()
    screenshot.save(output_path)
    print(f"Screenshot saved to {output_path}.")

def open_chrome(url: str):
    """Open Chrome to a specific URL using a subprocess."""
    print(f"Opening Chrome at {url} ...")
    subprocess.Popen([CHROME_PATH, url])
    time.sleep(5)

def find_and_click_lab_element(bounding_box_data: dict):
    """
    Click the bounding box that should lead to the 'Lab' (or blog) page.
    For simplicity, we pick the first bounding box whose label
    or description references "Lab" or "Blog".
    """
    target_label = None
    for label, info in bounding_box_data.items():
        text = f"{label} {info.get('description', '')}".lower()
        # Whole-word match, so words such as "available" or "collaborate" do not count as "lab"
        if re.search(r"\b(labs?|blog)\b", text):
            target_label = label
            break

    if not target_label:
        print("Could not find a bounding box that references Lab/Blog.")
        return
    coords = bounding_box_data[target_label]["coordinates"]

    # coords is [xmin, ymin, xmax, ymax] in screenshot pixels; click the center of the box.
    # Divide by 2 because on a Retina display screenshot pixels are twice the screen points.
    x_center = (coords[0] + coords[2]) / 2 / 2
    y_center = (coords[1] + coords[3]) / 2 / 2
    print(f"Clicking element: {target_label}")

    pyautogui.moveTo(x_center, y_center, duration=0.5)
    pyautogui.click()
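The following worked example shows the coordinate arithmetic behind the click. The helper names are hypothetical and not part of the article's code. First, the normalized box is scaled to screenshot pixels, as update_coordinates_to_pixels does. Then it is reduced to a single click point. Clicking the center of the box is a choice made here because the center is the point least likely to miss the element. The default factor of 2 assumes a Retina display, matching the division by 2 in find_and_click_lab_element.

```python
def normalized_box_to_pixels(box, width, height):
    """[xmin, ymin, xmax, ymax] on a 0..1000 scale -> screenshot pixel coordinates."""
    xmin, ymin, xmax, ymax = box
    # Multiply before dividing so integer inputs give exact results
    return [xmin * width / 1000, ymin * height / 1000,
            xmax * width / 1000, ymax * height / 1000]


def click_point(pixel_box, display_scale=2.0):
    """Center of a pixel box, converted from screenshot pixels to screen points.

    display_scale is an assumption: 2.0 for a typical macOS Retina display,
    1.0 when screenshot pixels and screen points coincide.
    """
    xmin, ymin, xmax, ymax = pixel_box
    return ((xmin + xmax) / 2 / display_scale, (ymin + ymax) / 2 / display_scale)


# Hypothetical box for a "Lab" link on a 2880x1800 Retina screenshot
pixels = normalized_box_to_pixels([100, 200, 300, 400], 2880, 1800)
print(pixels)               # [288.0, 360.0, 864.0, 720.0]
print(click_point(pixels))  # (288.0, 270.0)
```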

Finally, we implement capabilities for our agent to extract specific information from what it is "seeing."

def identify_elements_with_descriptions(image_path: str) -> list:
    """
    Step 1: Ask the model to identify clickable elements and include descriptions.
    Return a list of objects, each containing 'label' and 'description'.
    """
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    encoded_image = encode_image_to_base64(image_path)

    prompt = """
    You are given a screenshot of a website homepage.
    Identify all the relevant clickable elements (text, buttons, icons, tabs, images, etc.)
    on the website page only (discard browser elements if they appear in the image) and provide:
      - A semantically rich name as the label (e.g., "Lab Link" or "Blog Tab")
      - A short description of its purpose on the page and any relevant visual details

    Output JSON in this format:
    {
      "elements": [
        {
          "label": "some descriptive label",
          "description": "short description with visual nuances"
        },
        ...
      ]
    }
    """

    response = model.generate_content([
        {"mime_type": "image/png", "data": encoded_image},
        prompt
    ])

    parsed_data = extract_json_from_response(response)
    if "elements" not in parsed_data:
        raise ValueError("No 'elements' field found in the JSON response.")

    return parsed_data["elements"]

def propose_bounding_boxes(image_path: str, identified_elements: list) -> dict:
    """
    Step 2: Provide the list of elements from Step 1 (labels + descriptions).
    Ask the model to propose bounding boxes in [xmin, ymin, xmax, ymax] on a 0..1000 scale.
    Return a dict where keys are labels and values hold 'coordinates' + 'description'.
    The model is asked to repeat each element's description so it is kept in the final output.
    """
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    encoded_image = encode_image_to_base64(image_path)

    elements_json_str = json.dumps(identified_elements, indent=2)

    prompt = f"""
    The following clickable elements were identified (labels + descriptions):
    {elements_json_str}

    Propose a bounding box (in [xmin, ymin, xmax, ymax], 0..1000 scale) for each element
    so we can locate them on the screenshot.

    Output JSON in the format:
    {{
      "<element_label>": {{
        "coordinates": [xmin, ymin, xmax, ymax],
        "description": "<the same description from above>"
      }},
      ...
    }}
    """

    response = model.generate_content([
        {"mime_type": "image/png", "data": encoded_image},
        prompt
    ])

    parsed_data = extract_json_from_response(response)
    return parsed_data

def retrieve_latest_blog_info(image_path: str) -> tuple[str, str]:
    """
    Get the title and publication date of the latest blog post.
    """
    model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest")
    encoded_image = encode_image_to_base64(image_path)

    prompt = """
    You are given a screenshot of a website blog page.
    Identify the latest article and get its title and date.

    Output JSON in this format:
    {
      "title": "title of the article",
      "pub_date": "date of publishing of the article"
    }
    """

    response = model.generate_content([
        {"mime_type": "image/png", "data": encoded_image},
        prompt
    ])

    parsed_data = extract_json_from_response(response)
    if "title" not in parsed_data or "pub_date" not in parsed_data:
        raise ValueError("Missing 'title' or 'pub_date' field in the JSON response.")

    return parsed_data["title"], parsed_data["pub_date"]
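The article does not list the script that ties these functions together. Below is one possible driver that follows the steps described in this section. VisualAgentSteps and run_visual_agent are hypothetical names. The step functions are passed in through a small dataclass, so the control flow can be exercised with fake steps, without opening a browser or calling Gemini. The commented lines show how the article's functions could be plugged in.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class VisualAgentSteps:
    """The step functions the driver calls, passed in so the flow can be tested with fakes."""
    open_browser: Callable[[str], None]
    take_screenshot: Callable[[str], None]
    identify_elements: Callable[[str], List[dict]]
    propose_boxes: Callable[[str, List[dict]], dict]
    image_size: Callable[[str], Tuple[int, int]]
    to_pixels: Callable[[dict, int, int], None]
    click_blog_link: Callable[[dict], None]
    read_latest_post: Callable[[str], Tuple[str, str]]


def run_visual_agent(steps: VisualAgentSteps, url: str,
                     home_png: str, blog_png: str) -> Tuple[str, str]:
    """Open page -> screenshot -> elements -> boxes -> click -> screenshot -> metadata."""
    steps.open_browser(url)
    steps.take_screenshot(home_png)
    elements = steps.identify_elements(home_png)          # what can be clicked
    boxes = steps.propose_boxes(home_png, elements)        # where it is (0..1000 scale)
    width, height = steps.image_size(home_png)
    steps.to_pixels(boxes, width, height)                  # converts the boxes in place
    steps.click_blog_link(boxes)
    steps.take_screenshot(blog_png)                        # capture the page reached by the click
    return steps.read_latest_post(blog_png)


# Possible wiring with the functions defined above (not executed here):
# from PIL import Image
# steps = VisualAgentSteps(
#     open_browser=open_chrome,
#     take_screenshot=take_screenshot,
#     identify_elements=identify_elements_with_descriptions,
#     propose_boxes=propose_bounding_boxes,
#     image_size=lambda path: Image.open(path).size,
#     to_pixels=update_coordinates_to_pixels,
#     click_blog_link=find_and_click_lab_element,
#     read_latest_post=retrieve_latest_blog_info,
# )
# title, pub_date = run_visual_agent(steps, ZAAI_URL, SCREENSHOT_PATH, SCREENSHOT_BLOG_PATH)
```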

The visual workflow of the process can be seen in Figure 8.

Figure 8: Steps performed by the Visual Agent to extract the metadata from the latest blog article (image by author).

Conclusion

Agentic AI opens up many new possibilities. As these agents become more autonomous, their ability to act on goals and adapt to new information makes many old, complex problems much easier to solve. They are also gaining increasingly human-like capabilities. For instance, until recently, only humans could interpret an image and extract information from it. A few months ago, if we wanted a machine to extract this level of detail from an image, we had two options: 1) train a model for that specific goal, or 2) program the machine to perform a very specific task. Neither approach was scalable, because if something changed, we would most likely need to retrain the model or reprogram the machine.

Nowadays, multimodal models can interpret and analyze visual information by simply following text instructions and reasoning over the inputs. As a result, the same model can often keep working without modification, even when the context or the visual appearance of the desired information changes.

Agentic AI is still taking its first steps, and we look forward to seeing what comes next!

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Shukang Yin, Chaoyou Fu, Sirui Zhao, Ke Li, Xing Sun, Tong Xu, Enhong Chen. (2024). A Survey on Multimodal Large Language Models. arXiv:2306.13549.

[2] https://github.com/google-gemini/generative-ai-python/blob/main/README.md

[3] https://github.com/google-gemini/cookbook/tree/main/examples

The post Building Visual Agents that can Navigate the Web Autonomously appeared first on Towards Data Science.

Agentic AI: Building Autonomous Systems from Scratch https://towardsdatascience.com/agentic-ai-building-autonomous-systems-from-scratch-8f80b07229ea/ Fri, 13 Dec 2024 15:02:31 +0000 https://towardsdatascience.com/agentic-ai-building-autonomous-systems-from-scratch-8f80b07229ea/ A Step-by-Step Guide to Creating Multi-Agent Frameworks in the Age of Generative AI

The post Agentic AI: Building Autonomous Systems from Scratch appeared first on Towards Data Science.

This post was co-authored with Rafael Guedes.

Introduction

The rise of generative AI is the new platform shift of the digital era. It is being applied to problems ranging from automation in large enterprises to many kinds of R&D and creative work. The global market is projected to surpass $65 billion in 2024, and 86% of IT leaders anticipate large organizational changes [1]. So far, the biggest returns have come from chatbots (the most generic and widespread use case), code copilots, and enterprise search.

Investment continues to flow into AI, with $13.8 billion invested in 2024 (a sixfold increase from 2023) [1]. Beyond investment, businesses are embedding AI into their core strategies and systems. Technologies like retrieval-augmented generation (RAG), fine-tuning, and specialized models for vertical applications (e.g., healthcare, legal) are becoming mainstream.

Large Language Models (LLMs) have drawn widespread attention to AI and opened the door to new ways of solving old problems. One of these is Agentic AI, a framework in which autonomous agents work collaboratively to execute complex, multi-step workflows.

Our demo shows how to develop and work with a multi-agent system. It integrates three specialized agents:

  • A web researcher agent that ingests and analyzes internet data.
  • A transcriptor and summarizer agent that retrieves and condenses video or text data into actionable summaries.
  • A blog writer agent that synthesizes this information into a coherent structure.

These agents operate within a structured workflow. They leverage foundational LLMs and existing tools in everyday enterprise stacks. We show how organizations can streamline tasks, reduce human effort, and enhance output quality – all while maintaining adaptability to complex scenarios.

Figure 1: Multi-Agent System (image by author with DALL-E)

As always, the code is available on our GitHub.

AI Agents: What are they?

AI agents, usually powered by an LLM, are systems designed to act autonomously to achieve a specific goal. They receive an input prompt and have access to a toolset needed to complete certain tasks.

The input prompt can take several forms. It can be a simple text prompt given by a human with instructions to follow, such as "Write a blog post about AI Agents." In a Multi-Agent System (MAS), it can be the output of the previous agent, which can be text or more structured data such as JSON.

The tools an agent can access are crucial for its success, just as they are for humans. For example, a chef without a working oven cannot cook a delicious roast. In the case of AI agents, these tools are usually APIs that connect them to other systems, such as a search engine to look up information or a database to run queries against.

When building these kinds of Agents, there are two main classes to be defined [2]:

  1. The Agent (which has four main components):
    • The LLM used by the agent can be closed source, like GPT-4 or Claude Sonnet, or open source, like LLaMA 3.3 or Mixtral 8×22B. The LLM also takes parameters that should be set accordingly, such as the temperature or the maximum number of tokens produced. The choice of LLM depends on the task the agent has to perform. For example, GPT-4 has good reasoning capabilities, Claude Sonnet tends to perform better at coding, and GPT-4o-mini is the fastest of the three. When dealing with critical information, one might instead opt for an open-source model to avoid sharing that information with third-party companies.
    • The Role of the agent defines its responsibilities, providing purpose and guiding the agent through the tasks and behaviors expected of it. For example, the agent's role can be processing and analyzing information, retrieving data from a database, or coordinating interactions between other agents.
    • The Backstory defines the agent's current knowledge of its environment, its responsibilities, and its interactions with other agents or tools. It also defines the agent's current intent, i.e., what the agent plans to do based on its knowledge of the environment and its goal.
    • The Goal is what the agent is expected to achieve, and it usually translates into the agent's output. For example, if the agent is responsible for retrieving data from a database and answering a user's question, its goal is to get an answer, and the output is the answer.
  2. The Task (which has three main components):
    • The Description provides a detailed explanation of what needs to be done by clearly defining the nature of the task and the expected outcome. It also provides specific instructions and constraints the agent must respect. For example, if the task is to retrieve data from a database, the description must specify the retrieval parameters and any formatting requirements.
    • The Output describes how the task result should be presented, setting clear expectations for it. It can specify that the result should be text, JSON, a list, HTML, SQL, or the response from an API.
    • Finally, the Agent specifies which agent is responsible for executing the task.
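The two classes can be illustrated with plain Python dataclasses. This sketch shows the concepts only. AgentSpec, TaskSpec, the default LLM parameters, and the system_prompt format are assumptions made for illustration; they are not CrewAI's classes. CrewAI's own Agent and Task, used later in this article, accept similarly named fields such as role, goal, backstory, description, expected_output, and agent.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class AgentSpec:
    """The four agent components described above, plus the tools the agent may call."""
    llm: str                    # model name, e.g. "gpt-4o-mini"
    role: str
    backstory: str
    goal: str
    llm_params: Dict[str, Any] = field(
        default_factory=lambda: {"temperature": 0.2, "max_tokens": 1024})
    tools: List[Callable[..., Any]] = field(default_factory=list)

    def system_prompt(self) -> str:
        """One simple way to turn role, backstory and goal into instructions for the LLM."""
        return f"Role: {self.role}\nBackstory: {self.backstory}\nGoal: {self.goal}"


@dataclass
class TaskSpec:
    """The three task components: what to do, what to return, and who does it."""
    description: str
    expected_output: str
    agent: AgentSpec


writer = AgentSpec(
    llm="gpt-4o-mini",
    role="AI Agents Blog Writer",
    backstory="You're a meticulous writer with a keen eye for detail.",
    goal="Create detailed blog posts based on AI Agents research findings",
)
write_task = TaskSpec(
    description="Expand each research finding into a full blog section.",
    expected_output="A complete blog post formatted as HTML.",
    agent=writer,
)
print(write_task.agent.system_prompt())
```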

While a single AI agent can effectively perform a specific task, the full potential of agents is only realized when a group of them works together. By interacting and collaborating with each other, they offer scalability and specialization to solve complex problems. The next section addresses such Multi-Agent Systems (MAS).

Figure 2: Agent and Task definition (image by author)

Multi-Agent Collaborative System (MAS)

A MAS is a group of agents, also referred to as a Crew, each with unique skills and specialized capabilities. These agents collaborate on simpler tasks to achieve a bigger, more complex common goal [3].

Within the Crew, each agent is an individual LLM-powered entity with distinct characteristics, a role, and specific tools. Like humans, these agents communicate with one another by passing the output of their tasks to subsequent agents, which build upon it.

The structure of a Crew can be categorized into three main types based on the interaction between agents:

  • Sequential: the agents work in a chain, where one agent’s output is the next’s input. By solving smaller tasks, they can solve the bigger and more complex objective for which the Crew was designed.
  • Hierarchical: it usually consists of a manager and multiple subordinates. The manager delegates, plans, and manages the completion of tasks, while the subordinates execute the manager's instructions. In this scenario, agents can perform tasks simultaneously, since not every agent depends on another's output.
  • Hybrid: this structure has Sequential and Hierarchical environments within the same Crew. It typically happens when some agents, with complex tasks at hand, break them down into smaller ones and build a sub-crew with new agents. They become the leader of that sub-crew and, at the same time, a subordinate of the original Crew.
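The three structures can be sketched by reducing agents to plain text-in, text-out functions. The function names and toy agents are hypothetical. Real crews add LLM calls, tool use, and error handling. The sketch only shows how outputs flow between agents in each structure.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Tuple

AgentFn = Callable[[str], str]  # here an "agent" is simply text in -> text out


def run_sequential(agents: List[AgentFn], task_input: str) -> str:
    """Sequential crew: each agent's output becomes the next agent's input."""
    result = task_input
    for agent in agents:
        result = agent(result)
    return result


def run_hierarchical(plan: Callable[[str], List[Tuple[str, str]]],
                     workers: Dict[str, AgentFn],
                     combine: Callable[[List[str]], str],
                     goal: str) -> str:
    """Hierarchical crew: a manager plans (worker, subtask) pairs, delegates them,
    and combines the results. The subtasks do not depend on each other, so they
    run concurrently; pool.map returns results in the order of the plan."""
    assignments = plan(goal)
    with ThreadPoolExecutor() as pool:
        results = list(pool.map(lambda pair: workers[pair[0]](pair[1]), assignments))
    return combine(results)


# Toy agents; a "worker" can itself run a sequential sub-crew, which gives a hybrid structure.
def shout(text: str) -> str:
    return text.upper()


def exclaim(text: str) -> str:
    return text + "!"


def plan(goal: str) -> List[Tuple[str, str]]:
    return [("research", goal), ("write", goal)]


workers = {
    "research": lambda topic: f"notes on {topic}",
    "write": lambda topic: run_sequential([shout, exclaim], topic),  # sub-crew
}
print(run_hierarchical(plan, workers, " | ".join, "ai agents"))
# notes on ai agents | AI AGENTS!
```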
Figure 3: Multi-Agent Systems/Crew possible structures (image by author)

CrewAI: Creating a MAS to write a blog post

In this section, we will create a Multi-Agent System (MAS) to write a blog post about AI agents (we know it might sound a bit confusing to have AI agents writing about AI agents, but bear with us) using one of the most popular packages in this space, CrewAI. Figure 4 illustrates the complete architecture of our approach:

Figure 4: Agents architecture (image by author)

Our crew comprises three agents with different tasks that work collaboratively to generate a blog post in HTML format. These agents are:

  • Web Researcher Agent: responsible for querying a search engine called SearXNG and retrieving the most relevant and up-to-date YouTube URLs about AI Agents. The agent and its task definition can be seen below:
researcher:
  role: >
    {topic} Senior Data Researcher
  goal: >
    Uncover cutting-edge developments in {topic}
  backstory: >
    You're a seasoned researcher with a knack for uncovering the latest
    developments in {topic}. Known for your ability to find the most relevant
    information and present it in a clear and concise manner.

research_task:
  description: >
    Conduct a thorough research about {topic}
    Make sure you find any interesting and relevant youtube links given
    the current year is 2024.
  expected_output: >
    A list with youtube URLs that cover {topic} and the respective description
    with the most relevant information about {topic}. Ignore any links that don't
    start with "https://www.youtube.com".
  agent: researcher
  • Transcriptor & Summarizer Agent: connects to the YouTube API to retrieve transcriptions for the URLs provided by the Web Researcher Agent. It summarizes each transcription, extracts the main insights and references, and makes recommendations based on the video's content. As shown below, the agent and its task are defined as:
summarizer:
  role: >
    {topic} Summarizer
  goal: >
    Summarize and extract knowledge and other insightful and interesting information from {topic}
  backstory: >
    You're an expert on analyzing information and extracting the most important
    and insightful information in a concise manner.
    You're known for your ability to summarize and retrieve facts, references, quotes
    and recommend the most useful surprising information about the {topic}.

summarize_task:
  description: >
    Analyse the information about the {topic} thoroughly to extract the most valuable insights, facts, and recommendations.
    Adhere strictly to the provided schema when extracting information from the input content.
    Ensure that the output matches the field descriptions, types and constraints exactly. Ignore any links that don't
    start with "https://www.youtube.com".
  expected_output: >
    A json with a summary of the {topic} and the most valuable insights, facts, and recommendations. Also add the
    youtube links as references.
  agent: summarizer
  • Blog Writer Agent: the third and last agent of our crew. It uses the summary generated by the Transcriptor & Summarizer to create a blog post about the topic in HTML format. The HTML must have a professional appearance, including a navigation bar to help the reader navigate the article. The agent and its task are defined as:
blog_writer:
  role: >
    {topic} Blog Writer
  goal: >
    Create detailed blog posts based on {topic} research findings
  backstory: >
    You're a meticulous writer with a keen eye for detail.
    You're known for your ability to turn complex topics into clear and concise blog posts,
    making it easy for others to understand and act on the information you provide.

write_task:
  description: >
    Review the context you got and expand each topic into a full section for a blog post.
    Make sure the blog post is detailed and contains any and all relevant information.
    The blog post must contain an introduction, a body, a code example and a conclusion section.
  expected_output: >
    A fully-fledged blog post with the main topics, each presented as a complete section of information.
    Format it as HTML without using '```'. Make it look like a professional tech blog website,
    including a navbar, menu, and styling. Incorporate YouTube links as clickable references within
    the text and at the end.
  agent: blog_writer
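All the agent and task definitions above rely on a {topic} placeholder. In CrewAI, these placeholders are filled from the inputs passed when the crew is started, for example kickoff(inputs={"topic": "AI Agents"}). The sketch below shows the idea of that substitution step in plain Python. It is not CrewAI's own implementation, and interpolate is a hypothetical helper name. In practice, the dict would come from loading agents.yaml (e.g. with yaml.safe_load); it is written inline here to keep the sketch dependency-free.

```python
import re
from typing import Any, Dict

PLACEHOLDER = re.compile(r"\{(\w+)\}")


def interpolate(config: Any, inputs: Dict[str, str]) -> Any:
    """Recursively replace {name} placeholders in every string of a nested dict/list.

    Placeholders without a matching input are left as they are instead of raising.
    """
    if isinstance(config, str):
        return PLACEHOLDER.sub(lambda m: inputs.get(m.group(1), m.group(0)), config)
    if isinstance(config, dict):
        return {key: interpolate(value, inputs) for key, value in config.items()}
    if isinstance(config, list):
        return [interpolate(item, inputs) for item in config]
    return config  # numbers, booleans, None


agents_config = {
    "researcher": {
        "role": "{topic} Senior Data Researcher",
        "goal": "Uncover cutting-edge developments in {topic}",
    }
}
filled = interpolate(agents_config, {"topic": "AI Agents"})
print(filled["researcher"]["role"])  # AI Agents Senior Data Researcher
```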

The definitions above must be placed in two YAML files: one for agents (agents.yaml) and another for tasks (tasks.yaml). Once this is done, it is time to create the tools the agents use to perform their tasks.

In our case, the Blog Writer does not need any tools. The Researcher needs the search engine, and the Transcriptor & Summarizer needs a connection to the YouTube API.

Search Engine (SearXNG)

  • It is good practice to define the input schema when defining the tools. As shown in the code snippet below, our search engine tool expects to receive the search query and the number of results to retrieve.
  • Then, we define the tool itself. The __init__ function sets up the connection to the search engine, while the _run function specifies how the agent uses the tool. It searches for YouTube videos about the topic the user requested (in this case, AI Agents), appending " :youtube" to the query with the aim of focusing the results on YouTube.
from crewai.tools import BaseTool
from typing import Type, Optional, List, Dict
from pydantic import BaseModel, Field, PrivateAttr
from langchain_community.utilities import SearxSearchWrapper

class SearxSearchToolInput(BaseModel):
    """Input schema for SearxSearchTool."""

    query: str = Field(..., description="The search query.")
    num_results: int = Field(10, description="The number of results to retrieve.")

class SearxSearchTool(BaseTool):
    name: str = "searx_search_tool"
    description: str = (
        "A tool to perform searches using the Searx metasearch engine. "
        "Specify a query and optionally the number of results to retrieve."
    )
    args_schema: Type[BaseModel] = SearxSearchToolInput
    _searx_wrapper: SearxSearchWrapper = PrivateAttr()

    def __init__(self, searx_host: str, unsecure: bool = False):
        """Initialize the SearxSearchTool with SearxSearchWrapper."""
        super().__init__()
        self._searx_wrapper = SearxSearchWrapper(
            searx_host=searx_host, unsecure=unsecure
        )

    def _run(
        self,
        query: str,
        num_results: int = 10,
    ) -> List[Dict]:
        """Perform a search using the Searx API."""
        try:
            results = self._searx_wrapper.results(
                query=query + " :youtube",
                num_results=num_results,
            )
            return results
        except Exception as e:
            return [{"Error": str(e)}]
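The research task asks the LLM to ignore links that do not start with "https://www.youtube.com". The same rule can also be enforced deterministically in code, for example on the list that _run returns, so that non-YouTube links never reach the agent. The following sketch assumes each result is a dict with a "link" key, as in the dicts returned by SearxSearchWrapper.results. keep_youtube_results is a hypothetical helper and not part of the original tool.

```python
from typing import Dict, List


def keep_youtube_results(results: List[Dict], prefix: str = "https://www.youtube.com") -> List[Dict]:
    """Keep results whose 'link' starts with the YouTube prefix, dropping duplicates.

    Entries without a 'link' (such as the {"Error": ...} dict returned by _run on failure)
    are dropped as well.
    """
    seen = set()
    kept = []
    for result in results:
        link = str(result.get("link", ""))
        if link.startswith(prefix) and link not in seen:
            seen.add(link)
            kept.append(result)
    return kept


sample = [
    {"title": "Intro to agents", "link": "https://www.youtube.com/watch?v=abc123"},
    {"title": "A blog post", "link": "https://example.com/agents"},
    {"title": "Intro to agents (dup)", "link": "https://www.youtube.com/watch?v=abc123"},
]
print([r["link"] for r in keep_youtube_results(sample)])
# ['https://www.youtube.com/watch?v=abc123']
```

Whether to drop or surface the error entry is a design choice. Dropping it keeps the agent's input clean, but it also hides the failure from the agent.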

YouTube API

  • This tool follows the same principle by first defining the input schema, which consists of the YouTube URL and the language we want the transcription to be in.
  • We also define an output schema that returns both the transcription and the duration of the video.
  • Finally, the tool itself extracts the video ID from the URL and connects to the YouTube API to retrieve the transcription, as seen in the _run function.
from typing import Type, Optional
from pydantic import Field, BaseModel

from youtube_transcript_api import (
    NoTranscriptFound,
    TranscriptsDisabled,
    YouTubeTranscriptApi,
)

from crewai.tools import BaseTool

class YouTubeTranscriptToolInputSchema(BaseModel):
    """
    Input schema for the YouTubeTranscriptTool: the video URL and an optional language code.
    """

    video_url: str = Field(
        ..., description="URL of the YouTube video to fetch the transcript for."
    )
    language: Optional[str] = Field(
        None, description="Language code for the transcript (e.g., 'en' for English)."
    )

class YouTubeTranscriptToolOutputSchema(BaseModel):
    """
    Output schema for the YouTubeTranscriptTool. Contains the transcript text and the video duration.
    """

    transcript: str = Field(..., description="Transcript of the YouTube video.")
    duration: float = Field(
        ..., description="Duration of the YouTube video in seconds."
    )

class YouTubeTranscriptTool(BaseTool):
    """
    Tool for fetching the transcript of a YouTube video using the YouTube Transcript API.

    Attributes:
        args_schema (Type[BaseModel]): The schema for the input data (YouTubeTranscriptToolInputSchema).
    """

    name: str = "youtube_transcript_tool"
    description: str = (
        "A tool to perform youtube transcript extraction. "
        "Specify the url of the youtube video and optionally the language code."
    )
    args_schema: Type[BaseModel] = YouTubeTranscriptToolInputSchema

    def __init__(self):
        """
        Initializes the YouTubeTranscriptTool.
        """
        super().__init__()

    def _run(
        self, video_url: str, language: Optional[str] = None
    ) -> YouTubeTranscriptToolOutputSchema:
        """
        Runs the YouTubeTranscriptTool with the given parameters.

        Args:
            video_url (str): The URL of the YouTube video to fetch the transcript for.
            language (Optional[str]): The language code for the transcript (e.g., 'en' for English).

        Returns:
            YouTubeTranscriptToolOutputSchema: The output of the tool, adhering to the output schema.

        Raises:
            Exception: If fetching the transcript fails.
        """

        video_id = self.extract_video_id(video_url)
        try:
            if language:
                transcripts = YouTubeTranscriptApi.get_transcript(
                    video_id, languages=[language]
                )
            else:
                transcripts = YouTubeTranscriptApi.get_transcript(video_id)
        except (NoTranscriptFound, TranscriptsDisabled) as e:
            raise Exception(
                f"Failed to fetch transcript for video '{video_id}': {str(e)}"
            )

        transcript_text = " ".join([transcript["text"] for transcript in transcripts])
        total_duration = sum([transcript["duration"] for transcript in transcripts])

        return YouTubeTranscriptToolOutputSchema(
            transcript=transcript_text,
            duration=total_duration,
        )

    @staticmethod
    def extract_video_id(url: str) -> str:
        """
        Extracts the video ID from a YouTube URL.

        Args:
            url (str): The YouTube video URL.

        Returns:
            str: The extracted video ID.
        """
        return url.split("v=")[-1].split("&")[0]
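The split-based `extract_video_id` above assumes a `watch?v=` URL. Given a short link such as `https://youtu.be/<id>`, it returns the whole URL unchanged. If the researcher agent may return other URL shapes, a slightly more defensive parser could be swapped in. Below is a minimal standalone sketch that uses only the standard library's `urllib.parse`. The set of URL shapes it handles (watch, `youtu.be`, `/embed/`, `/shorts/`) is our assumption, not an exhaustive list of YouTube formats.

```python
from urllib.parse import parse_qs, urlparse


def extract_video_id(url: str) -> str:
    """Return the video ID from a few common YouTube URL shapes (sketch)."""
    parsed = urlparse(url)
    host = parsed.netloc.lower()
    # Short links: https://youtu.be/<id>
    if host in {"youtu.be", "www.youtu.be"}:
        return parsed.path.lstrip("/").split("/")[0]
    # Watch links: https://www.youtube.com/watch?v=<id>&t=42s
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]
    # Path-based links: https://www.youtube.com/embed/<id> or /shorts/<id>
    parts = [part for part in parsed.path.split("/") if part]
    if len(parts) >= 2 and parts[0] in {"embed", "shorts"}:
        return parts[1]
    raise ValueError(f"Could not find a video ID in {url!r}")


print(extract_video_id("https://www.youtube.com/watch?v=abc123XYZ_-&t=42s"))  # abc123XYZ_-
print(extract_video_id("https://youtu.be/abc123XYZ_-"))                       # abc123XYZ_-
```

Raising a `ValueError` instead of returning a best guess makes a bad URL fail loudly before the transcript API is called.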

With agents, tasks, and tools defined, we can now build our Crew and set their dependencies.

As shown below, we declare the agents using the @agent decorator, together with their tools and components (role, goal, and backstory). For tasks, we use the @task decorator and define their components (description, output, and agent). Finally, we define how they should collaborate: here, they work sequentially (Process.sequential).

import os
from crewai import Agent, Crew, Process, Task
from crewai.project import CrewBase, agent, crew, task
from crew_zaai.src.crew_zaai.tools.searx import SearxSearchTool
from crew_zaai.src.crew_zaai.tools.youtube import YouTubeTranscriptTool

@CrewBase
class CrewZaai:
    """CrewZaai crew"""

    agents_config = "config/agents.yaml"
    tasks_config = "config/tasks.yaml"

    @agent
    def researcher(self) -> Agent:
        search_tool = SearxSearchTool(
            searx_host=os.getenv("SEARXNG_BASE_URL"), unsecure=False
        )

        return Agent(
            config=self.agents_config["researcher"], tools=[search_tool], verbose=True
        )

    @agent
    def summarizer(self) -> Agent:
        youtube_tool = YouTubeTranscriptTool()

        return Agent(
            config=self.agents_config["summarizer"], tools=[youtube_tool], verbose=True
        )

    @agent
    def blog_writer(self) -> Agent:
        return Agent(config=self.agents_config["blog_writer"], verbose=True)

    @task
    def research_task(self) -> Task:
        return Task(
            config=self.tasks_config["research_task"],
        )

    @task
    def summarizer_task(self) -> Task:
        return Task(
            config=self.tasks_config["summarize_task"],
        )

    @task
    def write_task(self) -> Task:
        return Task(
            config=self.tasks_config["write_task"], output_file="assets/report.html"
        )

    @crew
    def crew(self) -> Crew:
        """Creates the CrewZaai crew"""
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True,
        )
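Conceptually, `Process.sequential` runs the tasks in the order they are declared, passing each task's output to the next one as context, while the `inputs` given to `kickoff` fill placeholders such as `{topic}` in the task descriptions. The toy sketch below models only that idea in plain Python. `ToyTask` and `run_sequential` are hypothetical names, the agents are simple lambdas, and this is not how CrewAI is implemented internally.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyTask:
    """Hypothetical stand-in for a CrewAI Task (not the real class)."""
    description: str                  # may contain placeholders like {topic}
    agent: Callable[[str, str], str]  # (description, context) -> output


def run_sequential(tasks: List[ToyTask], inputs: Dict[str, str]) -> List[str]:
    """Run tasks in order, feeding each output to the next task as context."""
    outputs: List[str] = []
    context = ""  # the first task starts without upstream context
    for task in tasks:
        description = task.description.format(**inputs)  # fill {topic}, etc.
        context = task.agent(description, context)
        outputs.append(context)
    return outputs


tasks = [
    ToyTask("Find recent videos about {topic}", lambda d, c: f"links for: {d}"),
    ToyTask("Summarize the videos", lambda d, c: f"summary of [{c}]"),
    ToyTask("Write a blog post", lambda d, c: f"post based on [{c}]"),
]
print(run_sequential(tasks, {"topic": "AI Agents"})[-1])
# post based on [summary of [links for: Find recent videos about AI Agents]]
```

The nesting in the final output shows the chain: the writer only sees what the summarizer produced, which in turn only saw the researcher's result.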

The last step is to make our agents work together to write the blog post by kicking off our Crew.

import warnings

from crew_zaai.src.crew_zaai.crew import CrewZaai
from dotenv import load_dotenv

warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
load_dotenv()

inputs = {"topic": "AI Agents"}
CrewZaai().crew().kickoff(inputs=inputs)

By default, the LLM used is GPT-4o-mini from OpenAI; therefore, the OpenAI API key must be set as an environment variable. We also need to set the API key for YouTube (check this link to create your key) and the URL for the search engine. Our .env file must contain the following variables:

YOUTUBE_API_KEY=<YOUR KEY>
OPENAI_API_KEY=<YOUR KEY>
SEARXNG_BASE_URL=https://search.zaai.ai

After running the script, the crew's output is saved to the assets/ folder (as report.html, per the write_task configuration). A screenshot of it can be seen below:

Figure 5: Blog post about AI Agents (image by author)

Conclusion

In this article, we explored the topic of the moment: Agentic AI.

Agentic AI is a shift in the industry towards a more autonomous and collaborative workflow without human intervention. It promises to free humans from boring and simple tasks and allow them to focus on more value-added ones.

We presented a simple use case where a MAS produced a blog post without human intervention. The agents searched the web for content, got transcriptions from recent YouTube videos, structured ideas, and produced an HTML page with the output. We can now use these agents for new and more complex use cases. For example, we could help customers search for products on a website using text and images, leveraging multimodal capabilities. We could also create personalized curriculums for a course (education) or review documents and assess their compliance (legal).

With the right investments and strategies, AI will continue to set new standards for productivity and creativity. The possibilities are endless, and the time to explore them is now.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Menlo Ventures. (2024). The State of Generative AI in the Enterprise. Retrieved from https://menlovc.com/2024-the-state-of-generative-ai-in-the-enterprise/

[2] Talebirad, Y., & Nadiri, A. (2023). Multi-Agent Collaboration: Harnessing the Power of Intelligent LLM Agents. arXiv:2306.03314.

[3] Han, S., Zhang, Q., Yao, Y., Jin, W., Xu, Z., & He, C. (2024). LLM Multi-Agent Systems: Challenges and Open Problems. arXiv:2402.03578.

The post Agentic AI: Building Autonomous Systems from Scratch appeared first on Towards Data Science.

TimesFM: The Boom of Foundation Models in Time Series Forecasting https://towardsdatascience.com/timesfm-the-boom-of-foundation-models-in-time-series-forecasting-29701e0b20b5/ Fri, 20 Sep 2024 19:18:03 +0000 https://towardsdatascience.com/timesfm-the-boom-of-foundation-models-in-time-series-forecasting-29701e0b20b5/ Explore How Google's Latest AI Model Delivers Zero-Shot Forecasting Accuracy Using Over 307 Billion Data Points

The post TimesFM: The Boom of Foundation Models in Time Series Forecasting appeared first on Towards Data Science.

This post was co-authored with Rafael Guedes.

Introduction

Forecasting is one of the most important use cases across all industries. Take retail as an example: several planning activities that contribute to optimizing margins, such as financial, production, or workforce planning, rely on forecasting capabilities. Forecasts affect stock management (e.g., waste and leftovers, or stockouts), customer service levels, and overall decision-making.

Developing an accurate forecasting model to support these processes requires a deep understanding of state-of-the-art (SOTA) forecasting methodologies, as well as knowledge of the specific business domain to which they are applied. These two requirements have motivated the growing interest in pre-trained models, since they reduce the need for highly customized setups. Combined with the success of large pre-trained models in the Natural Language Processing (NLP) community, a.k.a. Large Language Models (LLMs), this has opened a research path with many contributors.

In theory, language and time series tasks share several similarities, such as the sequential nature of the data. One key difference, however, is that time series data is continuous, whereas NLP models typically work with discrete tokens (words, phrases). In NLP, tokenization is straightforward: language can be broken into distinct units like words or subwords. Time series data, in contrast, has no natural "breakpoints" in its continuous sequences. As a result, converting continuous time series into meaningful, discrete tokens that a model can process, while preserving temporal patterns and relationships, is particularly challenging.

Despite these challenges, large tech companies and research labs have been making significant efforts to develop foundation models tailored for time series forecasting. We call foundation models any model trained on vast amounts of data – often millions, billions, or even trillions of data points – enabling them to generalize across a wide range of tasks and domains. A key feature of these models is their ability to perform zero-shot inference. This means we can generate accurate forecasts for new datasets without retraining our model, reducing time and effort when applying them to different use cases.

In this article, we provide an in-depth explanation of TimesFM, Google’s new foundation model for time series forecasting. We explore its architecture and the main components that enable the model to perform zero-shot inference. We also discuss the differences between the 4 foundation models we have researched so far: TimesFM, Chronos, MOIRAI, and TimeGPT.

Following this theoretical overview, we apply TimesFM to a specific use case and dataset. We cover the practical implementation details and provide a thorough analysis of the performance of the model. Finally, we compare TimesFM’s performance with TiDE, Chronos, and MOIRAI on a public dataset.

Figure 1: TimesFM (image by author with DALL-E)

As always, the code is available on our GitHub.

TimesFM: Training Data

Like other foundation models for time series, such as Chronos or Moirai, TimesFM [1] was designed to produce zero-shot forecasts as accurate as those of SOTA supervised forecasting models.

A foundation model must be trained on a large volume of temporal data to perform accurately across various use cases. The final dataset must be sufficiently heterogeneous to represent the wide variety of domains the model needs to predict accurately, considering specific trends, multiple seasonalities, and different time granularities.

To achieve this, the research team at Google created a dataset based on 4 main sources:

  • Google Trends data consists of search interest over time at hourly, daily, weekly, and monthly granularities. The data ranges from 2018 to 2019 for hourly granularity and from 2007 to 2021 for the other granularities, comprising a total of 0.5 billion time points.
  • Wiki Pageviews captures hourly views of Wikipedia pages. The data, ranging from 2021 to 2023, was aggregated to cover other granularities (daily, weekly, and monthly), resulting in a dataset of 300 billion time points.
  • Synthetic Data was generated to increase the diversity of seasonal patterns using a mixture of sine and cosine functions of different frequencies. This data also captures trends, such as linear and exponential patterns, with change points and step functions. The authors generated 3 million time series, each with 2,048 time points, adding 6 billion time points to the previous datasets.
  • Other real-world datasets were incorporated, including publicly available datasets such as the M4 dataset in five different granularities (yearly, quarterly, monthly, daily, and hourly) with 23 million time points, the LibCity dataset with 15-minute granularity and 34 million time points, Favorita Sales with daily granularity and 139 million time points, Weather with 10-minute granularity and 2 million time points, Traffic with hourly granularity and 15 million time points, and Electricity with hourly granularity and 8 million time points.

TimesFM was trained on nearly 307 billion time points and 205.3 million time series.

Model Architecture

TimesFM was designed to predict the next H time steps of a time series given a context of C past time points. Note that this model does not handle external information such as static covariates (e.g., product brand or category) or dynamic covariates (e.g., discounts or prices).

Input Layers

The input layers are responsible for converting the time series into input tokens for the transformer layers. This preprocessing is inspired by PatchTST [2], where tokenization involves breaking down the time-series data into non-overlapping patches of a predefined size (the number of patches equals the context length divided by the input patch length). By splitting the original series into patches, the data is converted from a continuous space into a discrete space (tokens), which offers several advantages:

  • Enabling the attention mechanism to extract local semantic meaning by examining groups of time series data instead of focusing on individual time steps.
  • Reducing the number of tokens fed to the encoder decreases the required memory and allows the model to process longer input sequences.
  • Providing the model with longer sequences gives it more information to analyze and more meaningful temporal relationships to extract, potentially resulting in more accurate forecasts.
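The patching step can be sketched in a few lines of NumPy. This is a minimal illustration, not TimesFM's code: we assume a 1-D context and, when its length is not a multiple of the patch length, simply drop the oldest points (real implementations may pad and mask instead). A context of 160 points with a patch length of 32 yields the 5 tokens used in the Figure 2 example.

```python
import numpy as np


def patchify(series: np.ndarray, patch_len: int = 32) -> np.ndarray:
    """Split a 1-D series into non-overlapping patches (one row per token)."""
    usable = (len(series) // patch_len) * patch_len
    # Keep the most recent points so the last patch ends at the last observation.
    trimmed = series[len(series) - usable:]
    return trimmed.reshape(-1, patch_len)


context = np.arange(160, dtype=np.float32)  # a context of C = 160 points
patches = patchify(context, patch_len=32)
print(patches.shape)  # (5, 32): 160 / 32 = 5 tokens
```

Five tokens instead of 160 individual time steps is exactly the memory saving described in the second bullet above.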

After the data is split into patches, each token is processed by a residual block composed of a Multi-Layer Perceptron (MLP) with one hidden layer and a skip connection, which maps it to a d-dimensional vector (d = 1,280 in the released model). Finally, positional encoding is added to these vectors to produce the final representation fed into the transformer layers.

For example, as shown in Figure 2, the input layer breaks the time series into 5 tokens, each of length 32. Before feeding them into the transformer layer, these tokens are processed by an MLP block to create five 1,280-dimensional vectors. Afterward, absolute positional encoding, based on sine and cosine functions from the original Transformer architecture, is added to the output of the MLP.

Figure 2: Preprocessing of input data involves converting the time points into tokens, which are then processed by an MLP block. Along with the skip connection, this creates the final d-dimensional vector, to which positional encoding is subsequently added (image by author).
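A rough sketch of the residual block in Figure 2 follows, with random (untrained) weights. The hidden size of 1,280 and the linear projection on the skip path (needed because a 32-value patch cannot be added directly to a 1,280-dimensional vector) are our assumptions for illustration, not details taken from the TimesFM code. Positional encodings would be added to the output as the next step.

```python
import numpy as np

rng = np.random.default_rng(0)


def residual_block(x, w_hidden, b_hidden, w_out, b_out, w_skip):
    """One-hidden-layer MLP plus a linear skip connection (sketch)."""
    hidden = np.maximum(0.0, x @ w_hidden + b_hidden)  # ReLU hidden layer
    return hidden @ w_out + b_out + x @ w_skip         # MLP output + skip path


patch_len, hidden_dim, model_dim = 32, 1280, 1280
patches = rng.normal(size=(5, patch_len))  # 5 tokens of length 32, as in Figure 2
params = (
    rng.normal(scale=0.02, size=(patch_len, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(scale=0.02, size=(hidden_dim, model_dim)), np.zeros(model_dim),
    rng.normal(scale=0.02, size=(patch_len, model_dim)),  # skip projection
)
embeddings = residual_block(patches, *params)
print(embeddings.shape)  # (5, 1280): one 1,280-dimensional vector per token
```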

Stacked Transformer

The transformer layers are stacked sequentially, each consisting of a multi-head self-attention layer followed by a feed-forward neural network. The authors used causal attention, and Figure 3 shows how it works: each output token (prediction) can only attend to the input tokens (data points) that occur before it in the sequence. This prevents data that would not yet be observed at that point in time from leaking into the training process.
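The causal mask can be illustrated with a single attention head in NumPy. To keep the sketch short, we assume no learned query, key, or value projections (the token vectors are used directly) and a toy width of 8. The point is only that the mask zeroes out every attention weight from a token to later tokens.

```python
import numpy as np


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Mask future positions: token i may only attend to tokens 0..i.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                     # exp(-inf) = 0 for masked entries
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 patch tokens, toy width 8
out, w = causal_attention(x, x, x)
print(np.round(w, 2))        # upper triangle is all zeros
```

Because of the mask, changing the last token leaves the outputs for all earlier tokens unchanged, which is the no-leakage property described above.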

During training, the model forecasts the next H time points based on different input window sizes. For example, the model is simultaneously trained on the first 32 time points to predict the next H time points, the first 64 time points to predict the next H time points, and so on. This progressive approach allows the model to learn from various window sizes during training.

Figure 3: Stacked transformer layer training process where H=64 (image by author)
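The progressive training setup amounts to the (context, target) pairs built below for a toy series, with a patch length of 32 and H = 64 as in Figure 3. The helper name is hypothetical. With causal attention, the model does not need a separate forward pass per pair: one pass over the series yields a prediction after every patch, so all pairs contribute to the loss at once.

```python
import numpy as np


def decoder_training_pairs(series, patch_len=32, horizon=64):
    """Build the (context, target) pairs a decoder-only model learns from.

    The output after patch j sees the first (j + 1) * patch_len points and
    is trained to predict the next `horizon` points.
    """
    pairs = []
    end = patch_len
    while end + horizon <= len(series):
        pairs.append((series[:end], series[end:end + horizon]))
        end += patch_len
    return pairs


series = np.arange(256, dtype=np.float32)
pairs = decoder_training_pairs(series, patch_len=32, horizon=64)
for context, target in pairs:
    print(len(context), "->", int(target[0]), "...", int(target[-1]))
```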

Output Layers

The output layer maps the output tokens to predictions of the part of the time series that follows the last input patch. As seen in the previous section, the model is trained so that the output patch length can differ from the input patch length (128 and 32, respectively, in the released 200M checkpoint). The authors found that this allows the model to forecast a horizon of any length faster and more accurately.

Figure 4: TimesFM from tokens to forecast values (image by author)
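To see why a long output patch helps, consider the autoregressive decoding loop: each step predicts one output patch, which is appended to the context before the next step. The sketch below uses a hypothetical naive `toy_model` in place of TimesFM and assumes the 128-point output patch and 512-point maximum context of the released checkpoint. A 300-step horizon then needs 3 decoding steps instead of the 10 it would need if the output patch were as short as the 32-point input patch.

```python
import math

import numpy as np

OUTPUT_PATCH_LEN = 128  # output patch length of the 200M checkpoint
MAX_CONTEXT = 512       # maximum context length


def toy_model(context: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for TimesFM: predicts one output patch by
    repeating the last observed value (a naive forecast)."""
    return np.full(OUTPUT_PATCH_LEN, context[-1])


def autoregressive_forecast(context: np.ndarray, horizon: int):
    """Decode output patches until `horizon` points have been produced."""
    steps = math.ceil(horizon / OUTPUT_PATCH_LEN)
    preds = []
    for _ in range(steps):
        patch = toy_model(context[-MAX_CONTEXT:])   # only the latest points fit
        preds.append(patch)
        context = np.concatenate([context, patch])  # feed predictions back in
    return np.concatenate(preds)[:horizon], steps


history = np.arange(100, dtype=np.float64)
preds, steps = autoregressive_forecast(history, horizon=300)
print(len(preds), steps)    # 300 3
print(math.ceil(300 / 32))  # 10 steps if each output patch had only 32 points
```

Fewer decoding steps also means fewer chances for errors in earlier predictions to compound.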

Loss Function

Since the authors focused on point forecasting, they used Mean Squared Error as the loss function. Nevertheless, they mention that it can easily be adapted to probabilistic forecasting, and the publicly available code for this model already includes this feature.
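For reference, here is the MSE loss alongside the quantile (pinball) loss. The pinball loss is one common way to turn a point forecaster into a probabilistic one, by training additional output heads for several quantile levels. We assume the quantile approach purely for illustration; we have not checked which loss the TimesFM code uses for its probabilistic outputs.

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error, the point-forecasting loss described above."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean((y_true - y_pred) ** 2))


def pinball(y_true, y_pred, q):
    """Quantile (pinball) loss for a quantile level q in (0, 1).

    Under-forecasts are weighted by q and over-forecasts by (1 - q), so
    minimizing it pushes y_pred toward the q-th quantile of the target.
    """
    diff = np.asarray(y_true, float) - np.asarray(y_pred, float)
    return float(np.mean(np.maximum(q * diff, (q - 1) * diff)))


print(mse([1, 2, 3], [1, 2, 5]))       # 1.3333333333333333
print(pinball([10.0], [8.0], q=0.9))   # 1.8 (under-forecast: 0.9 * 2)
print(pinball([10.0], [12.0], q=0.9))  # ~0.2 (over-forecast: 0.1 * 2)
```

At q = 0.9, under-forecasting costs nine times more than over-forecasting, which is why a head trained this way learns an upper band of the prediction interval.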

TimeGPT vs Chronos vs MOIRAI vs TimesFM: The Comparison

This section presents the similarities and differences between the foundation models we have studied in previous articles.

These models are designed for forecasting and time-series analysis, and the table below allows us to evaluate their features based on several characteristics. One important factor in our evaluation is whether the models are open source: Chronos, MOIRAI, and TimesFM are open for public access and modification, while TimeGPT is proprietary. In terms of training data size, TimesFM stands out with a significantly larger dataset. Still, more training data does not necessarily translate into better model performance. In fact, we prefer models trained on smaller datasets, since they can be updated more frequently.

As foundation models, all four support zero-shot forecasting. Another difference is that Chronos and TimesFM are limited to univariate analysis. All models offer probabilistic forecasting, allowing them to predict outcomes with uncertainty estimates. The number of model parameters, an indicator of complexity, is within a similar range across the models. Interestingly, these models use far fewer parameters than their counterparts for language tasks (LLMs typically come in sizes of 7B, 70B, and even 405B parameters).

Table 1: This table compares four forecasting and time-series models - TimeGPT, Chronos, MOIRAI, and TimesFM - highlighting key differences in openness, data size, learning capabilities, multivariate support, probabilistic forecasting, and model complexity (image by author).

TimesFM vs. Chronos vs. MOIRAI: a comparison on a public dataset

In this section, we will use TimesFM to forecast tourism visitors to Australia using a real-world dataset that is publicly available under the CC-BY-4.0 license. Subsequently, we will compare the forecasting performance of TimesFM against Chronos (its large version), Moirai, and TiDE. To access the code that generates the forecasts with Chronos and TiDE, please refer to this article, and for Moirai, please refer to this one.

Although TimesFM cannot use external information, other models like TiDE and Moirai can. Therefore, we enhanced the dataset with economic covariates (e.g., CPI, Inflation Rate, GDP) extracted from Trading Economics, which bases its indicators on official sources. We also performed some preprocessing to further increase the usability of the dataset. The final structure of the dataset is the following:

  • Unique ID: A combination of encoded names for States, Zones, Regions within Australia, and the purpose of the visit (e.g., business, holiday, visiting, other).
  • Time: Represents the time dimension of the dataset.
  • Target: The target variable to predict, in this case, the number of visits.
  • Dynamic Covariates: Economic indicators such as CPI, Inflation Rate, and GDP that vary over time.
  • Static Covariates (Static_1 to Static_4): Extracted from the unique ID, these provide additional information for analysis, including geographic and purpose-of-visit details.

We stored the new version of the dataset here so that our experiments can be easily reproduced.

We start by importing the libraries and setting global variables: the date column, the target column, the frequency of our series, and the forecast horizon.

import warnings

import numpy as np
import pandas as pd
import timesfm
import utils
from datasets import load_dataset

warnings.filterwarnings("ignore")

TIME_COL = "Date"
TARGET = "visits"
FORECAST_HORIZON = 8  # months
FREQ = "MS"  # monthly frequency, dated at the start of each month

After that, we load our dataset, which already has the exogenous features mentioned in the dataset description.

# load data
df = pd.DataFrame(load_dataset("zaai-ai/time_series_datasets", data_files={'train': 'data.csv'})['train']).drop(columns=['Unnamed: 0'])
df[TIME_COL] = pd.to_datetime(df[TIME_COL])

print(f"Distinct number of time series: {len(df['unique_id'].unique())}")
df.head()

Distinct number of time series: 304

Once the dataset is loaded, we can split the data between train and test (we decided to use the last 8 months of data for our test set).

# 8 months to test
train = df[df[TIME_COL] <= (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]
test = df[df[TIME_COL] > (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]

print(f"Months for training: {len(train[TIME_COL].unique())} from {min(train[TIME_COL]).date()} to {max(train[TIME_COL]).date()}")
print(f"Months for testing: {len(test[TIME_COL].unique())} from {min(test[TIME_COL]).date()} to {max(test[TIME_COL]).date()}")

Months for training: 220 from 1998-01-01 to 2016-04-01
Months for testing: 8 from 2016-05-01 to 2016-12-01

With the dataset split, we can forecast using TimesFM. For that, we need to load the model from Hugging Face and set the following parameters:

  • horizon_len: the forecast horizon we defined earlier.
  • context_len: how many points of the sequence the model can attend to (any positive integer up to 512).
  • The remaining hyperparameters must be fixed because the 200M model expects an input patch length of 32, an output patch length of 128, 20 layers, and a model dimension of 1280.

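tfm = timesfm.TimesFm(
    context_len=512,               # number of past points the model can attend to
    horizon_len=FORECAST_HORIZON,  # forecast 8 months ahead
    input_patch_len=32,            # fixed by the 200M checkpoint
    output_patch_len=128,          # fixed by the 200M checkpoint
    num_layers=20,                 # fixed by the 200M checkpoint
    model_dims=1280,               # fixed by the 200M checkpoint
    backend="cpu",
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

# forecast_on_df expects a long-format frame with 'unique_id', 'ds', and the value column
forecast_df = tfm.forecast_on_df(
    inputs=train.loc[:, [TIME_COL, "unique_id", TARGET]].rename(columns={TIME_COL: "ds"}),
    freq=FREQ,
    value_name=TARGET,
    num_jobs=-1,
)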

Once the forecast has finished, we can plot the ground truth values and the predictions for the 3 series with the most visits.
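The plotting code lives in the authors' repository and is not reproduced here. Selecting the three series with the most visits only takes a `groupby`. The sketch below runs on a small made-up frame with the same column names as our dataset (`unique_id`, `Date`, `visits`), so the numbers are illustrative only.

```python
import pandas as pd

# Made-up stand-in for the tourism data, using the same column names.
toy = pd.DataFrame({
    "unique_id": ["A", "A", "B", "B", "C", "C", "D", "D"],
    "Date": pd.to_datetime(["2016-05-01", "2016-06-01"] * 4),
    "visits": [10, 12, 300, 280, 50, 55, 120, 100],
})

# Rank series by total visits and keep the three largest.
top3 = toy.groupby("unique_id")["visits"].sum().nlargest(3).index.tolist()
to_plot = toy[toy["unique_id"].isin(top3)]
print(top3)  # ['B', 'D', 'C']
```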

Figure 5: TimesFM forecast vs. actual values (image by author)

Figure 5 shows that TimesFM struggled to forecast the first two series, especially the huge drop in the first series. The third series looks much better, and we have a nearly perfect overlap between the actuals and the forecast. It is also worth mentioning that the actuals are always within the prediction interval.

Having obtained the forecast from TimesFM, we can now load the forecast generated by TiDE, Chronos and Moirai and compute forecasting performance metrics for comparison. For better interpretability, we have used the Mean Absolute Percentage Error (MAPE) as our comparison metric.
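The per-month MAPE comparison can be sketched with pandas. We assume a hypothetical `results` frame holding, for each series and test month, the actual visits and one column per model; the values below are made up. MAPE is averaged over all series within each month. Rows with zero actual visits would need to be excluded, because the percentage error is undefined there.

```python
import pandas as pd

MODELS = ["TimesFM", "Chronos"]  # hypothetical column names, one per model

# Made-up frame: one row per (series, month) with actuals and each model's forecast.
results = pd.DataFrame({
    "Date": pd.to_datetime(["2016-05-01", "2016-05-01", "2016-06-01", "2016-06-01"]),
    "visits": [100.0, 200.0, 100.0, 50.0],
    "TimesFM": [110.0, 180.0, 90.0, 55.0],
    "Chronos": [95.0, 210.0, 120.0, 50.0],
})

# Absolute percentage error per row: |forecast - actual| / actual * 100.
ape = results[MODELS].sub(results["visits"], axis=0).abs().div(results["visits"], axis=0) * 100

# MAPE per month = mean APE over all series in that month.
mape_per_month = ape.groupby(results["Date"]).mean()
print(mape_per_month)
print(mape_per_month.idxmin(axis=1))  # lowest-MAPE model per month
```

Counting how often each model wins with `idxmin` is one way to produce statements like "lowest MAPE in 4 out of 8 months".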

Figure 6: MAPE comparison between TimesFM, MOIRAI, Chronos and TiDE (image by author)

As shown in Figure 6, TimesFM had the lowest MAPE in 4 out of 8 months, demonstrating a performance similar to Chronos, the most accurate foundation model we have researched so far.

We also tested all models on 4 large private datasets, and the results were consistent with those reported in this article. Testing on private datasets is very important for foundation models, since the public dataset used in the evaluation might have been part of the models' original training data. If that were the case, the evaluation would be biased.

Conclusion

In this article, we explored TimesFM, a foundation model for time-series forecasting developed by Google. Our analysis shows that TimesFM outperforms TiDE and performs similarly to Chronos on this public dataset. In the case of TiDE, the model had access to external information and was specifically trained on this dataset. Still, TimesFM outperformed it in 6 out of 8 months. It’s worth mentioning that we did not perform any hyperparameter tuning on TiDE, which could lead to improvements in its performance. Regarding Chronos, since both models had access to the same information, we expected a similar performance, which was the case. Still, Chronos had a slightly lower MAPE than TimesFM in 4 out of 8 months.

The field of foundation models in time series forecasting seems to be converging, similar to what we have observed with large models in language tasks. As datasets grow larger and models become more complex, the marginal improvements in performance start to diminish. There are two key aspects we are particularly interested in: 1) how to effectively fine-tune these models in the context of time series forecasting, as it has proven successful for language tasks, and 2) the need for a better evaluation framework for these types of models.

References

[1] Abhimanyu Das, Weihao Kong, Rajat Sen, Yichen Zhou. A decoder-only foundation model for time-series forecasting. arXiv:2310.10688, 2023.

[2] Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam. A Time Series is Worth 64 Words: Long-term Forecasting with Transformers. arXiv:2211.14730, 2022.

The Evolution of Llama: From Llama 1 to Llama 3.1 https://towardsdatascience.com/the-evolution-of-llama-from-llama-1-to-llama-3-1-13c4ebe96258/ Fri, 06 Sep 2024 04:44:26 +0000 https://towardsdatascience.com/the-evolution-of-llama-from-llama-1-to-llama-3-1-13c4ebe96258/ A Comprehensive Guide to the Advancements and Innovations in the Family of Llama Models from Meta AI

The post The Evolution of Llama: From Llama 1 to Llama 3.1 appeared first on Towards Data Science.

This post was co-authored with Rafael Guedes.

Introduction

Meta has released three major versions of its large language model (LLM), Llama, along with a minor (if we can call it that) update (version 3.1). The initial release of Llama in early 2023 marked a significant step forward for the open-source community in natural language processing (NLP). Meta has consistently contributed to this community by sharing its latest LLM versions.

To be precise, we should distinguish between open and open-source LLMs. Open-source software traditionally makes its source code available under licenses that allow public use and modification. In the context of LLMs, open LLMs typically disclose model weights and initial code, whereas open-source LLMs would also share the entire training process, including training data, under a permissive license. Most models today, including Meta’s Llama, fall into the open LLM category since they do not release the datasets used for training.

Llama has undergone three key architectural iterations. Version 1 introduced several enhancements to the original Transformer architecture. Version 2 implemented Grouped-Query Attention (GQA) in larger models. Version 3 extended GQA to smaller models, introduced a more efficient tokenizer, and expanded the vocabulary size. Version 3.1 did not change the core architecture; the main changes were a more thorough cleaning of the training data, a longer context length, and support for additional languages.

This article explores Llama’s architectural evolution, highlighting key advancements and their implications for the future of LLMs. It concludes with a practical experiment comparing Llama 2 and Llama 3, evaluating their performance on specific tasks using metrics like inference speed, answer length, and the Relative Answer Quality (RAQ) framework [1]. RAQ provides an objective ranking framework to test LLMs’ answers based on their accuracy relative to the ground truth, making it particularly useful for evaluating specific use cases.

Figure 1: Llama family (image by author with DALL-E)

As always, the code is available on our GitHub.

Llama: A Family of Open LLMs

Llama 1: The First Model

The first model of this family, Llama 1 [2], was built on the Transformer architecture introduced by Vaswani et al. in 2017 [3], using a decoder-only variant of the original encoder-decoder design. The Transformer was (and still is) one of the major breakthroughs in NLP and the backbone architecture of virtually all LLMs.

Llama 1 used it in its core architecture, combining it with several improvements, such as:

Pre-normalization

Inspired by the training-stability improvement used in the architecture of GPT-3 [4], Llama 1 normalizes the input of each transformer sub-layer instead of its output, as shown in Figure 2.

Figure 2: Differences between original and Llama 1 architectures where each input in the sub-layer transformer is normalized (image by author)
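The difference between the two layouts in Figure 2 fits in two functions. This is a simplified sketch: the normalization has no learned gain or bias, and we use LayerNorm for clarity even though Llama swaps in RMSNorm, as described next. The key property is that pre-normalization leaves the residual path as a clean identity, which is one common explanation for its better training stability.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize each token vector (no learned gain/bias, for simplicity)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def post_norm_block(x, sublayer):
    """Original Transformer: add the residual, then normalize the sum."""
    return layer_norm(x + sublayer(x))


def pre_norm_block(x, sublayer):
    """Pre-normalization: normalize the sub-layer input; the residual stays raw."""
    return x + sublayer(layer_norm(x))


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # 4 tokens, toy width 8
w = rng.normal(scale=0.1, size=(8, 8))


def toy_sublayer(h):
    """Stand-in for attention or the feed-forward network."""
    return h @ w


print(pre_norm_block(x, toy_sublayer).shape)   # (4, 8)
print(post_norm_block(x, toy_sublayer).shape)  # (4, 8)
```

If a sub-layer outputs zeros, the pre-norm block returns its input unchanged, while the post-norm block still rescales it.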

Additionally, they replace the traditional LayerNorm function with RMSNorm [5], which is more computationally efficient while preserving training stability and increasing model convergence.

RMSNorm achieves better efficiency because its authors demonstrated that the benefits of LayerNorm arise from rescaling invariance rather than recentering invariance. This insight allowed them to remove the mean calculation from the normalization process, making it simpler, just as effective, and significantly more efficient.

Figure 3: Equation differences between LayerNorm and RMSNorm (image by author)
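Since the equations in Figure 3 are only available as an image, here they are as a NumPy sketch. LayerNorm computes (x - μ) / √(σ² + ε) · g + b, while RMSNorm computes x / √(mean(x²) + ε) · g, with a learned gain g and bias b. The ε value is our choice for numerical stability. For a zero-mean input, both give the same result. RMSNorm is invariant to rescaling the input but, unlike LayerNorm, not to shifting it.

```python
import numpy as np

EPS = 1e-6  # small constant for numerical stability (our choice)


def layer_norm(x, gain, bias):
    """LayerNorm: re-center (subtract the mean), then re-scale by the std."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS) * gain + bias


def rms_norm(x, gain):
    """RMSNorm: re-scale by the root mean square only; no mean, no bias."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + EPS)
    return x / rms * gain


x = np.array([[1.0, 2.0, 3.0, 4.0]])
gain, bias = np.ones(4), np.zeros(4)
print(layer_norm(x, gain, bias))
print(rms_norm(x, gain))
```

Skipping the mean saves one reduction over the hidden dimension per call, which is where RMSNorm's efficiency gain comes from.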

SwiGLU activation function

Regarding the activation function, the authors replaced the well-known ReLU with SwiGLU [6], which has been shown to improve model performance. The main differences between the two functions are:

  • ReLU sets all negative values to 0 and returns positive values unchanged.
  • SwiGLU is a gated linear unit built on the Swish function, x·σ(βx), where σ is the sigmoid function and β is a parameter (which can be trainable) that controls the degree of interpolation between a linear function and ReLU. As β increases, the behavior becomes more similar to ReLU, as shown in Figure 4.
Figure 4: Behaviour differences between ReLU and SwiGLU, where when β=100, ReLU and SwiGLU overlap.
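The functions in Figure 4 can be reproduced with NumPy. We write Swish as x·σ(βx) and the SwiGLU feed-forward block as (Swish(xW) ⊙ xV)W₂, following the usual gated-linear-unit formulation. The weight shapes are toy values, and, as far as we know, Llama's reference code uses the β = 1 case (also called SiLU).

```python
import numpy as np


def sigmoid(x):
    """Numerically stable logistic function: 0.5 * (1 + tanh(x / 2))."""
    return 0.5 * (1.0 + np.tanh(x / 2.0))


def swish(x, beta=1.0):
    """Swish_beta(x) = x * sigmoid(beta * x); beta = 1 is also called SiLU."""
    return x * sigmoid(beta * x)


def swiglu_ffn(x, w, v, w2, beta=1.0):
    """Feed-forward block with a SwiGLU gate: (Swish(xW) * xV) W2."""
    return (swish(x @ w, beta) * (x @ v)) @ w2


xs = np.linspace(-5, 5, 11)
print(np.round(swish(xs, beta=1.0), 3))    # smooth, slightly negative for x < 0
print(np.round(swish(xs, beta=100.0), 3))  # practically identical to ReLU

rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))                           # 4 tokens, width 8
w, v = rng.normal(size=(8, 16)), rng.normal(size=(8, 16))  # toy gate/value weights
w2 = rng.normal(size=(16, 8))                              # toy output projection
print(swiglu_ffn(tokens, w, v, w2).shape)                  # (4, 8)
```

Writing the sigmoid via `tanh` avoids overflow warnings from `np.exp` when β·x is large and negative.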

Rotary Positional Embeddings

Positional Embeddings are crucial for LLMs because the Transformer architecture is order invariant. This means it would represent two sentences in the same way, even if they use the same words in a different order and with different meanings. For example, the following sentences would have the same meaning for a Transformer if positional embeddings were not applied:

Sentence 1: Llama 2 is better than Llama 1
Sentence 2: Llama 1 is better than Llama 2

The original paper [3] implemented absolute positional embeddings based on two sinusoidal functions (sine and cosine). Each position in the sequence has a unique positional embedding that is added to the word embedding, ensuring that two sentences with the same words in a different order are represented differently.

For simplicity, let’s assume that each word in a sentence is encoded as a one-dimensional value rather than a multi-dimensional vector. As shown in Figure 5, the words "1" and "2" have the same word-embedding value in both sentences. However, after adding the positional encoding, they are represented with different values (0.88 → 1.04 and 0.26 → 0.1, respectively).

Figure 5: Absolute Positional Embedding (image by author)

Although this approach solves the problem of Transformers being order-invariant, it still creates each positional embedding independently of the others. As a consequence, the proximity of two positions is not modeled: from the model's point of view, there is no difference between the correlation of positions 1 and 2 and that of positions 1 and 500. We know this is not the case, because words at positions 1 and 2 should, in general, be more closely related than words at positions 1 and 500.

Rotary positional embeddings [7] (RoPE) tackle this problem and model the relative position of the words by representing each position in the sequence through a rotation of the word embedding. Let's use the same example as before, 'Llama 2 is better than Llama 1', and assume that the word embedding now has two dimensions. The word better will be represented by rotating its original 2D vector by an angle that depends on its position m (4) and a constant θ, as shown in Figure 6.

Figure 6: Rotary positional embedding where the original vector is transformed into a new vector based on its position (m=4) and the constant θ (image by author)
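For the two-dimensional case described here, the rotation can be written out explicitly. This follows the RoFormer formulation [7]; q and k denote the embeddings of the tokens at positions m and n, and R(α) is a standard 2D rotation matrix:

```latex
\begin{aligned}
f(\mathbf{x}, m) &= R(m\theta)\,\mathbf{x},
\qquad
R(\alpha) = \begin{pmatrix} \cos\alpha & -\sin\alpha \\ \sin\alpha & \cos\alpha \end{pmatrix} \\
\langle f(\mathbf{q}, m),\, f(\mathbf{k}, n) \rangle
  &= \mathbf{q}^{\top} R(m\theta)^{\top} R(n\theta)\,\mathbf{k}
   = \mathbf{q}^{\top} R\big((n-m)\theta\big)\,\mathbf{k}
\end{aligned}
```

Because R(mθ)ᵀ = R(−mθ), the dot product depends only on the offset n − m, not on the absolute positions.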

This approach preserves the relative distance between words: even if we add more words to the original sentence, the similarity between two vectors stays the same. For example, if we add two words to get 'The LLM Llama 2 is better than Llama 1', the words better and than move from positions 4 & 5 to positions 6 & 7. Since both vectors are rotated by the same additional amount (2θ), the angle between them, and hence their similarity, stays the same (the dot product between the vectors in the left image equals the one in the right image).

Figure 7: Ability of rotational embeddings to preserve relative distance between tokens (image by author)
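The relative-distance property can be checked numerically. This NumPy sketch uses a single placeholder angle θ = 0.5 and made-up 2D embeddings for better and than; real RoPE applies a different θ to each pair of embedding dimensions and rotates queries and keys inside attention, which is omitted here.

```python
import numpy as np


def rotate_2d(vec: np.ndarray, position: int, theta: float) -> np.ndarray:
    """Rotate a 2D embedding by position * theta radians (2D RoPE)."""
    angle = position * theta
    rotation = np.array([[np.cos(angle), -np.sin(angle)],
                         [np.sin(angle),  np.cos(angle)]])
    return rotation @ vec


theta = 0.5                      # placeholder constant angle
better = np.array([0.3, 1.2])    # placeholder embedding for "better"
than = np.array([0.9, -0.4])     # placeholder embedding for "than"

# "Llama 2 is better than Llama 1": better at m=4, than at n=5
score_short = rotate_2d(better, 4, theta) @ rotate_2d(than, 5, theta)
# "The LLM Llama 2 is better than Llama 1": better at m=6, than at n=7
score_long = rotate_2d(better, 6, theta) @ rotate_2d(than, 7, theta)

print(np.isclose(score_short, score_long))  # True: only the offset n - m = 1 matters
```

Rotations also preserve vector length, so RoPE changes only the direction of an embedding, never its magnitude.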

Llama 2: The Evolved Form of Llama 1

Llama 2 [8] kept all the changes Llama 1 made to the original Transformer architecture. Additionally, it increased the context length from 2048 to 4096 tokens and replaced Multi-Head Attention (MHA) [3] with Grouped-Query Attention (GQA) [10] in the larger models (34B and 70B).

MHA is a bottleneck for Transformers because of the high memory demand of loading all the attention query, key, and value heads. There are two different approaches to overcome this problem:

  1. Multi-Query Attention [9] (MQA) significantly reduces the memory needed by using a single key and value head shared by multiple query heads in the attention layer. However, this solution can lead to quality degradation and training instability, which is why other open LLMs, such as T5, did not adopt it.
  2. GQA sits between MHA and MQA: it divides the query heads into G groups (GQA-G), each sharing a single key and value head. GQA-1 places all queries in one group and is therefore the same as MQA, while GQA-H (H = number of heads) is equivalent to MHA, with each query head forming its own group. Reducing the key and value heads to one per query group shrinks the key-value cache and, hence, the amount of data that must be loaded. This reduction, more moderate than MQA's, speeds up inference and lowers memory requirements during decoding, with quality close to MHA and speed nearly as high as MQA.
Figure 8: Overview of the different MHA, GQA, and MQA approaches (image by author)
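The grouping can be sketched in a few lines of NumPy. This is a simplified illustration under our own assumptions: the tensors are random placeholders, and the learned projections and causal mask of a real decoder are omitted. Setting the number of key/value heads G equal to H gives MHA, and G = 1 gives MQA.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention where several query heads share one K/V head.

    q:    (H, T, d)  one entry per query head
    k, v: (G, T, d)  one entry per key/value head; H must be divisible by G
    G == H -> MHA, G == 1 -> MQA, 1 < G < H -> GQA-G
    """
    n_heads, n_groups = q.shape[0], k.shape[0]
    assert n_heads % n_groups == 0, "query heads must split evenly into groups"
    heads_per_group = n_heads // n_groups
    # Each K/V head is reused by `heads_per_group` consecutive query heads.
    k = np.repeat(k, heads_per_group, axis=0)                  # (H, T, d)
    v = np.repeat(v, heads_per_group, axis=0)                  # (H, T, d)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # (H, T, T)
    return softmax(scores) @ v                                 # (H, T, d)


rng = np.random.default_rng(0)
H, T, d = 8, 4, 16
q = rng.normal(size=(H, T, d))
for G in (8, 2, 1):  # MHA, GQA-2, MQA
    k = rng.normal(size=(G, T, d))
    v = rng.normal(size=(G, T, d))
    out = grouped_query_attention(q, k, v)
    # Only G key heads and G value heads need to be cached per layer.
    print(f"G={G}: output {out.shape}, cached K/V heads per layer: {2 * G}")
```

The output shape is the same in all three cases; what changes is how many key and value heads must be stored in the cache and loaded at each decoding step.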

Llama 3: Size and Tokenization

Llama 3 [11] increased the context length from 4096 to 8192 tokens and extended GQA to the smaller model (8B). Besides that, the authors replaced the SentencePiece [12] tokenizer with TikToken [13], the tokenizer used in OpenAI models. This significantly improved model performance, since the new tokenizer has a vocabulary of 128K tokens instead of 32K.

The main difference between the two tokenizers is that TikToken ignores the byte pair encoding (BPE) [14] merge rules when an input word is already part of the vocabulary. For example, if generating is in the vocabulary, it is returned as a single token instead of being split by the merge rules into two smaller tokens, such as generat and ing.
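The toy sketch below contrasts the two behaviors. The merge list, the vocabulary, and the `ignore_merges` flag are hypothetical, chosen only so that plain BPE splits generating into generat and ing; this is not the TikToken implementation.

```python
# Toy merge list, highest priority first (hypothetical, for illustration only).
MERGES = [("i", "n"), ("in", "g"), ("g", "e"), ("ge", "n"), ("gen", "e"),
          ("gene", "r"), ("gener", "a"), ("genera", "t")]
# Toy vocabulary: the merge results plus the whole word "generating".
VOCAB = {"in", "ing", "ge", "gen", "gene", "gener", "genera", "generat", "generating"}


def bpe_tokenize(word: str, merges: list, vocab: set, ignore_merges: bool = False) -> list:
    """Split a word with BPE merge rules; optionally return whole-vocabulary words as-is."""
    if ignore_merges and word in vocab:
        return [word]                       # skip the merge rules entirely
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    symbols = list(word)                    # start from single characters
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:               # no applicable merge left
            break
        merged, i = [], 0
        while i < len(symbols):             # merge every occurrence of the best pair
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


print(bpe_tokenize("generating", MERGES, VOCAB))                      # ['generat', 'ing']
print(bpe_tokenize("generating", MERGES, VOCAB, ignore_merges=True))  # ['generating']
```

Words that are not in the vocabulary still go through the merge rules, so the shortcut only changes the output for words the vocabulary already contains.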

Llama 3.1: The Latest (and Biggest) Release

Released in July 2024, Llama 3.1 introduces a significant leap in context length (128K tokens) and multilingual support for eight languages. One of the key pieces of the release was its largest model, Llama 3.1 405B. Until then, open LLMs were generally released in sizes below 100B parameters.

Lastly, a summary of the evolution can be seen in the table below:

Table 1: Comparing Llama evolution regarding context length, vocabulary size, training data size, and languages they support.

Llama 2 vs Llama 3: Model Comparison

In this section, we apply Llama 2 and Llama 3 to SQuAD, a question-answering dataset released under the CC BY-SA 4.0 license and available on the Hugging Face Hub. This reading comprehension dataset consists of questions about a set of Wikipedia articles. Given the context, the model should be able to retrieve the correct answer to a question. The three most important fields for our use case are:

  • question – the question a model should answer.
  • context – background information from which the model needs to extract the answer.
  • answers – the text answer to the question.

The evaluation uses three quantitative metrics: the first assesses inference speed, the second measures answer length, and the third evaluates accuracy. For the latter, we use RAQ [1], which uses an independent LLM to rank the answers of Llama 2 and Llama 3 by how close they are to the ground truth answer.

We start by downloading both models in GGUF format so that we can run them on a CPU, and we place them under the folder model/.

We used the instruct version of each model with a 4-bit quantization:

After that, we import all the libraries and our own generator that receives the model we want to use as an argument.

import os

import matplotlib.pyplot as plt
import pandas as pd
import scikit_posthocs as sp
import seaborn as sns
from datasets import load_dataset
from dotenv import load_dotenv

import utils
from generator.generator import Generator

# load environment variables (e.g., OPENAI_API_KEY, used later for the RAQ ranking)
load_dotenv('env/var.env')

# one generator per model; the names must match the keys in config.yaml
Llama2 = Generator(model='Llama2')
Llama3 = Generator(model='Llama3')

This class is responsible for importing the model parameters defined in a config.yaml file with the following characteristics: context_length of 1024, temperature of 0.7, and max_tokens of 2000.

generator:
  Llama2:
    llm_path: "model/Llama-2-7b-32k-instruct.Q4_K_M.gguf"
  Llama3:
    llm_path: "model/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf"
  context_length: 1024
  temperature: 0.7
  max_tokens: 2000

It also creates the prompt template with LangChain, which formats the question and the context before they are passed to the LLM to get a response.

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import LlamaCpp

from base.config import Config


class Generator(Config):
    """Generator, aka LLM, to provide an answer based on some question and context"""

    def __init__(self, model) -> None:
        super().__init__()
        # template
        self.template = """
            Use the following pieces of context to answer the question at the end.
            {context}
            Question: {question}
            Answer:
        """
        # load llm from local file
        self.llm = LlamaCpp(
            model_path=f"{self.parent_path}/{self.config['generator'][model]['llm_path']}",
            n_ctx=self.config["generator"]["context_length"],
            temperature=self.config["generator"]["temperature"],
        )
        # create prompt template
        self.prompt = PromptTemplate(
            template=self.template, input_variables=["context", "question"]
        )

    def get_answer(self, context: str, question: str) -> str:
        """
        Get the answer from llm based on context and user's question
        Args:
            context: most similar document retrieved
            question: user's question
        Returns:
            llm answer
        """
        query_llm = LLMChain(
            llm=self.llm,
            prompt=self.prompt,
            llm_kwargs={"max_tokens": self.config["generator"]["max_tokens"]},
        )
        return query_llm.run({"context": context, "question": question})

With the LLMs loaded, we fetch the SQuAD dataset from Hugging Face and shuffle it to get a wide variety of question topics.

squad = load_dataset("squad", split="train")
squad = squad.shuffle()

Now, we can loop over 30 questions and contexts and record the above metrics.

# lists that collect each model's metrics across questions
Llama2_metrics = {"words_per_second": [], "words": [], "rank": []}
Llama3_metrics = {"words_per_second": [], "words": [], "rank": []}

for i in range(30):
    context = squad[i]['context']
    query = squad[i]['question']
    answer = squad[i]['answers']['text'][0]

    # Llama 2
    answer_Llama2, words_per_second, words = utils.get_llm_response(Llama2, context, query)
    Llama2_metrics["words_per_second"].append(words_per_second)
    Llama2_metrics["words"].append(words)

    # Llama 3
    answer_Llama3, words_per_second, words = utils.get_llm_response(Llama3, context, query)
    Llama3_metrics["words_per_second"].append(words_per_second)
    Llama3_metrics["words"].append(words)

    # RAQ: an independent LLM ranks both answers against the ground truth
    llm_answers_dict = {'Llama2': answer_Llama2, 'Llama3': answer_Llama3}
    rank = utils.get_gpt_rank(answer, llm_answers_dict, os.getenv("OPENAI_API_KEY"))
    Llama2_metrics["rank"].append(rank.index('1')+1)
    Llama3_metrics["rank"].append(rank.index('2')+1)

The function get_llm_response receives the loaded LLM, the context, and the question and returns the LLM answer and quantitative metrics.

import re
import time
from typing import Tuple

from generator.generator import Generator


def get_llm_response(model: Generator, context: str, query: str) -> Tuple[str, float, int]:
    """
    Generates an answer from a given LLM based on context and query
    returns the answer and the number of words per second and the total number of words
    Args:
        model: LLM
        context: context data
        query: question
    Returns:
        answer, words_per_second, words
    """
    init_time = time.time()
    answer_llm = model.get_answer(context, query)
    total_time = time.time() - init_time
    # count words: keep letters and apostrophes, treat everything else as a separator
    words = len(re.sub("[^a-zA-Z']+", ' ', answer_llm).split())
    words_per_second = words / total_time
    return answer_llm, words_per_second, words

After completing the evaluation, we plotted the metrics and observed that Llama 3 is faster than Llama 2, generating approximately 1.1 words per second on average, compared to Llama 2’s 0.25 words per second. Regarding answer length, Llama 3 produces longer answers, with an average of 70 words, while Llama 2 7B generates responses averaging 15 words. Finally, according to the Relative Answer Quality (RAQ) framework, Llama 3 achieved the best average rank, approximately 1.25, while Llama 2 performed worse, with an average rank of around 1.8.

Figure 9: Model Comparison (image by author)
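The averages above are simple per-model means of the lists collected in the evaluation loop. A minimal pandas sketch of that aggregation, assuming the same dictionary layout as `Llama2_metrics` and `Llama3_metrics`; the helper `summarize_metrics` is hypothetical and the numbers are dummy values, not our results:

```python
import pandas as pd


def summarize_metrics(metrics_by_model: dict) -> pd.DataFrame:
    """Return one row per model with the mean of each recorded metric.

    Expects {"model name": {"words_per_second": [...], "words": [...], "rank": [...]}}.
    """
    means = {name: pd.DataFrame(metrics).mean() for name, metrics in metrics_by_model.items()}
    return pd.DataFrame(means).T  # rows: models, columns: metrics


# Dummy values for illustration only (not the results reported above).
dummy_llama2 = {"words_per_second": [0.2, 0.3], "words": [12, 18], "rank": [2, 1]}
dummy_llama3 = {"words_per_second": [1.0, 1.2], "words": [65, 75], "rank": [1, 2]}
print(summarize_metrics({"Llama2": dummy_llama2, "Llama3": dummy_llama3}))
```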

Table 2 presents the results of the Dunn post-hoc test, which compares the performance of different language models. Each cell indicates whether the difference in performance between the respective models is statistically significant at a 5% significance level. "Significant" denotes a statistically significant difference (p-value ≤ 0.05), while "Not Significant" indicates no statistically significant difference (p-value > 0.05). According to the Dunn test results, Llama 3’s performance is significantly different from that of Llama 2.

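# Dunn's post-hoc test on the RAQ ranks, with Holm correction for multiple comparisons
p_values = sp.posthoc_dunn([Llama2_metrics['rank'], Llama3_metrics['rank']], p_adjust='holm')
p_values > 0.05  # True = not significant at the 5% level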
Table 2: Significance of differences in ranks among the set of LLMs.

Finally, in qualitative terms, we analyze the answers of both models to one of the questions. They both managed to correctly answer the question, ‘What percentage of improvement over energy code requirements will be the goal of all new construction and renovations?’ based on the following context:

Context: ‘Northwestern requires that all new buildings be LEED-certified. Silverman Hall on the Evanston campus was awarded Gold LEED Certification in 2010; Wieboldt Hall on the Chicago campus was awarded Gold LEED Certification in 2007, and the Ford Motor Company Engineering Design Center on the Evanston campus was awarded Silver LEED Certification in 2006. New construction and renovation projects will be designed to provide at least a 20% improvement over energy code requirements where technically feasible. The university also released at the beginning of the 2008–09 academic year the Evanston Campus Framework Plan, which outlines plans for future development of the Evanston Campus. The plan not only emphasizes the sustainable construction of buildings, but also discusses improving transportation by optimizing pedestrian and bicycle access. Northwestern has had a comprehensive recycling program in place since 1990. Annually more than 1,500 tons are recycled at Northwestern, which represents 30% of the waste produced on campus. Additionally, all landscape waste at the university is composted.’

Nevertheless, as shown below, Llama 2 contradicts itself: it first says that the answer is not in the context and then quotes the passage from the context that contains it. Conversely, Llama 3 identifies the answer in the context and provides a concise and correct answer.

Llama 2 answer: ‘It is not mentioned in the given passage what percentage improvement over energy code requirements will be the goal of all new construction and renovations. However, it is mentioned that "New construction and renovation projects will be designed to provide at least a 20% improvement over energy code requirements where technically feasible.’

Llama 3 answer: ‘According to the provided context, all new construction and renovation projects will aim to provide at least a 20% improvement over energy code requirements where technically feasible.’

Conclusions

The evolution of the Llama models shows a steady trajectory of innovation aimed at improving efficiency, performance, and versatility in LLMs. Starting with Llama 1, which introduced foundational changes like input normalization with RMSNorm and smoother activation functions, each subsequent version has built upon these advancements.

Llama 2 refined this approach by optimizing inference efficiency with GQA, setting the stage for even greater improvements in Llama 3. Llama 3 expanded on these capabilities by extending GQA to smaller models, adopting a more efficient tokenizer with a much larger vocabulary, doubling context lengths, and significantly increasing training data.

The recent release of Llama 3.1 marks a new milestone. It further extends context lengths to 128K tokens, adds support for more languages, and introduces the largest open model so far – the 405B model.

These enhancements across the Llama versions have led to models with superior adaptability across diverse applications. So far, the Llama models have been downloaded over 300 million times, and the integration of Llama models into thousands of products that leverage private LLM capabilities is only starting. Ironically, Llama now leads the way in advancing open AI, taking up the position once held by OpenAI when it was, in fact, more open.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Luís Roque, Rafael Guedes. "Research to Production: Relative Answer Quality (RAQ) and NVIDIA NIM." Towards Data Science. Medium, 2024.

[2] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. "Llama: Open and Efficient Foundation Language Models." arXiv preprint arXiv:2302.13971, 2023.

[3] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. "Attention Is All You Need." arXiv preprint arXiv:1706.03762, 2017.

[4] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, Dario Amodei. "Language Models are Few-Shot Learners." arXiv preprint arXiv:2005.14165, 2020.

[5] Biao Zhang, Rico Sennrich. "Root Mean Square Layer Normalization." arXiv preprint arXiv:1910.07467, 2019.

[6] Noam Shazeer. "GLU Variants Improve Transformer." arXiv preprint arXiv:2002.05202, 2020.

[7] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu. "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv preprint arXiv:2104.09864, 2021.

[8] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, Dan Bikel, Lukas Blecher, Cristian Canton Ferrer, Moya Chen, Guillem Cucurull, David Esiobu, Jude Fernandes, Jeremy Fu, Wenyin Fu, Brian Fuller, Cynthia Gao, Vedanuj Goswami, Naman Goyal, Anthony Hartshorn, Saghar Hosseini, Rui Hou, Hakan Inan, Marcin Kardas, Viktor Kerkez, Madian Khabsa, Isabel Kloumann, Artem Korenev, Punit Singh Koura, Marie-Anne Lachaux, Thibaut Lavril, Jenya Lee, Diana Liskovich, Yinghai Lu, Yuning Mao, Xavier Martinet, Todor Mihaylov, Pushkar Mishra, Igor Molybog, Yixin Nie, Andrew Poulton, Jeremy Reizenstein, Rashi Rungta, Kalyan Saladi, Alan Schelten, Ruan Silva, Eric Michael Smith, Ranjan Subramanian, Xiaoqing Ellen Tan, Binh Tang, Ross Taylor, Adina Williams, Jian Xiang Kuan, Puxin Xu, Zheng Yan, Iliyan Zarov, Yuchen Zhang, Angela Fan, Melanie Kambadur, Sharan Narang, Aurelien Rodriguez, Robert Stojnic, Sergey Edunov, Thomas Scialom. "Llama 2: Open Foundation and Fine-Tuned Chat Models." arXiv preprint arXiv:2307.09288, 2023.

[9] Noam Shazeer. "Fast Transformer Decoding: One Write-Head is All You Need." arXiv preprint arXiv:1911.02150, 2019.

[10] Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." arXiv preprint arXiv:2305.13245, 2023.

[11] Meta AI. "Introducing Llama 3." Meta AI Blog, 2024.

[12] Taku Kudo, John Richardson. "SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing." arXiv preprint arXiv:1808.06226, 2018.

[13] OpenAI. "TikToken." GitHub.

[14] Rico Sennrich, Barry Haddow, Alexandra Birch. "Neural Machine Translation of Rare Words with Subword Units." arXiv preprint arXiv:1508.07909, 2015.

The post The Evolution of Llama: From Llama 1 to Llama 3.1 appeared first on Towards Data Science.

Gemma vs. Llama vs. Mistral: Exploring Smaller AI Models https://towardsdatascience.com/gemma-vs-llama-vs-mistral-exploring-smaller-ai-models-672a95f4b9b7/ Tue, 06 Aug 2024 12:36:11 +0000 https://towardsdatascience.com/gemma-vs-llama-vs-mistral-exploring-smaller-ai-models-672a95f4b9b7/ A Comparative Study of Small-Scale Language Models: Evaluating Gemma, Llama 3, and Mistral in Reading Comprehension Tasks

The post Gemma vs. Llama vs. Mistral: Exploring Smaller AI Models appeared first on Towards Data Science.

This post was co-authored with Rafael Guedes.

Introduction

Large Language Models (LLMs) have been evolving rapidly. Each month, new models are developed to surpass the current top scorers in the market. This healthy competition is beneficial for creating new approaches that increase quality and speed. Additionally, companies are focused on developing smaller models to make them accessible to individuals or organizations without powerful computing resources.

Just a few weeks ago, Apple introduced Apple Intelligence at their Worldwide Developers Conference. This is a set of multiple generative models fine-tuned to help users write and refine text, prioritize and summarize notifications, create images, and take in-app actions. The only foundational and proprietary model developed by Apple in that suite was introduced at the same conference. It is a small model designed to run on-device, where the hardware becomes a significant constraint. In Apple’s case, the model is closed-source. What we know is that it is a ~3 billion parameter model on par with the 7b versions of Gemma, Mistral, and Llama 3 (according to the results shared by Apple).

While Apple’s new model is exciting, we cannot test or reuse it. Hence, we are more interested in publicly available models since developers and companies can use them to build new products and services. It’s important to distinguish between open LLMs and open-source LLMs. Historically, open-source software refers to computer programs released under specific licenses, making the source code available for public use or modification. With LLMs, there is additional complexity, including the training data and model weights. Therefore, open LLMs typically disclose the model weights and initial code. An open-source LLM, on the other hand, would share every step of the training process, including the training data, along with a permissive license. It should allow others to use, build upon, and further distribute the model. Nevertheless, most of the models released these days fall under the category of open LLMs since, for example, they do not publish the datasets used for training purposes. This is the case for Gemma by Google, Mistral by Mistral AI, and Llama by Meta.

In this article, we analyze Gemma more closely to understand what differentiates these smaller models. Gemma is one of Google's most recently released models. It comes in two versions, with 2 billion and 7 billion parameters, which makes it usable on edge devices, and it aims to outperform state-of-the-art models like Mistral and Llama 3.

Additionally, we apply Gemma, Llama 3, and Mistral to a reading comprehension dataset called SQuAD. The LLMs are tasked with answering specific questions based on given contexts. We assess their performance using quantitative metrics such as inference speed and average answer length. We also use the Relative Answer Quality (RAQ) framework proposed by [1]. RAQ bridges the gap in evaluating LLMs for specific use cases by ranking answers based on their accuracy relative to the ground truth, providing a more nuanced and practical assessment of model performance.

Figure 1: Gemma vs. Llama vs. Mistral (image by author with DALL-E)

As always, the code is available on our GitHub.

Gemma: the base text model of Gemini

Google released Gemma [2], an open LLM developed based on its powerful, closed-source model, Gemini [3].

Google released pre-trained and fine-tuned checkpoints to promote further research of the model in new use cases, making it available in two different sizes:

  • The 7B model is to be deployed and further developed on GPU or TPU.
  • The 2B model is designed to address computation constraints and allow its use on CPU or on-device applications.

Gemma promises state-of-the-art performance compared to other open models of roughly the same scale, such as Llama 2 7B or Mistral 7B, across domains such as question answering, common sense reasoning, mathematics/science, and coding.

Gemma: what is new?

Gemma’s architecture is based on a decoder-only [4] Transformer [5] with a context length of 8192 tokens. Let’s explore the approach taken to make it smaller.

Multi-Query Attention

The 2B model uses Multi-Query Attention (MQA) to significantly reduce the memory required to load all the query, key, and value heads, compared with the Multi-Head Attention (MHA) approach. MQA achieves this memory reduction by sharing a single key and value head across multiple query heads in the attention layer, as illustrated in Figure 2.

While this approach allows Gemma 2B to be deployed on devices with less memory, it can lead to quality degradation and training instability. Therefore, the authors opted for MHA in the 7B version, the same choice made in Llama 2 7B.

Figure 2: Differences between MHA and MQA (image by author)

RoPE Embeddings

Transformers require Positional Embeddings because they are inherently order-invariant. This means that without positional information, a Transformer would represent sentences with the same words but different orders and meanings in the same way. For example:

Sentence 1: Gemma is better than Llama 3

Sentence 2: Llama 3 is better than Gemma

Positional information is typically represented using two sinusoidal functions (sine and cosine). A unique positional embedding is created for each position in the sequence, based on the position itself, the dimension index within the embedding, and the model dimension.

Therefore, adding positional information is crucial for enabling Transformers to process text properly. The original Transformer architecture used Absolute Positional Embeddings, where a vector representation of a position is added to the vector representation of a token.

Figure 3: Absolute Positional Embedding (image by author)
Equations 1 and 2: Sine and cosine functions that generate positional embeddings, where pos is the position, i is the dimension index of the positional encoding, and d_model is the total dimension of the vector.
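For reference, these are the standard sinusoidal formulas from the original Transformer paper, where pos is the position, i the dimension index, and d_model the model dimension:

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)
\end{aligned}
```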

The challenge with Absolute Positional Embeddings is that they do not explicitly encode the relative distances between tokens. While they capture positional information using sine and cosine functions, these embeddings are calculated independently for each position. This means that the model does not inherently understand the proximity or relational significance of different positions within a sequence. For instance, the embeddings for tokens at positions 1 and 2 may appear similar due to the nature of the sinusoidal functions, but the model doesn’t explicitly recognize that these positions are adjacent.

Because of this, the model might not differentiate the relationship between tokens at positions 1 and 2 from the relationship between tokens at positions 1 and 500. In natural language processing, words that are close together in a sentence often share more context or have a stronger semantic or syntactic relationship than words that are far apart. Absolute Positional Embeddings might not completely capture this nuance. It can lead to limitations in capturing long-range dependencies or the hierarchical structure of language.

Figure 4: Relative Position should be included in position encoding (image by author)

Rotary Positional Embeddings (RoPE) [6] address this problem by modeling the relative positions of tokens through a rotation of the token embeddings in the sequence.

Let’s use the previous example, ‘Gemma is better than Llama,’ and consider each word as a token represented by a 2D vector. The word better will be represented by a 2D vector rotated from the original vector based on its position m and a constant angle θ, as shown in Figure 5.

Figure 5: Rotary positional embedding where the original vector is transformed into a new vector based on its position (m=3) and the constant θ (image by author)

This approach preserves the relative distance between tokens because the rotation keeps the similarity between two vectors unchanged when both are shifted by the same number of positions. For instance, if we add two words to the original sentence, making it 'The LLM Gemma is better than Llama', the positions of better and than change from (3 & 4) to (5 & 6). However, since both vectors are rotated by the same additional angle (2θ), the angle between them, and thus their similarity as measured by the dot product, stays the same, ensuring consistent relative positioning.

Figure 6: Ability of rotational embeddings to preserve relative distance between tokens (image by author)

GeGLU Activation Function

The authors replaced the traditional ReLU activation function with a variant of a Gated Linear Unit (GLU) called GeGLU, as another study [7] has shown that it improves the quality of the output generated by the LLM.

There are two differences between the ReLU and GeGLU:

  1. Activation function – GeGLU uses the Gaussian Error Linear Unit (GELU) [8], which differs from ReLU in that it multiplies the neuron input x by the cumulative distribution function of the standard normal distribution evaluated at x, i.e., GELU(x) = xΦ(x). As a result, an input is suppressed more strongly (it is more likely to be "dropped") as x decreases, instead of being cut off sharply at 0.
Figure 7: Difference between GELU and ReLU (source)
  2. Gating – In a standard feed-forward layer, the ReLU or GELU activation is applied between two linear transformations of the hidden representation x, represented by the matrices W1 and W2. A gated variant instead multiplies, element-wise, one linear projection of x by a gate computed from another projection: the original GLU uses a sigmoid gate, and GeGLU replaces the sigmoid with GELU, as shown in Equation 3.
Equation 3: Difference between GELU and Gated GELU
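For reference, the standard forms from the GLU-variants paper are shown below with biases omitted; ⊗ denotes element-wise multiplication, Φ the standard normal CDF, σ the sigmoid, and W, V, W_1, W_2 learned matrices. The notation in the Gemma figure may differ.

```latex
\begin{aligned}
\mathrm{GELU}(x) &= x\,\Phi(x) \\
\mathrm{FFN}_{\mathrm{GELU}}(x) &= \mathrm{GELU}(xW_1)\,W_2 \\
\mathrm{GLU}(x) &= \sigma(xW) \otimes xV \\
\mathrm{FFN}_{\mathrm{GEGLU}}(x) &= \big(\mathrm{GELU}(xW) \otimes xV\big)\,W_2
\end{aligned}
```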

Normalizer Location

The last modification to the original Transformer architecture is shown in Figure 8. The authors normalize both the input and output of each transformer sub-layer to improve training stability, contrary to the original paper, which only normalized the output.

Figure 8: Normalization added to the inputs of the transformer in Gemma architecture (image by author)

They also replaced the traditional LayerNorm function with RMSNorm [8]. It is computationally more efficient while maintaining training stability improvements and helping the model converge.

RMSNorm achieves better efficiency because its authors demonstrated that the benefits of LayerNorm come from re-scaling invariance rather than re-centering invariance. Re-scaling invariance means that the output of the normalization process remains unchanged if the input is multiplied by a constant factor. Re-centering invariance means that the output remains unchanged if a constant value is added to all the inputs, so shifting all inputs by the same amount does not affect the normalized outputs. This finding allows the removal of the overhead of computing the mean (you only need to compute the root mean square of the inputs), making RMSNorm simpler and more efficient.

Figure 9: Equation differences between LayerNorm and RMSNorm (image by author)
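The following minimal NumPy sketch contrasts the two normalizations on a vector of hypothetical activations. LayerNorm subtracts the mean and divides by the standard deviation; RMSNorm only divides by the root mean square. The gain/bias parameters and the eps value are illustrative assumptions. The checks at the end show the invariances discussed above: both functions are (up to eps) unaffected by rescaling, but only LayerNorm ignores a constant shift.

```python
import numpy as np

def layer_norm(x: np.ndarray, gain: np.ndarray, bias: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Re-centre (subtract the mean) and re-scale (divide by the standard deviation)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gain + bias

def rms_norm(x: np.ndarray, gain: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Only re-scale: divide by the root mean square; no mean is computed."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gain

rng = np.random.default_rng(0)
d = 16
x = rng.normal(size=d)
gain, bias = np.ones(d), np.zeros(d)

# Re-scaling invariance: multiplying the input by a constant leaves both outputs (almost) unchanged
print(np.allclose(rms_norm(3.0 * x, gain), rms_norm(x, gain)))                  # True
print(np.allclose(layer_norm(3.0 * x, gain, bias), layer_norm(x, gain, bias)))  # True
# Re-centering invariance: only LayerNorm ignores a constant shift of the inputs
print(np.allclose(layer_norm(x + 5.0, gain, bias), layer_norm(x, gain, bias)))  # True
print(np.allclose(rms_norm(x + 5.0, gain), rms_norm(x, gain)))                  # False
```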

Mistral AI vs. Meta vs. Google: a comparison between Gemma 7B, Llama 3 8B, and Mistral 7B

In this section, we put three LLMs (Gemma 7B, Mistral 7B, and Llama 3 8B) to the test. We use a question-answering dataset called SQuAD, released under the CC BY-SA 4.0 license (it can be found here). It is a reading comprehension dataset consisting of questions about a set of Wikipedia articles. Given the context, the models should be able to retrieve the correct answer to a question. The three most important fields for our use case are:

  • question – the question a model should answer.
  • context – background information from which the model needs to extract the answer.
  • answers – the text answer to the question.

The evaluation process will consist of two quantitative metrics:

  • words per second – assesses the inference speed.
  • words – assesses the length of the answer.

To assess the accuracy of the models in our use case, we use RAQ [1]. RAQ ranks the answers of all LLMs using an independent LLM based on how close they are to the ground truth answer.

We start by downloading the models in .gguf format so that we can run them on a CPU, and we place them in the model/ folder.

We used the instruct version of each model with 4-bit quantization (the Q4_K_M files listed in the config.yaml below).

After that, we import all the libraries and instantiate our Generator class, which receives the model we want to use as an argument.

import os

import seaborn as sns
import matplotlib.pyplot as plt
import scikit_posthocs as sp
import pandas as pd
import utils

from dotenv import load_dotenv
from datasets import load_dataset
from generator.generator import Generator

llama = Generator(model='llama')
mistral = Generator(model='mistral')
gemma = Generator(model='gemma')
load_dotenv('env/var.env')

The Generator class loads the model parameters defined in a config.yaml file: a context_length of 1024, a temperature of 0.7, and max_tokens of 2000.

generator:
  llama:
    llm_path: "model/Meta-llama-3-8B-Instruct-Q4_K_M.gguf"
  mistral:
    llm_path: "model/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
  gemma:
    llm_path: "model/gemma-7b-it-Q4_K_M.gguf"
  context_length: 1024
  temperature: 0.7
  max_tokens: 2000

It also creates the prompt template, which formats the query and the context before passing them to the LLM to get a response.

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms import LlamaCpp

from base.config import Config


class Generator(Config):
    """Generator, aka LLM, to provide an answer based on some question and context"""

    def __init__(self, model) -> None:
        super().__init__()
        # prompt template
        self.template = """
            Use the following pieces of context to answer the question at the end.
            {context}
            Question: {question}
            Answer:
        """
        # load llm from local file
        self.llm = LlamaCpp(
            model_path=f"{self.parent_path}/{self.config['generator'][model]['llm_path']}",
            n_ctx=self.config["generator"]["context_length"],
            temperature=self.config["generator"]["temperature"],
        )
        # create prompt template
        self.prompt = PromptTemplate(
            template=self.template, input_variables=["context", "question"]
        )
    def get_answer(self, context: str, question: str) -> str:
        """
        Get the answer from llm based on context and user's question
        Args:
            context: most similar document retrieved
            question: user's question
        Returns:
            llm answer
        """
        query_llm = LLMChain(
            llm=self.llm,
            prompt=self.prompt,
            llm_kwargs={"max_tokens": self.config["generator"]["max_tokens"]},
        )
        return query_llm.run({"context": context, "question": question})

With the LLMs loaded, we fetch the SQuAD dataset from Hugging Face and shuffle it to ensure enough variety in the question themes.

squad = load_dataset("squad", split="train")
squad = squad.shuffle()

Now, we can loop over 60 questions and contexts and record the metrics mentioned above.

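# per-question metrics for each model
llama_metrics = {"words_per_second": [], "words": [], "rank": []}
mistral_metrics = {"words_per_second": [], "words": [], "rank": []}
gemma_metrics = {"words_per_second": [], "words": [], "rank": []}

for i in range(60):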
    context = squad[i]['context']
    query = squad[i]['question']
    answer = squad[i]['answers']['text'][0]

    # Llama
    answer_llama, words_per_second, words = utils.get_llm_response(llama, context, query)
    llama_metrics["words_per_second"].append(words_per_second)
    llama_metrics["words"].append(words)
    # mistral
    answer_mistral, words_per_second, words = utils.get_llm_response(mistral, context, query)
    mistral_metrics["words_per_second"].append(words_per_second)
    mistral_metrics["words"].append(words)
    # gemma
    answer_gemma, words_per_second, words = utils.get_llm_response(gemma, context, query)
    gemma_metrics["words_per_second"].append(words_per_second)
    gemma_metrics["words"].append(words)

    # GPT-3.5 rank
    llm_answers_dict = {'llama': answer_llama, 'mistral': answer_mistral, 'gemma': answer_gemma}
    rank = utils.get_gpt_rank(answer, llm_answers_dict, os.getenv("OPENAI_API_KEY"))
    llama_metrics["rank"].append(rank.index('1')+1)
    mistral_metrics["rank"].append(rank.index('2')+1)
    gemma_metrics["rank"].append(rank.index('3')+1)

The function get_llm_response receives the loaded LLM, the context, and the question, and returns the LLM answer as well as the quantitative metrics.

import re
import time
from typing import Tuple

from generator.generator import Generator


def get_llm_response(model: Generator, context: str, query: str) -> Tuple[str, float, int]:
    """
    Generates an answer from a given LLM based on context and query
    returns the answer, the number of words per second and the total number of words
    Args:
        model: LLM
        context: context data
        query: question
    Returns:
        answer, words_per_second, words
    """
    init_time = time.time()
    answer_llm = model.get_answer(context, query)
    total_time = time.time() - init_time
    # count words after replacing everything except letters and apostrophes with spaces
    words = len(re.sub("[^a-zA-Z']+", ' ', answer_llm).split())
    words_per_second = words / total_time
    return answer_llm, words_per_second, words

We can see that Llama 3 is the fastest model, producing on average ~0.7 words per second, while Mistral produces ~0.26 and Gemma ~0.4 words per second. In terms of answer length, Llama 3 also produces longer answers than Mistral and Gemma, with an average of 148 words against 20 words for Mistral and 50 for Gemma. Finally, based on RAQ (where a lower rank is better), Mistral had the best average rank of approximately 1.81, followed by Gemma with an average of 2.05, while Llama 3 performed worst with an average rank of approximately 2.1.

Figure 10: Metrics comparison between all LLMs (image by author)
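A small sketch of how averages like these can be computed from the per-question dictionaries filled in the loop (llama_metrics, mistral_metrics, gemma_metrics). The helper name and the toy values are hypothetical and are not our results; for the RAQ rank, lower is better.

```python
from statistics import mean

def summarize(metrics_by_model: dict) -> dict:
    """Average every recorded metric (words_per_second, words, rank) for every model."""
    return {
        model: {name: float(mean(values)) for name, values in metrics.items()}
        for model, metrics in metrics_by_model.items()
    }

# Hypothetical values for two questions (not the article's results)
toy = {
    "llama": {"words_per_second": [0.6, 0.8], "words": [140, 156], "rank": [3, 2]},
    "mistral": {"words_per_second": [0.2, 0.3], "words": [18, 22], "rank": [1, 2]},
}
summary = summarize(toy)
print(summary["llama"]["words"])                         # 148.0
print(min(summary, key=lambda m: summary[m]["rank"]))    # mistral
```

With the real data, summarize({"llama": llama_metrics, "mistral": mistral_metrics, "gemma": gemma_metrics}) would return one set of averages per model.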

The RAQ framework also includes a statistical test to determine whether the observed differences are significant. Table 1 displays the results of the Dunn post-hoc test comparing the performance of the different language models. Each cell indicates whether the difference in performance between the respective models is statistically significant at a 5% significance level: "Significant" denotes a statistically significant difference (p-value ≤ 0.05), while "Not Significant" indicates no statistically significant difference (p-value > 0.05). At this significance level, the Dunn test shows that none of the pairwise differences between the models are significant.

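# Dunn post-hoc test on the ranks of each model, with Holm correction for multiple comparisons
p_values = sp.posthoc_dunn([llama_metrics['rank'], mistral_metrics['rank'], gemma_metrics['rank']], p_adjust='holm')
# True = not significant at the 5% level
p_values > 0.05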
Table 1: Significance of differences in ranks among the set of LLMs.

It is always important to assess some examples qualitatively. Below, we have the answers from the three models to the question ‘Power House Day is celebrated on what day in New Haven?’ based on the following context:

Context: ‘For over a century, New Haven citizens had fought in the colonial militia alongside regular British forces, as in the French and Indian War. As the American Revolution approached, General David Wooster and other influential residents hoped that the conflict with the government in Britain could be resolved short of rebellion. On 23 April 1775, which is still celebrated in New Haven as Powder House Day, the Second Company, Governor’s Foot Guard, of New Haven entered the struggle against the governing British parliament. Under Captain Benedict Arnold, they broke into the powder house to arm themselves and began a three-day march to Cambridge, Massachusetts. Other New Haven militia members were on hand to escort George Washington from his overnight stay in New Haven on his way to Cambridge. Contemporary reports, from both sides, remark on the New Haven volunteers’ professional military bearing, including uniforms.’

All 3 models gave correct answers. While Llama 3 and Gemma provided more complete answers, Mistral was more succinct.

Llama 3 answer: ‘New Haven’s Powder House Day is celebrated on April 23rd.’

Gemma answer: ‘Sure! The text states on which day Powder House Day is celebrated on: Powder House Day is celebrated on 23 April in New Haven.’

Mistral answer: ‘23 April’

Conclusion

On-device models present a great opportunity to enhance user experiences by making powerful LLMs accessible on devices with lower computational resources. Both Apple and Google are actively developing smaller, more efficient models to meet this need, enabling more people to benefit from advanced AI in their daily lives.

In this article, we explored Gemma, the open LLM developed by Google, which introduced four changes to the traditional Transformer architecture: Multi-Query Attention in the 2B version, RoPE embeddings for positional encoding, GeGLU as the activation function, and input normalization with RMSNorm.

We also compared Gemma’s performance against Llama 3 and Mistral on a reading comprehension dataset. We observed that Gemma produced more words per second and wrote longer answers than Mistral, but it did not surpass Llama 3 in these metrics. Using the RAQ framework, we assessed the accuracy of the three models. While the data showed better results from Mistral, followed by Gemma, the differences were not statistically significant. Therefore, we can say that the 3 models performed similarly when applied to our use case of reading comprehension.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Luís Roque, Rafael Guedes. Research to Production: Relative Answer Quality (RAQ) and NVIDIA NIM. https://medium.com/@luisroque/research-to-production-relative-answer-quality-raq-and-nvidia-nim-15ce0c45b3b6, 2024.

[2] Gemma Team, Google DeepMind. Gemma: Open Models Based on Gemini Research and Technology. arXiv:2403.08295, 2024.

[3] Gemini Team. Gemini: A Family of Highly Capable Multimodal Models. arXiv:2312.11805, 2023.

[4] Noam Shazeer. Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150, 2019.

[5] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. Attention Is All You Need. arXiv:1706.03762, 2017.

[6] Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu. RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864, 2021.

[7] Noam Shazeer. GLU Variants Improve Transformer. arXiv:2002.05202, 2020.

[8] Dan Hendrycks, Kevin Gimpel. Gaussian Error Linear Units (GELUs). arXiv:1606.08415, 2016.

[9] Biao Zhang, Rico Sennrich. Root Mean Square Layer Normalization. arXiv:1910.07467, 2019.

The post Gemma vs. Llama vs. Mistral: Exploring Smaller AI Models appeared first on Towards Data Science.

]]>
MMM: Bayesian Framework for Marketing Mix Modeling and ROAS https://towardsdatascience.com/mmm-bayesian-framework-for-marketing-mix-modeling-and-roas-ccade4005bd5/ Thu, 06 Jun 2024 05:21:33 +0000 https://towardsdatascience.com/mmm-bayesian-framework-for-marketing-mix-modeling-and-roas-ccade4005bd5/ Bayesian framework to model media channels performance, Return on Ad Spend (ROAS), and budget allocation using PyMC

The post MMM: Bayesian Framework for Marketing Mix Modeling and ROAS appeared first on Towards Data Science.

]]>
This post was co-authored with Rafael Guedes.

Introduction

Scalable internet businesses depend on marketing to drive growth. Not only on marketing, of course, but at a certain scale, very few companies can afford to be anything less than extremely efficient at acquiring customers. Two hot topics in which companies are investing heavily to bring Artificial Intelligence (AI) capabilities into marketing are Media Mix Modeling (MMM) and Customer Lifetime Value (LTV) prediction. Both focus on increasing the return on the investment organizations make in marketing. This article covers what MMM is and the best practices for applying it.

MMM is a technique that allows marketing teams to measure the impact of their investments and how they contribute to driving conversions. The complexity of this task has increased rapidly in the past years as the number of platforms available for advertising has skyrocketed. This phenomenon has spread potential customers over different media channels, which we can separate into offline and online buckets. Traditional offline channels do not rely on digital support and range from newspaper, radio, and television ads and coupons to a booth at a trade show. Online channels have multiplied, and companies use many of them together, such as email, social media, organic search, paid search, affiliate marketing, and influencer marketing.

One important caveat is that a good MMM requires an equally accurate data-driven attribution model, i.e., which channels contributed to acquiring a specific customer. Also, note that while attribution is performed at the user level, MMM is usually applied at the acquisition channel level. Data-driven attribution is out of the scope of this article.

In this article, our focus is twofold. First, we develop a Bayesian model designed to increase transparency on how each media channel performs. Secondly, we optimize the budget allocation to maximize our variable of interest, which in this case is revenue. Besides providing a detailed view of how a Bayesian approach works for MMM, we also give a walkthrough on implementing and applying it using a public dataset. We test the model accuracy and calculate each channel’s Return On Ad Spend (ROAS). Finally, we optimize a hypothetical budget across three channels to maximize revenue.

Figure 1: Marketing Mix Modelling - Optimising budget across different media channels (image by author with DALL-E)

As always, the code is available on our GitHub.

Media Mix Modeling: What is it?

MMM empowers organizations across the globe by measuring the effectiveness of their advertising channels and providing transparency on how media spending impacts sales. These models play an important role in supporting the decision-making process of budget allocation across channels by optimizing a target variable of interest, such as sales, return on ad spend (ROAS), revenue, conversion, LTV, etc.

Over the past years, many studies have been performed, and several models have been proposed to try to model the influence that spending has on the variables of interest [1]. These models are based on weekly or monthly data aggregated geographically. We are interested in modeling the relationship between our dependent variables, one or many of the variables of interest defined above, and independent variables. Some independent variables are obvious, e.g., the ad spend across channels. Still, we can extend our approach to include further related effects from price, product distribution, inflation, weather, seasonality, and market competition.

The traditional approaches rely on regression methods to infer causation from correlation. Nevertheless, the response of sales to media spending is not linear: there is saturation, which means diminishing returns at high spending levels. Moreover, advertising has a lag or carryover effect, meaning that spending in previous weeks can impact sales in the following weeks.

Figure 2: Example of Ad Saturation Curve and Ad Lag effect (image by author)

Bayesian Methods for Media Mix Modeling

Bayesian models can be specified to account for both the saturation/shape and the lag/carryover effects.

Before diving into the model details, let’s define a hypothetical dataset for a better understanding of what variables the model takes. Suppose we have weekly data at a country level where each row represents a Week (t), and each column represents either a Media Channel (m) or a Control Variable (c) such as seasonality or product price. The media spend of channel m at week t is denoted $X_{t,m}$, and the value of control variable c for the same week is denoted $Z_{t,c}$.

Table 1: Hypothetical weekly dataset with 3 media channels, one control variable, and the target variable (Sales)

Lag or Carryover Effect

The carryover effect is modeled by a function called adstock [1]. This function creates a cumulative effect of the spending in a specific channel: it transforms the channel's spend time series into a weighted average of the media spend from the current week and the previous L-1 weeks. L is the maximum duration of the carryover effect for a particular media channel, and it plays an important role in estimating the weight $w_m$ in the weighted average equation.

Equation 1: Adstock function that models the carryover effect

L can be set differently across media channels. It is a hyperparameter to be defined by an expert. If no prior information exists for a particular channel, the authors advise setting L to a large number, such as 13, to capture potentially heavily lagged effects.

The equation that defines the weight can have two different forms:

  1. Immediate/Geometric Adstock [2], used when the advertisement effect peaks at the same time as the ad exposure, i.e., we see a peak in sales in the same week we increased the spending on a media channel. In Equation 2, $\alpha_m$ is the retention rate of the ad effect.
Equation 2: Geometric decay function
  2. Delayed Adstock [1], used when the advertisement effect takes longer to build up and does not immediately impact sales. In Equation 3, $\theta_m$ is the delay of the peak effect.
Equation 3: Delayed Adstock function
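Since the equation images did not survive, here are Equations 1 to 3 written out as we understand the adstock formulation referenced above [1]. Here $x_{t,m}$ is the spend of channel m at week t (the $X_{t,m}$ defined earlier), and $w_m(l)$ is the weight given to the spend from l weeks ago; the exact notation is our transcription.

```latex
\text{adstock}(x_{t-L+1,m}, \dots, x_{t,m};\, w_m, L)
  = \frac{\sum_{l=0}^{L-1} w_m(l)\, x_{t-l,m}}{\sum_{l=0}^{L-1} w_m(l)} \tag{1}

w_m^{g}(l;\, \alpha_m) = \alpha_m^{\,l}, \qquad l = 0, \dots, L-1,\; 0 \le \alpha_m < 1 \tag{2}

w_m^{d}(l;\, \alpha_m, \theta_m) = \alpha_m^{\,(l - \theta_m)^2}, \qquad l = 0, \dots, L-1,\; 0 \le \alpha_m < 1,\; 0 \le \theta_m \le L-1 \tag{3}
```

The geometric weights are largest at l = 0 (the current week), while the delayed weights peak at l = θ_m.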

Let’s pick up our hypothetical dataset and calculate the Immediate and Delayed Adstock for the Facebook channel. To start, we added 5 more weeks to the dataset. We consider a retention rate ($\alpha_m$) of 80% and a peak delay ($\theta_m$) of 5 weeks. We then calculate the weights for the immediate and delayed effects to obtain the final Immediate and Delayed Adstock values at week 8.

Table 2: Immediate and Delayed Adstock calculation for Facebook

Figure 3 shows how much each week’s spending contributes to the sales volume at week 8.

Figure 3: Immediate and Delayed Adstock effect for our hypothetical Facebook Spend (image by author)
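The calculation behind this example can be sketched in a few lines of NumPy. The weekly spend values below are hypothetical (the table's numbers are not reproduced here); we reuse the retention rate of 0.8 and peak delay of 5 weeks, and we assume L = 8 so that every week up to week 8 is included. The function names are our own.

```python
import numpy as np

def geometric_weights(alpha: float, L: int) -> np.ndarray:
    """Immediate/geometric adstock weights: w(l) = alpha ** l for l = 0..L-1."""
    return alpha ** np.arange(L)

def delayed_weights(alpha: float, theta: float, L: int) -> np.ndarray:
    """Delayed adstock weights: w(l) = alpha ** ((l - theta) ** 2), peaking at l = theta."""
    lags = np.arange(L)
    return alpha ** ((lags - theta) ** 2)

def adstock_at_last_week(spend: np.ndarray, weights: np.ndarray) -> float:
    """Weighted average of the last L weeks of spend (spend ordered oldest -> newest)."""
    L = len(weights)
    if len(spend) < L:
        raise ValueError("need at least L weeks of spend")
    window = spend[-L:][::-1]  # x_t, x_{t-1}, ..., x_{t-L+1}
    return float(np.sum(weights * window) / np.sum(weights))

# Hypothetical 8 weeks of Facebook spend
spend = np.array([100, 120, 80, 90, 150, 110, 70, 130], dtype=float)
L = 8
w_imm = geometric_weights(0.8, L)
w_del = delayed_weights(0.8, 5, L)

print("immediate adstock at week 8:", round(adstock_at_last_week(spend, w_imm), 2))
print("delayed adstock at week 8:", round(adstock_at_last_week(spend, w_del), 2))
# Share of the week-8 delayed adstock contributed by each week (week 8 first, week 1 last)
print(np.round(w_del / w_del.sum(), 3))
```

With the delayed weights, the largest share comes from the spend five weeks before week 8, which is the behavior Figure 3 illustrates.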

Saturation or Shape Effect

The saturation or shape effect is modeled by transforming the media spends through a curvature function such as the logistic saturation function [3]. It is defined as follows:

Equation 4: Logistic saturation function
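Written out, the logistic saturation function takes the following form (our transcription, matching the logistic saturation used in PyMC-Marketing as far as we know):

```latex
f(x;\, \lambda) = \frac{1 - e^{-\lambda x}}{1 + e^{-\lambda x}}
```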

where x represents the media spend, and λ controls the steepness of the saturation curve, i.e., how quickly the effect of media spend saturates. A low λ value means a more gradual increase in the response function, so media spending has a noticeable effect over a large range of values. Conversely, a high λ value makes the curve saturate sooner, so additional spending quickly runs into diminishing returns. Figure 4 shows these different behaviors very clearly.

Figure 4: Logistic saturation functions based on different parameters (source)
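A short NumPy sketch of the logistic saturation function shows how λ changes the curve. The spend levels are hypothetical and assumed to be already scaled to a small range; the two λ values are arbitrary. Note that the function is equivalent to tanh(λx/2), which gives a quick sanity check.

```python
import numpy as np

def logistic_saturation(x, lam: float) -> np.ndarray:
    """(1 - exp(-lam * x)) / (1 + exp(-lam * x)): 0 at x = 0, approaching 1 as x grows."""
    x = np.asarray(x, dtype=float)
    return (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))

spend = np.linspace(0.0, 2.0, 5)  # hypothetical scaled spend levels: 0, 0.5, 1, 1.5, 2
for lam in (0.5, 3.0):
    # A low lambda grows almost linearly over this range; a high lambda flattens out early
    print(f"lambda={lam}:", np.round(logistic_saturation(spend, lam), 3))
```

With λ = 3, most of the response is already reached at a spend of about 1, so doubling the spend adds little; with λ = 0.5, the response keeps growing over the whole range.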

It is difficult to know which parameter values to use, since they are quite specific to how each channel behaves. In a Bayesian approach, however, these parameters are given prior distributions and then estimated from the data, so the model learns the most likely parameter values given the observations. Therefore, we must specify a distribution rather than a single value.

Combining the Carryover and Shape Effect

As mentioned in the previous two sections, to model the carryover and shape effects, we need to apply both transformations to the media spending of each channel. This raises the question of which transformation should be applied first. The authors suggest the following:

  • The shape effect follows the carryover if the media spending is heavily concentrated on certain periods.
  • The carryover follows the shape effect if the media spending is evenly distributed across multiple time periods.

Since organizations usually tend to concentrate their marketing activity, the most common approach is the carryover → shape effect combination.

That said, the dependent variable sales y at week t can be modeled through a linear combination of media spending and control variables. We also use a regression coefficient β to model different effects for different media channels.

Equation 5: Modeling sales combining Carryover → Shape Effect and control variables
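Written out, the sales model reads as follows. This is our transcription using the notation defined above, assuming M media channels and C control variables, with λ_m the saturation parameter of channel m:

```latex
y_t = \boldsymbol{\alpha} + \sum_{m=1}^{M} \beta_m\, f(x_{t,m}) + \sum_{c=1}^{C} \gamma_c\, Z_{t,c} + e_t,
\qquad
f(x_{t,m}) = \text{logistic}\big(\text{adstock}(x_{t-L+1,m}, \dots, x_{t,m};\, w_m, L);\, \lambda_m\big)
```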

where 𝛂 is the intercept. The function $f(x_{t,m})$ encodes the contribution of media channel m to the target variable, considering adstock (carryover) and saturation effects. $\gamma_c$ is the effect of control variable $Z_{t,c}$, and $e_t$ is white noise.

Bayesian Model

The Bayesian approach begins with defining prior distributions for the model parameters, reflecting initial beliefs before considering the data. As new data is introduced, the likelihood function, which represents the probability of observing the data given the parameters, is calculated. In this context, the data includes media channels X and control variables Z, which explain the dependent variable y. Using Bayes' theorem, the posterior distribution is obtained by combining the prior distributions and the likelihood function.

The authors rely on Gibbs Sampling [4] due to its sampling efficiency in selecting the parameter values for each media channel (X) and control variable (Z).

Equation 6: Maximising the likelihood of the vector of parameters (Φ) given the media channels (X), control variables (Z) and the dependent variable (y)
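In symbols, Bayes' theorem gives the posterior of the parameter vector Φ as proportional to the likelihood times the prior π(Φ); we write the likelihood as a calligraphic L to avoid a clash with the maximum lag L:

```latex
p(\Phi \mid y, X, Z) \;\propto\; \mathcal{L}(y \mid X, Z, \Phi)\, \pi(\Phi)
```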

Remember that the model relies less on prior distributions to estimate the parameters when the data carries strong information and has clear patterns.

Nevertheless, the authors left some guidance on how to define prior distributions for each of the parameters:

  • Retention rate (α) is constrained on [0, 1) and should have a prior defined on [0, 1), such as a beta or uniform distribution.
Figure 5: Beta Distribution (source)
  • Delay parameter (θ) is usually constrained on [0, L-1] and should have a prior such as uniform or scaled beta distribution.
Figure 6: Uniform Distribution (source)
  • Gamma (γ) and Intercept are usually modeled by a normal distribution.
Figure 7: Normal Distribution (source)
  • Lambda (λ) is usually modeled by a gamma distribution.
Figure 8: Gamma Distribution (source)
  • Regression coefficients (β) are usually given nonnegative priors, such as a half-normal distribution, since media spending is assumed not to affect y negatively.
Figure 9: Half Normal Distribution (source)

Bayesian Media Mix Modeling with PyMC

This section implements a Bayesian model on a public dataset from Kaggle under the License CC0: Public Domain. This dataset contains information about the spending on three different media channels (TV, Radio, and newspapers) and the sales for the same period.

The dataset is composed of the following:

  • ID – identifies a row;
  • TV Ad Budget ($) – advertisement spends on TV;
  • Radio Ad Budget ($) – advertisement spends on Radio;
  • Newspaper Ad Budget ($) – advertisement spends on Newspaper;
  • Sales ($) – the target variable.

The fitted Bayesian model will help us calculate ROAS, retention, and saturation effect per channel. Besides that, it will also help us optimize the budget allocation for future weeks.

To estimate the reliability of the model, we assess how well it models the dependent variable on unseen data based on the spending of each media channel and the control variables. We resort to regression metrics such as the Mean Absolute Error (MAE). As a benchmark, we use a naive model that always predicts the average value of the training data; companies often rely on exactly this kind of estimate when no MMM is available.
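A minimal sketch of this evaluation idea, with hypothetical numbers: the naive benchmark predicts the training mean for every test week, and MAE averages the absolute errors. In the notebook, sklearn's mean_absolute_error (imported below) computes the same quantity; the helper names here are our own.

```python
import numpy as np

def mae(y_true, y_pred) -> float:
    """Mean absolute error: average absolute difference between observed and predicted values."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))

def naive_mean_forecast(y_train, horizon: int) -> np.ndarray:
    """Naive benchmark: predict the training average for every future week."""
    return np.full(horizon, np.mean(y_train), dtype=float)

# Hypothetical weekly sales
y_train = np.array([10.0, 12.0, 11.0, 13.0, 14.0])
y_test = np.array([12.0, 15.0])
baseline = naive_mean_forecast(y_train, len(y_test))  # [12.0, 12.0]
print(mae(y_test, baseline))  # 1.5
```

An MMM is only useful for forecasting if its MAE on the test weeks is clearly below this baseline.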

We start by importing the libraries:

%matplotlib inline
%load_ext autoreload
%autoreload 2
import arviz as az
import datetime
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import utils
from pymc_marketing.mmm.delayed_saturated_mmm import DelayedSaturatedMMM
from sklearn.metrics import mean_absolute_error

Then, we load the dataset and perform some basic preprocessing. We simplify the column names and add a date column based on the ID, which lets us enrich the dataset with control variables such as seasonality and trend.

# load data and rename columns
df = pd.read_csv('data/data.csv')
df = df.rename(columns={'Unnamed: 0': 'id', 'TV Ad Budget ($)':'tv', 'Radio Ad Budget ($)': 'radio', 'Newspaper Ad Budget ($)': 'newspaper', 'Sales ($)': 'sales'})

# create datetime column
df['ds'] = df['id'].apply(lambda x: pd.to_datetime("2024-02-26")-datetime.timedelta(weeks=len(df)-x))

After that, we perform some exploratory data analysis to understand correlations within the data:

1. We evaluate the correlation between the dependent variable and each media channel.

  • TV is the most correlated feature with sales, while Newspaper is the least correlated.
corr_matrix = df[['sales', 'tv', 'radio', 'newspaper']].corr()
sns.heatmap(corr_matrix, annot=True, cmap='Blues')
plt.show()
Figure 10: Correlation Matrix (image by author)
Figure 10: Correlation Matrix (image by author)

2. We plot sales versus each media channel to assess if there is any lag effect between a peak in spending and a peak in sales:

  • There is no clear trend or seasonality in sales.
  • TV advertisements seem to have an immediate impact on sales.
  • Radio advertisements also seem to have an immediate impact on sales. For example, in weeks 1, 2, and 3 in 2022, when TV had lower advertisement values, we see 2 spikes in sales that match the spikes in Radio.
  • Newspaper advertisements seem to have a 1-2 week lag, but it is hard to tell since TV and Radio advertisements ran simultaneously.
# only sales
utils.line_plot(df.copy(), ['sales'], 'Sales over Time')

# sales vs tv spends
utils.line_plot(df.copy(), ['sales', 'tv'], 'Sales vs TV over Time')

# sales vs radio spends
utils.line_plot(df.copy(), ['sales', 'radio'], 'Sales vs Radio over Time')

# sales vs newspaper spends
utils.line_plot(df.copy(), ['sales', 'newspaper'], 'Sales vs Newspaper over Time')
Figure 11: Time Series plots (image by the author)

With the EDA finalized, we can start preparing the modeling part by:

1. Splitting the data into train and test:

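# .copy() avoids pandas' SettingWithCopyWarning when we add columns to the splits below
train_df = df.sort_values(by='ds').iloc[:-5,:].copy()
test_df = df.sort_values(by='ds').iloc[-5:,:].copy()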

2. Using the weekly data we generated before to extract control variables such as trend and seasonality.

  • We resort to the time series model Prophet from Meta to decompose our time series into trend and seasonality, and we use those as control variables.
seasonality, trend = utils.extract_trend_seasonality(train_df, 'sales', 5)
train_df.loc[:, 'seasonality'] = seasonality[:-5]
test_df.loc[:,'seasonality'] = seasonality[-5:]
train_df.loc[:,'trend'] = trend[:-5]
test_df.loc[:,'trend'] = trend[-5:]
Figure 12: Trend and Yearly seasonality (image by author)

3. Setting the different hyperparameters for the model. These can be defined through a traditional ML hyperparameter search, optimizing some regression metrics by changing the dist, mu, and sigma values. Remember that higher standard deviation values (sigma) give the model more freedom while searching for the optimal parameter.

my_model_config = {'intercept': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
 'beta_channel': {'dist': 'HalfNormal', 'kwargs': {'sigma': 2}},
 'alpha': {'dist': 'Beta', 'kwargs': {'alpha': 1, 'beta': 3}},
 'lam': {'dist': 'Gamma', 'kwargs': {'alpha': 3, 'beta': 1}},
 'likelihood': {'dist': 'Normal',
  'kwargs': {'sigma': {'dist': 'HalfNormal', 'kwargs': {'sigma': 2}}}},
 'gamma_control': {'dist': 'Normal', 'kwargs': {'mu': 0, 'sigma': 2}},
 'gamma_fourier': {'dist': 'Laplace', 'kwargs': {'mu': 0, 'b': 1}}}
Figure 13: Kruschke diagram of our model (image by author)

In Figure 13, we present the Kruschke diagram of the model we implemented. It provides a concise overview of the definitions we made earlier. There are a few aspects to consider when reading such a diagram. Note that we find the variables and their respective distributions within each node. For instance, the circle containing α depicts a beta distribution, as defined earlier. Shaded nodes represent the observed variables. Rounded-corner boxes indicate repetition. For example, since we have three acquisition channels, we set a distinct set of α, β, and λ parameters for each channel. The arrows illustrate dependencies. In our model, two arrows point to the likelihood function: one indicates a dependency on the mu parameter and another on the sigma parameter. The mu parameter itself has three additional dependencies. Recall that we chose to model sales by integrating the carryover effect, shape effect, and control variables.
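One way to sanity-check priors like those in my_model_config is to draw from them before fitting. The sketch below uses NumPy rather than PyMC, so it only approximates how PyMC-Marketing uses these priors (for example, it may apply them to scaled data). Note that PyMC parameterizes Gamma with a rate (beta), whereas NumPy's gamma takes a scale, so scale = 1/beta.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 200_000

# 'alpha' (retention rate): Beta(alpha=1, beta=3), mean 1 / (1 + 3) = 0.25
retention = rng.beta(1.0, 3.0, size=n)
# 'lam' (saturation): Gamma(alpha=3, beta=1); NumPy expects scale = 1 / rate
lam = rng.gamma(shape=3.0, scale=1.0 / 1.0, size=n)
# 'beta_channel': HalfNormal(sigma=2), i.e. the absolute value of a Normal(0, 2) draw
beta_channel = np.abs(rng.normal(0.0, 2.0, size=n))
# 'gamma_fourier': Laplace(mu=0, b=1)
gamma_fourier = rng.laplace(0.0, 1.0, size=n)

for name, draws in [("alpha", retention), ("lam", lam),
                    ("beta_channel", beta_channel), ("gamma_fourier", gamma_fourier)]:
    low, high = np.percentile(draws, [5, 95])
    print(f"{name:>13}: mean={draws.mean():.3f}, 90% interval=[{low:.3f}, {high:.3f}]")
```

The Beta(1, 3) prior, for instance, puts most of its mass on low retention rates, which is consistent with a belief that ad effects fade quickly.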

Now that we have defined the training and testing sets and the model config, we can initialize the Bayesian model and fit it to the training data.

  • Media Channels ["tv", "radio", "newspaper"]
  • Control Variables ["seasonality", "trend"]
  • From the EDA, the maximum adstock lag (L) does not seem to be higher than 2 weeks
mmm = DelayedSaturatedMMM(
    model_config=my_model_config,
    sampler_config={"progressbar": True},
    date_column="ds",
    channel_columns=["tv", "radio", "newspaper"],
    control_columns=["seasonality", "trend"],
    adstock_max_lag=2,
)

mmm.fit(X=train_df[['ds', 'tv', 'radio', "newspaper", "seasonality", "trend"]], y=train_df['sales'], target_accept=0.95, chains=4, random_seed=42)

After fitting the model, we can check how well it fits the training data by comparing the sampling predictions (blue) and the true values (black). In our case, they are well aligned.

mmm.sample_posterior_predictive(train_df[['ds', 'tv', 'radio', "newspaper", "seasonality", "trend"]], extend_idata=True, combined=True)
mmm.plot_posterior_predictive(original_scale=True);
Figure 14: Sampling predictions vs actual values (image by author)

Now, we can start interpreting the model using several approaches:

1. Checking the parameter estimates:

  • Radio seems to have the highest return on investment since its coefficient (beta) is the highest (1.185), followed by TV and Newspaper.
  • The retention rate α is 3.2% for TV, 2.3% for Radio and 23.9% for Newspaper.
az.summary(data=mmm.fit_result,
    var_names=[
        "intercept",
        "likelihood_sigma",
        "beta_channel",
        "alpha",
        "lam",
        "gamma_control",
    ],
)
Table 3: Posterior distribution of model parameters
  • The saturation rate λ is highest for TV (3.138), which accounts for 73% of the total spending. Figure 15 makes it easier to compare the saturation rate across the three channels.
Figure 15: Posterior distribution of the λ parameter for each channel (image by author)

2. Checking channel contribution and ROAS:

  • ROAS is calculated by setting the spend of one media channel to zero and measuring how much the predicted sales drop compared to the predictions with the current spend. For example, if we set the Newspaper spend to zero, we do not expect a big decline in sales; therefore, its ROAS will be low.
  • Although TV has the highest contribution because it received the most spending, the model predicts a higher ROAS for Radio.
# channel contribution
fig = mmm.plot_channel_contribution_share_hdi(figsize=(7, 5))

# ROAS calculation
utils.plot_ROAS(mmm, train_df, ["tv", "radio", "newspaper"])
Figure 16: Channel Contribution and ROAS (image by author)
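The counterfactual logic behind ROAS can be sketched in a few lines. This is not the article's utils.plot_ROAS helper: we assume ROAS is the drop in predicted sales when a channel's spend is set to zero, divided by that channel's spend, and toy_predict is a hypothetical stand-in for the fitted MMM with made-up parameters.

```python
import numpy as np


def toy_predict(spend):
    """Hypothetical stand-in for a fitted MMM: a baseline plus a saturating effect per channel."""
    # Illustrative parameters only, not the article's posterior estimates
    max_contrib = {"tv": 8.0, "radio": 6.0, "newspaper": 1.0}
    rate = {"tv": 0.01, "radio": 0.05, "newspaper": 0.02}
    n_periods = len(next(iter(spend.values())))
    sales = np.full(n_periods, 5.0)
    for ch, x in spend.items():
        x = np.asarray(x, dtype=float)
        sales = sales + max_contrib[ch] * (1 - np.exp(-rate[ch] * x))
    return sales


def counterfactual_roas(predict, spend, channel):
    """Drop in predicted sales when `channel` spend is zeroed, per unit of that channel's spend."""
    with_channel = predict(spend)
    # Same spend plan, but with the chosen channel switched off
    zeroed = {ch: (np.zeros(len(x)) if ch == channel else x) for ch, x in spend.items()}
    lift = with_channel - predict(zeroed)
    return lift.sum() / np.sum(spend[channel])


spend = {"tv": [200.0, 150.0, 100.0], "radio": [40.0, 30.0, 20.0], "newspaper": [20.0, 20.0, 20.0]}
for ch in spend:
    print(ch, round(counterfactual_roas(toy_predict, spend, ch), 4))
```

Because the effect saturates, a channel with large spend (TV here) can have the biggest absolute contribution and still a lower ROAS than a channel with smaller, less saturated spend.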

3. Finally, we can also assess what would happen if we increased the advertising spend in each channel by 50%, taking into account the carryover and saturation effects.

The x-axis is the input channel data percentage level:

  • When x = 1, we have the model's input spend data.
  • When x = 1.5, we have what the contribution would have been if we had increased the spending by 50%.

Newspaper seems to have reached its saturation point since a 50% increase in spending would not bring much more contribution.

Radio seems much less saturated than TV based on the slope of both lines.

plt.rcParams["figure.figsize"] = (20,5)
mmm.plot_channel_contributions_grid(start=0, stop=1.5, num=12);
Figure 17: Hypothetical analysis of how much contribution each channel gains with an increase in spending (image by author)

To check whether our conclusions are valid, we can use our test set to assess the model's ability to predict future sales based on the media channels and control variables. For that, we compute the MAE and compare it with that of a naive model that always predicts the average sales of the training set.

  • We obtain an MAE of 2.01 for an average target of 13.8.
  • Our error is 58% lower than the baseline's (a ratio of 0.417).
y_out_of_sample = mmm.sample_posterior_predictive(X_pred=test_df[['ds', 'tv', 'radio', "newspaper", "seasonality", "trend"]], extend_idata=False)
y_pred = [np.median(x) for x in y_out_of_sample['y']]

print(f"MAE: {mean_absolute_error(test_df['sales'], y_pred)} for an average target of {test_df['sales'].mean()}")
print(f"MASE: {mean_absolute_error(test_df['sales'], y_pred)/mean_absolute_error(test_df['sales'], [train_df['sales'].mean()]*len(test_df))}")

MAE: 2.008733107943637 for an average target of 13.8
MASE: 0.41701581608539257

These results suggest that the model is reliable and does a good job of modeling sales based on the media channels and control variables.

Budget Allocation

We assume that the effect of spending on sales is nonlinear and saturates at some point. Therefore, we need to determine which saturation function better fits our data. We have two function options to model saturation:

  • A sigmoid function where α (alpha) is the saturation point, beyond which additional spending no longer increases sales, and λ (lambda) influences the slope of the curve, with higher values making the curve steeper.
Equation 7: Sigmoid Function
  • A Michaelis-Menten function where α (alpha) is the maximum contribution a channel can have and λ (lambda) controls where the curve bends, i.e., how quickly it flattens out.
Equation 8: Michaelis-Menten Function
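Since Equations 7 and 8 are only available as images, the sketch below writes out one common parameterization of each curve in NumPy. Both forms are assumptions on our part (they mirror helper functions found in PyMC-Marketing, but the article's exact equations may be parameterized differently). The sigmoid is 0 at zero spend and approaches α as spend grows, with λ setting how fast; in the Michaelis-Menten form, α is the maximum contribution and λ is the spend at which a channel reaches half of it.

```python
import numpy as np


def sigmoid_saturation(x, alpha, lam):
    """Sigmoid-type saturation: 0 at x = 0, approaches alpha as x grows; lam sets the steepness."""
    x = np.asarray(x, dtype=float)
    return alpha * (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def michaelis_menten(x, alpha, lam):
    """Michaelis-Menten: alpha is the maximum contribution; lam is the spend giving alpha / 2."""
    x = np.asarray(x, dtype=float)
    return alpha * x / (lam + x)


# Hypothetical parameters: marginal gain of a 50% spend increase on each curve
spend = 100.0
print("sigmoid gain:", sigmoid_saturation(1.5 * spend, 5.0, 0.02) - sigmoid_saturation(spend, 5.0, 0.02))
print("MM gain:", michaelis_menten(1.5 * spend, 5.0, 80.0) - michaelis_menten(spend, 5.0, 80.0))
```

The closer a channel's current spend is to the flat part of its curve, the smaller this gain, which is what the 50% scenario above reads off visually.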

To determine which curve suits our data better, we use our fitted MMM to estimate the parameters of each function. We then plot both curves and visually check which one fits better.

  • For our specific use case, the sigmoid function performed better.
# plot the sigmoid response curves and extract each channel's alpha and lambda
sigmoid_response_curve_fig = mmm.plot_direct_contribution_curves(show_fit=True)
sigmoid_params = mmm.compute_channel_curve_optimization_parameters_original_scale(method='sigmoid')

# same for the Michaelis-Menten curves
mm_response_curve_fig = mmm.plot_direct_contribution_curves(show_fit=True, method='michaelis-menten')
mm_params = mmm.compute_channel_curve_optimization_parameters_original_scale(method='michaelis-menten')
Figure 18: Sigmoid and Michaelis-Menten fitting (image by author)

Now that we have each media channel's sigmoid parameters (α and λ), we know the saturation point of each channel. Beyond that point, additional spending in a channel will not increase returns, whereas the same amount invested in a less saturated channel may.

We can use an optimization algorithm to allocate the budget based on channel saturation, the total budget available, and the budget constraints for each channel. PyMC-Marketing provides an implementation based on Sequential Least Squares Programming (SLSQP). It maximizes the total contribution from all channels while taking into account three elements:

  • The total budget limitation;
  • The minimum and maximum spending limits for each channel;
  • The saturation curve.
result_sigmoid = mmm.optimize_channel_budget_for_maximum_contribution(
    method = 'sigmoid', #define saturation function
    total_budget = 500, # total budget
    parameters = sigmoid_params, # sigmoid parameters extracted previously
    budget_bounds = {'tv': [75, 296], 'radio': [10, 300], 'newspaper': [1, 25]} # budget constraints by channel
)
Table 4: Optimised budget allocation

Table 4 shows the results of our budget allocation, where Radio is the channel with the highest estimated contribution, and TV is the channel where we are advised to spend the highest budget.
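The optimization itself can be reproduced in spirit with SciPy's SLSQP solver. The sketch below is not PyMC-Marketing's code: it assumes the sigmoid form α(1 − e^(−λx))/(1 + e^(−λx)) for each channel, and the (alpha, lam) values are hypothetical; only the total budget and bounds are copied from the call above.

```python
import numpy as np
from scipy.optimize import minimize


def sigmoid_saturation(x, alpha, lam):
    """Assumed sigmoid response curve for one channel."""
    return alpha * (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def allocate_budget(params, total_budget, bounds):
    """Maximize total contribution subject to a fixed total budget and per-channel bounds."""
    channels = list(params)
    lo = np.array([bounds[c][0] for c in channels], dtype=float)
    hi = np.array([bounds[c][1] for c in channels], dtype=float)

    def neg_contribution(x):
        # SLSQP minimizes, so we negate the total contribution
        return -sum(sigmoid_saturation(xi, *params[c]) for xi, c in zip(x, channels))

    # Start from an even split, clipped to each channel's bounds
    x0 = np.clip(np.full(len(channels), total_budget / len(channels)), lo, hi)
    result = minimize(
        neg_contribution,
        x0,
        method="SLSQP",
        bounds=list(zip(lo, hi)),
        constraints=[{"type": "eq", "fun": lambda x: np.sum(x) - total_budget}],
    )
    return dict(zip(channels, result.x)), -result.fun


# Hypothetical (alpha, lam) per channel; bounds as in the article's call
params = {"tv": (10.0, 0.01), "radio": (8.0, 0.03), "newspaper": (1.0, 0.05)}
bounds = {"tv": [75, 296], "radio": [10, 300], "newspaper": [1, 25]}
alloc, total = allocate_budget(params, 500, bounds)
print({c: round(v, 1) for c, v in alloc.items()}, round(total, 3))
```

The equality constraint spends exactly the total budget, the bounds encode the per-channel limits, and the concave saturation curves make the solver shift money away from channels whose marginal return has flattened.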

Budget allocation under market uncertainty

In the current economy, we are facing a lot of uncertainty. Thus, we must design a budget allocation strategy that accommodates various scenarios.

Let’s consider three different scenarios:

  • Initial: the economy stays stable, and the budget allocation is the same as calculated in the previous section.
  • Scenario 2: the economy goes through a recession period, and our budget is cut by 40%.
  • Scenario 3: the economy gets more favorable and starts growing, and our budget increases by 20%.

We will use the same fitted MMM and the same sigmoid parameters to optimize the budget allocation under these scenarios. The code is the same as before, but this time we loop over the scenarios to reduce or increase the available budget.

scenarios_result = []
total_budget = 500
channels = ['tv', 'radio', 'newspaper']

for scenario in np.array([0.6, 1.2]):
    scenarios_result.append(
        mmm.optimize_channel_budget_for_maximum_contribution(
            method="sigmoid",  # define saturation function
            total_budget=total_budget * scenario,
            parameters=sigmoid_params,
            budget_bounds={
                channel: [1, total_budget * scenario] for channel in channels
            },
        ).to_dict()
    )
_ = mmm.plot_budget_scenearios(
    base_data=result_sigmoid, method="sigmoid", scenarios_data=scenarios_result
)

As shown in Figure 19, under the recession scenario, the budget allocated to TV decreases significantly more than the budget allocated to Radio, compared to the initial scenario. This is expected because Radio has a higher ROAS, as we saw before. Under the growth scenario, on the other hand, the budgets allocated to TV and Radio increase similarly.

Figure 19: Budget allocation under different scenarios (image by author)

Conclusion

AI for Media Mix Modeling can be the difference between getting a positive return on investment and acquiring valuable and loyal customers or draining our financial resources in the wrong media channel with the wrong customers.

In this article, we developed a Bayesian framework for Marketing Mix Modeling that provides transparency and assesses the further potential of each of a company's media channels to acquire new customers. In our approach, domain knowledge from the marketing teams can be incorporated by setting prior distributions. This improves the model's ability to understand the relationship between the media channels and the dependent variable of interest (e.g., sales). Finally, the budget allocation strategy can be optimized depending on the company's capacity to invest in acquiring new customers. In today's macroeconomic scenario, companies might be prioritizing profitability and, thus, reducing the budget available to invest in growth. We showed how to make a data-driven decision about where to cut with minimal impact. Conversely, we showed where to invest if the scenario is more positive and the company wants to deploy more resources to grow faster.

We are currently developing and deploying new AI applications across organizations. For example, we are enhancing customer experience with generative AI and improving the planning process with time series forecasting. In this article, we demonstrated how AI can improve the efficiency of the marketing budget. From our experience, an advanced and mature organization in terms of AI adoption needs a suite of specialized AI models focused on its core activities.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Yuxue Jin, Yueqing Wang, Yunting Sun, David Chan, Jim Koehler. (2017). Bayesian Methods for Media Mix Modeling with Carryover and Shape Effects.

[2] Dominique M. Hanssens, Leonard J. Parsons, Randall L. Schultz. (2003). Market response models: econometric and time series analysis. Springer Science & Business Media.

[3] Hill, A. V. (1910). The possible effects of the aggregation of the molecules of haemoglobin on its dissociation curves. Journal of Physiology, 40 (suppl), iv–vii. doi:10.1113/jphysiol.1910.sp001386.

[4] Gelfand, A. E. & Smith, A. F. (1990). Sampling-based approaches to calculating marginal densities. Journal of the American Statistical Association, 85 (410), 398–409.

All images are by the authors unless noted otherwise.

The post MMM: Bayesian Framework for Marketing Mix Modeling and ROAS appeared first on Towards Data Science.

]]>
Moirai: Time Series Foundation Models for Universal Forecasting https://towardsdatascience.com/moirai-time-series-foundation-models-for-universal-forecasting-dc93f74b330f/ Thu, 11 Apr 2024 05:05:21 +0000 https://towardsdatascience.com/moirai-time-series-foundation-models-for-universal-forecasting-dc93f74b330f/ The future of predictive analytics: Explore Moirai, Salesforce's new foundation model for advanced time series forecasting

The post Moirai: Time Series Foundation Models for Universal Forecasting appeared first on Towards Data Science.

]]>
This post was co-authored with Rafael Guedes.

Introduction

The development of time series foundation models has been accelerating over the last two quarters, with a new model released nearly every month. It started with TimeGPT [1] in the last quarter of 2023; since then, we have seen the release of Lag-Llama [2], Google's TimesFM [3], Amazon's Chronos [4], and Salesforce's Moirai [5].

To understand the growing interest in foundation models, we should define their core capability: zero-shot inference. It refers to the ability to accurately perform tasks or make predictions on data that these models never encountered during training. This ability has been explored for models applied across various domains, such as natural language processing (NLP), computer vision, and multimodal tasks (combining text, images, etc.). The term "zero-shot" comes from the idea that the model sees "zero" examples from a specific task or data domain during training yet can "shoot" at, or aim to perform, tasks in that area effectively. The term was introduced in the paper "Zero-Shot Learning with Semantic Output Codes," authored by Palatucci et al. (with Hinton among the co-authors) and presented at the NIPS conference in 2009. Since then, it has become one of the most prominent research topics and is now making its way into the field of time series analysis.

In this article, we explore Moirai, a new foundation model by Salesforce for time series forecasting. It builds on our series of articles about foundation models for time series forecasting, in which we provided detailed explanations and showcased the performance of models such as TimeGPT and Chronos on real-world datasets.

We provide an in-depth explanation of the architecture behind Moirai and the main components that allow the model to perform zero-shot inference. We also summarize the differences between Moirai and the other two foundation models we have researched so far, comparing, for example, the size of the training data, the number of model parameters, and whether they support multivariate forecasting.

Following this theoretical overview, we apply Moirai to a specific use case and dataset. We cover the practical implementation details and thoroughly analyze the model’s performance. Finally, we compare the performance of Moirai with TiDE and Chronos using a public dataset.

Figure 1: TimeGPT vs. TiDE vs. Chronos vs Moirai (image by author with DALL-E)

As always, the code is available on our GitHub.

Background

We define key concepts in time series forecasting to make it easier to understand the time series problems Moirai proposes to address.

Univariate time series forecasting focuses on predicting the future values of a single time series variable using only its past values. The forecasting model relies on the historical data of that single variable to identify patterns, trends, and cycles that can inform future predictions. An example would be forecasting tomorrow’s temperature based solely on past temperature records.

Multivariate time series forecasting involves predicting the future values of multiple related time series variables based on historical data. In this context, the forecast model accounts for the interdependencies and interactions between multiple variables to make predictions. For example, predicting the future sales of a product might consider not only past sales but also related factors such as marketing spend, seasonal trends, and competitor prices.

Covariates in time series forecasting are variables that can influence the outcome of the prediction. These variables can be known in advance or estimated for the forecast period. In both univariate and multivariate forecasting models, covariates incorporate additional insights beyond the historical data of the target variable. Examples include factors like holidays, special events, and economic indicators. Furthermore, in multivariate forecasting, covariates extend to include related time series data – these could be variables whose future values are either known or need to be predicted (see the example above).

Time series frequency refers to the intervals at which data points in a time series are recorded or observed, representing the regularity and granularity of the data over time. This frequency can range from high-frequency data, such as minute-by-minute transactions in financial markets, to low-frequency data, like annual economic indicators. Also, different frequencies can capture various trends, patterns, and seasonalities. For example, daily sales data may reveal patterns not visible in monthly aggregates, such as weekly cycles or the impact of specific days of the week.

Probabilistic forecasts extend beyond point predictions by providing a distribution of possible future outcomes. These output distributions represent the probability of different future values occurring, allowing for more informed decision-making under uncertainty. For instance, in scenarios where observations are strictly positive, such as sales volumes or energy consumption, probabilistic forecasts might use log-normal or gamma distributions to model the range of possible outcomes. Probabilistic forecasts are particularly useful in risk management and planning, as they enable stakeholders to evaluate the likelihood of various scenarios, from the most pessimistic to the most optimistic.

Moirai: the Time Series Foundation Model by Salesforce

Moirai is a foundation model for time series forecasting developed by Salesforce. It is designed as a universal model capable of forecasting a wide range of time series. To achieve this flexibility, the model addresses several challenges associated with time series data, including the ability to:

  • Handle all kinds of data frequencies (hourly, daily, weekly, etc.);
  • Accommodate any number and type of covariates, whether their future values are known or unknown;
  • Generate a probabilistic forecast using a flexible distribution that can be adapted to several cases.

The dataset is one of the core components of any foundation model. The authors built a large-scale and diverse dataset comprising 27 billion observations spanning nine distinct time series domains. Additionally, they introduced three main novel concepts: Multi Patch Size Projection Layers, Any-Variate Attention, and Mixture Distribution, each explained in detail in the following sections.

Figure 2: MOIRAI architecture (source)

Multi Patch Size Projection Layers

Patching was first introduced to time series forecasting with PatchTST [7]. It divides the time series into patches of size P, which are shorter subsequences of the original series. Why is patching useful for foundation models in time series forecasting?

Time series forecasting aims to capture the relationships between observations at different time steps. Foundation models tend to use transformer-based architectures. While transformers work well for NLP applications, a single time step does not carry semantic meaning the way a word in a sentence does. Hence, we need a way to extract local semantic information before applying the attention mechanism. Patching aggregates time steps into subseries-level components with richer semantic representations.

In simpler terms, we could say that as word embeddings represent words in a high-dimensional space, time series patches can be considered representations of segments of the series in a multidimensional space defined by their features.

This process brings numerous advantages, such as:

  • Enabling the attention mechanism to extract local semantic meaning by looking at a group of time steps instead of a single time step;
  • Reducing the number of tokens fed to the encoder, which reduces the memory needed and allows longer input sequences to be fed to the model;
  • With longer sequences, the model has more information to process and more meaningful temporal relationships to extract, potentially producing more accurate forecasts.
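A minimal patching sketch in NumPy shows the token reduction described above. The function name is hypothetical, and left-padding with zeros is our own simplification; a real model would also keep a mask marking which positions are padding.

```python
import numpy as np


def patchify(series, patch_size):
    """Split a 1-D series into non-overlapping patches of length `patch_size`.

    The series is left-padded with zeros so its length is a multiple of
    patch_size (an illustrative choice). Returns the patches and the pad length.
    """
    series = np.asarray(series, dtype=float)
    pad = (-len(series)) % patch_size
    padded = np.concatenate([np.zeros(pad), series])
    return padded.reshape(-1, patch_size), pad


patches, pad = patchify(np.arange(512), 32)
print(patches.shape)  # (16, 32): 16 tokens instead of 512 time steps
```

Each row is one token for the encoder, so the attention cost grows with the number of patches rather than with the number of raw time steps.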

The patch size used by the authors depends on the data frequency, where lower frequency data has smaller patch sizes while higher frequency data has larger patch sizes:

  • Yearly and Quarterly → Patch size 8
  • Monthly → Patch size 8, 16, 32
  • Weekly and Daily → Patch size 16, 32
  • Hourly → Patch size 32, 64
  • Minute-level → Patch size 32, 64, 128
  • Second-level → Patch size 64, 128

Regarding the model architecture, the authors use input and output patch layers. After the data is converted into patches, the input patch layer, a simple linear layer, maps each patch into a patch embedding that is fed to the encoder-only Transformer. The encoder's output tokens are then decoded by the output patch layer, i.e., the multi patch size output projection. Since there are five different patch sizes, the model has five input patch layers and five output patch layers, activated according to the patch size used to process the input data.

To clarify further, let's examine a specific example. Suppose we aim to forecast a quarterly time series. The data is segmented into patches of size 8. These patches are then processed by the input patch layer designed for patch size 8. The resulting patch embeddings are fed into the encoder-only Transformer, which processes them. Finally, the processed embeddings are projected by the output patch layer, again tailored to patch size 8.
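The per-patch-size layers can be pictured as a lookup table of linear projections keyed by patch size. This NumPy sketch is an illustration, not the uni2ts code: D_MODEL is a hypothetical embedding size, the weights are random rather than learned, the Transformer is skipped, and the output head simply maps back to p values, whereas in Moirai it produces the parameters of the forecast distribution.

```python
import numpy as np

rng = np.random.default_rng(0)
D_MODEL = 16                       # hypothetical embedding size
PATCH_SIZES = (8, 16, 32, 64, 128)

# One input projection (weight, bias) per patch size; random here, learned in a real model
input_proj = {p: (rng.normal(scale=p ** -0.5, size=(p, D_MODEL)), np.zeros(D_MODEL)) for p in PATCH_SIZES}
# One output projection per patch size, mapping embeddings back to patch length
output_proj = {p: (rng.normal(scale=D_MODEL ** -0.5, size=(D_MODEL, p)), np.zeros(p)) for p in PATCH_SIZES}


def embed_patches(patches):
    """Map patches of shape (num_patches, p) to embeddings with the layer for patch size p."""
    p = patches.shape[1]
    weight, bias = input_proj[p]  # KeyError if p is not a supported patch size
    return patches @ weight + bias


def project_back(tokens, p):
    """Map encoder tokens of shape (num_patches, D_MODEL) back to patches of size p."""
    weight, bias = output_proj[p]
    return tokens @ weight + bias


# Quarterly example: 40 observations cut into 5 patches of size 8
emb = embed_patches(np.arange(40, dtype=float).reshape(5, 8))
print(emb.shape)  # (5, 16)
```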

Figure 3: Patching process (image by author)

Any-Variate Attention

The traditional Transformer architecture expects a single sequence of target values. However, in a multivariate time series scenario, the model must handle multiple sequences of target values and dynamic covariates. Therefore, the authors introduced Any-Variate Attention to allow Moirai to process multiple sequences.

The process starts by flattening the multiple time series (variates) into a single sequence of values. Then, a variate encoding is applied to allow the model to distinguish the different variates in the sequence, which is important when calculating the attention score.
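Flattening with variate encoding can be sketched as concatenating the series and keeping two integer ids per token: its time index and its variate index. This is a simplified illustration with a hypothetical function name; in Moirai each token is a patch rather than a single time step.

```python
import numpy as np


def flatten_variates(series_by_variate):
    """Flatten several series into one token sequence with time and variate ids.

    series_by_variate: dict mapping a variate name to a 1-D array.
    Returns (values, time_ids, variate_ids, names), one id pair per token.
    """
    names = list(series_by_variate)
    values, time_ids, variate_ids = [], [], []
    for v, name in enumerate(names):
        x = np.asarray(series_by_variate[name], dtype=float)
        values.append(x)
        time_ids.append(np.arange(len(x)))    # position within its own series
        variate_ids.append(np.full(len(x), v))  # which series the token came from
    return np.concatenate(values), np.concatenate(time_ids), np.concatenate(variate_ids), names


vals, t, v, names = flatten_variates({"temperature": [20, 21, 22], "humidity": [0.5, 0.6, 0.55]})
print(t, v)
```

The attention layer can then use the time ids for positional information and the variate ids to tell tokens of the same series apart from tokens of other series.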

Figure 4: Flattening and variate encoding (image by author)

Any-Variate Attention has two fundamental characteristics: it achieves permutation equivariance with respect to variate ordering and permutation invariance with respect to variate indices.

Permutation equivariance regarding variate ordering means that if the sequence of observations within a variate is permuted, the model output for that variate reflects the same permutation. This property is required since we are working with time series, and the temporal dynamics must be preserved within each variate. Therefore, the model’s understanding of time series dynamics is consistent regardless of the input order.

Permutation invariance with respect to variate indices means that the model’s output does not change if the variates are reordered. For instance, let’s consider a scenario where we are processing temperature and humidity data as two variates in a multivariate time series setup. If we decide to swap the order in which these variates are presented to the model (first humidity, then temperature instead of first temperature, then humidity), it should not affect the final output. The model treats variate indices as interchangeable, focusing instead on the encoded relationships.

To achieve permutation equivariance/invariance, Moirai uses two distinct approaches:

  1. Rotary Positional Embeddings (RoPE) [8] ensure permutation equivariance through the way they encode position. RoPE encodes positional information by rotating the token representations in the embedding space, with a rotation angle proportional to each token's position in the sequence. Thus, it encodes the absolute position of each token while making the attention score between any pair of tokens depend only on their relative distance.
  2. A binary attention bias makes the model invariant to variate indices, treating the variates as unordered. The model adjusts its focus by applying different attention biases (learnable scalars) depending on whether two elements belong to the same variate (m=n) or to different variates (m≠n). This enables the Any-Variate Attention mechanism to handle arbitrary numbers of variates and their permutations.
Figure 5: Any-Variate Attention equation (image by the author)
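The two mechanisms combine into a single attention logit: a RoPE-rotated dot product over the time indices plus a scalar that depends only on whether the two tokens share a variate. The NumPy sketch below is an illustrative reimplementation of that idea, not the uni2ts code; q and k stand in for the projected queries and keys, and u_same and u_diff are fixed numbers here rather than learned parameters.

```python
import numpy as np


def rope(x, pos, base=10000.0):
    """Rotate consecutive feature pairs of x (shape (n, d), d even) by angles proportional to pos."""
    n, d = x.shape
    freqs = base ** (-np.arange(0, d, 2) / d)                # one frequency per feature pair
    angles = np.asarray(pos, dtype=float)[:, None] * freqs   # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x, dtype=float)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out


def any_variate_scores(q, k, time_ids, variate_ids, u_same, u_diff):
    """Attention logits: RoPE term on time ids plus a binary bias for same/different variate."""
    scores = rope(q, time_ids) @ rope(k, time_ids).T   # depends only on relative time offsets
    variate_ids = np.asarray(variate_ids)
    same = variate_ids[:, None] == variate_ids[None, :]
    return scores + np.where(same, u_same, u_diff)


rng = np.random.default_rng(0)
q, k = rng.normal(size=(6, 4)), rng.normal(size=(6, 4))
t = np.array([0, 1, 2, 0, 1, 2])   # time index of each flattened token
v = np.array([0, 0, 0, 1, 1, 1])   # variate index of each flattened token
print(any_variate_scores(q, k, t, v, u_same=0.5, u_diff=-0.5).round(3))
```

Shifting every time index by the same amount, or swapping the variate labels, leaves the scores unchanged, which is the behavior the two mechanisms are meant to provide.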

Mixture Distribution

Moirai is a probabilistic forecasting model, which means it learns the parameters of a distribution rather than merely providing a single point prediction. The output, being a distribution, enables decision-makers to evaluate the uncertainty of the predictions, as wider intervals indicate greater uncertainty from the model.

Like other probabilistic models such as DeepAR [9], Moirai estimates the parameters of a probability distribution by minimizing a loss function, specifically, the negative log-likelihood. Several distributions can be used; for instance, DeepAR can be configured to estimate the parameters of Gaussian, Beta, Negative Binomial, or Student's t-distributions.

Since Moirai is a foundation model, it is designed to forecast various data domains and thus cannot be limited to a single distribution. To accommodate all possible scenarios, the model learns the parameters of a mixture of distributions, each suited to different kinds of data:

  • Student’s t-distribution is a robust option for most time series due to its ability to handle outliers and data with heavier tails.
  • Negative Binomial distribution is useful for non-negative count data, as it does not predict negative values.
  • Log-Normal distribution effectively forecasts right-skewed data, such as economic indicators or natural phenomena.
  • Low Variance Normal distribution is used for data clustered tightly around the mean and is suitable for high-confidence predictions.
Figure 6: Mixture of Distributions (image by author)
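The loss for a mixture is the negative log of the weighted sum of the component likelihoods. The sketch below computes it for one observation with SciPy; the parameter names and values are hypothetical, in the model they would be predicted by the network for every time step, and mixing the Negative Binomial probability mass with the other components' densities is a simplification of this sketch. The weights are assumed to sum to 1.

```python
import numpy as np
from scipy import stats
from scipy.special import logsumexp


def mixture_nll(y, weights, params):
    """Negative log-likelihood of y under a 4-component mixture (illustrative parameterization)."""
    t_df, t_loc, t_scale = params["student_t"]
    nb_n, nb_p = params["neg_binomial"]
    ln_sigma, ln_scale = params["log_normal"]
    n_loc, n_scale = params["low_var_normal"]
    y = float(y)
    log_components = np.array([
        stats.t.logpdf(y, t_df, loc=t_loc, scale=t_scale),
        stats.nbinom.logpmf(y, nb_n, nb_p),                 # -inf unless y is a non-negative integer
        stats.lognorm.logpdf(y, ln_sigma, scale=ln_scale),  # -inf for y <= 0
        stats.norm.logpdf(y, loc=n_loc, scale=n_scale),
    ])
    # log(sum_i w_i * p_i(y)), computed stably in log space
    return -logsumexp(log_components, b=np.asarray(weights, dtype=float))


params = {
    "student_t": (3.0, 2.0, 1.0),
    "neg_binomial": (5, 0.5),
    "log_normal": (0.5, 2.0),
    "low_var_normal": (2.0, 0.1),
}
print(mixture_nll(3.0, [0.25, 0.25, 0.25, 0.25], params))
```

Training would average this quantity over all observations and let the network learn both the weights and each component's parameters, so a series can lean on whichever component fits it best.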

TimeGPT vs Chronos vs Moirai: The Comparison

This section presents the similarities and dissimilarities between the foundation models we have studied in this and previous articles.

Table 1: Comparison of foundation models for time series forecasting.

Table 1 compares the key characteristics of the foundation models. At this stage, we are not focused on comparing their performance, which is covered in the next section. First, Chronos and Moirai are open-source models and will benefit from community contributions. So, when in doubt, we recommend choosing open-source models with greater community support and potential to improve over time. An important takeaway is that Chronos demonstrates much better data efficiency, requiring significantly less training data. However, it does not yet support multivariate forecasting. Lastly, looking at the number of parameters, we can see that time series models are considerably smaller than LLMs, making them easier to use and deploy.

Moirai vs. Chronos: a comparison in a public dataset

In this section, we will use Moirai to forecast tourism visitors to Australia using a real-world dataset that is publicly available under the cc-by-4.0 license. Subsequently, we compare the forecasting performance of Moirai against Chronos (large version) and TiDE (to access the code that generated the forecast with Chronos and TiDE, please check our last article).

We enhanced the dataset with economic covariates (e.g., CPI, Inflation Rate, GDP) extracted from Trading Economics, which uses economic indicators based on official sources. We also perform some preprocessing to increase the usability of the dataset further. We stored the preprocessed dataset version here so that our experiments can be easily reproduced.

We start by importing the libraries and setting global variables. We set the date column, target column, dynamic covariates, the frequency of our series, and the forecast horizon.

%load_ext autoreload
%autoreload 2
import torch
import pandas as pd
import numpy as np
import utils

from datasets import load_dataset
from gluonts.dataset.pandas import PandasDataset
from huggingface_hub import hf_hub_download

from uni2ts.model.moirai import MoiraiForecast

TIME_COL = "Date"
TARGET = "visits"
DYNAMIC_COV = ['CPI', 'Inflation_Rate', 'GDP']
SEAS_COV=['month_1', 'month_2', 'month_3', 'month_4', 'month_5', 'month_6', 'month_7','month_8', 'month_9', 'month_10', 'month_11', 'month_12']
FORECAST_HORIZON = 8 # months
FREQ = "M"

After that, we load our dataset, which already has the exogenous features mentioned in the dataset description.

# load data and exogenous features
df = pd.DataFrame(load_dataset("zaai-ai/time_series_datasets", data_files={'train': 'data.csv'})['train']).drop(columns=['Unnamed: 0'])
df[TIME_COL] = pd.to_datetime(df[TIME_COL])

# one hot encode month
df['month'] = df[TIME_COL].dt.month
df = pd.get_dummies(df, columns=['month'], dtype=int)

print(f"Distinct number of time series: {len(df['unique_id'].unique())}")
df.head()

Distinct number of time series: 304

Once the dataset is loaded, we can split the data between train and test (we decided to use the last 8 months of data for our test set).

# 8 months to test
train = df[df[TIME_COL] <= (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]
test = df[df[TIME_COL] > (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]

print(f"Months for training: {len(train[TIME_COL].unique())} from {min(train[TIME_COL]).date()} to {max(train[TIME_COL]).date()}")
print(f"Months for testing: {len(test[TIME_COL].unique())} from {min(test[TIME_COL]).date()} to {max(test[TIME_COL]).date()}")

Months for training: 220 from 1998-01-01 to 2016-04-01
Months for testing: 8 from 2016-05-01 to 2016-12-01

Finally, we need to transform the pandas data frame into a GluonTS dataset to feed the model:

  • We concatenate the training dataset (target and dynamic covariates) with only the dynamic covariates of the test set (the target in the forecast horizon will be null). We then set the date column as the data frame's index.
  • We set the column that allows us to distinguish the different time series (unique_id).
  • We define which columns represent dynamic covariates known in the future (feat_dynamic_real).
  • We define the target column (target) and the series frequency (freq).
  • Note that there is no need to scale the data since the model handles it internally.
# create GluonTS dataset from pandas
ds = PandasDataset.from_long_dataframe(
    pd.concat([train, test[["unique_id", TIME_COL]+DYNAMIC_COV+SEAS_COV]]).set_index(TIME_COL), # concatenation with the test set's dynamic covariates
    item_id="unique_id",
    feat_dynamic_real=DYNAMIC_COV+SEAS_COV,
    target=TARGET,
    freq=FREQ
)

With the dataset ready, we can forecast using Moirai. For that, we need to load the model from Hugging Face and set the following parameters:

  • prediction_length: the forecast horizon we defined earlier.
  • context_length: how many items in the sequence the model can attend to (any positive integer).
  • patch_size: the length of each patch. As seen previously, the authors set different patch sizes depending on the frequency. To use the pre-defined values, patch_size should be set to ‘auto’. It can also be set to any value from {auto, 8, 16, 32, 64, 128}.
# Prepare pre-trained model by downloading model weights from huggingface hub
model = MoiraiForecast.load_from_checkpoint(
    checkpoint_path=hf_hub_download(
        repo_id="Salesforce/moirai-R-large", filename="model.ckpt"
    ),
    prediction_length=FORECAST_HORIZON,
    context_length=24,
    patch_size='auto',
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=ds.num_feat_dynamic_real,
    past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
    map_location="cuda:0" if torch.cuda.is_available() else "cpu",
)

predictor = model.create_predictor(batch_size=32)
forecasts = predictor.predict(ds)

# convert forecast into pandas
forecast_df = utils.moirai_forecast_to_pandas(forecasts, test, FORECAST_HORIZON, TIME_COL)
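To make `patch_size` more concrete: patch-based transformers such as Moirai split the context window into fixed-length segments and embed each segment as one token. The helper below is a generic sketch of non-overlapping patching for a single variate, not Moirai's internal code; the name `make_patches` and the left-padding with NaN are our own assumptions.

```python
import numpy as np

def make_patches(series: np.ndarray, patch_size: int) -> np.ndarray:
    """Split a 1-D series into non-overlapping patches of length `patch_size`.

    Hypothetical helper: the series is left-padded with NaN so that its length
    is a multiple of `patch_size`. This padding rule is an assumption, not
    Moirai's documented behaviour.
    """
    series = np.asarray(series, dtype=float)
    remainder = len(series) % patch_size
    if remainder:
        pad = np.full(patch_size - remainder, np.nan)
        series = np.concatenate([pad, series])
    return series.reshape(-1, patch_size)

# 24 observations, matching context_length=24 above, split into patches of 8.
context = np.arange(1, 25, dtype=float)
patches = make_patches(context, patch_size=8)
print(patches.shape)  # (3, 8)
```

With a 24-step context and a patch size of 8, the transformer would attend over three patch tokens instead of 24 individual values, which is the main motivation for patching.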

Once the forecast has finished, we can plot the ground truth values and the predictions.

Figure 7: MOIRAI forecast vs. actual values (image by author)

Figure 7 shows that Moirai struggled to forecast our time series and did not generate a stable forecast. Instead, it predicted several consecutive jumps with a higher magnitude than expected.

Having obtained the forecast from Moirai, we can now load the forecast generated by TiDE and Chronos and compute forecasting performance metrics for comparison. For better interpretability, we have used the Mean Absolute Percentage Error (MAPE) as our comparison metric.
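MAPE averages the absolute error relative to the actual value, so it reads as a percentage regardless of each series' scale. A minimal sketch of a per-month MAPE is below; the toy values and column names are assumptions, and skipping zero actuals is our choice, since the article does not say how zeros were handled or exactly how series were pooled.

```python
import numpy as np
import pandas as pd

def mape(actual, forecast) -> float:
    """MAPE in percent. Rows whose actual value is 0 are skipped because the
    percentage error is undefined there (our assumption; the article does not say)."""
    actual = np.asarray(actual, dtype=float)
    forecast = np.asarray(forecast, dtype=float)
    mask = actual != 0
    return float(np.mean(np.abs((actual[mask] - forecast[mask]) / actual[mask])) * 100)

# Hypothetical results: one row per (series, month) with the actual and forecast value.
results = pd.DataFrame({
    "Date": pd.to_datetime(["2016-05-01", "2016-05-01", "2016-06-01", "2016-06-01"]),
    "actual": [100.0, 200.0, 50.0, 80.0],
    "forecast": [110.0, 180.0, 55.0, 80.0],
})

# One MAPE value per forecast month, pooling all series for that month.
monthly_mape = results.groupby("Date")[["actual", "forecast"]].apply(
    lambda g: mape(g["actual"], g["forecast"])
)
print(monthly_mape)
```

Here May has errors of 10% on both series (MAPE 10) and June has 10% and 0% (MAPE 5).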

Figure 8: MAPE comparison between MOIRAI, Chronos, and TiDE (image by author)

As shown in Figure 8, Moirai has the highest MAPE over the forecast horizon as a whole. It slightly outperformed TiDE in one of the months but never managed to outperform Chronos. We conducted similar experiments on several private datasets, and the results consistently align with the findings presented in Figure 8. This consistency matters when analyzing foundation models, since their training datasets are not publicly disclosed: any public dataset could have been part of their training data, in which case a good result might simply reflect overfitting to that data.

Remember that Chronos does not use covariates and assumes independence between the time series. The fact that it still outperformed Moirai suggests that the Chronos approach is considerably better and has greater potential for future improvement.

Conclusion

In this article, we explored Moirai, one of the most recent foundation models for time series forecasting. It is another example of a model capable of zero-shot inference. We have covered Chronos and TimeGPT in detail, and Moirai's approach and model architecture are quite different from theirs. Thus, we believe it carries scientific value, and we appreciate that it is open-source.

Our experiments indicate that Moirai was unable to outperform either TiDE or Chronos. In the case of TiDE, both models have access to the same information, and TiDE was trained specifically on this data. However, we anticipated that Moirai would perform comparably to, or even better than, Chronos. Moirai has the advantage of accessing external information through dynamic covariates and is a multivariate time series model capable of benefiting from cross-relationships between different series.

The AI race to develop foundational models for time series forecasting is just starting, and we will closely monitor its progress. Stay tuned.

About me

Serial entrepreneur and leader in the AI space. I develop AI products for businesses and invest in AI-focused startups.

Founder @ ZAAI | LinkedIn | X/Twitter

References

[1] Garza, A., & Mergenthaler-Canseco, M. (2023). TimeGPT-1. arXiv. https://arxiv.org/abs/2310.03589

[2] Rasul, K., Ashok, A., Williams, A. R., Ghonia, H., Bhagwatkar, R., Khorasani, A., Darvishi Bayazi, M. J., Adamopoulos, G., Riachi, R., Hassen, N., Biloš, M., Garg, S., Schneider, A., Chapados, N., Drouin, A., Zantedeschi, V., Nevmyvaka, Y., & Rish, I. (2024). Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting. arXiv. https://arxiv.org/abs/2310.08278

[3] Das, A., Kong, W., Sen, R., & Zhou, Y. (2024). A decoder-only foundation model for time-series forecasting. arXiv. https://arxiv.org/abs/2310.10688

[4] Ansari, A. F., Stella, L., Turkmen, C., Zhang, X., Mercado, P., Shen, H., Shchur, O., Rangapuram, S. S., Arango, S. P., Kapoor, S., Zschiegner, J., Maddix, D. C., Mahoney, M. W., Torkkola, K., Wilson, A. G., Bohlke-Schneider, M., & Wang, Y. (2024). Chronos: Learning the Language of Time Series. arXiv. https://arxiv.org/abs/2403.07815

[5] Woo, G., Liu, C., Kumar, A., Xiong, C., Savarese, S., & Sahoo, D. (2024). Unified Training of Universal Time Series Forecasting Transformers. arXiv. https://arxiv.org/abs/2402.02592

[6] Palatucci, M., Pomerleau, D., Hinton, G. E., & Mitchell, T. M. (2009). Zero-shot Learning with Semantic Output Codes. In Y. Bengio, D. Schuurmans, J. Lafferty, C. Williams, & A. Culotta (Eds.), Advances in Neural Information Processing Systems (Vol. 22). Curran Associates, Inc. Retrieved from https://proceedings.neurips.cc/paper_files/paper/2009/file/1543843a4723ed2ab08e18053ae6dc5b-Paper.pdf

[7] Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2022). A Time Series is Worth 64 Words: Long-term Forecasting with Transformers. arXiv. https://arxiv.org/abs/2211.14730

[8] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv. https://arxiv.org/abs/2104.09864

[9] Salinas, D., Flunkert, V., & Gasthaus, J. (2017). DeepAR: Probabilistic Forecasting with Autoregressive Recurrent Networks. arXiv. https://arxiv.org/abs/1704.04110

The post Moirai: Time Series Foundation Models for Universal Forecasting appeared first on Towards Data Science.

Chronos: The Rise of Foundation Models for Time Series Forecasting https://towardsdatascience.com/chronos-the-rise-of-foundation-models-for-time-series-forecasting-aaeba62d9da3/ Fri, 05 Apr 2024 05:22:34 +0000 https://towardsdatascience.com/chronos-the-rise-of-foundation-models-for-time-series-forecasting-aaeba62d9da3/ Exploring Chronos: How foundational AI models are setting new standards in predictive analytics

The post Chronos: The Rise of Foundation Models for Time Series Forecasting appeared first on Towards Data Science.

Chronos: The Rise of Foundation Models for Time Series Forecasting

This post was co-authored with Rafael Guedes.

Introduction

Time series forecasting has been evolving towards foundation models due to their success in other Artificial Intelligence (AI) areas, particularly natural language processing (NLP). The development of foundational models has been accelerating over time, with a new, more powerful Large Language Model (LLM) released every month. This is not restricted to NLP; we see a similar pattern in computer vision as well. Segmentation models like Meta's Segment Anything Model (SAM) [1] can identify and accurately segment objects in unseen images. Multimodal models such as LLaVa [2] or Qwen-VL [3] can handle text and images to answer user questions. The common characteristic of these models is that they can perform accurate zero-shot inference, meaning that they do not need to be trained on your data to perform well.

At this point, it is probably helpful to define what a foundational model is and what makes it different from traditional approaches. First, a foundational model is large-scale, namely in its training, which gives it a broad understanding of the main patterns and important nuances we can find in the data. Second, it is general-purpose, i.e., it can perform various tasks without requiring task-specific training. Even though they don't need task-specific training, foundational models can still be fine-tuned (also known as transfer learning), adapting with relatively small datasets to perform better at a specific task.

Based on the above, why is applying them to time series forecasting so tempting? Foremost, foundational models in NLP are designed to understand and generate text sequences, and time series data are also sequential. Both problems also require the model to automatically extract and learn relevant features from the sequence of the data (temporal dynamics, in the case of time series). Additionally, the general-purpose nature of foundational models means we can adapt them to different forecasting tasks. This flexibility allows a single, powerful model to be applied across various domains and forecasting challenges, and then fine-tuned for a specific domain and application.

TimeGPT [4], developed by Nixtla, was one of the first foundation models for forecasting. Following it, other companies entered the race and developed new models, such as MOIRAI [5] from Salesforce, Lag-Llama [6] from Morgan Stanley, ServiceNow, and a group of Canadian universities, and TimesFM [7] from Google. More recently, Amazon joined them and developed Chronos [8], a foundational model for time series based on language model architectures.

In this article, we provide an in-depth explanation of the architecture behind Chronos. We also cover the main components that allow the model to perform zero-shot inference. Following this theoretical overview, we apply Chronos to a specific use case and dataset. We cover the practical implementation details and thoroughly analyze the model's performance. Finally, we compare the performance of Chronos (tiny and large versions) with TiDE on a public dataset.

We show that Chronos beats TiDE [9] on the public dataset using zero-shot inference, without any type of fine-tuning. The improvement Chronos brings to foundation models is clear (you can check our previous article, where we did an in-depth analysis of TimeGPT). We also provide evidence that the difference between the tiny and large versions of Chronos is not significant.

Figure 1: TimeGPT vs. Chronos vs. TiDE (image by author with DALL-E)

As always, the code is available on our GitHub.

Chronos: Learning the Language of Time Series by Amazon

Chronos is Amazon’s most recent foundation model for time series forecasting. It consists of a probabilistic model that uses a T5 (Text-to-Text Transfer Transformer) architecture [10] to forecast future patterns.

The T5 family of models is a series of language models developed by Google. Unlike traditional models designed for specific NLP tasks such as text classification, machine translation, or text summarization, T5 approaches every NLP task as a text-to-text problem. T5 is built upon the transformer encoder-decoder architecture [11]. Both the encoder and the decoder are made up of transformer blocks: the encoder processes the input text into a continuous representation, which the decoder then uses to generate the output text token by token.

The T5 model family includes several variants of different sizes, ranging from smaller models with fewer parameters, designed to be efficient in resource-constrained environments, to larger models with more parameters, capable of capturing more complex patterns and nuances in the data.

The authors of Chronos arrived at a T5-based approach by considering the fundamental differences between an LLM that predicts the next token in a sequence and a time series model that predicts the next value in a sequence. As we discussed in the introduction, the two are very similar in nature. The main difference lies in the fact that an LLM predicts from a finite dictionary of words, whereas a time series model deals with an unbounded set of continuous values. Nevertheless, both models use the sequential structure of the data to predict future values. Therefore, the relevant question is: can we transform our continuous values into discrete ones?

Figure 2: General overview of how Chronos works: pre-processing steps (left), loss function for training (center) and inference process (right) (source)

Adapting an LLM for Time Series Forecasting

Pre-Processing steps

As we mentioned before, the major difference between LLMs and forecasting models is that LLMs expect to handle a finite number of tokens. Therefore, we need to transform the unbounded set of continuous values, typical of time series data, into a finite set of tokens. This tokenization process requires 3 steps:

1. Scaling maps the input data into a useful range of values for the Quantization step (explained later). Unlike its usual goal of facilitating the optimization of deep learning models, here scaling helps us create the input tokens. The authors used mean scaling, which normalizes the input values by the mean of the absolute values over a pre-defined context length of historical values.

$$\tilde{x}_i = \frac{x_i}{s}, \qquad s = \frac{1}{C}\sum_{j=1}^{C} |x_j|$$

Equation 1: Mean scaling, where C is the context length.

2. Quantization converts the scaled continuous values into discrete tokens through binning. The authors used uniform binning, which assigns all values within a specific range to the same bin or, in other words, to the same token. In Figure 3, for simplicity, we used 4 bins; in practice, the model uses a vocabulary of 4,096 tokens.

Figure 3: Example of the application of Uniform Binning (image by author).

Uniform binning has a strong limitation: like decision trees used for time series forecasting, it cannot predict values outside the range of the target variable seen in the training set. We therefore discuss an alternative next.

Another possible approach is quantile binning, which creates bins containing the same number of samples. Nevertheless, it has its own limitations: since it strongly assumes that unseen data follow the same distribution as the training values, it can increase the heterogeneity between the training and prediction data.

3. The authors added special tokens to represent the end of the sequence (EOS). Another interesting choice was to represent missing values with a special token (PAD). These missing values could be due to a missing observation or to padding, which is often used when creating training batches to bring time series of different lengths to a common fixed length.
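The three tokenization steps can be strung together in a few lines of NumPy. This is a minimal sketch, assuming the 4-bin setup of Figure 3 over an arbitrary scaled range of [-2, 2] and reserving ids 0 and 1 for PAD and EOS; these choices and the function name `tokenize` are illustrative assumptions, not the settings of the released Chronos models.

```python
import numpy as np

PAD_ID, EOS_ID = 0, 1   # assumed ids reserved for the special tokens
N_SPECIAL = 2           # number of reserved ids before the first bin token

def tokenize(context, n_bins=4, low=-2.0, high=2.0):
    """Illustrative Chronos-style tokenizer: mean scaling -> uniform binning -> special tokens.

    n_bins, the scaled range [low, high] and the id layout are assumptions chosen
    to mirror the 4-bin example of Figure 3, not the released models' settings.
    """
    context = np.asarray(context, dtype=float)
    observed = ~np.isnan(context)

    # 1. Mean scaling: s is the mean absolute value of the observed context.
    scale = np.mean(np.abs(context[observed]))
    if scale == 0:
        scale = 1.0  # avoid dividing by zero for an all-zero context (our choice)
    scaled = context / scale

    # 2. Uniform binning: inner edges split [low, high] into n_bins equal-width bins;
    #    anything below low or above high lands in the first or last bin.
    inner_edges = np.linspace(low, high, n_bins + 1)[1:-1]
    bin_ids = np.digitize(np.nan_to_num(scaled), inner_edges)  # values in 0..n_bins-1
    token_ids = bin_ids + N_SPECIAL

    # 3. Special tokens: PAD replaces missing observations and EOS closes the sequence.
    token_ids = np.where(observed, token_ids, PAD_ID)
    return np.append(token_ids, EOS_ID), float(scale)

tokens, s = tokenize([2.0, 4.0, np.nan, 6.0])
print(tokens.tolist(), s)  # [4, 5, 0, 5, 1] 4.0
```

Note how values beyond the bin range collapse into the edge bins; this is the range limitation of uniform binning discussed above.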

Training Strategy

Chronos uses the same loss function as any other language model – cross-entropy. It performs regression via classification and is trained to minimize the difference between the predicted and the ground truth distribution. However, cross-entropy does not have a sense of distance, which is crucial in time series data. Two consecutive values are usually more correlated than two nonconsecutive values, and their correlation tends to decrease as they are separated further in time (also known as autocorrelation). Therefore, the model is expected to learn how to associate nearby bins together based on the distribution of bin indices in the training dataset.
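A toy calculation shows why cross-entropy has no sense of distance: the loss depends only on the probability assigned to the true bin, not on how far the predicted bin is from it. The logits over 8 bins below are made up for illustration.

```python
import numpy as np

def cross_entropy(logits: np.ndarray, target: int) -> float:
    """Categorical cross-entropy of a single prediction, computed stably from logits."""
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return float(-log_probs[target])

# Hypothetical logits: the model puts most of its mass on bin 3.
logits = np.array([0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0])

near_miss = cross_entropy(logits, target=4)  # true bin is adjacent to the predicted one
far_miss = cross_entropy(logits, target=7)   # true bin is four bins away
print(near_miss == far_miss)                 # True: both misses cost the same
```

A loss that penalizes the far miss more would need an explicit notion of distance between bins, which is exactly what the model has to learn implicitly from the data instead.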

The authors decided to keep the same loss function as language models because it does not require any modification to the language model architecture. This makes it easier to swap the backbone for other LLMs (e.g., Mixtral [13], Llama 2 [12]) and use their respective utilities. It also means the model is not limited to assuming a specific shape for the output distribution (e.g., a normal distribution). Instead, it can learn to predict future values that follow any type of distribution present in the training data.

Inference

An LLM generates a sample of token IDs as its prediction. Therefore, to obtain a probabilistic forecast, the authors draw several samples for each step in the forecast horizon. They then need to reverse the quantization and scaling operations that were applied in the pre-processing steps before feeding the language model.

Considering 4 samples drawn for t+1, they first perform dequantization by mapping each token ID back to a scaled value, namely the center of the corresponding bin:

Figure 4: Dequantization process (image by author)

Then, they unscale the values by multiplying them by the mean absolute value of the context window that precedes them:

Figure 5: Unscaling process (image by author)

Finally, the forecast interval is generated by taking different quantiles from the samples, e.g., Q10 for the lower bound, Q50 for the mid-value, and Q90 for the upper bound:

Figure 6: Getting mid-value forecast for t+1 (image by author)
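The three inference steps (dequantization, unscaling, and quantiles) can be sketched in NumPy. The sketch assumes an illustrative 4-bin layout over a scaled range of [-2, 2], with token ids 0 and 1 reserved for special tokens; these choices, the sample token ids, the scale of 4.0, and the function name `detokenize` are assumptions rather than the settings of the released models.

```python
import numpy as np

N_SPECIAL = 2  # assumed: ids 0 and 1 are reserved for PAD and EOS

def detokenize(token_ids, scale, n_bins=4, low=-2.0, high=2.0):
    """Map sampled token ids back to the original scale.

    Dequantization uses the centre of each bin; unscaling multiplies by the context
    scale s. The 4-bin layout over [-2, 2] is an illustrative assumption.
    """
    edges = np.linspace(low, high, n_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2       # [-1.5, -0.5, 0.5, 1.5] for the defaults
    bin_ids = np.asarray(token_ids) - N_SPECIAL
    return centers[bin_ids] * scale

# Four hypothetical samples drawn for step t+1, with a context scale s = 4.0.
samples = detokenize([4, 5, 5, 4], scale=4.0)    # -> [2.0, 6.0, 6.0, 2.0]
lower, mid, upper = np.quantile(samples, [0.1, 0.5, 0.9])
print(samples, lower, mid, upper)
```

With real models, many more samples are drawn per step, so the quantiles become a much smoother description of the predictive distribution than with four samples.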

Other important remarks

The authors chose the T5 architecture since it is available in different sizes, ranging from 16M (Tiny) to 11B (XXL) parameters. However, they reduced the vocabulary size from 32,128 to 4,096, resulting in fewer parameters (the released models range from 8M to 710M parameters, quite small compared to their NLP counterparts). They also tested GPT-2 to show that, following their approach, any language model can be used as a replacement for T5.

The model does not allow external information such as static (product brand, color, etc.) or dynamic (product price, macroeconomic data, etc.) covariates. Also, it treats each time series as a simple sequence without time or frequency information (hourly, daily, weekly, or monthly data), which might become a disadvantage when modeling seasonality. Another limitation is that it is a univariate model: it forecasts one time series at a time, so it cannot model dependencies between time series.

As in other time series foundation models like TimeGPT, zero-shot inference is achieved by training the model on datasets from various domains and frequencies. Namely, the authors used datasets from domains such as energy, transport, healthcare, retail, web, weather, and finance.

Comparing Chronos vs. TiDE

In this section, we will use Chronos to forecast tourism visitors to Australia using a real-world dataset that is publicly available under the cc-by-4.0 license. Subsequently, we compare the forecasting performance of Chronos (tiny and large versions) with TiDE, using its implementation from the Python library Darts.

We enhanced the dataset with economic covariates (e.g., CPI, Inflation Rate, GDP) extracted from Trading Economics, which uses economic indicators based on official sources. We also perform some preprocessing to further increase the usability of the dataset. The final structure of the dataset is the following:

  • Unique ID: A combination of encoded names for States, Zones, Regions within Australia, and the purpose of the visit (e.g., business, holiday, visiting, other).
  • Time: Represents the time dimension of the dataset, dynamically adjusted for each series.
  • Target: The target variable for forecasting, specifically focusing on visits.
  • Dynamic Covariates: Economic indicators such as CPI, Inflation Rate, and GDP that vary over time.
  • Static Covariates (Static_1 to Static_4): Extracted from the unique ID, these provide additional information for analysis, including geographic and purpose-of-visit details.

We stored the new version of the dataset here so that our experiments can be easily reproduced.

We start by importing the libraries and setting global variables. We set the date column, target column, static covariates, dynamic covariates to fill with 0, dynamic covariates to fill with linear interpolation, the frequency of our series, the forecast horizon, and the scalers to use.

%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import torch
import utils
from datetime import timedelta
from chronos import ChronosPipeline
from darts import TimeSeries
from darts.dataprocessing.pipeline import Pipeline
from darts.models import TiDEModel
from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression
from darts.dataprocessing.transformers import StaticCovariatesTransformer, MissingValuesFiller

TIME_COL = "Date"
TARGET = "visits"
STATIC_COV = ["static_1", "static_2", "static_3", "static_4"]
DYNAMIC_COV = ['CPI', 'Inflation_Rate', 'GDP']
FREQ = "MS"
FORECAST_HORIZON = 8 # months
SCALER = Scaler()
TRANSFORMER = StaticCovariatesTransformer()
PIPELINE = Pipeline([SCALER, TRANSFORMER])

After that, we load our dataset and enrich it with those exogenous features as mentioned in the dataset description:

# load data and exogenous features
df = pd.read_csv('data/data.csv', parse_dates=['Date']).drop(columns=['Year', 'Month']).set_index('Date')
df = utils.preprocess_dataset(df, DYNAMIC_COV, TIME_COL, TARGET)

print(f"Distinct number of time series: {len(df['unique_id'].unique())}")
df.head()

Distinct number of time series: 304

Once the dataset is loaded, we split the data into train and test sets (we use the last 8 months of data as our test set) and transform the pandas DataFrame into the Darts TimeSeries format.

  • Using the function TimeSeries.from_group_dataframe, we can easily define the static covariates, our target, the column with the time reference, and the frequency of the series.
  • We also use the fill_missing_dates and fillna_value arguments to fill the target variable with 0 in case some of our series have gaps between months.
# 8 months to test
train = df[df[TIME_COL] <= (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]
test = df[df[TIME_COL] > (max(df[TIME_COL])-pd.DateOffset(months=FORECAST_HORIZON))]

# read train and test datasets and transform train dataset
train_darts = TimeSeries.from_group_dataframe(
      df=train,
      group_cols=STATIC_COV,
      time_col=TIME_COL,
      value_cols=TARGET,
      freq=FREQ,  
      fill_missing_dates=True,
      fillna_value=0)

# since we have several time series, not all of them have the same number of months in the forecast set
print(f"Months for training: {len(train[TIME_COL].unique())} from {min(train[TIME_COL]).date()} to {max(train[TIME_COL]).date()}")
print(f"Months for testing: {len(test[TIME_COL].unique())} from {min(test[TIME_COL]).date()} to {max(test[TIME_COL]).date()}")

Months for training: 220 from 1998-01-01 to 2016-04-01
Months for testing: 8 from 2016-05-01 to 2016-12-01
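For readers more familiar with pandas, filling gaps in a monthly series with 0 behaves roughly like re-indexing onto a complete month-start calendar. The sketch below shows that pandas analogue on a made-up series; it illustrates the idea and is not what Darts runs internally. One difference worth noting: `fill_value` in `asfreq` only affects dates that `asfreq` inserts, not NaNs already present in the data.

```python
import pandas as pd

# Hypothetical monthly series in which March 2016 is missing entirely.
visits = pd.Series(
    [5.0, 7.0, 9.0],
    index=pd.to_datetime(["2016-01-01", "2016-02-01", "2016-04-01"]),
    name="visits",
)

# Re-index onto a complete month-start ("MS") calendar; dates inserted this way get 0.
filled = visits.asfreq("MS", fill_value=0)
print(filled)
```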

We have the historical data in a TimeSeries format, so now it is time to create the dynamic covariates in the same format.

# create dynamic covariates for each series in the training data
dynamic_covariates = []
for serie in train_darts:
    # add the month as a one-hot encoded covariate, extended over the forecast horizon
    covariate = datetime_attribute_timeseries(
        serie,
        attribute="month",
        one_hot=True,
        cyclic=False,
        add_length=FORECAST_HORIZON,
    )

    static_1 = serie.static_covariates['static_1'].item()
    static_2 = serie.static_covariates['static_2'].item()
    static_3 = serie.static_covariates['static_3'].item()
    static_4 = serie.static_covariates['static_4'].item()

    # select the rows of this series and fill gaps in the economic covariates with interpolation
    mask = (
        (df['static_1'] == static_1)
        & (df['static_2'] == static_2)
        & (df['static_3'] == static_3)
        & (df['static_4'] == static_4)
    )
    dyn_cov_interp = TimeSeries.from_dataframe(df[mask], time_col=TIME_COL, value_cols=DYNAMIC_COV, freq=FREQ, fill_missing_dates=True)
    covariate = covariate.stack(MissingValuesFiller().transform(dyn_cov_interp))

    dynamic_covariates.append(covariate)
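In the loop above, `datetime_attribute_timeseries` with `one_hot=True` and `add_length=FORECAST_HORIZON` builds month indicators that extend past the end of each training series, so this covariate is known over the forecast horizon. A rough pandas analogue is below; the 12-month history is made up (chosen to end in April 2016, like our training set), and this is a sketch of the idea rather than Darts' implementation.

```python
import pandas as pd

FORECAST_HORIZON = 8  # months, as in the article

# Hypothetical 12-month history ending in April 2016, like our training set.
history_index = pd.date_range("2015-05-01", periods=12, freq="MS")

# Extend the calendar FORECAST_HORIZON months past the last observation.
full_index = pd.date_range(history_index[0], periods=len(history_index) + FORECAST_HORIZON, freq="MS")

# One indicator column per calendar month (1.0 when the timestamp falls in that month).
month_one_hot = pd.get_dummies(full_index.month, prefix="month").astype(float)
month_one_hot.index = full_index
print(month_one_hot.shape)  # (20, 12)
```

The last row corresponds to December 2016, the final month of our test period.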

After splitting the data and creating the covariates, we can forecast the 304 series using TiDE:

# scale covariates
dynamic_covariates_transformed = SCALER.fit_transform(dynamic_covariates)

# scale data and transform static covariates
data_transformed = PIPELINE.fit_transform(train_darts)

TiDE_params = {
    "input_chunk_length": 12, # number of months to lookback
    "output_chunk_length": FORECAST_HORIZON,
    "num_encoder_layers": 2,
    "num_decoder_layers": 2,
    "decoder_output_dim": 1,
    "hidden_size": 15,
    "temporal_width_past": 4,
    "temporal_width_future": 4,
    "temporal_decoder_hidden": 26,
    "dropout": 0.1,
    "batch_size": 16,
    "n_epochs": 15,
    "likelihood": QuantileRegression(quantiles=[0.25, 0.5, 0.75]),
    "random_state": 42,
    "use_static_covariates": True,
    "optimizer_kwargs": {"lr": 1e-3},
    "use_reversible_instance_norm": False,
}

model = TiDEModel(**TiDE_params)
model.fit(data_transformed, future_covariates=dynamic_covariates_transformed, verbose=False)
pred = SCALER.inverse_transform(model.predict(n=FORECAST_HORIZON, series=data_transformed, future_covariates=dynamic_covariates_transformed, num_samples=50))
tide_forecast = utils.transform_predictions_to_pandas(pred, TARGET, train_darts, [0.25, 0.5, 0.75])

Once the forecast has finished, we will use the same data to forecast with Chronos. Since Chronos is not a multivariate time series model, we need to forecast each series individually. First, we load both models, the tiny and the large version, and we loop over the 304 series available in the dataset.

# load model
pipeline_tiny = ChronosPipeline.from_pretrained(
  "amazon/chronos-t5-tiny",
  device_map="cuda",
  torch_dtype=torch.bfloat16,
)

pipeline_large = ChronosPipeline.from_pretrained(
  "amazon/chronos-t5-large",
  device_map="cuda",
  torch_dtype=torch.bfloat16,
)

# run forecast
forecast_tiny = []
forecast_large = []
for ts in train_darts:
    # rebuild the series id by concatenating its four static covariates
    unique_id = "".join(map(str, ts.static_covariates_values()[0]))
    test_ts = test[test['unique_id'] == unique_id]

    # tiny
    lower, mid, upper = utils.chronos_forecast(pipeline_tiny, ts.pd_dataframe().reset_index(), FORECAST_HORIZON, TARGET)
    forecast_tiny.append(utils.convert_forecast_to_pandas([lower, mid, upper], test_ts))

    # large
    lower, mid, upper = utils.chronos_forecast(pipeline_large, ts.pd_dataframe().reset_index(), FORECAST_HORIZON, TARGET)
    forecast_large.append(utils.convert_forecast_to_pandas([lower, mid, upper], test_ts))

# convert list to data frames
forecast_tiny = pd.concat(forecast_tiny)
forecast_large = pd.concat(forecast_large)

Once the forecast has finished, we can plot the ground truth values and the predictions. We decided to check the top 3 series in terms of volume:

# get series ordered by volume in a descending way
series = test.groupby('unique_id')[TARGET].sum().reset_index().sort_values(by=TARGET, ascending=False)['unique_id'].tolist()

for ts in series[:3]:
    utils.plot_actuals_forecast(df[df["unique_id"]==ts], forecast_tiny[forecast_tiny["unique_id"] == ts], ts)
    utils.plot_actuals_forecast(df[df["unique_id"]==ts], forecast_large[forecast_large["unique_id"] == ts], ts)
Figure 9: Chronos forecast vs. actual values (image by author)

Figure 9 shows that there are no significant differences between the large and tiny versions of Chronos. For ABBHol, the results are impressive for both variants: they accurately capture the seasonality of the data and the timing of peaks and drops without any external information.

Having obtained the forecast from Chronos, we can now load the forecast generated by TiDE and compute forecasting performance metrics for comparison. For better interpretability, we have used the Mean Absolute Percentage Error (MAPE) as our comparison metric.

Figure 10: MAPE comparison between Chronos and TiDE (image by author)

As shown in Figure 10, Chronos Large beat TiDE in 6 out of the 8 months. For the top 100 time series, TiDE also had a higher MAPE than the Tiny version of Chronos in 7 out of 8 months. Although we did not perform any kind of hyperparameter tuning for TiDE, it is still impressive that Chronos, in a zero-shot inference setting, achieved better results than a powerful model like TiDE trained specifically on this dataset.

Conclusion

This article explored Chronos, the most recent foundation model for time series forecasting. This type of model promises to perform accurately in zero-shot inference, which is especially useful for organizations lacking the specialized expertise to develop SOTA models in-house. Our analysis shows Chronos beating TiDE with significant differences in MAPE. Also, the results show no significant difference between the two Chronos versions (tiny and large).

One critical methodological note we would like to emphasize is the importance of testing new models on a private dataset before drawing firm conclusions. An essential point discussed in this article is the challenge these models face, akin to that of LLMs: they struggle to generate new data outside of their training distribution. Given that model developers often do not disclose their training datasets, there is a possibility that these models were trained on the same public datasets that we are using to test them.

When running a test with a private dataset, one interesting possibility would be to compare a version of Chronos fine-tuned on our own data against a model such as TiDE.

The AI race to develop foundational models for time series forecasting is just starting, and we will closely monitor its progress and report here. Stay tuned!


References

[1] Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T., Whitehead, S., Berg, A. C., Lo, W.-Y., Dollár, P., & Girshick, R. (2023). Segment Anything. arXiv. https://arxiv.org/abs/2304.02643

[2] Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning. arXiv. https://arxiv.org/abs/2304.08485

[3] Bai, J., Bai, S., Yang, S., Wang, S., Tan, S., Wang, P., Lin, J., Zhou, C., & Zhou, J. (2023). Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond. arXiv. https://arxiv.org/abs/2308.12966

[4] Garza, A., & Mergenthaler-Canseco, M. (2023). TimeGPT-1. arXiv. https://arxiv.org/abs/2310.03589

[5] Woo, G., Liu, C., Kumar, A., Xiong, C., Savarese, S., & Sahoo, D. (2024). Unified Training of Universal Time Series Forecasting Transformers. arXiv. https://arxiv.org/abs/2402.02592

[6] Rasul, K., Ashok, A., Williams, A. R., Ghonia, H., Bhagwatkar, R., Khorasani, A., Darvishi Bayazi, M. J., Adamopoulos, G., Riachi, R., Hassen, N., Biloš, M., Garg, S., Schneider, A., Chapados, N., Drouin, A., Zantedeschi, V., Nevmyvaka, Y., & Rish, I. (2024). Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting. arXiv. https://arxiv.org/abs/2310.08278

[7] Das, A., Kong, W., Sen, R., & Zhou, Y. (2024). A decoder-only foundation model for time-series forecasting. arXiv. https://arxiv.org/abs/2310.10688

[8] Ansari, A. F., Stella, L., Turkmen, C., Zhang, X., Mercado, P., Shen, H., Shchur, O., Rangapuram, S. S., Arango, S. P., Kapoor, S., Zschiegner, J., Maddix, D. C., Mahoney, M. W., Torkkola, K., Wilson, A. G., Bohlke-Schneider, M., & Wang, Y. (2024). Chronos: Learning the Language of Time Series. arXiv. https://arxiv.org/abs/2403.07815

[9] Das, A., Kong, W., Leach, A., Mathur, S., Sen, R., & Yu, R. (2023). Long-term Forecasting with TiDE: Time-series Dense Encoder. arXiv. https://arxiv.org/abs/2304.08424

[10] Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., & Liu, P. J. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. arXiv. https://arxiv.org/abs/1910.10683

[11] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. arXiv. https://arxiv.org/abs/1706.03762

[12] Touvron, H., Martin, L., Stone, K., Albert, P., Almahairi, A., Babaei, Y., Bashlykov, N., Batra, S., Bhargava, P., Bhosale, S., Bikel, D., Blecher, L., Ferrer, C. C., Chen, M., Cucurull, G., Esiobu, D., Fernandes, J., Fu, J., Fu, W., Fuller, B., Gao, C., Goswami, V., Goyal, N., Hartshorn, A., Hosseini, S., Hou, R., Inan, H., Kardas, M., Kerkez, V., Khabsa, M., Kloumann, I., Korenev, A., Koura, P. S., Lachaux, M.-A., Lavril, T., Lee, J., Liskovich, D., Lu, Y., Mao, Y., Martinet, X., Mihaylov, T., Mishra, P., Molybog, I., Nie, Y., Poulton, A., Reizenstein, J., Rungta, R., Saladi, K., Schelten, A., Silva, R., Smith, E. M., Subramanian, R., Tan, X. E., Tang, B., Taylor, R., Williams, A., Kuan, J. X., Xu, P., Yan, Z., Zarov, I., Zhang, Y., Fan, A., Kambadur, M., Narang, S., Rodriguez, A., Stojnic, R., Edunov, S., & Scialom, T. (2023). Llama 2: Open Foundation and Fine-Tuned Chat Models. arXiv. https://arxiv.org/abs/2307.09288

[13] Jiang, A. Q., Sablayrolles, A., Roux, A., Mensch, A., Savary, B., Bamford, C., Chaplot, D. S., de las Casas, D., Bou Hanna, E., Bressand, F., Lengyel, G., Bour, G., Lample, G., Lavaud, L. R., Saulnier, L., Lachaux, M.-A., Stock, P., Subramanian, S., Yang, S., Antoniak, S., Le Scao, T., Gervet, T., Lavril, T., Wang, T., Lacroix, T., & El Sayed, W. (2024). Mixtral of Experts. arXiv. https://arxiv.org/abs/2401.04088

The post Chronos: The Rise of Foundation Models for Time Series Forecasting appeared first on Towards Data Science.
