Comments (2)
Sure. Here is the code for getting the initial prediction:
import torch
import numpy as np
from utils import data_loader_wind_nl
import plotly.graph_objects as go
import pickle
import os
import time
def get_predictions(pred_time, model_name, *, dev="cpu"):
    """Run a saved wind-speed model over the test set and score it.

    Loads the checkpoint ``models/wind_speed_nl_comparison/<model_name>``,
    predicts every sample of the test loader built from
    ``data/Wind_data_NL/dataset.pkl`` and maps predictions and targets back
    from min-max scaling, using the training-set min/max of feature 0
    (wind speed) stored in ``data/Wind_data_NL/scaler.pkl``.

    Args:
        pred_time: Prediction horizon, forwarded to the test loader as
            ``prediction_timestep``.
        model_name: File name of the checkpoint inside the model folder.
            The checkpoint must be a dict holding a ``"model"`` object and
            its ``"state_dict"``.
        dev: Torch device used for inference. Checkpoint tensors are mapped
            onto it while loading. Defaults to ``"cpu"``.

    Returns:
        ``(pred, target, mse, mae)``. ``pred`` and ``target`` are float64
        arrays of shape ``(n_samples, n_outputs)`` in unscaled units
        (presumably one column per station; confirm against the model).
        ``mse`` and ``mae`` are the per-column errors averaged over all
        samples.

    Raises:
        ValueError: if the test loader yields no samples.
    """
    modelFolder = "models/wind_speed_nl_comparison"
    dataFile = "data/Wind_data_NL/dataset.pkl"
    scalerFile = "data/Wind_data_NL/scaler.pkl"

    # Load the min and max values computed on the training split.
    with open(scalerFile, "rb") as f:
        scaleValues = pickle.load(f)
    # Only the wind speed scaling (feature 0) is needed to undo the normalisation.
    minWS = scaleValues["feature_min_train"][0]
    maxWS = scaleValues["feature_max_train"][0]
    scale = maxWS - minWS

    test_dl = data_loader_wind_nl.get_test_loader(dataFile,
                                                  input_timesteps=6,
                                                  prediction_timestep=pred_time,
                                                  CTF=True,
                                                  batch_size=1,
                                                  num_workers=0,
                                                  pin_memory=True)

    start_time = time.time()
    # map_location lets a checkpoint saved on a GPU be loaded onto `dev`
    # (e.g. a CPU-only machine) instead of failing during deserialisation.
    # NOTE(review): the checkpoint stores a whole pickled module object.
    # torch.load unpickles arbitrary code, so only load trusted files.
    # torch>=2.6 defaults to weights_only=True and will reject this format.
    loaded = torch.load(f"{modelFolder}/{model_name}", map_location=dev)
    model = loaded["model"]
    model.load_state_dict(loaded["state_dict"])
    model.to(dev)
    model.eval()

    preds, targets = [], []
    with torch.no_grad():
        for xb, yb in test_dl:
            preds.append((model(xb.to(dev)) * scale + minWS).cpu().numpy())
            targets.append((yb * scale + minWS).cpu().numpy())
    if not preds:
        raise ValueError("test loader yielded no samples; cannot compute errors")

    # vstack gives (n_samples, n_outputs) for both (n_outputs,) and
    # (1, n_outputs) batch outputs, so the number of stations is not hard-coded.
    # Cast to float64 to keep the precision of the original preallocated arrays.
    pred = np.vstack(preds).astype(np.float64)
    target = np.vstack(targets).astype(np.float64)

    err = target - pred
    mse = np.mean(err ** 2, axis=0)
    mae = np.mean(np.abs(err), axis=0)
    print(f"Model {model_name} took: {time.time()-start_time:.3f}s")
    return pred, target, mse, mae
# Compare the saved models for a single prediction horizon (in hours).
pred_time = 1
modelFolder = "models/wind_speed_nl_comparison"
models = os.listdir(modelFolder)
models_to_compare = [f"wind_model_NL_{pred_time}h_MultidimConvNetwork_16KernelsPerLayer.pt"]
# Match the "_<n>h_" token of the naming scheme exactly. A plain substring
# test for "1h" would also pick up "11h"/"21h" models, and `".pt" in m` would
# also match files such as "x.pth" or "x.pt.bak".
# Sort because os.listdir returns entries in arbitrary order.
wind_models = sorted(m for m in models if f"_{pred_time}h_" in m and m.endswith(".pt"))
print(wind_models)
pred_multi, target, mse_multi, mae_multi = get_predictions(pred_time, models_to_compare[0])
And this is the code for plotting it:
# `plt` was used below without ever being imported.
import matplotlib.pyplot as plt

# Plot a window of the test set for every output column: real values as a
# solid line, the model's predictions dashed. The MAE shown in each title is
# computed over the whole test set, not just the plotted window.
start = 1337
length = 250
cityNames = ['Schiphol', 'De Bilt', 'Leeuwarden', 'Eelde', 'Rotterdam', 'Eindhoven', 'Maastricht']
figs = []
print(target.shape)
# savefig does not create missing directories.
os.makedirs("plots", exist_ok=True)
window = slice(start, start + length)
for i in range(target.shape[-1]):
    plt.figure(figsize=(15, 5))
    plt.plot(target[window, i], linestyle="-", label="Real")
    plt.plot(pred_multi[window, i], linestyle="--", label="Multidimensional")
    plt.title(f"{pred_time}h prediction: {cityNames[i]}, MAE: {np.mean(np.abs(target[:, i]-pred_multi[:, i])):.2f}",
              fontsize=35)
    plt.ylabel("Wind speed in 0.1m/s", fontsize=30)
    plt.xlabel("Time index", fontsize=30)
    plt.yticks(fontsize=15)
    plt.legend(fontsize=30)
    plt.savefig(f"plots/{pred_time}h_prediction_{cityNames[i]}.png", dpi=300)
    plt.show()
    # Release the figure; otherwise one stays open per iteration.
    plt.close()
Did that help?
from multidim_conv.
Thank you for your answer; it is really helpful. As a beginner, though, I still need a little time to understand it and learn. Thank you again!
from multidim_conv.
Related Issues (2)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: use data as art.
-
Game
Something interesting about games: make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China Tencent open source team.
from multidim_conv.