jukofyork commented on June 18, 2024

It looks like the example from the original LoRD code has similar problems, due to 2 extra tokens being added:

https://github.com/thomasgauthier/LoRD/blob/main/LoRD.ipynb
https://colab.research.google.com/github/thomasgauthier/LoRD/blob/main/LoRD.ipynb

https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B

  "vocab_size": 32002

https://huggingface.co/thomasgauthier/OpenSnark-LoRD

Is this really working as a real LoRA would, or has some information been lost from the lm_head.weight and embed_tokens.weight tensors of OpenSnark-LoRD?
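A toy sketch of why the extra tokens matter, assuming a base embedding table with fewer rows than the fine-tuned one (for OpenHermes-2.5 that would be Mistral-7B-v0.1's 32000 rows versus 32002; the sizes below are made up). A delta can only be taken after truncating, and the rows that get truncated away are exactly the new tokens' embeddings:

```python
import torch

# Toy sizes (assumed): 8 base rows and 2 added tokens stand in for 32000 -> 32002.
base_vocab, added, dim = 8, 2, 4
torch.manual_seed(0)
base_embed = torch.randn(base_vocab, dim)
finetuned_embed = torch.cat([
    base_embed + 0.01 * torch.randn(base_vocab, dim),  # fine-tuned versions of the old rows
    torch.randn(added, dim),                           # rows for the newly added tokens
])

# Subtracting the two tables directly fails because the row counts differ.
mismatch = False
try:
    finetuned_embed - base_embed
except RuntimeError:
    mismatch = True

# Truncating makes the subtraction work, but the added rows never reach the delta.
delta = finetuned_embed[:base_vocab] - base_embed
dropped = finetuned_embed[base_vocab:]
print(mismatch, tuple(delta.shape), tuple(dropped.shape))  # True (8, 4) (2, 4)
```

Whatever low-rank factorisation is applied afterwards only ever sees `delta`, so anything learned in `dropped` has to be carried some other way.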

from mergekit.

jukofyork commented on June 18, 2024

So I've looked into this more and found vllm-project/vllm#2816,

with a real example model using this here:

https://huggingface.co/yard1/llama-2-7b-sql-lora-test/tree/main?show_file_info=adapter_model.safetensors

adapter_model.safetensors

base_model.model.model.embed_tokens.lora_embedding_A 	[8, 32004] 	BF16
base_model.model.model.embed_tokens.lora_embedding_B 	[4096, 8] 	BF16
.
.
.
base_model.model.lm_head.lora_A.weight 	[8, 4096] 	BF16
base_model.model.lm_head.lora_B.weight 	[32004, 8] 	BF16

new_embeddings.safetensors

input_embeddings 	[4, 4096] 	BF16
output_embeddings 	[4, 4096] 	BF16

I don't really understand why this needs to be 32004 in both files; surely the point of adding new_embeddings.safetensors is so that the new token embedding weights can be trained at full rank?
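One reading that is at least consistent with these shapes (an assumption, not a description of what vLLM actually does): Llama-2's base vocab is 32000, so 32004 = 32000 + the 4 rows in new_embeddings.safetensors. If the new rows are first appended to the base table and the LoRA delta is then applied over the whole extended vocabulary, lora_embedding_A (and lm_head's lora_B) must span all 32004 rows, while the new rows themselves still start from full-rank values. A toy sketch with made-up sizes, using the PEFT-style (B @ A).T layout for embedding LoRAs:

```python
import torch

torch.manual_seed(0)
# Toy sizes standing in for 32000 base tokens, 4 added tokens, hidden 4096, rank 8.
base_vocab, added, dim, r = 6, 2, 4, 3
scaling = 1.0  # assumed lora_alpha / r

base_embed = torch.randn(base_vocab, dim)              # base embed_tokens.weight
input_embeddings = torch.randn(added, dim)             # full-rank rows, like new_embeddings.safetensors
lora_embedding_A = torch.randn(r, base_vocab + added)  # [r, vocab + added], like [8, 32004]
lora_embedding_B = torch.randn(dim, r)                 # [hidden, r], like [4096, 8]

# Hypothesised merge: append the new rows, then add the low-rank delta over the
# whole extended vocabulary.
extended = torch.cat([base_embed, input_embeddings], dim=0)
delta = (lora_embedding_B @ lora_embedding_A).T * scaling  # [vocab + added, hidden]
merged = extended + delta
print(tuple(merged.shape))  # (8, 4)
```

Under this reading the new tokens get both a full-rank starting point and a share of the low-rank update.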

I'm also a bit confused about why they are called input_embeddings and output_embeddings but apply to the embed_tokens and lm_head weights.

This is how I've seen some code use these:

# Model Loading
import torch

def add_embeddings(model, embed_path, tokenizer):
    # Start from a fresh table with the same shape as the model's current input embeddings
    old_embeddings = model.get_input_embeddings()
    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
    new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
    model._init_weights(new_embeddings)
    # Load the embedding rows for the added tokens
    embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
    # tokenizer.vocab_size counts only the base vocabulary (no added tokens)
    vocab_size = tokenizer.vocab_size
    # Copy the base-vocabulary rows, then place the new-token rows right after them
    new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
    new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
        new_embeddings.weight.dtype
    ).to(new_embeddings.weight.device)
    model.set_input_embeddings(new_embeddings)
    # Re-tie lm_head to the input embeddings if the model uses tied weights
    model.tie_weights()

So the final shape of new_embeddings.weight will be (vocab_size + new_num_tokens, old_embedding_dim), but this doesn't seem to match with the real example in the llama-2-7b-sql-lora-test repo...
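The final shape is fixed by torch.nn.Embedding(old_num_tokens, old_embedding_dim), so (vocab_size + new_num_tokens, old_embedding_dim) only holds if the model was already resized to include the added tokens; tokenizer.vocab_size in Hugging Face tokenizers counts only the base vocabulary. A sketch that replays the slicing on plain tensors (all sizes are hypothetical):

```python
import torch

# Hypothetical sizes: tokenizer.vocab_size (base vocab only), the number of rows
# in embed_weights, and the hidden size.
vocab_size, new_num_tokens, dim = 6, 2, 4

def final_shape(old_num_tokens: int) -> tuple:
    """Replay add_embeddings' slicing on plain tensors and return the resulting shape."""
    new_weight = torch.zeros(old_num_tokens, dim)  # same shape as nn.Embedding(old_num_tokens, dim).weight
    embed_weights = torch.ones(new_num_tokens, dim)
    new_weight[vocab_size : vocab_size + new_num_tokens, :] = embed_weights
    return tuple(new_weight.shape)

# Model already resized to len(tokenizer): the new rows fit.
print(final_shape(vocab_size + new_num_tokens))  # (8, 4)

# Model not resized: the target slice is empty, so the assignment raises.
fits_without_resize = True
try:
    final_shape(vocab_size)
except RuntimeError as err:
    fits_without_resize = False
    print("RuntimeError:", err)
```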

jukofyork commented on June 18, 2024

So it looks like the names input_embeddings and output_embeddings are just vLLM's LoRA renaming of embed_tokens and lm_head, and are correct:

https://github.com/vllm-project/vllm/blob/main/tests/lora/test_lora_manager.py

EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}
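Read in reverse, the mapping says which module each tensor in new_embeddings.safetensors extends. A trivial sketch of that lookup (the helper name target_module is hypothetical):

```python
# EMBEDDING_MODULES as quoted from vLLM's test_lora_manager.py.
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

# Inverse lookup: key in new_embeddings.safetensors -> module it extends.
NEW_EMBEDDING_TARGETS = {v: k for k, v in EMBEDDING_MODULES.items()}

def target_module(new_embedding_key: str) -> str:
    """Map a new_embeddings.safetensors key to the module whose weight it extends."""
    return NEW_EMBEDDING_TARGETS[new_embedding_key]

print(target_module("input_embeddings"))   # embed_tokens
print(target_module("output_embeddings"))  # lm_head
```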

I'm still unsure, though, why llama-2-7b-sql-lora-test has 32004 in both files.

The from_lora_tensors() and from_local_checkpoint() methods here might help to explain it:

https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py

but I don't really get what's going on there... :/

For now I'm going to try truncating these, and also see if I can save any 1D vectors (like norms) too: if these were targeted by the LoRA fine-tune, then they should also be saved for archival purposes. If I can get that working, I'll come back to this and try to figure it out.
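A minimal sketch of how the changed 1D vectors could be picked out, assuming both models' weights are available as plain name-to-tensor dicts (the helper changed_1d_tensors and the toy names and values are hypothetical):

```python
import torch

def changed_1d_tensors(base: dict, finetuned: dict) -> list:
    """Names of 1-D tensors (e.g. norm weights) whose values differ between two state dicts."""
    changed = []
    for name, base_t in base.items():
        ft_t = finetuned.get(name)
        if ft_t is None or base_t.dim() != 1 or ft_t.shape != base_t.shape:
            continue  # only compare 1-D tensors present in both with matching shapes
        if not torch.equal(base_t, ft_t):
            changed.append(name)
    return changed

# Toy state dicts (hypothetical names and values).
base = {
    "model.norm.weight": torch.ones(4),
    "model.layers.0.input_layernorm.weight": torch.ones(4),
    "lm_head.weight": torch.ones(2, 4),
}
finetuned = {k: v.clone() for k, v in base.items()}
finetuned["model.norm.weight"] += 0.1
finetuned["lm_head.weight"] += 0.1  # 2-D, so left to the low-rank path instead

result = changed_1d_tensors(base, finetuned)
print(result)  # ['model.norm.weight']
```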

jukofyork commented on June 18, 2024

I think this is almost working now:

import json
import logging
import os

from typing import Any, Dict, List, Optional, Tuple

import bitsandbytes as bnb
import click
import torch
from peft.tuners.lora import QuantLinear
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel

from mergekit.card import generate_card_lora
from mergekit.common import ModelReference
from mergekit.io import LazyTensorLoader

def _low_rank_decomposition(
    weight: torch.Tensor,
    reduced_rank: int = 16,
    min_fraction_variance: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Decompose a 2D matrix into low-rank matrices A and B using SVD.

    :param weight: The matrix to decompose, of shape (H, W)
    :param reduced_rank: The final rank of the decomposition
    :param min_fraction_variance: Minimum fraction of variance explained
    :return: A tuple of tensors (A, B, final_rank)
    """
    if weight.dim() != 2:
        raise ValueError(
            f"Only support 2D matrix, but your input has {weight.dim()} dimensions."
        )

    dtype = weight.dtype

    # SVD Decomposition
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)

    # Square the singular values
    S_squared = S ** 2

    # Compute the total sum of squares
    total_sum_of_squares = S_squared.sum().item()

    # Compute the cumulative variance explained
    cumulative_variance_explained = torch.cumsum(S_squared, dim=0) / total_sum_of_squares
    
    # Set the final rank based on the value we were given.
    final_rank = reduced_rank
    
    # Reduce further if this is too large for this specific matrix.
    final_rank = min(final_rank, min(weight.shape))
    
    # Reduce the rank if we have enough variance explained
    if min_fraction_variance < 1:
        rank_based_on_variance = torch.searchsorted(cumulative_variance_explained, min_fraction_variance).item() + 1
        final_rank = min(final_rank, rank_based_on_variance)

    # Truncate matrices based on the final rank
    A = Vh[:final_rank, :]
    B = U[:, :final_rank] @ torch.diag(S[:final_rank])
            
    return A.to(dtype), B.to(dtype), final_rank


def decompose_delta_weight(
    new_weight: torch.Tensor,
    base_weight: torch.Tensor,
    reduced_rank: int,
    min_fraction_variance: float,
    device: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Decompose the delta weight into low-rank matrices A and B.

    :param new_weight: The fine-tuned weight matrix.
    :param base_weight: The original weight matrix of the base model.
    :param reduced_rank: The rank for the low-rank decomposition.
    :param min_fraction_variance: Minimum fraction of variance explained.
    :param device: The device to perform computation on.
    :return: A tuple (A, B, final_rank)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    new_weight = new_weight.to(device)
    base_weight = base_weight.to(device)

    # Truncate new_weight to match the dimensions of base_weight.
    # TODO: use 'new_embeddings.safetensors' for added tokens instead of truncating them away.
    if new_weight.shape != base_weight.shape:
        new_weight = new_weight[:base_weight.shape[0], :base_weight.shape[1]]

    delta_weight = new_weight - base_weight

    assert (
        0 < min_fraction_variance <= 1
    ), f"The min_fraction_variance ({min_fraction_variance}) must be in the range (0, 1]."
    
    A, B, final_rank = _low_rank_decomposition(delta_weight, reduced_rank=reduced_rank, min_fraction_variance=min_fraction_variance)

    return A, B, final_rank


def find_all_layer_names(model: PreTrainedModel) -> Tuple[List[str], List[str], List[str]]:
    
    embedding_names = []
    linear_names = []
    other_names = []
    
    for name, module in model.named_modules():
        
        # Check if an embedding layer
        if isinstance(module, (torch.nn.Embedding)):
            embedding_names.append(name)
            
        # Check if a linear layer
        elif (
            isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, QuantLinear))
            or "Linear" in module.__class__.__name__
            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
        ):
            linear_names.append(name)
            
        # Else another type of layer we are interested in (currently only norm layers).
        # TODO: handle other 1D layer types.
        elif module.__class__.__name__.lower().endswith("norm"):
            other_names.append(name)
    
    return embedding_names, linear_names, other_names
    
    
def get_all_module_names(model_id: str) -> Tuple[List[str], List[str], List[str]]:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, state_dict={}, device_map="meta"
    )  # avoid loading weights as we won't need them
    
    embedding_names, linear_names, other_names = find_all_layer_names(model)

    return embedding_names, linear_names, other_names


def create_peft_config(
    base_model_name_or_path: str, rank: int, rank_pattern: Dict[str, int], alpha: int, target_modules: List[str], modules_to_save: List[str]
) -> Dict[str, Any]:
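    return {
        # alpha == rank in this script, so mirroring rank_pattern keeps alpha / r == 1 for every layer
        "alpha_pattern": dict(rank_pattern),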
        "auto_mapping": None,
        "base_model_name_or_path": base_model_name_or_path,
        "bias": "none",
        "fan_in_fan_out": False,
        "inference_mode": True,
        "init_lora_weights": True,
        "layers_pattern": None,
        "layers_to_transform": None,
        "loftq_config": {},
        "lora_alpha": alpha,
        "lora_dropout": 0,
        "megatron_config": None,
        "megatron_core": "megatron.core",
        "modules_to_save": modules_to_save,
        "peft_type": "LORA",
        "r": rank,
        "rank_pattern": rank_pattern,
        "revision": None,
        "target_modules": target_modules,
        "task_type": "CAUSAL_LM",
        "use_rslora": False,
    }


def reconstruct_invocation(args):
    """
    Reconstructs the command-line invocation string based on the given arguments stored in a dictionary.

    Parameters:
    - args: A dictionary containing the command arguments with keys matching the parameter names.
      Expected keys are 'base_model', 'finetuned_model', 'out_path', 'no_lazy_unpickle', 'desired_rank', 'min_fraction_variance', 'model_name' and 'device'.

    Returns:
    - The reconstructed command-line invocation string.
    """
    # Provide a default value for out_path if it's not in the dictionary
    out_path = args.get("out_path", "OUTPUT_PATH")

    invocation = f"mergekit-extract-lora {args['base_model']} {args['finetuned_model']} {out_path}"
    if args.get("no_lazy_unpickle"):
        invocation += " --no-lazy-unpickle"
    if args.get("desired_rank"):
        invocation += f" --rank={args['desired_rank']}"
    if args.get("min_fraction_variance"):
        invocation += f" --min_fraction_variance={args['min_fraction_variance']}"
    if args.get("model_name"):
        invocation += f" --model_name={args['model_name']}"
    if args.get("device"):
        invocation += f" --device={args['device']}"

    return invocation


@click.command("mergekit-extract-lora")
@click.argument("finetuned_model", type=str)
@click.argument("base_model", type=str)
@click.argument("out_path", type=click.Path())
@click.option(
    "--no-lazy-unpickle",
    is_flag=True,
    help="Disable lazy unpickler (more stable, higher memory usage)",
)
@click.option(
    "--rank",
    "desired_rank",
    type=int,
    default=32,
    help="Rank for the low-rank decomposition",
)
@click.option(
    "--min_fraction_variance",
    type=float,
    default=1.0,
    help="Minimum fraction of variance explained (default is 1.0, which means all variance)",
)
@click.option(
    "--model_name",
    type=str,
    default=None,
    help="Name of the resulting model (shown in the model card)",
)
@click.option(
    "--device",
    type=str,
    default=None,
    help="PyTorch device to perform SVD computation on",
)
def main(
    finetuned_model: str,
    base_model: str,
    out_path: str,
    no_lazy_unpickle: bool,
    desired_rank: int,
    min_fraction_variance: float,
    model_name: str,
    device: str,
) -> None:
    """
    Decomposes delta weights between a base model and a finetuned model, saving a PEFT model to the specified output path.

    \b
    Arguments:
    FINETUNED_MODEL - the model ID or path to use as the PEFT extraction target model.
    BASE_MODEL - the model ID or path to use as the base model.
    OUT_PATH - the output path where the PEFT model will be saved.
    """

    invocation_args = {
        "base_model": base_model,
        "finetuned_model": finetuned_model,
        "desired_rank": desired_rank,
        "min_fraction_variance": min_fraction_variance,
        "device": device,
        "out_path": out_path,
        "model_name": model_name,
        "no_lazy_unpickle": no_lazy_unpickle,
    }

    os.makedirs(out_path, exist_ok=True)

    base_model_ref = ModelReference.parse(base_model)
    finetuned_model_ref = ModelReference.parse(finetuned_model)

    embedding_names, linear_names, other_names = get_all_module_names(base_model_ref.model.path)
    finetuned_embedding_names, finetuned_linear_names, finetuned_other_names = get_all_module_names(finetuned_model_ref.model.path)

    assert set(embedding_names) == set(finetuned_embedding_names), "Model architecture mismatch"
    assert set(linear_names) == set(finetuned_linear_names), "Model architecture mismatch"
    assert set(other_names) == set(finetuned_other_names), "Model architecture mismatch"

    base_loader = LazyTensorLoader(
        base_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle)
    )
    finetuned_loader = LazyTensorLoader(
        finetuned_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle)
    )

    lora_weights = {}
    rank_pattern = {}

    # Embedding weights
    for layer_name in tqdm(embedding_names):
        
        # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
        # @@@ WHY DOES THIS NEED TRANSPOSING TO WORK??? @@@
        # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
        base_weight = base_loader.get_tensor(f"{layer_name}.weight").T
        finetuned_weight = finetuned_loader.get_tensor(f"{layer_name}.weight").T
        
        lora_A, lora_B, final_rank = decompose_delta_weight(
            finetuned_weight, base_weight, desired_rank, min_fraction_variance, device=device
        )
        
        lora_weights[f"base_model.model.{layer_name}.lora_embedding_A"] = lora_A.to(
            "cpu"
        ).contiguous()
        lora_weights[f"base_model.model.{layer_name}.lora_embedding_B"] = lora_B.to(
            "cpu"
        ).contiguous()
        
        # Add the final rank to the rank_pattern dictionary only if it is different from the desired_rank
        if final_rank != desired_rank:
            rank_pattern[layer_name] = final_rank
            
        print(f"E--> {layer_name} : {final_rank}")
    
    # Linear weights
    for layer_name in tqdm(linear_names):
                
        base_weight = base_loader.get_tensor(f"{layer_name}.weight")
        finetuned_weight = finetuned_loader.get_tensor(f"{layer_name}.weight")
        
        lora_A, lora_B, final_rank = decompose_delta_weight(
            finetuned_weight, base_weight, desired_rank, min_fraction_variance, device=device
        )
        
        lora_weights[f"base_model.model.{layer_name}.lora_A.weight"] = lora_A.to(
            "cpu"
        ).contiguous()
        lora_weights[f"base_model.model.{layer_name}.lora_B.weight"] = lora_B.to(
            "cpu"
        ).contiguous()
        
        # Add the final rank to the rank_pattern dictionary only if it is different from the desired_rank
        if final_rank != desired_rank:
            rank_pattern[layer_name] = final_rank
            
        print(f"L--> {layer_name} : {final_rank}")

    # Other weights (saved whole and listed in modules_to_save)
    for layer_name in tqdm(other_names):
        
        finetuned_weight = finetuned_loader.get_tensor(f"{layer_name}.weight")
        
        lora_weights[f"base_model.model.{layer_name}.weight"] = finetuned_weight.to(
            "cpu"
        ).contiguous()
        
        print(f"O--> {layer_name}")
        
    lora_config = create_peft_config(
        base_model_name_or_path=base_model_ref.model.path,
        alpha=desired_rank,  # Setting the alpha to the reduced rank value as `peft` will scale the LoRA weights by alpha/r when applying the adapter
        rank=desired_rank,
        rank_pattern=rank_pattern,
        target_modules=list(
            set([module_name.split(".")[-1] for module_name in embedding_names])
            .union(set([module_name.split(".")[-1] for module_name in linear_names]))
        ),
        modules_to_save=list(
            set([module_name.split(".")[-1] for module_name in other_names])   # Used to replace completely when LoRA is applied...
        )
    )

    with open(os.path.join(out_path, "adapter_config.json"), "w") as f:
        json.dump(lora_config, f, indent=2)

    save_file(lora_weights, os.path.join(out_path, "adapter_model.safetensors"))

    invocation_args.pop("out_path")  # don't include out_path for privacy
    invocation = reconstruct_invocation(invocation_args)

    card_md = generate_card_lora(
        base_model_ref=base_model_ref,
        finetuned_model_ref=finetuned_model_ref,
        invocation=invocation,
        name=model_name,
    )

    with open(os.path.join(out_path, "README.md"), "w", encoding="utf-8") as fp:
        fp.write(card_md)

    logging.info(f"PEFT LoRA adapters saved to {out_path}")


if __name__ == "__main__":
    main()
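On the "WHY DOES THIS NEED TRANSPOSING" question in the embedding loop: my understanding, worth checking against PEFT's LoRA Embedding layer, is that PEFT stores lora_embedding_A as [r, num_embeddings] and lora_embedding_B as [embedding_dim, r], and reconstructs the delta as (B @ A).T. That matches the [8, 32004] / [4096, 8] shapes in llama-2-7b-sql-lora-test, and it means the SVD has to be taken of the transposed delta. A toy check of the shapes and the reconstruction:

```python
import torch

torch.manual_seed(0)
vocab, dim, r = 10, 6, 6  # toy sizes; r = min(vocab, dim) so the reconstruction is exact

delta = torch.randn(vocab, dim)  # finetuned embed_tokens.weight - base, shape [vocab, dim]

# Decompose the *transposed* delta, as the embedding loop does.
U, S, Vh = torch.linalg.svd(delta.T, full_matrices=False)  # delta.T is [dim, vocab]
lora_embedding_A = Vh[:r, :]                      # [r, vocab], same layout as [8, 32004]
lora_embedding_B = U[:, :r] @ torch.diag(S[:r])   # [dim, r], same layout as [4096, 8]

# Embedding delta as PEFT is assumed to rebuild it: (B @ A).T -> [vocab, dim]
reconstructed = (lora_embedding_B @ lora_embedding_A).T
print(tuple(lora_embedding_A.shape), tuple(lora_embedding_B.shape), tuple(reconstructed.shape))
print(torch.allclose(reconstructed, delta, atol=1e-5))  # True
```

Without the transpose, A would come out as [r, dim] and B as [vocab, r], which is the Linear layout rather than the embedding one.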

I'm not quite sure about lora_embedding_A and lora_embedding_B, as I don't know why they need transposing to match the dimensions like that, but I have successfully extracted a rank-1024 LoRA from Mistral-7B-Instruct-v0.2 with Mistral-7B-v0.1 as the base, creating this adapter_config.json:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "./Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 1024,
  "lora_dropout": 0,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "post_attention_layernorm",
    "norm",
    "input_layernorm"
  ],
  "peft_type": "LORA",
  "r": 1024,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "down_proj",
    "gate_proj",
    "k_proj",
    "o_proj",
    "embed_tokens",
    "up_proj",
    "lm_head",
    "v_proj",
    "q_proj"
  ],
  "task_type": "CAUSAL_LM",
  "use_rslora": false
}

and a 5.1GB adapter_model.safetensors file.
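As a sanity check on the 5.1GB figure, the parameter count can be worked out by hand. This assumes the published Mistral-7B-v0.1 dimensions (hidden size 4096, intermediate size 14336, 32 layers, 8 KV heads of size 128, vocab 32000), bfloat16 storage, rank 1024 everywhere (so rank_pattern stays empty, as above) and the norm weights stored whole:

```python
# Assumed Mistral-7B-v0.1 config values; rank 1024 for every target module.
hidden, inter, layers, kv, vocab, rank = 4096, 14336, 32, 1024, 32000, 1024

def lora_params(in_f: int, out_f: int, r: int) -> int:
    """A is [r, in_f] and B is [out_f, r], so r * (in_f + out_f) parameters."""
    return r * (in_f + out_f)

per_layer = (
    lora_params(hidden, hidden, rank)        # q_proj
    + 2 * lora_params(hidden, kv, rank)      # k_proj, v_proj
    + lora_params(hidden, hidden, rank)      # o_proj
    + 2 * lora_params(hidden, inter, rank)   # gate_proj, up_proj
    + lora_params(inter, hidden, rank)       # down_proj
)
embeddings = 2 * lora_params(vocab, hidden, rank)  # embed_tokens + lm_head
norms = layers * 2 * hidden + hidden               # input/post_attention layernorms + final norm

total = layers * per_layer + embeddings + norms
size_gib = total * 2 / 2**30                       # 2 bytes per bfloat16 value
print(total, round(size_gib, 2))  # 2758545408 5.14
```

That is about 5.5e9 bytes, or roughly 5.1 GiB, which lines up with the reported file size if it was measured in GiB.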

I then re-merged the LoRA back onto Mistral-7B-v0.1.

It works and definitely follows instructions, but it has a bit of a repetition problem, which could be due to the low-rank decomposition or possibly some bug in the code.
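One way to probe whether the truncation is plausibly to blame is to measure how much of each delta's sum of squared singular values a given rank keeps; this is the same quantity _low_rank_decomposition uses for min_fraction_variance. A minimal sketch (variance_retained is a hypothetical helper; the random matrix only exercises it, and real fine-tuning deltas will have a very different spectrum):

```python
import torch

def variance_retained(delta: torch.Tensor, rank: int) -> float:
    """Fraction of sum(S**2) kept by a rank-`rank` SVD truncation of `delta`."""
    S = torch.linalg.svdvals(delta.float())
    S2 = S ** 2
    return (S2[:rank].sum() / S2.sum()).item()

torch.manual_seed(0)
delta = torch.randn(256, 128)  # toy stand-in for finetuned_weight - base_weight
for r in (16, 64, 128):
    print(r, round(variance_retained(delta, r), 3))
```

Running it over each real finetuned_weight - base_weight pair would show which layers lose the most at rank 1024.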


I also ran an extra sanity test that puts everything in modules_to_save, like so:

{
  "alpha_pattern": {},
  "auto_mapping": null,
  "base_model_name_or_path": "./Mistral-7B-v0.1",
  "bias": "none",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 1024,
  "lora_dropout": 0,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "post_attention_layernorm",
    "norm",
    "input_layernorm",
    "down_proj",
    "gate_proj",
    "k_proj",
    "o_proj",
    "embed_tokens",
    "up_proj",
    "lm_head",
    "v_proj",
    "q_proj"
  ],
  "peft_type": "LORA",
  "r": 1024,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [ ],
  "task_type": "CAUSAL_LM",
  "use_rslora": false
}

to confirm that the correct naming for the un-decomposed weights really is:

        lora_weights[f"base_model.model.{layer_name}.weight"] = finetuned_weight.to(
            "cpu"
        ).contiguous()

and not:

        lora_weights[f"model.model.{layer_name}.weight"] = finetuned_weight.to(
            "cpu"
        ).contiguous()

and also to check whether these should be added to adapter_model.safetensors or to a separate model.safetensors file.

This ran almost identically to the stock Mistral-7B-Instruct-v0.2, but with a very slight deviation in the tokens chosen.

I think this is because of the bfloat16 --> float16 conversion during the lora-cache creation (which Mergekit won't let you change from float16), but it could be something I have missed from modules_to_save... I'm 90% sure it's not, though:

https://github.com/arcee-ai/mergekit/blob/main/mergekit/_data/architectures/llama.json

as clearly all these are listed in modules_to_save above.
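Whether the conversion explains the deviation I can't confirm, but it is easy to see where bfloat16 --> float16 is lossy. bfloat16 has 8 exponent and 7 mantissa bits while float16 has 5 and 10, so values inside float16's normal range convert exactly; only very small magnitudes (which land in float16's subnormal range) and values above 65504 change:

```python
import torch

x = torch.tensor([3.0, 1e-6, 1e5], dtype=torch.bfloat16)
y = x.to(torch.float16)            # 1e-6 becomes a float16 subnormal, 1e5 overflows to inf
round_trip = y.to(torch.bfloat16)

print(round_trip == x)   # only the in-range value 3.0 survives unchanged
print(torch.isinf(y))    # the value above 65504 became inf
```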


Anyway, I'll look at this more on Monday and see if I can get the new_embeddings.safetensors handling working for models with added tokens, and then it's mostly done.

jukofyork commented on June 18, 2024

I have successfully extracted and re-applied a full-rank LoRA from Mistral-7B-Instruct-v0.2 with Mistral-7B-v0.1 as the base.

There is an eventual divergence, but I'm pretty sure it's just down to numerical error in the SVD or the bfloat16 --> float16 conversion (as even putting everything in modules_to_save causes a slight divergence too...):

Mistral-7B-Instruct-v0.2:

USER

Write me the opening chapter of a 'Grimdark' stories trilogy in the style of Joe Abercrombie and Rob J Hayes

ASSISTANT

Title: The Crimson Reckoning

Chapter One: A Harvest of Sorrow

The sun dipped below the horizon, casting long shadows over the desolate landscape. The wind howled through the barren trees, their skeletal branches scraping against each other like banshees in mourning. In this forsaken place, where hope was a rare commodity and despair wore a permanent smile, a lone figure trudged through the mud. His name was Morgrath Ironhand, and he was no stranger to sorrow.

Morgrath's face was etched with the lines of a thousand battles - each scar, each wrinkle, a testament to his past. His eyes held the cold, hard glint of a man who had seen too much death and destruction. He clutched an ancient warhammer in one hand, its crimson handle stained with the blood of countless enemies. The other hand cradled a leather pouch filled with seeds - seeds that represented his last hope.

The land around him was a desolate wasteland, ravaged by war and plague. The once-proud cities lay in ruins, their inhabitants either dead or scattered to the winds. Morgrath's people, the Ironborn, had been among the hardest hit. Their strongholds had been sacked, their warriors slain or enslaved. But Morgrath refused to give in to despair. He believed that even in this darkest of times, there was a glimmer of hope - a chance for redemption and renewal.

As he walked, Morgrath's thoughts were interrupted by the sound of hooves on the muddy ground. He turned to see a group of riders approaching, their faces hidden beneath hooded cloaks. They were the Blackhand Brotherhood, a notorious band of mercenaries known for their brutality and treachery. Their leader, a hulking figure with a scarred face and cold, dead eyes, rode at the front.

"Morgrath Ironhand," he sneered, "the legendary warrior of the Ironborn. I've heard tales of your exploits - how you've managed to survive in this forsaken world. Impressive, for one so old."

Morgrath glared back at him, his hand tightening around the haft of his warhammer. "I've survived longer than you ever will, Blackhand."

The mercenary leader chuckled, a sound that sent shivers down Morgrath's spine. "We have come to offer you a deal, old man. Join us in our next campaign, and we'll ensure that your people are well-compensated. Or we can make an example of you here and now."

Morgrath's eyes narrowed, his grip on the warhammer growing tighter. "I serve no master but my own."

The mercenaries laughed, their laughter echoing through the desolate landscape. But Morgrath knew that they were bluffing - they needed him as much as he needed them. And he had a plan.

As the sun set and the stars began to twinkle in the night sky, Morgrath made his move. With a roar, he charged at the mercenaries, his warhammer swinging through the air like a blade of crimson justice. The battle that ensued was brutal and savage - a dance of death that left the ground slick with blood and tears.

But Morgrath's greatest weapon was not his warhammer or his strength, but his determination and belief in a better future. And as the last mercenary fell, he knew that he had won more than just a battle - he had won a harvest of hope.

With renewed purpose, Morgrath continued on his journey, planting the seeds as he went. He would rebuild his people's strongholds, restore their pride and dignity. And though the road ahead was long and fraught with danger, he knew that he would not walk it alone.

For in this grimdark world, even a glimmer of hope could be a beacon of light.

USER

Can you tell me about Joe Abercrombie and Rob J Hayes?

ASSISTANT

Joe Abercrombie and Rob J. Hayes are both acclaimed fantasy authors known for their gritty, dark, and realistic storytelling. They have gained a significant following in the fantasy genre with their intricate plots, complex characters, and vivid world-building.

Joe Abercrombie is a British author who has published several bestselling novels in the Last Argument of Kings series (also known as The First Law Trilogy), as well as standalone novels and short stories. His writing style is characterized by its gritty realism, dark humor, and intricate character development. Abercrombie's works often explore themes of power, morality, and the human condition in a harsh and unforgiving world.

Rob J. Hayes is an American author who has published several fantasy novels, including the Ties That Bind series (also known as The Chronicler's Quill), the Children of Rage series, and standalone novels. His writing style is similar to Abercrombie's in its gritty realism and dark themes, but he also incorporates elements of magic and the supernatural into his stories. Hayes' works often explore complex moral dilemmas, the consequences of power, and the nature of heroism in a brutal world.

Both Abercrombie and Hayes have been praised for their ability to create compelling characters that are flawed, complex, and morally ambiguous. Their writing is known for its vivid descriptions, intricate plotting, and exploration of themes that challenge traditional fantasy tropes. Their works often contain graphic violence, dark humor, and mature themes, making them popular among readers who enjoy a more adult take on the fantasy genre.

Mistral-7B-v0.1 + Mistral-7B-Instruct-v0.2-LoRA:

USER

Write me the opening chapter of a 'Grimdark' stories trilogy in the style of Joe Abercrombie and Rob J Hayes

ASSISTANT

Title: The Crimson Reckoning

Chapter One: A Harvest of Sorrow

The sun dipped below the horizon, casting long shadows over the desolate landscape. The wind howled through the barren trees, their skeletal branches scraping against each other like banshees in mourning. In this forsaken place, where hope was a rare commodity and despair wore a permanent smile, a lone figure trudged through the mud. His name was Morgrath Ironhand, and he was no stranger to sorrow.

Morgrath's face was etched with the lines of a thousand battles - each scar, each wrinkle, a testament to his past. His eyes held the cold, hard glint of a man who had seen too much death and destruction. He clutched an ancient warhammer in one hand, its crimson handle stained with the blood of countless enemies. The other hand cradled a leather pouch filled with seeds - seeds that represented his last hope.

The land around him was a desolate wasteland, ravaged by war and plague. The once-proud cities lay in ruins, their inhabitants either dead or scattered to the winds. Morgrath's people, the Ironborn, had been among the hardest hit. Their strongholds had been sacked, their warriors slain or enslaved. But Morgrath refused to give in to despair. He believed that even in this darkest of times, there was a glimmer of hope - a chance for redemption and renewal.

As he walked, Morgrath's thoughts were interrupted by the sound of hooves on the muddy ground. He turned to see a group of riders approaching, their faces hidden beneath hooded cloaks. They were the Blackhand Brotherhood, a notorious band of mercenaries known for their brutality and treachery. Their leader, a hulking figure with a scarred face and cold, dead eyes, rode at the front.

"Morgrath Ironhand," he sneered, "the legendary warrior of the Ironborn. I've heard tales of your exploits - how you've managed to survive in this forsaken world. Impressive, for one so old."

Morgrath glared back at him, his hand tightening around the haft of his warhammer. "I've survived longer than you ever will, Blackhand."

The mercenary leader chuckled, a sound that sent shivers down Morgrath's spine. "We have come to offer you a deal, old man. Join us in our next campaign, and we'll ensure that your people are well-compensated. Or we can make an example of you here and now."

Morgrath's eyes narrowed, his grip on the warhammer growing tighter. "I serve no master but my own."

The mercenaries laughed, their laughter echoing through the desolate landscape. But Morgrath knew that they were bluffing - they needed him as much as he needed them. And he had a plan.

As the sun set and the stars began to twinkle in the night sky, Morgrath made his move. With a roar, he charged at the mercenaries, his warhammer swinging through the air like a blade of crimson justice. The battle that ensued was brutal and savage - a dance of death that left the ground slick with blood and tears.

But Morgrath's greatest weapon was not his warhammer or his strength, but his determination and belief in a better future. And as the last mercenary fell, he knew that he had won more than just a battle - he had won a harvest of hope.

With renewed purpose, Morgrath continued on his journey, planting the seeds as he went. He knew that it would be a long and arduous process, but he believed in the power of redemption and renewal. And so, amidst the ruins of a broken world, Morgrath Ironhand began to sow the seeds of a new beginning.

The Crimson Reckoning had begun.

USER

Can you tell me about Joe Abercrombie and Rob J Hayes?

ASSISTANT

Joe Abercrombie and Rob J. Hayes are both renowned fantasy authors known for their gritty, dark, and realistic storytelling. They have gained a significant following in the fantasy genre with their intricate plots, complex characters, and vivid world-building.

Joe Abercrombie is a British author who has written several critically acclaimed fantasy series. His work is characterized by its gritty realism and moral ambiguity. Some of his most notable works include the First Law Trilogy (The Blade Itself, Before They Are Hanged, and Last Argument of Kings), the Shattered Sea Trilogy (Half a King, Half the World, and Half a War), and the Age of Madness series (A Little Hatred, The Trouble with Peace, and The Wisdom of Crowds).

Rob J. Hayes is an American author who has written several fantasy novels and short stories. His work shares similarities with Abercrombie's in terms of its dark and gritty tone. Some of his most notable works include the Ties That Bind series (Bloodrune, Bloodrune: The Price of Betrayal, and Bloodrune: The Reckoning), the Chronicler's Tale series (The Heresy Within, The Heresy's End, and The Heretic's Revenge), and the Seventh Thunders series (The Seventh Thunders, The Fifth Empire of Man, and The Last War).

Both Abercrombie and Hayes have been praised for their ability to create complex characters that inhabit intricately detailed worlds. Their stories often explore themes of morality, power, and the human condition in a dark and unforgiving world. Their writing styles are similar in that they do not shy away from depicting violence, death, and the harsh realities of life. However, they also offer moments of hope and redemption amidst the darkness.


The output was identical right up to the last few words of the final paragraph, and the answer to the follow-up question didn't say "Rob J. Hayes is an Australian author" or "Rob J. Hayes is an Irish author" as the 1024-rank reconstruction did.


I originally forgot to set the alpha_pattern dictionary for the modules listed in rank_pattern, and got the spookiest thing I've ever seen from an LLM:

USER

Write me the opening chapter of a 'Grimdark' stories trilogy in the style of Joe Abercrombie and Rob J Hayes

ASSISTANT

Title: A Bercromian Tale of M Вла, H, and the unquenchd Ius: Cum Garam, Ias or this in a (un❶ shrines, elde, Ere or O N ethere all at once, 22 at M Mosen, 11 in the Schrines of the Gods of M Mosen, and my uncomely-ed-and-aesist I part : At this hallowed place of the aforemented and more , _ this is at the heart of Old Gosan, amongst all its malign relics, devotional relics and the last remnants of long-dead traditions.
_ A myriad multitude gathered in one fated place, Thas in the 22's _ I and _ E aeons.
_ And at the shrine of M Neros, a city not yet recovered by Gracchans or Romans, but young men who had seen better days.
_ and among these were those who still clung to the old ways,their faces seared with the sun's light-paining and fair complexion of Aeon, a man of great beauty and refined bearing.
_ And at the shrine of M Neros,a young man whose eyes shone withthe sun.
_ and Aemon, a beautiful woman who had seen better days, the priestesses, the scholars, and the old men who still tended to their cities and their people.
_ and Aemon, a man of great beauty and refined bearing,-,_ A fair woman with a fine figure.
_ And Aemon, a beautiful woman whose face was seared by the sun.
_ and the other priests, who were not only revered for their beauty and their ability to rule men.
_ and Aemon, the priestesses of M Neros,
_ Aeon_, a man of great beauty,and a beautiful woman whose face was seared by the sun.
_ and Aethon,
_ Aeion_,
_ and Aethon.

"and Aemon,
"a fair man,"
the son of the god Aeon, who dwelt in M Neros at the shrine of M Neros.
"And a beautiful woman,
_ and a beautiful woman.
_ and Aethon,

a fair priestess,

and Aeeton,

a man of great beauty and refined bearing.
"and the other priests,

the sons of the god Aeon,

and Aethon and Iam.

"And a beautiful woman

in the face of_ Aethon_,

a fair woman

WTF! 😮
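Here is why a missing alpha_pattern does so much damage. PEFT's default (non-rsLoRA) scaling multiplies each LoRA delta by lora_alpha / r. If rank_pattern overrides a module's rank but alpha_pattern doesn't override its alpha, that module's delta gets rescaled. Below is a minimal sketch that uses plain dicts in place of a real LoraConfig, with hypothetical module names and ranks. PEFT itself matches pattern keys against module names rather than doing the exact lookups used here:

```python
# Sketch only: plain dicts stand in for a PEFT LoraConfig, and module names are
# looked up exactly (PEFT matches rank_pattern/alpha_pattern keys as patterns).
# Assumes PEFT's default scaling of lora_alpha / r (i.e. use_rslora=False).

def build_patterns(module_ranks, default_rank):
    """Return r/lora_alpha/rank_pattern/alpha_pattern so every module scales by 1.0."""
    rank_pattern = {name: r for name, r in module_ranks.items() if r != default_rank}
    # alpha must mirror rank for every overridden module, otherwise it falls back
    # to the global lora_alpha and the delta is rescaled by lora_alpha / r
    alpha_pattern = dict(rank_pattern)
    return {
        "r": default_rank,
        "lora_alpha": default_rank,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
    }


def effective_scaling(config, module_name):
    """Scale factor applied to lora_B @ lora_A for one module."""
    r = config["rank_pattern"].get(module_name, config["r"])
    alpha = config["alpha_pattern"].get(module_name, config["lora_alpha"])
    return alpha / r


ranks = {"q_proj": 64, "k_proj": 16, "lm_head": 1024}  # hypothetical ranks
config = build_patterns(ranks, default_rank=64)
print({m: effective_scaling(config, m) for m in ranks})
# {'q_proj': 1.0, 'k_proj': 1.0, 'lm_head': 1.0}

broken = dict(config, alpha_pattern={})  # the "forgot alpha_pattern" case
print({m: effective_scaling(broken, m) for m in ranks})
# {'q_proj': 1.0, 'k_proj': 4.0, 'lm_head': 0.0625}
```

An SVD-extracted LoRA needs a scale of exactly 1.0 everywhere. The broken case therefore amplifies some deltas and shrinks others.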


I'll look at the new_embeddings.safetensors stuff next and then this should be creating LoRAs that are 100% correct.


jukofyork avatar jukofyork commented on June 18, 2024

I think I'm just not going to bother with the new_embeddings.safetensors stuff, as it looks like it could be quite problematic:

The only examples I can see on Hugging Face that use it are these:

https://huggingface.co/yard1/llama-2-7b-sql-lora-test/tree/main
https://huggingface.co/alpindale/l2-lora-test/tree/main

adapter_model.safetensors

base_model.model.model.embed_tokens.lora_embedding_A 	[8, 32 004] 	BF16
base_model.model.model.embed_tokens.lora_embedding_B 	[4 096, 8] 	BF16
.
.
.
base_model.model.lm_head.lora_A.weight 	[8, 4 096] 	BF16
base_model.model.lm_head.lora_B.weight 	[32 004, 8] 	BF16

new_embeddings.safetensors

input_embeddings 	[4, 4 096] 	BF16
output_embeddings 	[4, 4 096] 	BF16

which made me think I could just do this:

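# Zero-pad the base tensor's rows up to the fine-tuned vocab size
base_weight = torch.nn.functional.pad(base_weight, (0, 0, 0, padding_size), "constant", 0)
# Keep the fine-tuned rows for the added tokens so they can be stored full-rank
extra_weight = finetuned_weight[-padding_size:, :]
new_embedding_weights["input_embeddings"] = extra_weight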

to pad the embed_tokens (and lm_head) tensors to match the above, and then save the new_embedding_weights dictionary into a file called new_embeddings.safetensors... But it doesn't work:

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.embed_tokens.lora_embedding_A.default: copying a param with shape torch.Size([4096, 32768]) from checkpoint, the shape in current model is torch.Size([4096, 32000]).
	size mismatch for base_model.model.lm_head.lora_B.default.weight: copying a param with shape torch.Size([32768, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).

and if I just truncate and then save the new_embeddings.safetensors:

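# Slice out the added-token rows *before* truncating; slicing afterwards
# would just return the last rows of the base vocabulary
extra_weight = finetuned_weight[-padding_size:, :]
finetuned_weight = finetuned_weight[:base_weight.shape[0]]
new_embedding_weights["input_embeddings"] = extra_weight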

the stuff saved in new_embeddings.safetensors doesn't get used when the LoRA is merged back and the final dimensions stay as 32000 rather than 32768.

I'm definitely saving the correctly specified new_embeddings.safetensors tensors too:

{
  "input_embeddings": {
    "dtype": "BF16",
    "shape": [
      768,
      4096
    ],
    "data_offsets": [
      0,
      6291456
    ]
  },
  "output_embeddings": {
    "dtype": "BF16",
    "shape": [
      768,
      4096
    ],
    "data_offsets": [
      6291456,
      12582912
    ]
  }
}
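The data_offsets above are consistent with the shapes: 768 × 4096 BF16 values at 2 bytes each is 6,291,456 bytes per tensor. Here's a small sketch for checking that kind of header. It assumes the standard safetensors layout, which is an 8-byte little-endian header length followed by the JSON header. DTYPE_SIZES only covers the dtypes used in this thread:

```python
import json
import struct

# Bytes per element for the dtypes that appear in this thread (others omitted)
DTYPE_SIZES = {"BF16": 2, "F16": 2, "F32": 4}


def read_header(raw: bytes) -> dict:
    """Parse the JSON header of a safetensors blob (8-byte LE length prefix + JSON)."""
    (header_len,) = struct.unpack("<Q", raw[:8])
    return json.loads(raw[8:8 + header_len].decode("utf-8"))


def check_offsets(header: dict) -> int:
    """Check each tensor's byte range matches shape * dtype size and that ranges
    are contiguous from 0. Returns the total data-buffer size in bytes."""
    tensors = sorted(
        (info for name, info in header.items() if name != "__metadata__"),
        key=lambda info: info["data_offsets"][0],
    )
    expected_start = 0
    for info in tensors:
        numel = 1
        for dim in info["shape"]:
            numel *= dim
        start, end = info["data_offsets"]
        if start != expected_start or end - start != numel * DTYPE_SIZES[info["dtype"]]:
            raise ValueError(f"inconsistent offsets: {info}")
        expected_start = end
    return expected_start


header = {
    "input_embeddings": {"dtype": "BF16", "shape": [768, 4096], "data_offsets": [0, 6291456]},
    "output_embeddings": {"dtype": "BF16", "shape": [768, 4096], "data_offsets": [6291456, 12582912]},
}
blob = json.dumps(header).encode("utf-8")
raw = struct.pack("<Q", len(blob)) + blob  # header only; a real file also has the data
print(check_offsets(read_header(raw)))  # 12582912
```

A passing check only shows that the file is well formed. It says nothing about whether the loader actually reads those tensors.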

So I also noticed that the two examples above have a file called added_tokens.json:

{
  "[/assistant]": 32003,
  "[/user]": 32001,
  "[assistant]": 32002,
  "[user]": 32000
}

which made me think I could just make either a fake one of these with "token_32000", "token_32001", ... as the names, or even parse the fine-tuned tokenizer_config.json file to get the proper values out...
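Here is a sketch of the "parse tokenizer_config.json" idea. It assumes the extra tokens are appended after the base vocabulary, so they occupy ids base_vocab_size to new_vocab_size - 1. For ids the config doesn't list, it falls back to hypothetical token_<id> placeholder names. The function name is made up for illustration:

```python
import json


def added_tokens_from_config(tokenizer_config, base_vocab_size, new_vocab_size):
    """Build an added_tokens.json-style {content: id} map for ids in [base, new).

    Assumes the extra tokens were *appended* after the base vocabulary.
    Ids missing from added_tokens_decoder get a placeholder "token_<id>" name.
    """
    decoder = tokenizer_config.get("added_tokens_decoder", {})
    id_to_content = {int(i): entry["content"] for i, entry in decoder.items()}
    return {
        id_to_content.get(token_id, f"token_{token_id}"): token_id
        for token_id in range(base_vocab_size, new_vocab_size)
    }


config = {"added_tokens_decoder": {"32000": {"content": "[user]"}, "32001": {"content": "[/user]"}}}
print(json.dumps(added_tokens_from_config(config, 32000, 32004), indent=2))
```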

BUT: then I saw that Mistral-7B-Instruct-v0.3 seems to prepend the extra tokens (!!!):

{
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "[INST]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "[/INST]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "5": {
      "content": "[TOOL_CALLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "6": {
      "content": "[AVAILABLE_TOOLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "7": {
      "content": "[/AVAILABLE_TOOLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },

    "770": {
      "content": "[control_768]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  }
}

and there appear to be more of them than the 768 that were added (???).
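One possible reconciliation of the counts, offered as a hedged sketch, rests on two assumptions. The first is that ids 0 to 770 are all present, since the excerpt above skips most of them. The second is that <unk>, <s> and </s> at ids 0 to 2 already existed in the 32000-token vocabulary. If both hold, then 771 listed entries minus 3 pre-existing ones gives exactly the 768 extra rows:

```python
def reconcile_counts(listed_ids, preexisting_ids, old_vocab_size, new_vocab_size):
    """Compare the special-token entries against the growth in vocabulary size."""
    listed = set(listed_ids)
    return {
        "listed": len(listed),
        "genuinely_new": len(listed - set(preexisting_ids)),
        "vocab_growth": new_vocab_size - old_vocab_size,
    }


# Assumes ids 0..770 are all present and that ids 0-2 (<unk>, <s>, </s>)
# were already in the 32000-token vocabulary.
print(reconcile_counts(range(771), {0, 1, 2}, 32000, 32768))
# {'listed': 771, 'genuinely_new': 768, 'vocab_growth': 768}
```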


It's then not even 100% clear what Mergekit will do with this when you use the +LoRA syntax anyway, so overall this seems like way too much error-prone code to add for these cases...

I'll give the code a good tidy-up tomorrow, add some extra command-line options to control things like saving the norms, and put in a proper PR.


I should also say that I tried to see if it were possible to save the 1D norm vectors as a (1, 1) lora_A and (1, n) lora_B (in the hope that the mismatched dimensions might get squeezed down). This would be good, as we could then store the norms' delta values, rather than the actual norms, to match the rest of the code... But sadly the PEFT code doesn't like it:

    def _create_new_module(lora_config, adapter_name, target, **kwargs):
        # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters,
        # because the first match is always used. Therefore, the default layers should be checked last.
        dispatchers = []

        # avoid eager bnb import
        if is_bnb_available():
            from .bnb import dispatch_bnb_8bit

            dispatchers.append(dispatch_bnb_8bit)

        if is_bnb_4bit_available():
            from .bnb import dispatch_bnb_4bit

            dispatchers.append(dispatch_bnb_4bit)

        dispatchers.extend(
            [
                dispatch_eetq,
                dispatch_aqlm,
                dispatch_awq,
                dispatch_gptq,
                dispatch_hqq,
                dispatch_megatron,
                dispatch_default,
            ]
        )

        new_module = None
        for dispatcher in dispatchers:
            new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
            if new_module is not None:  # first match wins
                break

        if new_module is None:
            # no module could be matched
            raise ValueError(
                f"Target module {target} is not supported. Currently, only the following modules are supported: "
                "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`."
            )

        return new_module
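PEFT rejects this because norm layers aren't among the module types listed in that error message. The idea itself is simple, though. Here is a pure-Python sketch of storing a 1D norm delta as rank-1 factors. It follows PEFT's Linear shape convention, where lora_A is (r, in_features) and lora_B is (out_features, r), so with r = 1 the delta sits in an (n, 1) lora_B. The helper names are hypothetical:

```python
def norm_delta_as_lora(base_norm, tuned_norm):
    """Factor a 1-D norm delta as lora_B @ lora_A with rank r = 1.

    Shapes follow PEFT's Linear convention: lora_A is (r, in) = (1, 1) and
    lora_B is (out, r) = (n, 1). Hypothetical helper, not a PEFT API.
    """
    delta = [t - b for b, t in zip(base_norm, tuned_norm)]
    lora_a = [[1.0]]
    lora_b = [[d] for d in delta]
    return lora_a, lora_b


def apply_norm_lora(base_norm, lora_a, lora_b, scaling=1.0):
    # lora_B @ lora_A is (n, 1); squeeze the trailing dim to get a 1-D delta
    delta = [row[0] * lora_a[0][0] * scaling for row in lora_b]
    return [b + d for b, d in zip(base_norm, delta)]


base = [1.0, 1.0, 1.0]
tuned = [1.25, 0.5, 1.0]
a, b = norm_delta_as_lora(base, tuned)
print(apply_norm_lora(base, a, b))  # [1.25, 0.5, 1.0]
```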

@cg123, I don't think there is any way in Mergekit to use task arithmetic to do this either, e.g. take the delta between a fine-tuned model and a base model, and then add this delta onto a different base model.

This would be useful for trying to transplant deltas from models fine-tuned on base llama-2-70b onto miqu-1-70b, as an alternative to Midnight-Miqu-70B-v1.0's [0, 0, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0, 0] SLERP merge method...


cg123 avatar cg123 commented on June 18, 2024

Thanks for investigating this! I know LoRA with added tokens can be a minefield.

Mistral-v0.3 is uniquely crazy in having prepended the new tokens. I've never seen that anywhere before. Mergekit will do the right thing if you use tokenizer_source: union, I'm 90% sure.

RE: task arithmetic, you can do something equivalent to what you want:

base_model: llama-2-70b
merge_method: task_arithmetic
models:
  - model: llama_2_finetune
    parameters:
      weight: <whatever>
  - model: miqu-1-70b
    parameters:
      weight: 1.0

This will give you llama2 + (miqu-llama2)*1 + (finetune-llama2)*k == miqu + (finetune-llama2)*k, up to floating point precision anyway. This will still only really work for similar base models though.
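Here is a scalar sketch of that identity. Each "tensor" is a single float, and the weights are assumed to be used as-is with no normalisation. The function and the model values are made up for illustration:

```python
def task_arithmetic(base, weighted_models):
    """base + sum_i w_i * (model_i - base), per tensor name (floats stand in for tensors)."""
    return {
        name: b + sum(w * (m[name] - b) for m, w in weighted_models)
        for name, b in base.items()
    }


llama2 = {"w": 1.0}
miqu = {"w": 3.0}
finetune = {"w": 1.5}
k = 0.5

merged = task_arithmetic(llama2, [(finetune, k), (miqu, 1.0)])
direct = miqu["w"] + k * (finetune["w"] - llama2["w"])
print(merged["w"], direct)  # 3.25 3.25
```

The llama2 term cancels, so only the fine-tune's delta is carried onto miqu.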


jukofyork avatar jukofyork commented on June 18, 2024

Thanks for investigating this! I know LoRA with added tokens can be a minefield.

Mistral-v0.3 is uniquely crazy in having prepended the new tokens. I've never seen that anywhere before. Mergekit will do the right thing if you use tokenizer_source: union, I'm 90% sure.

Yeah, I just left it truncating the tensors for now, as it looks way too complex and error-prone to do anything else!

For experiments like these:

https://huggingface.co/ChuckMcSneed/Euryale-1.3-L2-70B-LORA
https://huggingface.co/ChuckMcSneed/Xwin-LM-70B-V0.1-LORA

I think truncation is probably fine, but for high-rank archiving it would probably be better not to truncate.

RE: task arithmetic, you can do something equivalent to what you want:

base_model: llama-2-70b
merge_method: task_arithmetic
models:
  - model: llama_2_finetune
    parameters:
      weight: <whatever>
  - model: miqu-1-70b
    parameters:
      weight: 1.0

This will give you llama2 + (miqu-llama2)*1 + (finetune-llama2)*k == miqu + (finetune-llama2)*k, up to floating point precision anyway. This will still only really work for similar base models though.

Ah, thanks! I wouldn't have thought of doing that, but it makes sense now! :)


jukofyork avatar jukofyork commented on June 18, 2024


Just in case anybody searches and finds this, it makes sense to exclude the norms from the transplanted delta:

base_model: llama-2-70b
merge_method: task_arithmetic
models:
  - model: llama_2_finetune
    parameters:
      weight:
        - filter: norm  # Should match: 'input_layernorm', 'post_attention_layernorm' and 'norm'.
          value: 0.0
        - value: <whatever>
  - model: miqu-1-70b
    parameters:
      weight: 1.0

as their weight represents a non-negative standard deviation initialized to 1 during training.

The other weights should (hopefully) be approximately mean-centred, so adding a proportion of the fine-tuned deltas should keep them approximately mean-centred.

It's an interesting alternative to linear interpolation, which (conceptually) standardises the scale too, and hence can make sense to apply to the norms, unlike the above.
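The sketch below sets the norm-excluding transplant next to plain linear interpolation, with floats standing in for tensors. It assumes the filter is a substring match on the tensor name, as the comment in the config suggests. The tensor names follow the Llama layout and the values are invented:

```python
def weight_for(tensor_name, k, filters=("norm",)):
    # Assumed substring match, as the config comment suggests: any tensor whose
    # name contains "norm" gets weight 0.0, everything else gets k
    return 0.0 if any(f in tensor_name for f in filters) else k


def transplant(base, finetune, target, k):
    """target + w(name) * (finetune - base), with norms excluded from the delta."""
    return {
        name: target[name] + weight_for(name, k) * (finetune[name] - base[name])
        for name in target
    }


def lerp(a, b, t):
    # Linear interpolation is a convex combination, so a positive norm scale
    # stays between the two models' values instead of drifting by an added delta
    return (1.0 - t) * a + t * b


names = [
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.norm.weight",
]
base = dict(zip(names, [1.0, 0.0, 1.0]))
finetune = dict(zip(names, [1.5, 0.2, 0.8]))
target = dict(zip(names, [0.9, 0.1, 1.2]))
merged = transplant(base, finetune, target, k=0.5)
print(merged)
```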

I'm just testing it as an alternative to sophosympatheia's SLERP merge method:

models:
  - model: /home/llm/mergequant/models/BASE/152334H_miqu-1-70b-sf
  - model: /home/llm/mergequant/models/mr-70b-v2.0.3
merge_method: slerp
base_model: /home/llm/mergequant/models/BASE/152334H_miqu-1-70b-sf
parameters:
  t:
    - value: [0, 0, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0, 0]
  embed_slerp: true
dtype: float16
tokenizer_source: model:/home/llm/mergequant/models/BASE/152334H_miqu-1-70b-sf

which has proven useful for merging miqu-1-70b with older 4k-context models, whilst keeping miqu-1-70b's 32k context.
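For reference, here is a rough pure-Python sketch of what that config asks for: a per-layer t taken from the 11-point gradient, then SLERP between the two models' flattened tensors. Two details are assumptions rather than confirmed mergekit behaviour. One is how mergekit spreads the gradient over the 80 layers of a 70B Llama; the sketch uses evenly spaced anchors with linear interpolation. The other is the exact degenerate-angle fallback:

```python
import math


def slerp(t, v0, v1, eps=1e-8):
    """Spherical interpolation between two flat vectors.

    The angle comes from the normalised vectors, but the weights are applied to
    the original vectors; falls back to linear interpolation when nearly colinear.
    """
    n0 = math.sqrt(sum(x * x for x in v0))
    n1 = math.sqrt(sum(x * x for x in v1))
    dot = sum(a * b for a, b in zip(v0, v1)) / (n0 * n1)
    dot = max(-1.0, min(1.0, dot))
    if abs(dot) > 1.0 - eps:
        return [(1.0 - t) * a + t * b for a, b in zip(v0, v1)]
    theta = math.acos(dot)
    s0 = math.sin((1.0 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


def layer_t(gradient, num_layers):
    """Spread the gradient anchors evenly over the layers and linearly interpolate."""
    if num_layers == 1:
        return [gradient[0]]
    last = len(gradient) - 1
    values = []
    for i in range(num_layers):
        pos = i * last / (num_layers - 1)
        lo = min(int(pos), last)
        hi = min(lo + 1, last)
        frac = pos - lo
        values.append(gradient[lo] * (1.0 - frac) + gradient[hi] * frac)
    return values


gradient = [0, 0, 0.2, 0.3, 0.4, 0.5, 0.4, 0.3, 0.2, 0, 0]
ts = layer_t(gradient, 80)  # llama-2-70b and miqu-1-70b both have 80 layers
print(round(ts[0], 3), round(max(ts), 3), round(ts[-1], 3))
```

Under this scheme the first and last layers stay pure miqu-1-70b (t = 0), which is consistent with the long-context behaviour being preserved.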


jukofyork avatar jukofyork commented on June 18, 2024


I also found that, for the specific case of miqu-1, you need to make sure the correct "rope_theta": 1000000 is used (e.g. by copying over miqu-1-70b-sf's config.json), otherwise you get llama-2-70b's "rope_theta": 10000 value.
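Here is a minimal sketch of that fix. It assumes the merged model's config.json is otherwise correct and only rope_theta needs carrying over. The paths in the usage comment are hypothetical:

```python
import json
from pathlib import Path


def patch_rope_theta(source_cfg: dict, merged_cfg: dict) -> dict:
    """Return a copy of merged_cfg with rope_theta taken from source_cfg."""
    patched = dict(merged_cfg)
    patched["rope_theta"] = source_cfg["rope_theta"]
    return patched


def patch_config_file(source_path: str, merged_path: str) -> None:
    """Rewrite the merged config.json at merged_path in place, using source_path's rope_theta."""
    source_cfg = json.loads(Path(source_path).read_text())
    merged_cfg = json.loads(Path(merged_path).read_text())
    Path(merged_path).write_text(json.dumps(patch_rope_theta(source_cfg, merged_cfg), indent=2) + "\n")


# Hypothetical paths:
# patch_config_file("152334H_miqu-1-70b-sf/config.json", "merged-model/config.json")
```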
