Comments (8)

CodingBeard commented on August 20, 2024

Hey, I've continued digging and managed to get the human-readable graph def of that saved model. Under the hood, the save method seems to use op.SaveV2 to save the variables, and you can see the function definition of StatefulPartitionedCall_2 by searching the graph def for __inference__traced_save_719.

Human readable graph: https://gist.github.com/CodingBeard/769a42d06a9b9d518e69f6c1ae41e45b

I got the graph in that format by making use of TF_GraphDebugString from tensorflow/c/c_api_experimental.h, adding the following function to github.com/galeone/tensorflow/go/graph.go:

func (g *Graph) GetDebugString() string {
	// TF_GraphDebugString returns a C-allocated, human-readable dump of the
	// graph and writes its length into the second argument.
	var length C.size_t
	graphDebugChar := C.TF_GraphDebugString(g.c, &length)
	defer C.free(unsafe.Pointer(graphDebugChar))

	return C.GoString(graphDebugChar)
}

You can figure out what's going on under the hood of a StatefulPartitionedCall (which is just the graph of a tf.function) using that debug string method.
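For example, here's a minimal sketch of how that can be used to inspect the save function (it assumes the patched graph.go above; "model_dir" and the import path are placeholders for your setup):

package main

import (
	"fmt"
	"strings"

	tf "github.com/galeone/tensorflow/tensorflow/go"
)

func main() {
	// Load the SavedModel and grep the graph dump for the traced save
	// function to see which ops StatefulPartitionedCall_2 ends up running.
	model, err := tf.LoadSavedModel("model_dir", []string{"serve"}, nil)
	if err != nil {
		panic(err)
	}

	for _, line := range strings.Split(model.Graph.GetDebugString(), "\n") {
		if strings.Contains(line, "__inference__traced_save") {
			fmt.Println(line)
		}
	}
}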

galeone commented on August 20, 2024

If you're interested in using a model for inference, it's better to train the model in Python, export it as a SavedModel, and then use it from tfgo.

If instead you're interested in training a model with tfgo (note that we currently don't support model saving, so you have to keep the model in memory after training and can't save it to disk), you must export the training graph + model in Python as a SavedModel and use it from tfgo. I wrote an article that explains this process: https://pgaleone.eu/tensorflow/go/2020/11/27/deploy-train-tesorflow-models-in-go-human-activity-recognition/
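For reference, the inference path with tfgo looks roughly like this (a sketch only: the export directory, input shape, and operation names are placeholders you'd replace with your own, and the import paths differ between tfgo versions):

package main

import (
	"fmt"

	tf "github.com/galeone/tensorflow/tensorflow/go"
	tg "github.com/galeone/tfgo"
)

func main() {
	// Load the SavedModel exported from Python with the "serve" tag.
	model := tg.LoadModel("export_dir", []string{"serve"}, nil)

	input, _ := tf.NewTensor([][]float32{{0.5, 0.5, 0.5}})

	// Feed the serving input and fetch the serving output; the exact
	// operation names depend on the exported signatures.
	results := model.Exec([]tf.Output{
		model.Op("StatefulPartitionedCall", 0),
	}, map[tf.Output]*tf.Tensor{
		model.Op("serving_default_input", 0): input,
	})

	fmt.Println(results[0].Value())
}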

CodingBeard commented on August 20, 2024

@galeone

I read through your guide and figured out how to save a model after training in golang.

Note that I am using Python TF 2.4.1 and github.com/tensorflow/tensorflow v2.0.3+incompatible on the Go side. It may work on other TF versions and also with tfgo.

When I saved a model with TensorFlow in Python, it had two additional StatefulPartitionedCall outputs and a saver_filename input. These are not visible using saved_model_cli, but you can see them by looping through .Operations() on a loaded model in Go.

The first of the two extra outputs, when saver_filename is fed a string value (e.g. tf.NewTensor("save_dir/variables/variables")), will save the variables (i.e. variables.index and variables.data-00000-of-00001) in save_dir/variables.

You can then copy the original saved_model.pb into save_dir, and when you load from that directory it will have the weights trained in Go.
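Concretely, the save call boils down to something like this (a sketch against the plain Go bindings; model is a loaded *tf.SavedModel, and the _2 suffix depends on how many signatures your model exports):

// Feeding the hidden "saver_filename" input and fetching the hidden
// StatefulPartitionedCall_2 output writes variables.index and
// variables.data-* under save_dir/variables.
filename, _ := tf.NewTensor("save_dir/variables/variables")

_, err := model.Session.Run(
	map[tf.Output]*tf.Tensor{
		model.Graph.Operation("saver_filename").Output(0): filename,
	},
	[]tf.Output{
		model.Graph.Operation("StatefulPartitionedCall_2").Output(0),
	},
	nil,
)
if err != nil {
	panic(err)
}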

Thanks for your article, it pointed me in the right direction.

galeone commented on August 20, 2024

@CodingBeard thanks for reading the article and for the feedback!

I'm not aware of that method of getting the operation name, but your experience can help other readers. Would you be so kind as to open a merge request to the blog repo (https://github.com/galeone/galeone.github.io/blob/master/_posts/2020-11-27-deploy-train-tesorflow-models-in-go-human-activity-recognition.md) adding a few lines on how to find the node names when saved_model_cli doesn't show them?

It will be a great addition!

CodingBeard commented on August 20, 2024

I'm not using tfgo in my project, but with an older version of TensorFlow's Go package it looks like this:

// Print every operation in the loaded graph; the hidden save/restore ops
// show up here even though saved_model_cli doesn't list them.
model, e := tf.LoadSavedModel("model_dir", []string{"serve"}, nil)
if e != nil {
    panic(e.Error())
}

for _, operation := range model.Graph.Operations() {
    fmt.Println(operation.Name())
}

It seems the save node is always the first hidden StatefulPartitionedCall after the ones visible in saved_model_cli. So, for example, in your article where the learn signature is StatefulPartitionedCall and the predict signature is StatefulPartitionedCall_1, the one to call for saving the variables will be StatefulPartitionedCall_2.
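A quick way to spot them is to filter the operation list from the snippet above (again just a sketch; it reuses the model variable and assumes "fmt" and "strings" are imported):

// Print only the StatefulPartitionedCall ops; the ones beyond the exported
// signatures are the hidden save/restore calls.
for _, op := range model.Graph.Operations() {
	if strings.HasPrefix(op.Name(), "StatefulPartitionedCall") {
		fmt.Println(op.Name())
	}
}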

The full example, using Python TF 2.4.1 and the Go package tensorflow/tensorflow/go v2.0.3, is as follows:

Python:

import tensorflow as tf
from tensorflow import keras as k
from tensorflow.keras import Model


class GolangModel(tf.Module):
    def __init__(self):
        super().__init__()

        bool_input = k.layers.Input(
            shape=(3,),
            name='bool_input',
            dtype='float32',
            batch_size=10
        )

        output = k.layers.Dense(
            1
        )(bool_input)

        self.model = Model(bool_input, output)
        self._global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        self._optimizer = k.optimizers.Adam()
        self._loss = k.losses.binary_crossentropy

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        ]
    )
    def learn(self, data, labels):
        self._global_step.assign_add(1)
        with tf.GradientTape() as tape:
            loss = self._loss(labels, self.model(data))

        gradient = tape.gradient(loss, self.model.trainable_variables)
        self._optimizer.apply_gradients(zip(gradient, self.model.trainable_variables))
        return {"loss": loss}

    @tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)])
    def predict(self, data):
        prediction = self.model(data)
        return {"prediction": prediction}

gm = GolangModel()

# Call each tf.function once before exporting.
gm.learn(
    tf.zeros([10, 3], dtype=tf.float32),
    tf.zeros([10, 1], dtype=tf.float32),
)
gm.predict(tf.zeros((10, 3), dtype=tf.float32))

# Export with explicit signatures so both learn and predict are reachable
# from the SavedModel.
tf.saved_model.save(
    gm,
    "/data/models/gm",
    signatures={
        "learn": gm.learn,
        "predict": gm.predict,
    },
)

golang:

	// Load the SavedModel exported from Python above.
	gm, e := tf.LoadSavedModel("/data/models/gm", []string{"serve"}, nil)
	if e != nil {
		errorHandler.Error(e)
		return nil
	}

	// Prediction before training: feed the predict signature's input
	// ("predict_data") and fetch its output (StatefulPartitionedCall_1).
	boolInput, e := tf.NewTensor([][]float32{{0.5, 0.5, 0.5}, {0, 0, 0}})

	result, e := gm.Session.Run(
		map[tf.Output]*tf.Tensor{
			gm.Graph.Operation("predict_data").Output(0): boolInput,
		},
		[]tf.Output{
			gm.Graph.Operation("StatefulPartitionedCall_1").Output(0),
		},
		nil,
	)
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	floatResults, ok := result[0].Value().([][]float32)
	if !ok {
		fmt.Println("No float results")
		return nil
	}

	fmt.Println(floatResults)

	trainData, e := tf.NewTensor([][]float32{
		{0.5, 0.5, 0.5},
		{0.5, 0.5, 0.5},
		{0.5, 0.5, 0.5},
		{0.5, 0.5, 0.5},
		{0.5, 0.5, 0.5},
		{0, 0, 0},
		{0, 0, 0},
		{0, 0, 0},
		{0, 0, 0},
		{0, 0, 0},
	})
	trainLabels, e := tf.NewTensor([][]float32{
		{1},
		{1},
		{1},
		{1},
		{1},
		{0},
		{0},
		{0},
		{0},
		{0},
	})

	// Train for 1000 steps: feed the learn signature's inputs ("learn_data",
	// "learn_labels") and fetch its output (StatefulPartitionedCall).
	for i := 0; i < 1000; i++ {
		_, e := gm.Session.Run(
			map[tf.Output]*tf.Tensor{
				gm.Graph.Operation("learn_data").Output(0):   trainData,
				gm.Graph.Operation("learn_labels").Output(0): trainLabels,
			},
			[]tf.Output{
				gm.Graph.Operation("StatefulPartitionedCall").Output(0),
			},
			nil,
		)
		if e != nil {
			errorHandler.Error(e)
			return e
		}
	}

	// Prediction after training, to compare with the output from before.
	boolTest, e := tf.NewTensor([][]float32{{0.5, 0.5, 0.5}, {0, 0, 0}})

	test, e := gm.Session.Run(
		map[tf.Output]*tf.Tensor{
			gm.Graph.Operation("predict_data").Output(0): boolTest,
		},
		[]tf.Output{
			gm.Graph.Operation("StatefulPartitionedCall_1").Output(0),
		},
		nil,
	)
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	testResults, ok := test[0].Value().([][]float32)
	if !ok {
		fmt.Println("No post training float results")
		return nil
	}

	fmt.Println(testResults)

	os.RemoveAll("gm-trained")
	os.MkdirAll("gm-trained/variables", os.ModePerm)
	savedModel, e := ioutil.ReadFile("/data/models/gm/saved_model.pb")
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	e = ioutil.WriteFile("gm-trained/saved_model.pb", savedModel, os.ModePerm)
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	// Feed the hidden "saver_filename" input and run the hidden saver op
	// (StatefulPartitionedCall_2) to write variables.index and
	// variables.data-* under gm-trained/variables.
	filenameInput, e := tf.NewTensor("gm-trained/variables/variables")
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	_, e = gm.Session.Run(
		map[tf.Output]*tf.Tensor{
			gm.Graph.Operation("saver_filename").Output(0): filenameInput,
		},
		[]tf.Output{
			gm.Graph.Operation("StatefulPartitionedCall_2").Output(0),
		},
		nil,
	)
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	// Reload the freshly written SavedModel and verify that the trained
	// weights are used.
	gmTrained, e := tf.LoadSavedModel("gm-trained", []string{"serve"}, nil)
	if e != nil {
		errorHandler.Error(e)
		return nil
	}

	boolTest, e = tf.NewTensor([][]float32{{0.5, 0.5, 0.5}, {0, 0, 0}})

	test, e = gmTrained.Session.Run(
		map[tf.Output]*tf.Tensor{
			gmTrained.Graph.Operation("predict_data").Output(0): boolTest,
		},
		[]tf.Output{
			gmTrained.Graph.Operation("StatefulPartitionedCall_1").Output(0),
		},
		nil,
	)
	if e != nil {
		errorHandler.Error(e)
		return e
	}

	testResults, ok = test[0].Value().([][]float32)
	if !ok {
		fmt.Println("No post training float results")
		return nil
	}

	fmt.Println(testResults)

Hopefully that makes sense and you can use it for your article, and maybe even add a method to save the model in this repo.

galeone commented on August 20, 2024

woah, thanks for sharing! I guess I'm going to extract something from your code for tfgo 😄

MIchaelFU0403 commented on August 20, 2024

How can I save a trained model using Go? With the help of tfgo and CodingBeard's code, I can train simple machine learning models in an online scenario and test them by evaluating the relevant performance metrics. Could you please give me some insight into saving a trained model in SavedModel or checkpoint format?

MIchaelFU0403 commented on August 20, 2024

I tried to use the provided "saver_filename" and "StatefulPartitionedCall_2", but it did not work in other self-built models similar to CodingBeard's gm.
