The Clinical Trials Assistant is a multi-agent generative AI application that helps users find completed, published results of clinical trials for the treatment or prevention of a disease they are interested in. Large Language Models (LLMs) are used to validate user input, summarize complex text for people without a medical background, and evaluate the summarizations.

This is my capstone project for the Google Kaggle 5-day Generative AI Course. The application demonstrates a few uses of LLMs in a pipeline to produce easy-to-understand summaries of the trial results. My Kaggle application is here.

This blog (1) introduces LLMs, (2) highlights some of the application's lifecycle, and (3) presents highlights from the application. References used in the blog are at the bottom of the article.


(1) LLMs

LLMs accomplish complex goals through capabilities built on top of one core skill: predicting the next word in a sequence of words. Predicting the next word is a difficult task. The space of possible combinations of words is far larger than the amount of data that could ever be given to a model, so the problem of learning the probabilities of word combinations is severely under-determined.

Note that in these algorithms a "word" might be a partial word or even a byte-pair token.

To predict a combination of words, we start from the joint probability of the sequence. Let x = [x1, x2, …, xn]. By the chain rule of probability, p(x) = p(x1) · p(x2 | x1) · … · p(xn | xn-1, …, x1). Estimating the last factor directly, p(xn | xn-1, …, x1), requires on the order of V^(n-1) entries, where V is the number of possible values a word can take, so that calculation is not tractable. To make the problem tractable while still preserving the autoregressive property of using the previous steps to predict the current step, many different algorithms have been developed. The transformer, and its uses in LLMs, are the latest of these algorithms.
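As a concrete (and purely hypothetical) illustration of the chain-rule factorization, the sketch below scores a short sequence with a toy bigram model, where each conditional is approximated by p(x_i | x_{i-1}); the tokens and probabilities are made up for the example.

```python
# Toy illustration (not from the project): chain-rule factorization of a
# sequence probability using a hypothetical bigram approximation.
from math import exp, log

# assumed toy conditional probabilities p(next | previous)
bigram = {
    ("<s>", "the"): 0.5,
    ("the", "trial"): 0.2,
    ("trial", "ended"): 0.1,
}

def sequence_probability(tokens):
    """Multiply the chain-rule factors; a full model conditions on the whole prefix."""
    logp = 0.0
    prev = "<s>"
    for tok in tokens:
        logp += log(bigram.get((prev, tok), 1e-6))  # back off to a tiny probability
        prev = tok
    return exp(logp)

print(sequence_probability(["the", "trial", "ended"]))  # 0.5 * 0.2 * 0.1 = 0.01
```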

Details of Transformers and their components.

Milestones in LLM components

Year    Construct
2015    Attention
2017    Transformer architecture
2018    Contextual word embeddings and pretraining
2019    Prompting

LLMs learn the probabilities of word sequences by training on a very large corpus to predict the next word. They are built from transformers, which are generative models with autoregressive structure and non-linear weights. The transformer was created by Google in 2017. Its network architecture was less complex than the state-of-the-art (SOTA) RNN and CNN models of the time, was parallelizable, required less training time, and demonstrated higher quality results on language translation.


Briefly, definitions of a few constructs:

Vector embeddings

are ways to represent data in a compact, binned (grouped) manner. For instance, if a language had a million words and each word were its own element, the vector would need 1 million indices. By grouping words by meaning, or by latent aspects learned from the data, an embedding can use far fewer than a million indices. Efficient embeddings may have latent groupings that are not easy to interpret, and they may share the same embedded vector space with objects of different types (modalities, e.g., image, text, and audio embeddings sharing the same vector embedding space).
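To make the contrast concrete, here is a small illustrative sketch (toy vocabulary and random weights, not from the application) of a sparse one-hot index versus a dense embedding lookup.

```python
# Minimal sketch: one-hot index vs. dense embedding lookup (toy values).
import numpy as np

vocab = {"disease": 0, "trial": 1, "placebo": 2}   # assumed toy vocabulary
vocab_size, embed_dim = len(vocab), 4              # dense dimension << vocab size

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, embed_dim))  # learned in practice

one_hot = np.zeros(vocab_size)
one_hot[vocab["trial"]] = 1.0                      # sparse, vocab_size long

dense = embedding_table[vocab["trial"]]            # compact, embed_dim long
print(one_hot, dense)
```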

Distances

as differences or similarities between embeddings are part of what is needed to calculate model objectives efficiently. Objectives such as minimizing cross-entropy losses, use of K-L divergence, Jensen-Shannon divergence, contrastive learning, etc. can be used in training Neural Network (NN) models. (TODO: follow-up on training multimodal embeddings.)
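As a small illustration of distance computations over embeddings, the sketch below computes cosine similarity and a softmax over dot products with made-up vectors; these are the building blocks used by the attention mechanisms described next.

```python
# Illustrative only: cosine similarity and softmax over dot products.
import numpy as np

def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def softmax(z):
    z = z - np.max(z)                 # numerical stability
    e = np.exp(z)
    return e / e.sum()

query = np.array([0.2, 0.9, -0.1])
memory = np.array([[0.1, 1.0, 0.0],   # similar to the query
                   [-0.8, 0.1, 0.9]]) # dissimilar

print([cosine_similarity(query, m) for m in memory])
print(softmax(memory @ query))        # probability of matching each memory row
```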

Attention from Embedding Distances

The transformer architecture was influenced by noticing that "end-to-end" memory networks use a recurrent attention mechanism to perform well on simple question and answer tasks and language modeling tasks. Details of the "end-to-end" memory network help to understand the Transformer architecture.

The "end-to-end" memory network is a query answer model.
{x_i} is the input set of data to be stored in memory.
The set {x_i} is converted to embedded memory vectors {m_i} by embedding matrix A, and they have dimension d.
A query q is converted to embedding u by an embedding matrix B.
The probability of matching u to any m_i is the softmax of their inner products:
p_i = softmax(u^T m_i), where softmax(z_i) = e^{z_i} / Σ_j e^{z_j}
x_i is converted to the output embedding vector c_i by embedding matrix C.
Then the response vector from memory is o = Σ_i p_i · c_i. The function is smooth, enabling gradients and back-propagation. So far, this is not so different from a "Two-Tower" DNN model construction before defining the number of layers.


The single layer "end-to-end" has final prediction: a_est = softmax(W(o+u)) where W is a final weight matrix. All 3 embedding matrices A, B, and C as well as W are jointly learned during training by minimizing a standard cross-entropy loss between a_est and the true label a.
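A compact sketch of that single-layer forward pass, with toy dimensions and random matrices standing in for the learned embeddings (illustrative, not the original implementation):

```python
# Single-layer end-to-end memory network forward pass, following the
# equations above. Dimensions and data are made up.
import numpy as np

rng = np.random.default_rng(1)
V, d, n_mem = 50, 8, 6                 # vocab size, embedding dim, memory slots

A = rng.normal(size=(V, d))            # input memory embedding
B = rng.normal(size=(V, d))            # query embedding
C = rng.normal(size=(V, d))            # output memory embedding
W = rng.normal(size=(V, d))            # final prediction weights

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

x = rng.integers(0, V, size=n_mem)     # token ids of the stored inputs (toy)
q = rng.integers(0, V)                 # token id of the query (toy)

m = A[x]                               # memory vectors m_i
c = C[x]                               # output vectors  c_i
u = B[q]                               # query embedding u

p = softmax(m @ u)                     # p_i = softmax(u^T m_i)
o = p @ c                              # o = sum_i p_i * c_i
a_est = softmax(W @ (o + u))           # a_est = softmax(W(o + u))
print(a_est.shape)                     # distribution over the vocabulary
```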


The multiple-layer "end-to-end" network with K layers gives the input query embedding to the first layer, and thereafter the output of each layer is given as input to the next: u(k+1) = u(k) + o(k).
Each layer k has its own embedding matrices A(k) and C(k), though they are constrained to keep the number of parameters small. The final prediction is a_est = softmax(W · u(K+1)) = softmax(W · (o(K) + u(K))). The end-to-end memory network authors found that some layers concentrate only on nearby words while other layers have broad attention over all memory locations.

The Transformer builds on this idea, making attention, combined with positional information, the central mechanism of its architecture.

The transformer architecture

There is an encoder and a decoder.

The encoder maps the input representation sequence {x1, x2, …, xn} to a sequence {z1, z2, …, zn} of continuous representations. The z are inputs to the decoder, which generates symbols {y1, y2, …, ym} one at a time.

Encoder:

The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers: the first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. The output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by that specific sub-layer and LayerNorm is a normalization function. The output dimension is the same size for all embedding layers and sub-layers.
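A minimal sketch of the residual-plus-normalization wrapper, output = LayerNorm(x + Sublayer(x)), with a dummy sublayer standing in for self-attention or the feed-forward network:

```python
# Illustrative only: the sub-layer wrapper used throughout the encoder/decoder.
import numpy as np

def layer_norm(x, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sublayer_output(x, sublayer):
    return layer_norm(x + sublayer(x))          # residual connection, then normalize

x = np.random.default_rng(2).normal(size=(5, 8))    # 5 positions, d_model = 8
print(sublayer_output(x, lambda t: 0.1 * t).shape)   # dummy sublayer
```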

Decoder:

The decoder is composed of a stack of N = 6 identical layers, with the same two sub-layers as the encoder. A third sub-layer performs multi-head attention over the output of the encoder stack. The output of each sub-layer is LayerNorm(x + Sublayer(x)), as in the encoder. Masking: the self-attention sub-layer is modified to prevent positions from attending to subsequent positions (attention is paid to past positions only). Like the end-to-end memory model, the output embeddings are offset by one position, which, combined with the masking, ensures that the prediction for position i depends only on known outputs at positions less than i.

Attention:

Query, keys, and values are embedding vectors. The attention function maps a query and a set of key-value pairs to an output. The output is a weighted sum of the values, where the weight for each value is calculated from the compatibility of the query with the corresponding key.

Scaled Dot-Product Attention:

The query and key embeddings have dimension d_k, and the value embeddings have dimension d_v. For a given query q, the attention output is Σ_i p_i · v_i, where the weights p are the softmax of the scaled dot products q·k_i ÷ sqrt(d_k). In matrix form: Attention(Q, K, V) = softmax(Q·K^T ÷ sqrt(d_k)) · V.
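A minimal NumPy sketch of scaled dot-product attention as defined above (toy shapes, random data):

```python
# Illustrative only: softmax(Q K^T / sqrt(d_k)) V.
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                      # query-key compatibility
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)       # row-wise softmax
    return weights @ V                                   # weighted sum of the values

rng = np.random.default_rng(3)
Q = rng.normal(size=(4, 16))   # 4 query positions, d_k = 16
K = rng.normal(size=(6, 16))   # 6 key positions
V = rng.normal(size=(6, 32))   # matching values, d_v = 32
print(scaled_dot_product_attention(Q, K, V).shape)       # (4, 32)
```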

Multi-Head Attention:

replaces a single attention function with the following algorithm. For each of h heads, linearly project the queries, keys, and values with learned projections to dimensions d_k, d_k, and d_v, respectively. The attention function can then be performed for each projected q, k, v in parallel, and each head's output has dimension d_v. The h outputs are concatenated and then projected once more: MultiHead(Q, K, V) = Concat(head_1, …, head_h) · W^O
where head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V). The projections are parameter matrices
W_i^Q of shape d_model × d_k
W_i^K of shape d_model × d_k
W_i^V of shape d_model × d_v
and
W^O of shape (h·d_v) × d_model.
In the paper they use h = 8 parallel attention layers (a.k.a. heads), and d_k = d_v = d_model ÷ h = 64. Because of the reduced dimension of each head, the total computational cost is similar to single-head attention with full dimensionality.
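As a concrete illustration, here is a minimal NumPy sketch of multi-head attention following the equations above; the head count, dimensions, and random weights are toy values, not the trained model's.

```python
# Illustrative only: project, attend per head in parallel, concatenate, project.
import numpy as np

def attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

def multi_head_attention(Q, K, V, Wq, Wk, Wv, Wo):
    heads = [attention(Q @ wq, K @ wk, V @ wv)       # each head works in d_k/d_v dims
             for wq, wk, wv in zip(Wq, Wk, Wv)]
    return np.concatenate(heads, axis=-1) @ Wo       # Concat(head_1..head_h) W^O

rng = np.random.default_rng(4)
h, d_model = 8, 512
d_k = d_v = d_model // h                             # 64, as in the paper

Wq = rng.normal(size=(h, d_model, d_k)) * 0.02
Wk = rng.normal(size=(h, d_model, d_k)) * 0.02
Wv = rng.normal(size=(h, d_model, d_v)) * 0.02
Wo = rng.normal(size=(h * d_v, d_model)) * 0.02

x = rng.normal(size=(10, d_model))                   # 10 positions
print(multi_head_attention(x, x, x, Wq, Wk, Wv, Wo).shape)   # (10, 512)
```

The original transformer uses multi-head attention in three places: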

(A) in "encoder-decoder" attention layers:

☇ Every position in the decoder attends over all positions in input sequence. (mimics sequence-to-sequence models)
query (q): previous decoder layer output
key (k): encoder output
value (v): encoder output

(B) in encoder self-attention layers:

Self-attention means the queries, keys, and values all come from the same place, which here is the output of the previous layer in the encoder.
☇ Each position in the encoder can attend to all positions in the previous layer of the encoder.
query (q): previous encoder layer output
key (k): same as q and v, they are output of encoder's previous layer
value (v): same as q and k

(C) in decoder self-attention layers:

Self-attention means the queries, keys, and values all come from the same place, which here is the output of the previous layer in the decoder. To preserve the auto-regressive property, leftward information flow is prevented by masking out (setting to −∞) all values in the input of the softmax that correspond to illegal connections.
☇ Each position in decoder attends to all positions in the decoder's previous layer.

Each layer in the encoder and decoder also has a fully connected feed-forward network applied to each position separately and identically to provide non-linearity: FFN(x) = max(0, x·W1 + b1)·W2 + b2.
Positional encodings of dimension d_model are added to the input embeddings at the bottoms of the encoder and decoder stacks. For each position pos, dimension 2i of the encoding is given by a sine function, PE(pos, 2i) = sin(pos ÷ 10000^(2i/d_model)), and the adjacent dimension 2i+1 by a cosine function, PE(pos, 2i+1) = cos(pos ÷ 10000^(2i/d_model)).
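An illustrative sketch of the sinusoidal positional encodings and the position-wise feed-forward network (toy dimensions, random FFN weights):

```python
# Illustrative only: sinusoidal positional encoding plus FFN(x) = max(0, xW1+b1)W2+b2.
import numpy as np

def positional_encoding(n_positions, d_model):
    pos = np.arange(n_positions)[:, None]            # (n_positions, 1)
    i = np.arange(0, d_model, 2)[None, :]            # even dimensions 2i
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)                     # dimension 2i: sine
    pe[:, 1::2] = np.cos(angles)                     # dimension 2i+1: cosine
    return pe

def ffn(x, W1, b1, W2, b2):
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(5)
W1, b1 = rng.normal(size=(512, 2048)) * 0.02, np.zeros(2048)
W2, b2 = rng.normal(size=(2048, 512)) * 0.02, np.zeros(512)

pe = positional_encoding(50, 512)
print(ffn(pe, W1, b1, W2, b2).shape)                 # (50, 512)
```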

Characteristics that Vaswani et al. improved with their Transformer architecture:

  1. reduced the total computational complexity per model layer
  2. increased the amount of computation that can be parallelized, as measured by the minimum number of sequential operations required
  3. shortened the path length between long-range dependencies in the network, measured by the maximum path length between any two input and output positions, which makes long-range dependencies easier to learn.

(2) Lifecycle

The first steps are defining the goal of the project, finding a stable, robust source of data for it, and building the smallest working version of the pieces to explore its feasibility. With a confirmed feasible goal, one can then further define the application. By the way, the contest requirement was to use generative AI as a uniquely valuable tool in an application which would otherwise be less capable of accomplishing its real-world goals. For this I created a clinical trials assistant from multiple gen-AI agents. First results from hand-picked articles containing complex medical terms showed that text summarization by the Gemini-2-*, Gemini-1-*, and Gemma-3-* models gave very good results that did not require a medical background to understand. The prompts I used were simple, instructive zero-shot prompts. I used deterministic generation settings (temperature = 0) with all agents. Neither model fine-tuning nor parameter-efficient tuning was deemed necessary for this prototype.

Standard software requirements, analysis, design and implementation were followed with additional elements added for generative AI architecture.

Details of the gen AI lifecycle here (also in docs/lifecycle in the GitHub repository)

  1. Requirements (functional and non-functional):
    • Capstone capabilities: At least 3 of the gen AI capstone capabilities must be included.
    • Agents: text-capable LLMs. All need to be able to reason; one needs to be able to perform document summarization; one (preferably a different model than the summarizer) needs to be able to evaluate the summarization; and one needs to be able to recognize valid disease names.
    • Data: robust, stable API sources
    • Deployment/Serving/Client:
      • Cloud hosted LLMs supplied by Google.
      • Stubs for cloud based logging.
      • The code must run from start to finish within a Kaggle notebook. The assistant is an interactive application, so it also needs the option to run automatically (non-interactively).
      • Ease of porting to mobile environments is kept in mind.
    • Prompts
      • Zero-shot instructions for all tasks
    • Logging, Monitoring:
      • latencies for all API and LLM invocations
      • size of data: number of input and output tokens
      • drift of data: watching for APIs returning data different than expected
      • errors
      • user feedback
    • Protect User from harmful content
    • Ensure regulations for privacy and other guidelines are followed
    • Latencies: Each response to the user should ideally take less than 3 seconds.
    • QPS:
      • Rate limiting for API requests: the requests go directly from the client Kaggle notebook to the APIs, so client-side rate limiting should be implemented, and model choices that handle scale well within the budget should be made.
  2. Analysis and Design
    • Data I/O assessment
      • data to and from the NIH APIs is small
      • data to and from the Google LLM APIs is small
    • Choice of LLMs w/ preference for smaller models, no need for image or audio in this prototype
      • Gemini-1.5-flash and smaller models: SOTA; handles long context well; ability to follow complex instructions and complete complex tasks; document understanding; can take system instructions specifically; can output results in a structured format; can scale well. Use is free up to rate and data limits, with costs increasing beyond them.
      • Gemma: context window constraint of 128k tokens; ability to follow complex instructions and complete complex tasks; document understanding. It is an open-source model that performs very well, though it has fewer abilities than the Gemini models. Use is free within the rate and data-size limits, but the model might have scale limits.
    • Function Calling: client side methods to build and test
      • query to user for disease name w/ option to exit
        • Google LLM agent to validate the disease name
      • retrieve clinical trials
        • API request to clinicaltrials.gov, parsing, logging
      • query to user for trial selection w/ option to exit
      • query to user for citation selection w/ option to exit
        • API request to the NLM's PubMed, parsing, logging
      • article results summarization
        • Google LLM agent to summarize text
        • Google LLM agent to evaluate summarization
        • parsing, logging
    • Gen AI orchestration layer: langgraph
    • Integrity of data: best practices are followed in using APIs. The APIs themselves and Kaggle follow secure practices.
    • Integrity of logs: the Kaggle environment is session-based; cloud logging stubs are made but not implemented, so no concerns there.
    • Protection of User from Harmful content
      • User Feedback is requested and logged. Additionally, the Kaggle notebooks have a messaging environment where users can ask questions or leave comments.
    • Regulations: GDPR, CCPA, and other guidelines are implicitly followed because no PII is requested or stored.
    • Prompts
      • Versions: prompts are stored in language and version directories to allow mixing of components while experimenting with improvements for the application.
    • Logging, Monitoring
      • implemented in client locally w/ stubs for remote aggregation in the cloud
    • Source version control in github
    • Development tools were the Kaggle notebook and the JetBrains Pycharm IDE
  3. Implementation (see The Gen AI Application below)

The Gen AI Application

The application is hosted in a Kaggle notebook here. LangGraph was used for orchestration of the function calls (a.k.a. tools) by client-side invocations (client-side rather than LLM invocations, to reduce token use). A sequential planner pattern with conditional cycles was used for the workflow. Nodes were created for each function, with conditional edges from each node to the next in the sequence or to an exit when the user requests it; a sketch of the graph wiring follows the node list below.
  1. node: user_input_disease
  2. node: fetch_trials
  3. node: user_choose_trial_number
  4. node: user_choose_citation_number
  5. node: fetch_abstract
  6. node: llm_summarization
  7. node: feedback_query
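Here is a minimal sketch of how such a LangGraph workflow could be wired; the node function bodies and the AppState fields are hypothetical stand-ins, not the project's actual code.

```python
# Sketch only: sequential planner with conditional exits, assuming hypothetical
# node functions and state fields.
from typing import TypedDict
from langgraph.graph import StateGraph, END

class AppState(TypedDict, total=False):
    disease: str
    trials: list
    citation: dict
    summary: str
    quit: bool

def user_input_disease(state: AppState) -> AppState: ...
def fetch_trials(state: AppState) -> AppState: ...
def user_choose_trial_number(state: AppState) -> AppState: ...
def user_choose_citation_number(state: AppState) -> AppState: ...
def fetch_abstract(state: AppState) -> AppState: ...
def llm_summarization(state: AppState) -> AppState: ...
def feedback_query(state: AppState) -> AppState: ...

def next_or_exit(next_node):
    # conditional edge: continue the sequence unless the user chose to quit
    return lambda state: END if state.get("quit") else next_node

graph = StateGraph(AppState)
nodes = [user_input_disease, fetch_trials, user_choose_trial_number,
         user_choose_citation_number, fetch_abstract, llm_summarization,
         feedback_query]
for fn in nodes:
    graph.add_node(fn.__name__, fn)

graph.set_entry_point("user_input_disease")
for current, nxt in zip(nodes, nodes[1:]):
    graph.add_conditional_edges(current.__name__, next_or_exit(nxt.__name__))
graph.add_edge("feedback_query", END)

app = graph.compile()   # app.invoke({}) would run the workflow
```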
When the application starts, the user is presented with a welcome message and then a query for the disease name.

[to user]
"This librarian searches clinical trials for completed, published results and summarizes the results for you." "Please enter a disease to search for (q to quit at any time):"


The disease name checker is an instance of ChatGoogleGenerativeAI with a chosen LLM and a deterministic temperature of 0. The disease name checker is given this prompt with the disease name and returns a response which includes metadata such as the number of input and output tokens:

[to LLM agent]
"You are a librarian at the National Institutes of Health. Do you recognize the words {disease_name} as a valid disease name? Answer yes or no."


The trials are retrieved from the US National Library of Medicine's Clinical Trials database and presented to the user.
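For illustration, a request along these lines could look like the sketch below; the endpoint, parameters, and field names are assumptions based on the public ClinicalTrials.gov API v2, not necessarily the project's exact request.

```python
# Sketch only: fetch completed trials for a disease from ClinicalTrials.gov.
import requests

def fetch_completed_trials(disease: str, max_results: int = 10):
    url = "https://clinicaltrials.gov/api/v2/studies"     # assumed endpoint
    params = {
        "query.cond": disease,                            # condition/disease search
        "filter.overallStatus": "COMPLETED",              # only completed trials
        "pageSize": max_results,
        "format": "json",
    }
    resp = requests.get(url, params=params, timeout=10)
    resp.raise_for_status()
    return resp.json().get("studies", [])

trials = fetch_completed_trials("asthma")
print(len(trials))
```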
The user selects a trial, and published trial-result citations are then presented to them. The user selects a result from the list, and the abstract of that result is retrieved from the US National Library of Medicine's PubMed database. A prompt with text-summarization instructions is given to the document summary agent, which has been configured for a deterministic response.

[to LLM agent]
"Summarize this text in simple terms in a serious tone."


The agent's response is presented to the user. The response is asynchronously evaluated by another agent, which is given a somewhat lengthy prompt of instructions, and the evaluation is logged. The user is asked if they would like to submit feedback; if so, they are presented with a couple of questions with itemized choices, and their responses are logged. Lastly, the user is asked if they would like to make another query.

And that is the prototype. Thanks to Google for sponsoring this 5-day intensive course on Generative AI and providing great examples and resources for us to use!

References