
Ditto: Fair and Robust Federated Learning Through Personalization

This repository contains the code and experiments for the manuscript:

Ditto: Fair and Robust Federated Learning Through Personalization

Fairness and robustness are two important concerns for federated learning systems. In this work, we identify that robustness to data and model poisoning attacks and fairness, measured as the uniformity of performance across devices, are competing constraints in statistically heterogeneous networks. To address these constraints, we propose employing a simple, general framework for personalized federated learning, Ditto, and develop a scalable solver for it. Theoretically, we analyze the ability of Ditto to achieve fairness and robustness simultaneously on a class of linear problems. Empirically, across a suite of federated datasets, we show that Ditto not only achieves competitive performance relative to recent personalization methods, but also enables more accurate, robust, and fair models relative to state-of-the-art fair or robust baselines.
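Concretely, Ditto learns one personalized model per device by regularizing it toward the global model. The following is a minimal sketch of that idea on a toy quadratic objective; the function, variable names, and step sizes are illustrative, not the repository's API:

```python
import numpy as np

# Ditto's personalized objective for device k is roughly
#     F_k(v) + (lam / 2) * ||v - w||^2,
# so a gradient step on the personalized model v pulls it toward both the
# local optimum of F_k and the global model w. Here F_k(v) = ||v - b||^2 / 2.
def ditto_local_step(v, w, b, lam, lr):
    grad = (v - b) + lam * (v - w)  # grad F_k(v) + lam * (v - w)
    return v - lr * grad

v = np.zeros(2)            # personalized model
w = np.ones(2)             # global model (held fixed for this illustration)
b = np.array([4.0, 4.0])   # local optimum of F_k

for _ in range(2000):
    v = ditto_local_step(v, w, b, lam=1.0, lr=0.1)

# With lam = 1 the fixed point solves (v - b) + (v - w) = 0, i.e.
# v = (b + w) / 2, an interpolation between local and global solutions.
```

Larger lam drags the personalized model closer to the global one; lam approaching 0 recovers purely local training.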

We also provide a PyTorch implementation.

Preparation

Dataset generation

For each dataset, we provide links to the downloadable datasets used in our experiments. The paper and the README files under the separate ditto/data/$dataset folders describe how these datasets are generated, and we provide instructions and scripts for preprocessing and/or sampling the data.

Downloading dependencies

pip3 install -r requirements.txt

Run the point estimation example

We provide a Jupyter notebook that simulates the federated point estimation problem. To run it, make sure you are in the ditto folder, and start

jupyter notebook

then open point_estimation.ipynb and run it cell by cell to reproduce the results.

Run on federated benchmarks

(A subset of) Options in run.sh:

  • dataset is chosen from [femnist, fmnist, celeba, vehicle], where fmnist is short for Fashion MNIST.
  • model should be the model corresponding to that dataset. The model name is the $model in flearn/models/$dataset/$model.py.
  • $optimizer is chosen from ['ditto', 'apfl', 'ewc', 'kl', 'l2sgd', 'mapper', 'meta', 'fedavg', 'finetuning'].
  • fedavg trains global models; ditto with lam=0 corresponds to training separate local models.
  • $lambda is the lambda we use for Ditto (dynamic lambdas can be enabled by setting --dynamic_lam to 1).
  • num_corrupted is the number of corrupted devices (see the paper for the total number of devices).
  • random_updates indicates whether we launch Attack 2 (Def 1 in paper).
  • boosting indicates whether we launch Attack 3 (Def 1 in paper).
  • If both random_updates and boosting are set to 0, we default to Attack 1 (Def 1 in paper).
  • By default, all robust baselines are disabled. To test one of them, set --optimizer=fedavg and set one of the robust baselines to 1 (chosen from gradient_clipping, krum, mkrum, median, k_norm, k_loss, fedmgda in run.sh). For fedmgda, one additionally needs to set the fedmgda_eps hyperparameter, chosen from the continuous range [0, 1]. In our experiments, we pick the best fedmgda_eps among {0, 0.1, 0.5, 1} based on validation performance on benign devices.
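As an illustration only, a run.sh-style invocation assembled from these options might look like the sketch below. The flag spellings and the model name are assumptions based on the option names above, so check run.sh for the exact interface.

```shell
#!/usr/bin/env bash
# Hypothetical sketch of a run.sh-style invocation: Fashion MNIST, Ditto,
# lambda=1, clean setting (no corrupted devices, so no attack is launched).
# Flag names mirror the options listed above; the real run.sh may differ.
dataset=fmnist
model=ann   # placeholder name: pick the file under flearn/models/fmnist/
python3 main.py \
  --dataset=$dataset \
  --model=$model \
  --optimizer=ditto \
  --lam=1 \
  --num_corrupted=0 \
  --random_updates=0 \
  --boosting=0
```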

Example instructions for Fashion MNIST

  • Download datasets (link and instructions under ditto/data/fmnist/README.md)
  • Fashion MNIST, Ditto, without attacks, lambda=1: bash run_fashion_clean_ditto_lam1.sh
  • Fashion MNIST, Ditto, A1 (50% adversaries), lambda=1: bash run_fashion_a1_50_ditto_lam1.sh

Example instructions for Vehicle

  • Download datasets (link and instructions under ditto/data/vehicle/README.md)
  • Vehicle, Ditto, without attacks, lambda=1: bash run_vehicle_clean_ditto_lam1.sh
  • Vehicle, Ditto, A1 (50% adversaries), lambda=1: bash run_vehicle_a1_50_ditto_lam1.sh

ditto's People

Contributors

litian96

ditto's Issues

Learning rates, number of iterations (global vs. local), and the 'local' line on the comparison graphs

Good afternoon,

I am a student at EPFL and I study your paper on Ditto for a project.

There are three aspects of the Ditto algorithm that are not clear to me.

  1. Are you using different learning rates for the global and the local optimization (ηg, ηl), and if so, how do you choose them?
  2. Where do you specify the numbers of global and local iterations (r, s), and how do you choose them?
  3. On the performance graphs, what does the constant 'local' line correspond to?

(Attached screenshots: the algorithm and a comparison graph.)

Thank you very much for your repository,

Best regards,

Elie Graham

Bug Report

Hello Litian,

When I run bash run_fashion_a1_50_ditto_lam1.sh, I run into a bug:

    flearn/trainers_MTL/ditto.py", line 53, in train
      c.train_data['y'][i] = np.random.randint(0, 10, len(c.train_data['y']))
    UnboundLocalError: local variable 'i' referenced before assignment
Could you tell me how to fix it? Thank you!
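The error says `i` is never bound when that line executes. Without seeing the surrounding loop it is hard to be definitive, but if the intent of the attack is to randomize every label of a corrupted device, one plausible fix is to drop the stray index and assign the whole label vector at once. The `train_data` dict below is a hypothetical stand-in for the client object in flearn/trainers_MTL/ditto.py:

```python
import numpy as np

np.random.seed(0)

# Hypothetical stand-in for a corrupted client's training data; the real
# object lives in flearn/trainers_MTL/ditto.py. This only illustrates the fix.
train_data = {'y': np.array([3, 1, 4, 1, 5])}

# The reported line indexes with `i`, which is undefined at that point:
#   c.train_data['y'][i] = np.random.randint(0, 10, len(c.train_data['y']))
# Assuming the attack should replace all labels with random ones, assign
# the whole label vector instead of a single (undefined) index:
train_data['y'] = np.random.randint(0, 10, len(train_data['y']))
```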
