replicate / cog-sdxl
Stable Diffusion XL training and inference as a cog model
Home Page: https://replicate.com/stability-ai/sdxl
License: Apache License 2.0
When trying to run the SDXL preprocessor it can't find the model that's in the sample code (i.e. "replicate/sdxl_preprocess:bd1158a5052ed46176da900ad7e2a80ea04a3c46196d93f9e1db879fd1ce7f29"). The link to the model in the documentation is broken as well. Any chance to get access to this?
I know that this template, as of now, doesn't support training multiple concepts. I would love to help out with a PR, but would need some guidance as to how this could be implemented.
The training script consistently crashes if we submit a comma in the token_string.
Failing training model
https://replicate.com/p/e76k0ze9a9rgm0cf1cftefcr0m?input=form
Traceback (most recent call last):
File "/usr/local/lib/python3.9/site-packages/cog/server/worker.py", line 217, in _predict
result = predict(**payload)
File "train.py", line 133, in train
n_tok = int(token.split(":")[1])
IndexError: list index out of range
Reproducible, crashes every time.
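The traceback shows `int(token.split(":")[1])` failing, which happens whenever an entry has no `:COUNT` suffix. A comma inside token_string can produce such an entry if the training script splits the string on commas (an assumption; that code is not quoted in this issue). A minimal sketch of a more forgiving parser, assuming a `NAME:COUNT,NAME:COUNT` format and a default count of 2 (both assumptions); `parse_token_map` is a hypothetical helper, not part of cog-sdxl:

```python
from typing import Dict


def parse_token_map(token_map: str, default_count: int = 2) -> Dict[str, int]:
    """Parse "NAME[:COUNT],NAME[:COUNT],..." into {NAME: COUNT} (hypothetical format)."""
    tokens: Dict[str, int] = {}
    for entry in token_map.split(","):
        entry = entry.strip()
        if not entry:
            continue  # tolerate trailing or doubled commas
        name, sep, count = entry.partition(":")
        name = name.strip()
        if not name:
            raise ValueError(f"empty token name in entry {entry!r}")
        if sep:
            try:
                n_tok = int(count)
            except ValueError:
                raise ValueError(f"token count must be an integer in entry {entry!r}") from None
        else:
            # No ":COUNT" given: token.split(":")[1] raises IndexError in this case
            n_tok = default_count
        if n_tok < 1:
            raise ValueError(f"token count must be >= 1 in entry {entry!r}")
        tokens[name] = n_tok
    return tokens
```

An alternative design is to reject commas in token_string up front with a clear validation error, which keeps the single-token behavior unchanged.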
The LoRA model does not load when run with the code from the following blog post.
blog: https://replicate.com/blog/fine-tune-sdxl
When I changed the LoRA loading to use the load_attn_procs method, as shown below, the model loaded correctly.
```python
import torch
from diffusers import DiffusionPipeline

# TokenEmbeddingsHandler comes from the cog-sdxl repo; clone it first:
#   git clone https://github.com/replicate/cog-sdxl cog_sdxl
from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler

# Load the SDXL base model in half precision on the GPU
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Load the fine-tune's LoRA attention processors (should take < 2 seconds)
pipe.unet.load_attn_procs("/content/lora.safetensors")

# Load the trained token embeddings into both SDXL text encoders
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings("/content/embeddings.pti")
```
Thank you for the great service and the easy fine-tuning of SDXL.
In the DreamBooth API, the model is pushed automatically, without the need to create a model on Replicate.
Img2img with a mask is called inpainting: a black-and-white mask tells SDXL which areas to "inpaint" and which parts of the original image to preserve.
Currently the mask colors:
If the input image contains white, or large white areas (the exact threshold is unknown), SDXL inpainting with a mask does not work.
Example: https://replicate.com/p/rh7cbvtbclfkon7nirofjr6z2u
On replicate.com, fine-tuning cog-sdxl produces a version that fails during setup with this error:
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: 'URLPath' object has no attribute 'encode'
This has no effect on replicate.com, where we can run fine-tuned versions against the base model by hotswapping weights, but I suspect it may have broken cog predict for fine-tuned models.
Full setup run log below:
Loading safety checker...
downloading url: https://weights.replicate.delivery/default/sdxl/safety-1.0.tar
downloading to: ./safety-cache
downloading took: 1.6426427364349365
downloading url: https://weights.replicate.delivery/default/sdxl/sdxl-vae-fix-1.0.tar
downloading to: ./sdxl-cache
downloading took: 8.76220154762268
Loading sdxl txt2img pipeline...
Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]
Loading pipeline components...: 14%|█▍ | 1/7 [00:00<00:05, 1.17it/s]
Loading pipeline components...: 43%|████▎ | 3/7 [00:01<00:01, 2.75it/s]
Loading pipeline components...: 71%|███████▏ | 5/7 [00:01<00:00, 4.76it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00, 5.79it/s]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00, 4.35it/s]
Traceback (most recent call last):
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/worker.py", line 185, in _setup
run_setup(self._predictor)
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/predictor.py", line 98, in run_setup
predictor.setup(weights=weights)
File "predict.py", line 184, in setup
self.load_trained_weights(weights, self.txt2img_pipe)
File "predict.py", line 77, in load_trained_weights
local_weights_cache = self.weights_cache.ensure(weights)
File "/src/weights.py", line 78, in ensure
path = self.weights_path(url)
File "/src/weights.py", line 98, in weights_path
hashed_url = hashlib.sha256(url.encode()).hexdigest()
AttributeError: 'URLPath' object has no attribute 'encode'
Traceback (most recent call last):
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/runner.py", line 292, in setup
for event in worker.setup():
File "/root/.pyenv/versions/3.9.18/lib/python3.9/site-packages/cog/server/worker.py", line 126, in _wait
raise FatalWorkerException(raise_on_error + ": " + done.error_detail)
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: 'URLPath' object has no attribute 'encode'
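The failing line is `hashlib.sha256(url.encode())` in /src/weights.py, which assumes `url` is a `str`. A minimal sketch of the defensive version, coercing to `str` before hashing; this mirrors the `weights = str(weights)` line in load_trained_weights quoted further down this page. `cached_weights_path` is a hypothetical stand-in for `weights_path`, whose full body isn't shown here, and whether `str()` on cog's URLPath yields the original URL is an assumption worth checking.

```python
import hashlib
import os
from typing import Any


def cached_weights_path(cache_dir: str, url: Any) -> str:
    """Map a weights URL (a str or a path-like object) to a stable cache directory."""
    # hashlib needs bytes; objects such as cog's URLPath have no .encode(),
    # which is exactly the AttributeError in the log above, so coerce first.
    url_str = str(url)
    hashed_url = hashlib.sha256(url_str.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, hashed_url)
```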
I would like to suggest adding support for some commonly used styles, like those in https://huggingface.co/spaces/google/sdxl. This could enhance the user experience. If it is feasible, I am willing to submit a pull request (PR) for the implementation.
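Style presets of this kind can be modeled as prompt templates plus extra negative prompts applied before generation. A minimal sketch under that assumption; `STYLES` and `apply_style` are hypothetical, and the preset wording is illustrative only, not copied from the google/sdxl Space:

```python
from typing import Dict, Tuple

# Hypothetical presets: (positive template containing "{prompt}", extra negative prompt)
STYLES: Dict[str, Tuple[str, str]] = {
    "none": ("{prompt}", ""),
    "cinematic": (
        "cinematic still of {prompt}, shallow depth of field, film grain",
        "cartoon, illustration, low quality",
    ),
    "watercolor": (
        "watercolor painting of {prompt}, soft washes, paper texture",
        "photo, 3d render",
    ),
}


def apply_style(style: str, prompt: str, negative_prompt: str = "") -> Tuple[str, str]:
    """Return (prompt, negative_prompt) with the chosen style preset applied."""
    if style not in STYLES:
        raise ValueError(f"unknown style {style!r}; choose from {sorted(STYLES)}")
    template, style_negative = STYLES[style]
    # str.replace avoids str.format's brace parsing if a template ever contains braces
    styled = template.replace("{prompt}", prompt)
    # Merge the preset's negatives with the user's, skipping empty parts
    negative = ", ".join(part for part in (style_negative, negative_prompt) if part)
    return styled, negative
```

In a predictor, a hypothetical `style` input would be applied to prompt and negative_prompt before they are placed in the pipeline kwargs, so the rest of the code path stays unchanged.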
I am using this model for some specific use cases via the Replicate API. I know I can fork this and create my own version, but I have a small request that I think everyone could benefit from, so instead of writing this small enhancement myself, I thought it might make sense for you to implement it.
Right now, images generated via the Replicate API do not include any parameter info, which makes it hard to know which parameters were used once the images are saved locally. I would like to ask whether you would consider adding this metadata to the PNG, similar to how Automatic1111 does it.
You can read the A1111 source, but potentially an even easier way is to see how image viewers read that info. You can view the implementation by SD Prompt Reader here:
Essentially, the prompt and other metadata are stored inside the info field:
There's a very good reason to store it in the info field instead of EXIF, and one of the primary reasons is that it won't be stripped by Discord. This allows PNGs to be shared on a Discord server, so that prompts can be shared and evolved among friends.
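The "info field" is where Pillow exposes PNG text chunks (`Image.info`). A stdlib-only sketch of writing and reading such a chunk, assuming the keyword `parameters`, which is what A1111-compatible readers commonly look for (treat the keyword as an assumption); `add_text_chunk` and `read_text_chunks` are hypothetical helpers:

```python
import struct
import zlib
from typing import Dict

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def make_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Build a PNG chunk: 4-byte length, 4-byte type, data, CRC-32 of type + data."""
    crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)


def add_text_chunk(png: bytes, keyword: str, text: str) -> bytes:
    """Return a copy of png with a tEXt chunk inserted just before IEND."""
    if not png.startswith(PNG_SIGNATURE):
        raise ValueError("not a PNG file")
    # tEXt is Latin-1 only; non-Latin-1 prompts would need an iTXt chunk instead
    payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1")
    pos = len(PNG_SIGNATURE)
    while pos + 8 <= len(png):
        (length,) = struct.unpack(">I", png[pos:pos + 4])
        if png[pos + 4:pos + 8] == b"IEND":
            return png[:pos] + make_chunk(b"tEXt", payload) + png[pos:]
        pos += 12 + length  # length field + type + data + CRC
    raise ValueError("PNG has no IEND chunk")


def read_text_chunks(png: bytes) -> Dict[str, str]:
    """Collect all tEXt chunks as {keyword: text}."""
    found: Dict[str, str] = {}
    pos = len(PNG_SIGNATURE)
    while pos + 8 <= len(png):
        (length,) = struct.unpack(">I", png[pos:pos + 4])
        if png[pos + 4:pos + 8] == b"tEXt":
            data = png[pos + 8:pos + 8 + length]
            keyword, _, text = data.partition(b"\x00")
            found[keyword.decode("latin-1")] = text.decode("latin-1")
        pos += 12 + length
    return found
```

With Pillow available, the equivalent is `info = PngInfo(); info.add_text("parameters", text); image.save(path, pnginfo=info)` using `PIL.PngImagePlugin.PngInfo`, which also switches to iTXt for non-Latin-1 text.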
I hope that you would consider this feature as I believe that it would be beneficial to all.
I can explain my particular use case if necessary, but even for personal use, not having the prompt information makes these PNGs less useful. I know I could insert the metadata into each PNG myself after receiving it from the Replicate server, but it would be fantastic if every PNG already had that info.
Thanks very much!
I pulled the repo and ran the following command:
cog train -i input_images=@example_datasets/zeke.zip -i use_face_detection_instead=True
However, I encountered the following error:
...
Running prediction...
{"logger": "uvicorn.error", "timestamp": "2024-06-26T02:08:26.505009Z", "exception": "Traceback (most recent call last):\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/uvicorn/protocols/http/httptools_impl.py\", line 399, in run_asgi\n result = await app( # type: ignore[func-returns-value]\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/uvicorn/middleware/proxy_headers.py\", line 70, in __call__\n return await self.app(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/applications.py\", line 284, in __call__\n await super().__call__(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/applications.py\", line 122, in __call__\n await self.middleware_stack(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/middleware/errors.py\", line 184, in __call__\n raise exc\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/middleware/errors.py\", line 162, in __call__\n await self.app(scope, receive, _send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/middleware/exceptions.py\", line 79, in __call__\n raise exc\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/middleware/exceptions.py\", line 68, in __call__\n await self.app(scope, receive, sender)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py\", line 20, in __call__\n raise e\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py\", line 17, in __call__\n await self.app(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/routing.py\", line 718, in __call__\n await route.handle(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/routing.py\", line 276, in handle\n await 
self.app(scope, receive, send)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/starlette/routing.py\", line 66, in app\n response = await func(request)\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/applications.py\", line 239, in openapi\n return JSONResponse(self.openapi())\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/applications.py\", line 214, in openapi\n self.openapi_schema = get_openapi(\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/openapi/utils.py\", line 421, in get_openapi\n definitions = get_model_definitions(\n File \"/root/.pyenv/versions/3.9.19/lib/python3.9/site-packages/fastapi/utils.py\", line 64, in get_model_definitions\n model_name = model_name_map[model]\nKeyError: <enum 'lr_scheduler'>", "severity": "ERROR", "message": "Exception in ASGI application\n"}
ⅹ Failed to get OpenAPI schema: 500
Could you please provide some insight into what might be going wrong?
When setting refiner=base_image_refiner, the refine_steps argument promises to control how many denoising steps will be performed with the refiner:
However, in practice, this is only true if you additionally set prompt_strength=1.0. The reason is that cog-sdxl sets num_inference_steps=refine_steps here, and diffusers computes the init_step for refinement based on both num_inference_steps and strength here. This results in surprising behavior, where you set refine_steps=N only to see that it ran for fewer steps. I think you could either:
- set strength=1 here
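The truncation described above can be reproduced with plain arithmetic: diffusers img2img-style pipelines run only the tail of the schedule, about int(num_inference_steps * strength) steps. The sketch below mirrors that calculation (simplified: it ignores denoising_start and scheduler order) and adds a hypothetical helper, `steps_for_target`, that picks num_inference_steps so the refiner runs a requested number of steps at a given strength. This is one possible workaround, not one proposed in the issue.

```python
import math


def effective_refine_steps(num_inference_steps: int, strength: float) -> int:
    """Approximate number of steps an img2img-style pipeline actually runs."""
    # Only the last int(num_inference_steps * strength) steps of the schedule are used
    return min(int(num_inference_steps * strength), num_inference_steps)


def steps_for_target(target_steps: int, strength: float) -> int:
    """Smallest num_inference_steps that yields target_steps at this strength."""
    if not 0 < strength <= 1:
        raise ValueError("strength must be in (0, 1]")
    n = math.ceil(target_steps / strength)
    # Guard against floating-point rounding leaving int(n * strength) just below target
    while effective_refine_steps(n, strength) < target_steps:
        n += 1
    return n
```

For example, refine_steps=10 with prompt_strength=0.3 runs only 3 refiner steps, while passing 34 as num_inference_steps at the same strength runs 10.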
Deploying a custom model works, but for sites with low traffic the cold boot can be painfully long.
I'm curious whether we could add an optional parameter that points to a LoRA stored on Hugging Face (or civit.ai), which we could load before the generation and unload after it.
I'm happy to contribute if you give me a thumbs up.
Right now we generate image captions with BLIP. Our training input is just a collection of images. It would be nice if we could accept a list of images and their captions.
Using the built-in upcast fixed VAE, fine details like hair come out fuzzy and pixelated. This is particularly noticeable when cog-sdxl is used to create a full fine-tune and LoRAs are then trained on top of that full fine-tune from images generated by the model.
Switching to madebyollin/sdxl-vae-fp16-fix for fp16/bf16 training fixes the issue.
Generation from full fine-tune:
Generation from LoRA trained on full fine-tune as base model:
The captioning step runs before the segmentation and cropping steps used for the face model. This means you get a caption for the whole image, but the image is then cropped to focus on the face, which can change what is visible and invalidate the caption.
https://github.com/replicate/cog-sdxl/blob/main/preprocess.py#L500-L531
Even with the same settings and seed, img2img output varies significantly between runs. Only occasionally is the output the same, as expected.
This makes it extremely difficult to experiment with settings and see their effects, because it is impossible to tell whether a change comes from the settings or from inherent randomness in the output.
Expected behavior: img2img consistently generates the same output for the same input.
After two days of experiments, I think I have traced it back to how the model is loaded. If img2img runs on the same instance and the log shows the following message
.. weights already loaded
then the output is consistent with previously generated images.
If img2img runs on two different model instances and the log shows the message
Ensuring enough disk space... Loading fine-tuned model
then the outputs are different.
Example:
*the differences are more pronounced at higher resolutions like 1024
Example outputs (images not preserved): "Missing Leg", "Big Face, 3 Ears"
downloading url: weights
downloading to: ./trained-model
Error downloading file: Head "weights": unsupported protocol scheme ""
Traceback (most recent call last):
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/cog/server/worker.py", line 185, in _setup
run_setup(self._predictor)
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/cog/predictor.py", line 98, in run_setup
predictor.setup(weights=weights)
File "predict.py", line 186, in setup
self.load_trained_weights(weights, self.txt2img_pipe)
File "predict.py", line 83, in load_trained_weights
download_weights(weights, local_weights_cache)
File "predict.py", line 68, in download_weights
result.check_returncode()
File "/root/.pyenv/versions/3.9.17/lib/python3.9/subprocess.py", line 460, in check_returncode
raise CalledProcessError(self.returncode, self.args, self.stdout,
subprocess.CalledProcessError: Command '['pget', '-x', 'weights', './trained-model']' returned non-zero exit status 1.
ⅹ Model setup failed
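The log shows pget receiving the bare string `weights` instead of a URL (`unsupported protocol scheme ""`). A minimal sketch, assuming it is acceptable to fail early with a clearer message before shelling out to pget; `require_http_url` is a hypothetical helper, not part of cog-sdxl:

```python
from typing import Any
from urllib.parse import urlparse


def require_http_url(value: Any) -> str:
    """Return value as a string if it is an http(s) URL, else raise ValueError."""
    url = str(value)
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(
            f"expected an http(s) URL for weights, got {url!r}; "
            "pass a full URL such as https://..."
        )
    return url
```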
Loading weights from weights.replicate.delivery currently fails intermittently in production:
Loading safety checker...
downloading url: https://weights.replicate.delivery/default/sdxl/safety-1.0.tar
downloading to: ./safety-cache
Traceback (most recent call last):
File "/root/.pyenv/versions/3.11.1/lib/python3.11/site-packages/cog/server/worker.py", line 185, in _setup
run_setup(self._predictor)
File "/root/.pyenv/versions/3.11.1/lib/python3.11/site-packages/cog/predictor.py", line 98, in run_setup
predictor.setup(weights=weights)
File "/src/predict.py", line 99, in setup
download_weights(SAFETY_URL, SAFETY_CACHE)
File "/src/predict.py", line 56, in download_weights
subprocess.check_output(["pget", "-x", url, dest])
File "/root/.pyenv/versions/3.11.1/lib/python3.11/subprocess.py", line 466, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.pyenv/versions/3.11.1/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['pget', '-x', 'https://weights.replicate.delivery/default/sdxl/safety-1.0.tar', './safety-cache']' returned non-zero exit status 1.
Traceback (most recent call last):
File "/root/.pyenv/versions/3.11.1/lib/python3.11/site-packages/cog/server/runner.py", line 292, in setup
for event in worker.setup():
File "/root/.pyenv/versions/3.11.1/lib/python3.11/site-packages/cog/server/worker.py", line 126, in _wait
raise FatalWorkerException(raise_on_error + ": " + done.error_detail)
cog.server.exceptions.FatalWorkerException: Predictor errored during setup: Command '['pget', '-x', 'https://weights.replicate.delivery/default/sdxl/safety-1.0.tar', './safety-cache']' returned non-zero exit status 1.
We are still debugging which layer this occurs in.
Currently, pget is executed with check_output, which prevents its output from being displayed in the setup run logs. To help with debugging, using check_call instead should let pget's stdout/stderr show up in the logs, which could reveal the underlying error.
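A minimal sketch of the check_call approach, with retries added on the assumption that an intermittent failure is worth retrying (the retries are an addition, not something proposed in the issue); `run_with_retries` is a hypothetical helper:

```python
import subprocess
import sys
import time
from typing import Sequence


def run_with_retries(cmd: Sequence[str], attempts: int = 3, backoff_s: float = 2.0) -> None:
    """Run cmd, retrying on a non-zero exit status with exponential backoff."""
    for attempt in range(1, attempts + 1):
        try:
            # check_call does not capture output: the child's stdout/stderr go
            # straight to this process's streams, so pget's own errors reach the logs.
            subprocess.check_call(list(cmd))
            return
        except subprocess.CalledProcessError as err:
            print(f"attempt {attempt}/{attempts} failed: {err}", file=sys.stderr, flush=True)
            if attempt == attempts:
                raise
            time.sleep(backoff_s * 2 ** (attempt - 1))
```

Usage would look like `run_with_retries(["pget", "-x", url, dest])`, keeping the same arguments the current code passes to pget.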
Hey @cloneofsimo, what's the license for the code in this repo?
Situation:
Input captioning text: a photo of TOK
0%| | 0/10 [00:00<?, ?it/s]
0%| | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/cog/server/worker.py", line 217, in _predict
result = predict(**payload)
File "train.py", line 138, in train
input_dir = preprocess(
File "/src/preprocess.py", line 78, in preprocess
load_and_save_masks_and_captions(
File "/src/preprocess.py", line 424, in load_and_save_masks_and_captions
captions = blip_captioning_dataset(
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/src/preprocess.py", line 221, in blip_captioning_dataset
inputs = processor(image, text=text, return_tensors="pt").to("cuda")
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/transformers/feature_extraction_utils.py", line 224, in to
new_data[k] = v.to(*args, **kwargs)
File "/root/.pyenv/versions/3.9.17/lib/python3.9/site-packages/torch/cuda/__init__.py", line 247, in _lazy_init
torch._C._cuda_init()
RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
It seems the NVIDIA driver on the device isn't being recognized. What could be the issue? I'm unable to debug the device where the Docker image runs because I can't SSH into it.
This issue is created for stability-ai/sdxl.
When I use the source image and its mask of dimension 1024x768, the output image is 1024x1024.
For reference, andreasjansson/stable-diffusion-inpainting does a great job of inpainting with SD 1.5. I'm expecting similar output from SDXL.
Just a heads-up that I have started work on a web UI for cog training, similar to the one for kohya_ss.
If any of you are interested in contributing code, you are welcome to.
Link to the repo if you are interested. I'm hoping to have an MVP by the end of the day: https://github.com/bmaltais/cog-sdxl
Hello @daanelson @cloneofsimo
I have been trying to get the captions.csv files working but when checking the run logs on replicate.com the code just auto-captions everything. I tested the example caption file from test_remote_train.py
"https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
but that run still got autocaptioned.
Could you please fix this issue soon?
The filename required for custom captions is called "caption.csv" (code)
However, there is an error message that refers to it in the plural form as "captions.csv" (code)
It is also referred to as "captions.csv" on the Replicate blog (which is where I originally saw it)
Additionally, this page does not mention caption.csv, and I think it would be another good place to document it: https://replicate.com/stability-ai/sdxl#training-inputs
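Given the naming mismatch above, a training script could accept either filename. A minimal sketch, assuming a header row with `image_file` and `caption` columns (the column names are an assumption, not confirmed by this issue); `find_caption_file` and `parse_captions` are hypothetical helpers, and a `None` result is where a script would fall back to BLIP auto-captioning:

```python
import csv
import os
from typing import Dict, Iterable, Optional

# Hypothetical: accept both spellings, since the code and the docs disagree
CAPTION_FILENAMES = ("caption.csv", "captions.csv")


def find_caption_file(input_dir: str) -> Optional[str]:
    """Return the first caption CSV found in input_dir, or None if there is none."""
    for name in CAPTION_FILENAMES:
        path = os.path.join(input_dir, name)
        if os.path.isfile(path):
            return path
    return None


def parse_captions(lines: Iterable[str]) -> Dict[str, str]:
    """Parse CSV lines (header row first) into {image filename: caption}."""
    reader = csv.DictReader(lines)
    # Column names are assumed; adjust them to whatever the training script expects
    missing = {"image_file", "caption"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"caption CSV is missing columns: {sorted(missing)}")
    return {row["image_file"].strip(): row["caption"].strip() for row in reader}
```

Reading from disk would be `with open(path, newline="", encoding="utf-8") as f: captions = parse_captions(f)`.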
https://github.com/huggingface/hf_transfer
It just needs two lines of code:
- add "hf_transfer" to python_packages in cog.yaml
- add os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" to the top of train.py and/or predict.py
To help support a flow similar to https://replicate.com/cloneofsimo/lora, would it be possible to add an unload method for a LoRA?
It appears you can call load multiple times and it will overwrite the last-used LoRA, but it would be useful to be able to unload back to the default SDXL weights, so that you could use the public SDXL model for LoRA inference.
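diffusers SDXL pipelines expose `load_lora_weights()` and `unload_lora_weights()`, so a small context manager could scope a LoRA to a single prediction and restore the base weights afterwards. A minimal sketch; `temporary_lora` is hypothetical, not part of cog-sdxl, and it only handles the LoRA weights: anything else a fine-tune loads (such as the token embeddings in the snippet earlier on this page) would need its own cleanup.

```python
from contextlib import contextmanager
from typing import Any, Iterator


@contextmanager
def temporary_lora(pipe: Any, lora_source: str) -> Iterator[Any]:
    """Load a LoRA for the duration of a with-block, then unload it.

    Assumes pipe has load_lora_weights()/unload_lora_weights(), as diffusers'
    SDXL pipelines do. If loading itself fails, no unload is attempted.
    """
    pipe.load_lora_weights(lora_source)
    try:
        yield pipe
    finally:
        # Runs even if generation raises, so the next request starts from plain SDXL
        pipe.unload_lora_weights()
```

Usage would be `with temporary_lora(pipe, lora_path) as p: image = p(prompt=...).images[0]`.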
In predict.py, the feature extractor is loaded from a cache folder at line 154:
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
But that cache folder is never downloaded.
It should be added to /script/download_weigths.py
This issue is for the stability-ai/sdxl model. When I use the model for img2img predictions and set a custom width and height, the output still follows the dimensions of the original image. For example, I set the width to 1024 and the height to 768, and the parameters are ignored. Can we pass the width and height arguments to img2img_pipe?
```python
if image and mask:
    print("inpainting mode")
    sdxl_kwargs["image"] = self.load_image(image)
    sdxl_kwargs["mask_image"] = self.load_image(mask)
    sdxl_kwargs["strength"] = prompt_strength
    sdxl_kwargs["width"] = width
    sdxl_kwargs["height"] = height
    pipe = self.inpaint_pipe
elif image:
    print("img2img mode")
    sdxl_kwargs["image"] = self.load_image(image)
    sdxl_kwargs["strength"] = prompt_strength
    # width/height are not passed here, so the output keeps the input image's size
    pipe = self.img2img_pipe
else:
    print("txt2img mode")
    sdxl_kwargs["width"] = width
    sdxl_kwargs["height"] = height
    pipe = self.txt2img_pipe
```
On localhost, it works fine. I am currently using the free version of the API.
Instance can think it has loaded a fine-tune even though it hasn't
replicate_weights
Because load_lora_weights sets self.tuned_weights to the passed-in URL before actually loading the weights, if a prediction is canceled while the download/... is still happening (e.g. before it finishes), the instance can end up in an invalid state.
```python
def load_trained_weights(self, weights, pipe):
    from no_init import no_init_or_tensor

    # weights can be a URLPath, which behaves in unexpected ways
    weights = str(weights)

    if self.tuned_weights == weights:
        print("skipping loading .. weights already loaded")
        return

    self.tuned_weights = weights

    ### SNIP - now we actually load the weights ###
```
We need to ensure that load_lora_weights leaves the model in a recoverable, correct state even if it is canceled during a prediction.
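One way to address this: clear tuned_weights before loading and record the new value only after loading returns. A minimal sketch, assuming the real download-and-load logic can be wrapped in one callable and that a cancellation surfaces as an exception inside it; `WeightsLoader` and `load_fn` are hypothetical names. The pipeline itself may still be partially modified after an interrupted load, so a complete fix might also need to reset it.

```python
from typing import Any, Callable, Optional


class WeightsLoader:
    """Tracks which fine-tuned weights are loaded, committing state only on success."""

    def __init__(self, load_fn: Callable[[str, Any], None]):
        self._load_fn = load_fn  # performs the actual download + load into pipe
        self.tuned_weights: Optional[str] = None

    def load_trained_weights(self, weights: Any, pipe: Any) -> None:
        weights = str(weights)
        if self.tuned_weights == weights:
            print("skipping loading .. weights already loaded")
            return
        # Forget the previous weights first: if loading is interrupted below,
        # the instance must not believe that any fine-tune is fully loaded.
        self.tuned_weights = None
        self._load_fn(weights, pipe)
        # Reached only if load_fn returned normally
        self.tuned_weights = weights
```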
The current code is case-sensitive:
Lines 404 to 410 in de1a389
This will silently ignore .JPG, .JPEG, and .PNG files.
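The code at lines 404 to 410 isn't reproduced here, so this is a hypothetical helper rather than a patch: normalizing the extension with `.lower()` before comparing makes uppercase extensions match. `IMAGE_EXTENSIONS`, `is_image_file`, and `filter_images` are illustrative names.

```python
import os
from typing import Iterable, List

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # extend as needed (e.g. ".webp")


def is_image_file(filename: str) -> bool:
    """Case-insensitive extension check, so "IMG_001.JPG" is accepted."""
    return os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS


def filter_images(filenames: Iterable[str]) -> List[str]:
    """Keep only filenames with a recognized image extension, in order."""
    return [name for name in filenames if is_image_file(name)]
```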
Hi,
I have been trying to load fine-tuned SDXL models into Automatic1111. Running inference on Replicate works correctly.
I downloaded the model weights, placed them in the folder /stable-diffusion-webui/models/lora, and loaded them into the prompt with the Lora tab, which adds "lora:lora:1".
When I generate an image, it doesn't work and generates other, unrelated things.