Code Monkey home page Code Monkey logo

autodiffm's Introduction

AutoDiffM

automatic differentiation framework in C++ supporting vectors; 多变量微分自动计算框架(纯粹为了展示基本原理)

功能

支持多变量正向(forward-mode)自动微分计算;这是极其原始的原理演示版本,反向(reverse-mode)版本待完成(to be done)。

int main()
{
	// Sample code for the multi-variable forward-mode auto-diff feature.

	// 1. Register every variable with its name and value. The partial derivatives
	//    are evaluated at this point (x1 = 2, x2 = 3, x3 = 0, x4 = 1).
	argument_register reg;
	reg.begin_regist();
		reg.regist("x1", 2);  // value of x1 is 2
		reg.regist("x2", 3);  // value of x2 is 3
		reg.regist("x3", 0);  // value of x3 is 0
		reg.regist("x4", 1);  // value of x4 is 1
		reg.regist("x1", 1);  // deliberate duplicate: "x1" is already registered, so this registration fails
	reg.end_regist();

	// 2. Fetch each registered variable by its name.
	auto& x1 = reg["x1"];
	auto& x2 = reg["x2"];
	auto& x3 = reg["x3"];
	auto& x4 = reg["x4"];

	// 3. Compute y; its partial derivatives with respect to x1..x4 are carried along.
	//    `^` is overloaded by the library (power, judging by the sample output). In C++
	//    `^` binds more loosely than `+`, so (x1^x4) must stay parenthesized.
	auto y = x1*x2*x1*x2 + sin(x2)*log_e(x1) + 2 * x3 + x4 + (x1^x4) + log_10(x1);

	// 4. Output results; each label names the variable actually differentiated against.
	cout << "dy/dx1 = " << y.get_diff(x1) << endl;
	cout << "dy/dx2 = " << y.get_diff(x2) << endl;
	cout << "dy/dx3 = " << y.get_diff(x3) << endl;
	cout << "dy/dx4 = " << y.get_diff(x4) << endl;
}

例子2

// test DR ...
// Chains five segments (heading increment theta[i], length L[i]) into a
// DR (dead-reckoning style) track. It prints and writes to "d:\\111.csv" the
// accumulated x, y and heading alpha after each segment. It then prints the
// partial derivatives of the final x and y with respect to every L[i] and theta[i].
void test()
{
	cout << "---------------------------------------------------------" << endl;
	cout << "test DR and its derivatives over theta and L" << endl;
	// A test of the DR algorithm to find which inputs most affect the system performance.
	// Full-precision pi (the earlier literal 3.1415926525 was wrong from the 9th decimal on).
	const double PI = 3.14159265358979323846;
	argument_register reg1;
	reg1.begin_regist();
		// The names must match the reg1[...] lookups below (theta0..theta4, L0..L4);
		// otherwise theta[0] and L[0] would refer to variables that were never registered.
		reg1.regist("theta0", 6 * PI / 180.00);
		reg1.regist("theta1", 6 * PI / 180.00);
		reg1.regist("theta2", 6 * PI / 180.00);
		reg1.regist("theta3", 6 * PI / 180.00);
		reg1.regist("theta4", 6 * PI / 180.00);

		reg1.regist("L0", 20);
		reg1.regist("L1", 20);
		reg1.regist("L2", 20);
		reg1.regist("L3", 20);
		reg1.regist("L4", 20);
	reg1.end_regist();

	dual theta[] =
	{
		reg1["theta0"],
		reg1["theta1"],
		reg1["theta2"],
		reg1["theta3"],
		reg1["theta4"]
	};
	dual L[] =
	{
		reg1["L0"],
		reg1["L1"],
		reg1["L2"],
		reg1["L3"],
		reg1["L4"]
	};

	// Index 0 holds the starting state (reg1.zero()); index i + 1 holds the state after
	// segment i. That is why these arrays have one more slot than theta/L.
	dual x[] = { reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero() };  // 6 data
	dual y[] = { reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero() };
	dual alpha[] = { reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero(), reg1.zero() };

	const int size = sizeof(theta) / sizeof(theta[0]);

	// DR algorithm using the difference angle from the gyro and the distance (L) from the odometer.
	for (int i = 0; i < size; ++i)
	{
		alpha[i + 1] = alpha[i] + theta[i];
		x[i + 1] = x[i] + L[i] * cos(alpha[i + 1]);
		y[i + 1] = y[i] + L[i] * sin(alpha[i + 1]);
	}
	// Output alpha, x, y to the console and to the CSV file.
	ofstream os("d:\\111.csv");
	for (int i = 0; i < size; ++i)
	{

		cout << x[i + 1].get_value() << ",";
		cout << y[i + 1].get_value() << ",";
		cout << alpha[i + 1].get_value() << endl;

		os << x[i + 1].get_value() << ",";
		os << y[i + 1].get_value() << ",";
		os << alpha[i + 1].get_value() << endl;
	}
	os.close();
	// Derivatives of the final x (x[size]) and y (y[size]) with respect to
	// L[0..size-1] and theta[0..size-1].
	cout << "--------------------------------------------" << endl;
	for (int i = 0; i < size; ++i)
	{
		cout << "dx/d_L[" << i << "] = " << x[size].get_diff(L[i]) << endl;
	}
	for (int i = 0; i < size; ++i)
	{
		cout << "dx/d_theta[" << i << "] = " << x[size].get_diff(theta[i]) << endl;
	}
	cout << "--------------------------------------------" << endl;
	for (int i = 0; i < size; ++i)
	{
		cout << "dy/d_L[" << i << "] = " << y[size].get_diff(L[i]) << endl;
	}
	for (int i = 0; i < size; ++i)
	{
		cout << "dy/d_theta[" << i << "] = " << y[size].get_diff(theta[i]) << endl;
	}
}
···

```c++
输出结果类似如下(我改过程序参数,结果不一定完全一样):
fail to register variable x1, it is already there!
dy/dx1 = 37.2877
dy/dx2 = 23.3138
dy/dx3 = 2
dy/dx4 = 2.38629
--------------------------------------------
dx0/dL0 = -18.4861
dx0/dL1 = 0.997564
dx0/dL2 = 0.994522
dx0/dL3 = 0.990268
dx0/dL4 = 0.984808
dx0/dtheta0 = -18.4861
dx0/dtheta1 = -18.4861
dx0/dtheta2 = -8.34699
dx0/dtheta3 = -6.25643
dx0/dtheta4 = -3.47296
--------------------------------------------
dy0/dL0 = 158.756
dy0/dL1 = 0.0697565
dy0/dL2 = 0.104528
dy0/dL3 = 0.139173
dy0/dL4 = 0.173648
dy0/dtheta0 = 158.756
dy0/dtheta1 = 158.756
dy0/dtheta2 = 59.392
dy0/dtheta3 = 39.5015
dy0/dtheta4 = 19.6962
// example 3, supporting registering an array
void test_array()
{
	argument_register reg1;
	reg1.begin_regist();
		reg1.regist("theta", {1.0, 1.0,    2,   3,   5, -0.5});   // register an array with name "theta" and its data as {1.0, 1.0,    2,   3,   5, -0.5}
		reg1.regist("L",     {10,   10, 10.5,  1.5,  0,    0});   // array an array with name "L"
	reg1.end_regist();

        vector<dual> theta = reg1("theta");                              // to use array "theta"
        vector<dual> L = reg1("L");                                      // to use array "L"

        // ...
}
// sample code for finding the solution of the given function using newton's iterative method
using fun = std::function<dual(const dual&)>;

// Searches for a root of f with Newton's method, starting from x_init.
//
// Each step evaluates y = f(x) and moves x to x - f(x) / f'(x). The derivative
// f'(x) comes from automatic differentiation (d(y)/d(x)). Iteration stops once
// |f(x)| <= max_error for the point just evaluated, or after max_iterations steps.
// Newton's method can cycle or diverge, and the bound keeps that from looping forever.
//
// Returns the last estimate of x. That is one Newton step past the point that
// met the tolerance, or simply the latest iterate if the bound was hit. Convergence
// is not guaranteed, so callers that need certainty should check f at the result.
double resolve_newton(fun f, const double x_init, const double max_error = 1E-6, const int max_iterations = 1000)
{
	argument_register reg;
	reg.begin_regist();
		reg.regist("x",x_init);
	reg.end_regist();

	auto x = reg["x"];
	auto y = x;  // only initialised so that y has the right type; overwritten in the loop
	int iteration = 0;
	do
	{
		y = f(x);
		auto x1 = x.data() - y.data() / (d(y)/d(x));
		x.set_data(x1);
		++iteration;
	} while (std::abs(y.data()) > max_error && iteration < max_iterations);
	return  x.data();
}


int main()
{
	test5();
	cout<<"the root = "<<resolve_newton([](const dual& x)->dual{return x*x-sqrt_N(x,2)-2;}, 10)<<endl;
	
}

联系

别联系了,还很原始,等我有空再改进吧。

autodiffm's People

Contributors

zxg519 avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.