

HansBambel commented on September 23, 2024

Sure. Here is the code for computing the predictions:

import os
import pickle
import time

import numpy as np
import torch

from utils import data_loader_wind_nl  # data-loading helpers from the repository's utils package

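def get_predictions(pred_time, model_name):
    modelFolder = "models/wind_speed_nl_comparison"

    dataFile = "data/Wind_data_NL/dataset.pkl"
    # Load the min and max values that were used to scale the training data
    scalerFile = "data/Wind_data_NL/scaler.pkl"
    with open(scalerFile, "rb") as f:
        scaleValues = pickle.load(f)
    minVals = scaleValues["feature_min_train"]
    maxVals = scaleValues["feature_max_train"]
    # We only care about the wind speed scaling (feature index 0)
    minWS = minVals[0]
    maxWS = maxVals[0]

    dev = "cpu"

    test_dl_CTF = data_loader_wind_nl.get_test_loader(dataFile,
                                                      input_timesteps=6,
                                                      prediction_timestep=pred_time,
                                                      CTF=True,
                                                      batch_size=1,
                                                      num_workers=0,
                                                      pin_memory=True)

    start_time = time.time()

    # The checkpoint holds the pickled model object and its state_dict,
    # so the model's class must be importable when loading.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # weights_only=False is required on PyTorch >= 2.6 to unpickle a full model
    # (the argument exists since PyTorch 1.13); only use it with trusted files.
    loaded = torch.load(f"{modelFolder}/{model_name}", map_location=dev, weights_only=False)
    model = loaded["model"]
    model.load_state_dict(loaded["state_dict"])
    model.to(dev)
    model.eval()

    test_dl = test_dl_CTF
    # One row per test sample, one column per weather station (7 stations)
    pred = np.zeros((len(test_dl), 7))
    target = np.zeros((len(test_dl), 7))
    with torch.no_grad():
        for i, (xb, yb) in enumerate(test_dl):
            # Undo the min-max scaling to get wind speeds back in the original units
            pred[i] = (model(xb.to(dev))*(maxWS-minWS)+minWS).cpu().numpy()
            target[i] = (yb.to(dev)*(maxWS-minWS)+minWS).cpu().numpy()
    # Average over time (axis 0) to get one MSE and one MAE value per station
    mse = np.mean((target-pred)**2, axis=0)
    mae = np.mean(np.abs(target-pred), axis=0)

    print(f"Model {model_name} took: {time.time()-start_time:.3f}s")
    return pred, target, mse, mae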

pred_time = 1
modelFolder = "models/wind_speed_nl_comparison"
models = os.listdir(modelFolder)
# Checkpoint that is evaluated below
models_to_compare = [f"wind_model_NL_{pred_time}h_MultidimConvNetwork_16KernelsPerLayer.pt"]
# All checkpoints in the folder for this prediction horizon (printed for reference only)
wind_models = [m for m in models if f"{pred_time}h" in m and ".pt" in m]
print(wind_models)

pred_multi, target, mse_multi, mae_multi = get_predictions(pred_time, models_to_compare[0])
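
The two key numerical steps in `get_predictions` are undoing the min-max scaling (the model works on values scaled with the training minimum and maximum, so `x * (maxWS - minWS) + minWS` maps its outputs back to the original units) and averaging the errors over the time axis only (`axis=0`), which yields one MSE and one MAE per station. Below is a minimal numpy-only sketch with made-up numbers, assuming the scaler was a plain min-max scaler `(x - min) / (max - min)`; the helper names `min_max_scale`, `inverse_min_max` and `per_station_errors` are hypothetical and not part of the repository.

```python
import numpy as np


def min_max_scale(x, x_min, x_max):
    """Scale values to [0, 1] with the training min and max (assumed scaler)."""
    return (x - x_min) / (x_max - x_min)


def inverse_min_max(x_scaled, x_min, x_max):
    """Undo min_max_scale; the same formula get_predictions applies to outputs."""
    return x_scaled * (x_max - x_min) + x_min


def per_station_errors(target, pred):
    """MSE and MAE per station (column), averaged over time steps (rows)."""
    err = target - pred
    mse = np.mean(err ** 2, axis=0)
    mae = np.mean(np.abs(err), axis=0)
    return mse, mae


# Made-up example: 4 time steps, 2 stations, wind speed in 0.1 m/s
min_ws, max_ws = 0.0, 200.0
target = np.array([[50.0, 80.0],
                   [60.0, 90.0],
                   [70.0, 100.0],
                   [80.0, 110.0]])
# Pretend the model is off by 0.05 in scaled units (= 10 in original units here)
pred_scaled = min_max_scale(target, min_ws, max_ws) + 0.05
pred = inverse_min_max(pred_scaled, min_ws, max_ws)

mse, mae = per_station_errors(target, pred)
print("MSE per station:", mse)
print("MAE per station:", mae)
```

Because the error is a constant 10 (in 0.1 m/s) at every step, both stations end up with an MSE of about 100 and an MAE of about 10. Note that the scaling matters for interpretation: an error of 0.05 in scaled units only becomes meaningful after multiplying by `maxWS - minWS`.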

And this is the code for plotting it:

import os

import matplotlib.pyplot as plt

start = 1337   # first time index shown in the plots
length = 250   # number of time steps shown
cityNames = ['Schiphol', 'De Bilt', 'Leeuwarden', 'Eelde', 'Rotterdam', 'Eindhoven', 'Maastricht']
print(target.shape)  # (number of test samples, number of stations)
os.makedirs("plots", exist_ok=True)  # savefig fails if the folder does not exist
# Real values as a solid line, our predictions as a dashed line
for i in range(target.shape[-1]):
    plt.figure(figsize=(15,5))
    plt.plot(target[start:start+length, i], linestyle="-", label="Real")
    plt.plot(pred_multi[start:start+length, i], linestyle="--", label="Multidimensional")
    # The MAE in the title is computed over the whole test set, not only the plotted window
    plt.title(f"{pred_time}h prediction: {cityNames[i]}, MAE: {np.mean(np.abs(target[:, i]-pred_multi[:, i])):.2f}",
              fontsize=35)
    plt.ylabel("Wind speed (0.1 m/s)", fontsize=30)
    plt.xlabel("Time index", fontsize=30)
    plt.yticks(fontsize=15)
    plt.legend(fontsize=30)
    plt.savefig(f"plots/{pred_time}h_prediction_{cityNames[i]}.png", dpi=300)
    plt.show()
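
In the plotting loop, the MAE shown in each title is computed over the whole test set (`target[:, i]`), while the curves only show the `length` time steps starting at `start`, so the title number is not the error of the visible window. If you also want the error of the plotted window, a small numpy-only sketch like the following computes both for one station; the function name `window_and_overall_mae` and the toy data are hypothetical.

```python
import numpy as np


def window_and_overall_mae(target, pred, station, start, length):
    """MAE of one station over the plotted window and over the full test set.

    target, pred: arrays of shape (time_steps, stations), e.g. as returned
    by get_predictions.
    """
    err = np.abs(target[:, station] - pred[:, station])
    window_mae = float(np.mean(err[start:start + length]))
    overall_mae = float(np.mean(err))
    return window_mae, overall_mae


# Made-up data: 10 time steps, 2 stations; station 0 is wrong only at t = 8 and 9
target = np.zeros((10, 2))
pred = np.zeros((10, 2))
pred[8:, 0] = 5.0

window_mae, overall_mae = window_and_overall_mae(target, pred, station=0, start=0, length=5)
print(f"window MAE: {window_mae:.2f}, overall MAE: {overall_mae:.2f}")
# prints "window MAE: 0.00, overall MAE: 1.00"
```

One thing to keep in mind: if `start + length` runs past the end of the test set, numpy slicing silently returns fewer points instead of raising an error, so the plot and the window MAE simply cover a shorter span.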

Did that help?

from multidim_conv.

hasaki1999 commented on September 23, 2024

Thank you for your answer; it is really helpful. As a beginner, I still need a little time to understand and learn from it. Thank you again!
