
Header visualization of semantic alignment across language models

LEAD: Linear Embedding Alignment across Deep Neural Network Language Models' Representations

Published December 10th, 2024
* Equal contribution

Abstract

Recent advances in Large Language Models (LLMs) have demonstrated a remarkable ability to capture semantic information. We investigate whether different language embedding models learn similar semantic representations despite variations in architecture, training data, and initialization. Previous work explored model similarity through top-k retrieval results and Centered Kernel Alignment (CKA), with mixed results, but it leaves a gap in the field of large language embedding models, our focus here: more modern similarity quantification methods from computer vision, such as model stitching, which operationalizes "similarity" in a way that emphasizes downstream utility, remain unexplored. We apply stitching by training linear and nonlinear (MLP) mappings, called "stitches," between embedding spaces; a stitch aims to map embeddings of the same datapoints from one space onto the other. We define two spaces as connectivity-aligned if stitches achieve low mean squared error, indicating an approximate bijection.

Our analysis spans 6 embedding datasets (5,000-20,000 documents), 18 models (between 20 and 30 layers, including both open-source and OpenAI models), and stitches ranging from linear maps to MLPs 7 layers deep, with a focus on linear stitches. We hoped that stitching would recover similarity between models, in line with a strong interpretation of the platonic representation hypothesis. However, things appear to be more complicated. Our results suggest that embedding models are not linearly connectivity-aligned: linear stitches do not perform significantly better than mean estimators. A brief foray into MLPs suggests that shallow MLPs do not necessarily work out of the box either, though more work remains to be done on non-linear stitches, since we have not fully explored their potential here. Stitches matter because their success provides an operational, and therefore useful, notion of representational similarity. Our findings support the hypothesis that alignment metrics such as CKA are not always informative of behavioral or feature overlap between models.

00 Introduction & Terminology

Modern "embedding" language models convert text into dense vectors called embeddings, which capture semantic meaning in a high-dimensional space, enabling downstream usage such as for semantic search and visualization. We investigate whether different embedding models learn similar semantic representations despite variations in architecture, training data, and initialization. If these representations are indeed (approximately) universal, we could leverage this to efficiently translate between different models' embeddings.

The paper Beyond Benchmarks: Evaluating Embedding Model Similarity for Retrieval Augmented Generation Systems explored embedding model representational similarity through techniques like top-k search retrieval results and Centered Kernel Alignment (CKA). We take their methodology and expand upon it here by systematically studying whether there exists a cheap way to determine similarity between, and translate between, the embeddings of models with similar characteristics. Despite our initial hopes, our approach gives further evidence that representations vary in more-than-trivial ways.

            graph TD
              %% Define the main text input
              Text(["Input Text t"])
          
              %% Define embedding spaces as subgraphs
              subgraph SpaceA["Embedding Space A"]
                  NativeA["Native Embedding A(t)"]
              end
          
              subgraph SpaceB["Embedding Space B"]
                  NativeB["Native Embedding B(t)"]
                  StitchedAB["Stitched Embedding S(A(t))"]
              end
          
              %% Define the models and transformations
              ModelA["Model A"]
              ModelB["Model B"]
              Stitch{{"Stitch S"}}
          
              %% Define the relationships
              Text --> ModelA
              Text --> ModelB
              ModelA --> NativeA
              ModelB --> NativeB
              NativeA --> Stitch
              Stitch --> StitchedAB
          
              %% Add evaluation
              MSE[["MSE(S(A(t)), B(t)) < ε"]]
              StitchedAB -.-> MSE
              NativeB -.-> MSE
          
              %% Style definitions
              classDef space fill:#f0f4f8,stroke:#a3bffa,rx:10
              classDef model fill:#ebf4ff,stroke:#63b3ed,rx:5
              classDef stitch fill:#faf5ff,stroke:#b794f4,rx:5
              classDef evaluation fill:#f0fff4,stroke:#68d391,rx:5
          
              %% Apply styles
              class SpaceA,SpaceB space
              class ModelA,ModelB model
              class Stitch stitch
              class MSE evaluation
          
Relationship between native embeddings, stitched embeddings, and embedding spaces.

We define key terminology and notation used throughout this post:

• Native embedding: the embedding A(t) that model A produces for a text t in its own embedding space.
• Stitch: a learned mapping S (affine or MLP) from one embedding space to another, trained so that S(A(t)) approximates B(t).
• Stitched embedding: the image S(A(t)) of a native embedding under a stitch.
• Connectivity-aligned: two embedding spaces are connectivity-aligned if a stitch between them achieves low mean squared error, i.e. MSE(S(A(t)), B(t)) < ε; if the stitch is linear, we call them linearly connectivity-aligned.

Our key hypothesis was that embedding models of similar scale (parameters, training data size, etc.) are not only connectivity-aligned, but linearly so. We were also curious as to whether a rotation would be sufficient. The results suggest that they are in fact not linearly connectivity-aligned, and that shallow MLPs require some finesse to achieve connectivity-alignment. This finding has both theoretical implications for understanding how language models represent meaning and practical applications for efficient embedding translation.

01 Motivation

Imagine you've just spent weeks processing millions of documents through a language model to create semantic search capabilities. Then, a more powerful model is released – but recomputing all those embeddings would cost thousands of dollars and days of processing time. What if there was a better way?

Exploring semantic relationships in data with MantisAI's visualization tools

The Origin Story

This research emerged from real-world challenges at MantisAI, where we help organizations understand and visualize large document collections. Our customers frequently needed to switch between different embedding models – sometimes prioritizing accuracy, other times speed or cost. But each switch required reprocessing entire datasets, creating significant computational overhead and compatibility challenges between workspaces.

The Big Picture

We spend the majority of this blog post diving into the theory-relevant results, rather than the cost savings possible for semantic search systems. However, it is worth remembering that there are multiple angles from which this work is useful. In the future, we also imagine that deeper theoretical knowledge could drive improvements in ML algorithms, making them more interpretable, robust, efficient, or performant.

02 Relevant Work

There is ample evidence supporting the idea that neural network representations may be aligned to some degree, most notably the observations behind the platonic representation hypothesis, illustrated below.

Diagram illustrating the platonic representation hypothesis

At the same time, representational similarity tools are becoming mature enough to explore such questions deeply and empirically, and they are percolating across both the machine learning and neuroscience communities; for example, ICLR hosted a workshop on the topic in the year of writing (2024). Techniques such as CKA, stitching, CCA, and Orthogonal Procrustes, along with representational similarity analysis in neuroscience, have matured, enabling us to systematically investigate such questions.

03 Methodology

Overview

Our investigation into embedding model similarity followed a systematic experimental approach spanning multiple model architectures, embedding spaces, and datasets. The methodology consists of five main components: model selection, dataset selection, stitch architecture design, evaluation framework, and embedding parameter configuration.

  1. Embedding Model Selection

    We evaluated translation quality across several embedding models available on Hugging Face and via the OpenAI API. These models were selected to match the most similar prior work, Beyond Benchmarks, for comparability and reproducibility. We excluded the Cohere models, since we did not have API access (free access is limited to prohibitively low request rates), and the Mistral model, since it was too large for our GPU.
    View Model Names
    
              MODEL_NAMES = [
              # MODEL_NAME                    EMBEDDING DIMENSION
              "WhereIsAI/UAE-Large-V1",                    # 1024
              "BAAI/bge-base-en-v1.5",                     # 768
              "BAAI/bge-large-en-v1.5",                    # 1024
              "BAAI/bge-small-en-v1.5",                    # 384
              "intfloat/e5-base-v2",                       # 768
              "intfloat/e5-large-v2",                      # 1024
              "intfloat/e5-small-v2",                      # 384
              "thenlper/gte-base",                         # 768
              "thenlper/gte-large",                        # 1024
              "thenlper/gte-small",                        # 384
              "sentence-transformers/gtr-t5-base",         # 768
              "sentence-transformers/gtr-t5-large",        # 768
              "mixedbread-ai/mxbai-embed-large-v1",        # 1024
              "sentence-transformers/sentence-t5-base",     # 768
              "sentence-transformers/sentence-t5-large",    # 768
              "openai/text-embedding-3-large",             # 3072
              "openai/text-embedding-3-small",             # 1536
              ]
            
  2. Detailed Model Comparison Table
    Model Family  Variant                  Architecture  Dimension  Parameters  Training Data
    BAAI BGE      large-v1.5               DeBERTa-V3    1024       335M        330M+ text pairs
    BAAI BGE      base-v1.5                DeBERTa-V3    768        110M        330M+ text pairs
    BAAI BGE      small-v1.5               DeBERTa-V3    384        33M         330M+ text pairs
    E5            large-v2                 DeBERTa-V3    1024       335M        CCNet + web data
    E5            base-v2                  DeBERTa-V3    768        110M        CCNet + web data
    E5            small-v2                 DeBERTa-V3    384        33M         CCNet + web data
    GTE           large                    DeBERTa-V3    1024       335M        MS MARCO + public datasets
    GTE           base                     DeBERTa-V3    768        110M        MS MARCO + public datasets
    GTE           small                    DeBERTa-V3    384        33M         MS MARCO + public datasets
    T5-based      gtr-t5-large             T5 encoder    768        770M        C4 + MS MARCO
    T5-based      gtr-t5-base              T5 encoder    768        110M        C4 + MS MARCO
    T5-based      sentence-t5-large        T5 encoder    768        770M        C4 + NLI datasets
    T5-based      sentence-t5-base         T5 encoder    768        220M        C4 + NLI datasets
    UAE           large-v1                 RoBERTa       1024       355M        Adversarial training
    MXBai         embed-large-v1           DeBERTa-V3    1024       335M        700M+ pairs contrastive training, 30M+ fine-tuning
    OpenAI        text-embedding-3-large   Proprietary   3072       -           Not public
    OpenAI        text-embedding-3-small   Proprietary   1536       -           Not public
  3. Dataset Selection

    Each dataset we tested is one of the MTEB embedding benchmark datasets, chosen for their relevance to many embedding tasks. Each spans thousands of short documents. Unlike the Beyond Benchmarks paper, we randomly subsampled each dataset to at most 20K training documents and 5K validation documents, since we were computationally constrained for this project.
    View Dataset Names
    
              DATASETS = [
                "arguana",    # Around 10K Short documents
                "fiqa",       # Around 50K, shortened to 20K
                "scidocs",    # Around 25K, shortened to 20K
                "nfcorpus",   # Around 5K
                "hotpotqa",   # Over 100K, shortened to 20K
                "trec-covid", # At least 20K, shortened to 20K
              ]
            

    You can read more about the datasets below:

    • ArguAna
      A dataset of ~10K argument pairs from online discussions, where each pair contains one argument in favor and one against a given topic. Useful for evaluating semantic understanding of argumentative text and stance detection.
    • FiQA
      A financial domain dataset containing ~50K text snippets from financial news and social media platforms. Includes questions, answers, and sentiment analysis annotations specific to financial topics.
    • SciDocs
      A collection of ~25K scientific documents and paper abstracts from various fields. Designed to evaluate scientific document understanding and recommendation systems.
    • NFCorpus
      A dataset of ~5K medical documents from NLM (National Library of Medicine), including clinical trials, systematic reviews, and medical news articles. Used for evaluating biomedical information retrieval.
    • HotpotQA
      A large question-answering dataset with over 100K question-answer pairs. Questions require finding and reasoning over multiple supporting documents to arrive at the answer, testing multi-hop inference capabilities.
    • TREC-COVID
      A dataset of ~20K scientific articles related to COVID-19 from the CORD-19 collection. Includes queries from biomedical researchers about COVID-19 and relevant document judgments, designed to evaluate biomedical literature search systems.
  4. Stitch Architecture

    We implemented stitch functions using Ordinary Least Squares (OLS) to find a best-fitting affine function transforming the source embeddings into the target embeddings. We also trained MLPs ranging in depth from zero nonlinearities (linear) to 6 nonlinearities (7 layers). Each MLP had the same width throughout: the larger of the input and output widths. A minimal sketch of both stitch types appears after this list.

  5. Evaluation Metrics

    We considered a range of different metrics to analyze the relationship between embedding spaces, settling on a mix of metrics from our reference paper, statistics on the stitching models bridging two embedding spaces, and some visualizations to better understand exactly what these stitches are doing.

    • Our embedding similarity metrics aim to corroborate the work in our reference paper, Beyond Benchmarks: Evaluating Embedding Model Similarity for Retrieval Augmented Generation Systems. We look primarily at CKA (Centered Kernel Alignment), a metric that measures the similarity between the feature representations of two neural networks. We compare these metrics with the past results to make sure that our work is consistent.
    • Our linear connectivity-alignment metrics are the train and validation accuracy of our stitching models, measured by Mean Squared Error (MSE). To make this more interpretable, we also provide: mean absolute error (MAE) on the validation set, R squared between the stitched embeddings and the target embeddings, and explained absolute variation between the stitched embeddings and the target embeddings (analogous to R squared but using MAE instead of MSE). We also plot the singular value distributions of the stitches and compare their affine shifts to the means of the embedding datasets. This helps us qualitatively understand what sort of matrix each linear stitch learned.
  6. Embedding Parameters

    We used the parameters shown below for our embedding models. The datasets comprise documents and sample queries (for reproducible top-K search result analysis). We embed each separately, prepending a prefix as shown below, as in the prior work. Unlike the prior work, we use an OpenAI text splitter with a fixed model. This lets us ensure that we compare embeddings of the exact same text across models, something we suspect is a subtle bug in Beyond Benchmarks that we reported to its authors, though we do not find qualitatively different results before linearly connecting the embedding spaces. As previously mentioned, we reduce and split the query and document datasets to at most 20K and 5K documents respectively. Then we split each document into at least one chunk and embed each chunk with each model, using the default SentenceTransformers encode function. A sketch of this embedding step appears after the configuration block below.

    
            VECTOR_SEARCH_SENTENCE_DEFAULT_CHUNK_SIZE=256
            VECTOR_SEARCH_DISTANCE_FUNCTION="cosine"
            VECTOR_SEARCH_NORMALIZE_EMBEDDINGS="true"
            VECTOR_SEARCH_CHUNK_PREFIX="passage: "
            VECTOR_SEARCH_QUERY_PREFIX="query: "
            VECTOR_SEARCH_TEXT_SPLITTER_CHUNK_OVERLAP=25
            BATCH_SIZE=64
            CHUNK_SIZE=256
        

04 Results and Analysis

We conducted extensive experiments across different model scales and architectures. Here are our key findings:

Measuring Representational Similarity

Firstly, we reproduce some of the results from the original paper. In line with the metrics presented in Beyond Benchmarks, we examined the pairwise similarity between our various embedding models. Above you can see our CKA (Centered Kernel Alignment) matrix, a common metric for measuring the similarity between representations. To compute it, a text dataset is fed through the embedding models \(A\) and \(B\) to produce sample embeddings \(Z_A\) and \(Z_B\). These embeddings are first mean-centered, then their kernel tables are respectively computed. Next, these kernels are flattened and their correlation (normalized inner product) is computed. Scores closer to \(1\) indicate more similar representations, and scores closer to \(0\) less similar ones. In our matrix we also introduced a control "embedding" model, not included in the reference paper: a standard gaussian matrix of shape (ArguAna length \(\times\) \(768\) embedding dimensions).

CKA Matrix measuring representational similarity between embedding models with text sampled from ArguAna. Note that the larger OpenAI models were excluded from this metric due to time constraints.
Original Paper CKA Plots
Plot from the original paper of the CKA similarity matrix. As is visible, our results are not too dissimilar, showcasing the same cluster with the BGE and GTE models.
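
For reference, here is a minimal sketch of the linear CKA computation described above (mean-center, form the kernel tables, then take the normalized inner product of their flattened entries). The function name linear_cka is ours.

    import numpy as np

    def linear_cka(Z_a: np.ndarray, Z_b: np.ndarray) -> float:
        """Linear CKA between two sets of embeddings of the same texts.
        Z_a: (n, d_a), Z_b: (n, d_b), rows aligned by datapoint."""
        Z_a = Z_a - Z_a.mean(axis=0)            # mean-center each space
        Z_b = Z_b - Z_b.mean(axis=0)
        K_a = Z_a @ Z_a.T                       # (n, n) kernel tables
        K_b = Z_b @ Z_b.T
        # flatten and take the normalized inner product (cosine between kernels)
        return float((K_a.ravel() @ K_b.ravel()) /
                     (np.linalg.norm(K_a) * np.linalg.norm(K_b)))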

As one would hope, the random embedding model is extremely dissimilar from the other embeddings, with similarity scores of only \(0.10\) to \(0.16\), while the strongest similarity goes all the way up to \(1.00\) (approximately complete similarity), between mxbai-embed-large-v1 and UAE-Large-V1, models from different companies, both with embedding dimension \(1024\). Based on these preliminary similarity scores alone, we hypothesized that these two models would be connectivity-aligned.

The lowest similarity we observe is \(0.72\), between bge-small-en-v1.5 and e5-large-v2, of dimensions \(384\) and \(1024\) respectively: still quite a high similarity in representations. While we hypothesized that most models would be linearly connectivity-aligned, we also suspected that it would be easier to linearly map from larger models to smaller models. As will become visible later, this appears to be qualitatively true, which is unsurprising, but we do not provide a statistical analysis.

Measuring Stitch-Connectivity

We present the results of our stitching experiments in the table below: the raw mean squared error (MSE) for each stitch, as well as the MAE, R squared, and explained absolute variation. Each entry in the table corresponds to a pair of models. Values are reported on a logarithmic (base 10) scale, since the MSE is low.

Note: The axis labels are not in the same order as in the CKA matrix.

At a Distance

The plots above show the log mean errors in stitching from embeddings of the source model (x-axis) to embeddings of the target model (y-axis) using our OLS-derived affine function.

Overall, the affine stitches performed fairly well, ranging from a log MSE of \(-5.63\) (0.35% MSE) to \(-3.16\) (4.24% MSE).

The top four performing stitches were:

  1. (\(-5.63\)) mxbai-embed-large-v1 to UAE-Large-V1
  2. (\(-5.62\)) UAE-Large-V1 to mxbai-embed-large-v1
  3. (\(-5.12\)) UAE-Large-V1 to bge-large-en-v1.5
  4. (\(-5.11\)) bge-large-en-v1.5 to UAE-Large-V1

All of these models are of dimension \(1024\), corroborating the observation from the original paper that dimension is sometimes correlated with "similarity".

The bottom four performing stitches were all unsurprisingly stitched from bge-small-en-v1.5, one of our two smallest models with dimension \(384\). We had hypothesized that this would be the case. The only decently performing stitch with this native space was bge-small-en-v1.5 to gte-small, another model of dimension \(384\). This too is unsurprising.

However, what caught us off guard is that stitches from gte-small to other embedding spaces performed on par with those from far larger models, including OpenAI's text-embedding-3-small, a model nearly 4 times its size. If text-embedding-3-small were storing a more expressive representation of semantic features than gte-small, then why would affine stitches perform similarly when operating from either model's embeddings? As we observe later, these embedding spaces do not seem to be genuinely linearly connectivity-aligned, so it is likely that the stitches are learning some simple baseline strategy that does not depend too heavily on the additional nuance encodable by the big OpenAI models.

Native Space Performance

To investigate, we needed a more interpretable metric than MSE and MAE. Those tell us in absolute terms how geometrically distant two high-dimensional vectors are, but it is hard to use that knowledge to infer anything practical. Below we plot R squared and MAE explained, which tell us what percentage of the variation present in the distribution of target embeddings is explained by the stitched source embeddings (see the sketch below). We provide controls with a mean estimator and a random gaussian. Clearly all the linear stitches perform better than the random gaussian, but they barely improve upon the mean estimator.
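
The two interpretability-oriented scores can be computed roughly as below; both compare the stitch's residuals against a mean estimator fit on the training set, which is the baseline discussed in the text. The function names r_squared and mae_explained are ours.

    import numpy as np

    def r_squared(pred: np.ndarray, target: np.ndarray, train_mean: np.ndarray) -> float:
        """1 - MSE(stitch) / MSE(train-set mean estimator); values <= 0 mean the
        stitch does no better than always predicting the mean."""
        ss_res = np.sum((target - pred) ** 2)
        ss_tot = np.sum((target - train_mean) ** 2)
        return 1.0 - ss_res / ss_tot

    def mae_explained(pred: np.ndarray, target: np.ndarray, train_mean: np.ndarray) -> float:
        """Same idea with absolute errors: 1 - MAE(stitch) / MAE(mean estimator)."""
        ae_res = np.sum(np.abs(target - pred))
        ae_tot = np.sum(np.abs(target - train_mean))
        return 1.0 - ae_res / ae_tot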

Target Space Performance

The fact that R squared is so abysmally low (and sometimes even negative) suggests that these affine transformations are not good in absolute terms: simply predicting the mean (from the train set) performs effectively as well. We trained stitches both with gradient descent (twice) and with ordinary-least-squares-based models, and we tuned a ridge regression parameter to reduce overfitting. In each case, the mean estimator was just as good as the affine stitches. This result strongly suggests that the affine stitches are not reconstructing target embeddings. However, their training loss curves (visible in the appendix) plateaued, suggesting that they may not have much more mileage left. In the table below we present some results for larger neural networks (MLPs) trained on the same objective via gradient descent. They do not perform significantly better.

The pattern in which larger models' embeddings map more readily onto smaller ones, but not vice versa, persists. However, these MLPs do not greatly improve upon the affine stitches. We did not hyperparameter-sweep exhaustively, so it is possible that MLPs may yet bear fruit, but they are unlikely to work on the first try.

A natural question to ask at this point is whether these models may simply be learning to use their bias to become mean estimators. Below we plot the spectra of the affine and MLP matrices to determine this. While the MLP matrices have a fairly flat spectrum, suggesting some sort of rotation, the affine matrices have a relatively large boost on the largest singular values. Possibly they are catching on to a direction of high variance in the dataset, a question to answer with future research. Regardless, it should also be noted that the mean bias norm was very low: around 0.07. The norm of the mean of many of the embedding datasets was larger, often surpassing 1 (though not 20). This means that the affine stitches are most likely not mean estimators.
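
The spectra below can be reproduced roughly as follows: take a trained stitch's weight matrix, compute its singular values, and compare the bias norm to the norm of the target embeddings' train-set mean. Variable and function names here are illustrative, not the project's code.

    import numpy as np

    def stitch_spectrum(W: np.ndarray) -> np.ndarray:
        """Singular values of a stitch weight matrix, largest first.
        A flat spectrum is consistent with a rotation-like map; a few dominant
        values mean the map mostly stretches the space in a few directions."""
        return np.linalg.svd(W, compute_uv=False)

    def mean_estimator_check(b: np.ndarray, Z_tgt_train: np.ndarray) -> tuple[float, float]:
        """Compare the stitch bias norm to the norm of the target train-set mean.
        If the stitch acted as a mean estimator, these should be comparable
        (and the weight matrix close to zero)."""
        return float(np.linalg.norm(b)), float(np.linalg.norm(Z_tgt_train.mean(axis=0)))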

Panels: singular value spectra of sampled stitch matrices, for OLS and MLP stitches, on linear and log scales.
Plots of the singular values of randomly sampled weight matrices from different subsets of the trained stitches. On the left are the singular values and on the right are the same plots in logspace. On the top are the matrices trained to solve the linear MSE-minimization objective. On the bottom are the matrices that make up the layers of deeper MLPs (from 2 up to 7 layers deep) trained on the same objective. From the distributions, it is clear that the OLS matrices do not represent rotations; however, they are not extremely low rank either. They seem to stretch the space significantly in a few directions. The MLP matrices act more like rotations, but further investigation is required to better understand the behavior of these mappings. For aesthetic reasons, 200 samples were drawn at random for each plot (out of around 1600 available per plot).

05 Discussion and Conclusion

We set out to check (1) whether different embedding models' embeddings are linearly mappable, (2) what degree of complexity is needed for a good mapping between embeddings, and (3) which types of models are easier to linearly connect. Along the way we explored whether the linear mappings could be rotations and whether they might be learning mean estimators. The answer to (1) is likely no. Answering (2) requires more work, but the required complexity seems likely to be more than 2 layers and at least one nonlinearity, and probably far more at reasonable widths. Our answer to (3) is that it is easier to go from larger to smaller models, and we found that these mappings are usually not rotations in the linear case. The spectra of our mappings revealed some unexpected behavior and open up new questions for research, such as whether the dataset's diversity impacts stitch performance.

Overall, we believe that different embedding models' embeddings are not linearly connectivity-aligned, and posit that while more research is needed on MLP stitches, they are unlikely to work out of the box and will require some tuning.

06 Implications and Limitations

The key implication of this work is that for transformer-based embedding language models (like the ones we use, as one would find for semantic search), the embeddings are not linearly connectivity-aligned. Therefore, if the platonic representation hypothesis (PRH) and the linear representation hypothesis (LRH) are true, then either the models we use are too small (or trained on too little data), or they store features non-linearly, or they store them linearly but in some way which is not linearly mappable (for example, they might use a sparse code that packs more information into fewer dimensions, or the spatial relationships between related concepts might be permuted relative to other such embedding models'). Unfortunately, if we want more evidence for the PRH and LRH, we will need to try harder than this. Another implication is that cheap embedding translation for data visualization or semantic search is unlikely to work.

There are some limitations to our work. One key limitation is that we train with an MSE objective and evaluate with R squared and similar metrics; it is possible that these are simply not indicative of downstream tasks. With more time, we would explore in detail how linear (and non-linear) stitching affects semantic search rankings, since that is a real-world use case. We would also consider alternative objectives (not MSE) and invest more time into tuning small non-linear stitches. Another important limitation is that, due to our computational constraints and in the interest of consistency with past work, we did not train on truly large or truly diverse datasets. It is possible that the affine stitches could have far outperformed mean estimators if we had merged all the datasets in their entirety to create a much larger, more diverse training set for the stitch. As it stands, it is possible that the amount or type of data was simply insufficient. Lastly, it should be noted that our results are for embedding language models, and may not hold for language models or deep neural networks at large.

Appendix: Training Loss Curves

To showcase that our models were trained sufficiently, observe these loss plots:

Panels: training loss for various models on ArguAna (plateauing); training loss for various models on TREC-COVID (plateauing somewhat); training loss for a single model on HotpotQA (plateauing); validation loss for a single model on HotpotQA (noisier than the training loss, but plateauing).
Plots of the training and validation loss for various models on different datasets over the course of training. These are mostly indicative of broader patterns. Generally, our models plateaued, though in some specific cases deeper MLPs might still be improvable. We tried varying the Adam learning rate and batch size in a few ways for specific cases, but did not find a large improvement. This suggests that something more involved than a "train a generic MLP with defaults" approach would be necessary to effectively map between semantic representations. We have not yet determined whether the objective, optimizer, dataset, architecture, or fundamental difficulties of the problem are the primary factors.

📚 Cite this work

@article{culp_hernandez_embed_stitch_2024,
    title   = {LEAD: Linear Embedding Alignment across Deep Neural Network Language Models' Representations},
    author  = {Culp, Gatlen* and Hernandez, Adriano*},
    note    = {* Equal Contribution},
    journal = {MIT Deep Learning Blogs},
    year    = {2024},
    month   = {dec},
    url     = {https://gatlenculp.github.io/embedding_translation/},
}

💻 Use our code

Our code will eventually be made public here: https://github.com/GatlenCulp/embedding_translation