Medium-Range Weather Forecasting with Time- and Space-aware Deep Learning Part 3: Graph Neural Network Forecasts
AI-based forecasting models come in many flavors. In this post, we'll sample a couple of prominent approaches built on graph neural network architectures.
This is the third post in a series on deep learning-based methods in medium-range weather forecasting. Part one discussed the history and foundational theory of numerical weather prediction (NWP). Part two discussed some of the issues with modern NWP approaches and how both the problem of medium-range weather forecasting itself and its data align well with the strengths of machine learning. If you’d like to read the entire article in its full form, it is posted here.
Machine Learning Approaches to NWP
While machine learning methods have long been used for climatological and weather forecasting applications (Schultz et al., 2021), competitive AI-NWP approaches for global, medium-range weather prediction have only begun to emerge in the past five years. The earliest approaches to medium-range AI-NWP, like those in Dueben & Bauer (2018), Weyn, Durran, & Caruana (2020), or Rasp & Thuerey (2021), were hampered by low resolution and data quality, resulting in models which were only marginally more skillful than forecasts based on climatological averages. To reach performance competitive with the ECMWF’s IFS forecasts, AI-NWP approaches required the release of the 0.25°-resolution ERA5 dataset, which provides the data fidelity needed to compete with IFS HRES. As we will see, these new methods also require the integration of spatial, temporal, and physical inductive biases which align closely with the fluid and thermodynamic laws underlying the behavior of medium-range weather. At the moment, the implementation of these inductive biases in AI-NWP models comes in two architectural flavors: Graph Neural Networks and Transformers. We’ll cover the former in this post and the latter in the next.
Graph Neural Network Approaches
Graph Neural Networks (GNNs) are a convenient architecture for encoding spatial inductive biases given their dependence on the connectivity of an underlying network structure upon which their functional form is defined. These deep learning algorithms learn to represent complex interactions on an underlying network by performing parameterized message passing between adjacent nodes, updating parameters so as to minimize the discrepancy between the representations computed at each node and those implied by the true data. This message passing can take many structural forms, but nearly all of them may be interpreted as the composition of a messaging operation between adjacent nodes, an aggregation of inbound messages at each node, and an update to the node representations based on these aggregated local messages (Veličković, 2022).
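To make the message-aggregate-update pattern concrete, here is a minimal NumPy sketch of a single message passing step; the weight matrices `W_msg` and `W_upd` and the `tanh` nonlinearity are illustrative choices, not any particular paper's parameterization:

```python
import numpy as np

def message_passing_step(h, edges, W_msg, W_upd):
    """One message -> aggregate -> update iteration on a graph.

    h:     (num_nodes, d) array of node representations
    edges: iterable of (src, dst) index pairs defining connectivity
    W_msg: (2*d, d) weights producing a message from a node pair
    W_upd: (2*d, d) weights updating a node from its aggregated messages
    """
    num_nodes, d = h.shape
    agg = np.zeros((num_nodes, d))
    for src, dst in edges:
        # Message: a learned function of the sender and receiver states,
        # aggregated (here, summed) at each receiving node.
        agg[dst] += np.tanh(np.concatenate([h[src], h[dst]]) @ W_msg)
    # Update: combine each node's state with its aggregated inbound messages.
    return np.tanh(np.concatenate([h, agg], axis=1) @ W_upd)
```

Stacking k such steps lets information propagate across k hops of the graph, which is the mechanism behind the receptive-field argument below.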
By stacking multiple message passing layers in sequence, the locally-defined GNN architecture begins to integrate more global information, as data from more distant nodes diffuses into the representations of each node’s immediate neighbors with each message, aggregate, and update iteration. Generally speaking, a GNN will require k message passing layers in order for the architecture to account for interactions among nodes within a k-hop neighborhood. By this logic, one can infer that in order to guarantee each node representation learned by a GNN contains information from every other node in the graph, one would need as many message passing layers as the longest shortest path between any two nodes, a value known as the graph's diameter¹. The functional definition of GNNs imbues them with an implicit spatial inductive bias, assigning similar representations to nodes which are nearby within the graph (as measured by the number of hops along the underlying graph).
Given the overwhelming success of grid-based NWP since Richardson’s first pioneering efforts, it is reasonable to expect that a spatial inductive bias, a predisposition to process each grid square based on the information contained within the square’s neighbors, would lead to a family of predictive functions aligned with the spatially-determined dynamics underlying weather patterns. Instead of solving for each grid point’s forecasted weather given its neighbors as in traditional NWP, GNN-based methods seek to predict the weather at the next time step given the node representations of the weather encoded within the neighborhood of each grid point.
Forecasting Global Weather with Graph Neural Networks
The first GNN-based AI-NWP method to achieve performance nearing that of production-quality NWP was presented in Keisler (2022). A herculean single-author effort, Keisler’s model differentiates itself from prior work primarily through a substantial increase in the fidelity of the underlying grid data used during training and forecasting. The model is based on a grid size of approximately 110km (1°) and incorporates interpolated ERA5 weather observation data for six variables across 13 vertical atmospheric pressure levels at each of the grid’s 65,160 nodes, orders of magnitude more fidelity than any prior AI-NWP approach up to that point.
Keisler aggregates hourly reanalysis data from 1950-2021 covering temperature, geopotential height, specific humidity, and three directional wind components across 13 pressure levels. After mapping all of this data to a 1° globe-spanning grid, one could at this point apply a GNN directly to the raw data to produce weather change forecasts by performing message passing between neighboring grid points, refining across each layer the (13 × 6 = 78)-dimensional node representations derived from weather observations in recent time steps (plus pre-computed data like solar radiation, orography, the land-sea mask, the day of year, and the sine and cosine of latitude and longitude).
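As a rough illustration of what these node representations contain, the following sketch assembles per-node feature vectors from random stand-in arrays; the shapes mirror the description above, but the exact feature set and normalization are assumptions for illustration:

```python
import numpy as np

num_nodes = 65_160                               # 1-degree lat/lon grid
atmos = np.random.randn(num_nodes, 13, 6)        # stand-in: 13 levels x 6 variables
lat = np.random.uniform(-90, 90, num_nodes)      # node latitudes (degrees)
lon = np.random.uniform(0, 360, num_nodes)       # node longitudes (degrees)
orography = np.random.randn(num_nodes)           # stand-in surface elevation
land_sea = np.random.randint(0, 2, num_nodes)    # land/sea mask
day_of_year = np.full(num_nodes, 172 / 365.0)    # normalized day of year

features = np.concatenate([
    atmos.reshape(num_nodes, -1),                # 13 x 6 = 78 dynamic features
    np.sin(np.radians(lat))[:, None], np.cos(np.radians(lat))[:, None],
    np.sin(np.radians(lon))[:, None], np.cos(np.radians(lon))[:, None],
    orography[:, None], land_sea[:, None], day_of_year[:, None],
], axis=1)
print(features.shape)  # (65160, 85) under these illustrative choices
```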
While this naive approach would likely yield some amount of predictive skill, it faces a major structural inefficiency. Because the Earth is approximately a sphere, grid cells spaced at 1° (110km at the equator) shrink substantially as one moves north or south towards the poles. This grid irregularity biases the model’s predictive capacity towards the poles: with many more grid points packed near the poles, polar regions exert an outsized influence on an RMSE-like loss over global weather prediction. Because human population density is biased towards the equator and weather forecasts are primarily for human consumption, such a poleward bias would result in particularly sub-optimal performance in a production setting.
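A standard mitigation in the AI-NWP literature, used for instance in WeatherBench's latitude-weighted RMSE metric, is to weight each grid cell's error by the cosine of its latitude, in proportion to the cell's true area. A minimal sketch:

```python
import numpy as np

def lat_weighted_mse(pred, target, lat_deg):
    """Mean squared error with each grid cell weighted by approximate area.

    pred, target: (num_lat, num_lon) gridded fields
    lat_deg:      (num_lat,) latitudes in degrees
    """
    # Cell area on a lat/lon grid shrinks with cos(latitude), so cosine
    # weighting prevents the densely packed polar rows from dominating.
    w = np.cos(np.radians(lat_deg))
    w = w / w.mean()                      # normalize to mean weight 1
    return ((pred - target) ** 2 * w[:, None]).mean()
```

Keisler's fix, described next, attacks the same imbalance through the graph structure itself rather than through the loss.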
Keisler resolves this structural inefficiency with a clever architectural decision. Instead of performing message passing directly on the original 1° latitude/longitude grid, Keisler first subsamples the grid down to a ~3° (~330km) icosahedral mesh of roughly 6,000 nodes distributed uniformly across the globe. To map raw weather observations from the original grid onto the mesh, Keisler introduces an encoder GNN which connects each grid point to its spatially closest icosahedral mesh node in a bipartite manner. This encoder essentially provides a parameterized downsampling operation from the pole-biased latitude/longitude grid to the uniform icosahedral mesh. The architecture then performs a number of message passing operations using the latent node representations on the mesh before passing these representations through a decoder, a parameterized upsampling function which undoes the original encoder operation, mapping mesh information back onto the 1° grid resolution.
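The bipartite grid-to-mesh assignment can be sketched with a nearest-neighbor query over Cartesian coordinates on the unit sphere; the node positions below are random stand-ins, and the single-nearest-node connectivity is a simplifying assumption:

```python
import numpy as np
from scipy.spatial import cKDTree

def to_xyz(lat_deg, lon_deg):
    """Convert latitude/longitude in degrees to unit-sphere xyz coordinates."""
    lat, lon = np.radians(lat_deg), np.radians(lon_deg)
    return np.stack([np.cos(lat) * np.cos(lon),
                     np.cos(lat) * np.sin(lon),
                     np.sin(lat)], axis=-1)

rng = np.random.default_rng(0)
grid_lat, grid_lon = rng.uniform(-90, 90, 65_160), rng.uniform(0, 360, 65_160)
mesh_lat, mesh_lon = rng.uniform(-90, 90, 6_000), rng.uniform(0, 360, 6_000)

# Each grid point sends its encoded features to its nearest mesh node;
# these pairs define the bipartite edges of the encoder GNN.
_, nearest_mesh = cKDTree(to_xyz(mesh_lat, mesh_lon)).query(to_xyz(grid_lat, grid_lon))
encoder_edges = list(zip(range(65_160), nearest_mesh))
```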
The architecture predicts changes in the weather 6 hours ahead of the input observation data. However, 6-hour forecasts are relatively easy to make, especially for traditional NWP systems. Ideally, one would like the architecture to also produce competitive forecasts days in advance. To achieve this, Keisler formulates a loss function which penalizes model predictions at each 6-hour step as the forecast is rolled out to 3 days, gaining longer-horizon forecasting skill at a slight cost to near-term skill.
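A sketch of this multi-step training objective, assuming a `model` that maps a weather state to its predicted 6-hour change, and using simple unweighted errors for brevity:

```python
def rollout_loss(model, state, targets, num_steps=12):
    """Average prediction error over an autoregressive rollout.

    model:   callable mapping a weather state to its predicted 6-hour change
    state:   initial weather state (array of per-node variables)
    targets: true states at each of the next num_steps 6-hour intervals
             (12 steps x 6 hours = the 3-day training horizon)
    """
    loss = 0.0
    for t in range(num_steps):
        state = state + model(state)              # step forward 6 hours
        loss += ((state - targets[t]) ** 2).mean()
    return loss / num_steps
```

Because each predicted state feeds the next prediction, gradients of this loss flow through the entire rollout, directly penalizing the error accumulation that plagues naive single-step training.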
This combination of keen architectural choices and extensive training data results in a model whose forecasting performance is noticeably more skillful and comprehensive than that of its predecessors. Although Keisler’s GNN approach took five and a half days to train, this training can be done on a single NVIDIA A100 GPU and, once trained, the model can produce a 5-day forecast in less than a second. Despite this success, this first attempt at a GNN-based AI-NWP model still lags behind the skill and resolution of the ECMWF’s IFS across multiple variables, pressure levels, and lead times.
GraphCast
In a section discussing data preprocessing, Keisler (2022) makes the following observation:
One useful feature of using message-passing GNNs is that we can encode the relative positions between nodes into the messages, so that a single model can learn from data at different resolutions. We took advantage of this by first training on 2-degree data for the first round of training and then switching to training on 1-degree data for the last two rounds. For reasons we do not understand, this produced better results than training on 1-degree data throughout.
The GraphCast model (Lam et al., 2023) capitalizes on this observation that multiple data resolutions benefit AI-NWP performance, scaling up the resolution and performance of Keisler’s GNN approach while adding additional levels to the icosahedral mesh. GraphCast is structured similarly to its predecessor, an encoder-GNN-decoder architecture, but builds on this model in a few crucial ways. These improvements have made GraphCast one of the most accurate medium-range AI-NWP models proposed to date, outperforming ECMWF IFS across a number of surface-level and atmospheric forecasting tasks, especially for forecasts within a 5-day time horizon.
The first major improvement introduced by GraphCast is its operational data scale. The DeepMind and Google Research-affiliated team scaled the observational grid from Keisler’s earlier 1° (110km) resolution to ERA5’s native 0.25° (28km) resolution, resulting in a base grid with over a million nodes, approaching the resolution of the highest-fidelity NWP model, ECMWF HRES (0.1°). Beyond upsampling the data resolution, GraphCast trains on 5 surface-level variables (2m temperature, the two 10m wind components, mean sea-level pressure, and total precipitation) alongside the atmospheric variables used in Keisler (2022), now at 37 pressure levels, resulting in (5 + 6 × 37 = 227) weather observation variables per grid point. Note that while the model is trained on this collection of 227 variables, its performance is evaluated on a 69-variable subset corresponding to the variables covered in the WeatherBench and ECMWF Scorecard benchmarks.
As in previous approaches, the output of GraphCast is a forecast of the change to each weather variable six hours ahead. GraphCast differs by taking as input not only the current weather observation but also the weather from the previous forecast time, six hours prior. In other words, GraphCast uses the current weather state and the state six hours before it to forecast the weather six hours into the future. These forecasts can then be rolled out to generate arbitrarily long weather state trajectories, two input states at a time.
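A sketch of this two-state rollout, with `model` standing in for the full encoder-GNN-decoder:

```python
def graphcast_rollout(model, prev_state, curr_state, num_steps):
    """Autoregressive forecast from two consecutive input states.

    model: callable mapping (state at t-6h, state at t) to the
           predicted change in the weather from t to t+6h
    """
    trajectory = []
    for _ in range(num_steps):
        next_state = curr_state + model(prev_state, curr_state)
        trajectory.append(next_state)
        prev_state, curr_state = curr_state, next_state  # slide the window
    return trajectory
```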
Perhaps the most innovative architectural feature introduced within GraphCast is the expansion of the globe-spanning icosahedral mesh into an icosahedral multi-mesh. That is, instead of performing the GNN’s message-passing operations on a single mesh whose nodes uniformly span the globe at a single spacing scale, the GraphCast model performs simultaneous message-passing operations on a collection of seven icosahedral meshes of increasing granularity. The coarsest mesh has 12 nodes, while the finest mesh, obtained by repeatedly refining each triangular face into four smaller faces (adding a new node at the midpoint of each edge), contains 40,962 nodes. Because each refinement retains all existing nodes, the nodes of each coarser mesh are a subset of the nodes of every finer mesh, so the edges of all seven meshes can be overlaid on the finest mesh’s nodes in a canonical way. This creates a pathway for information to flow between mesh levels in addition to the information flow along edges within each level, meaning nodes shared with the coarser mesh levels can serve as hubs for gathering and transmitting longer-range information within finer levels. This multi-mesh architecture’s success provides further empirical evidence for Keisler’s original observation that weather prediction is better facilitated by including variable information at multiple spatial scales.
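The node counts of the seven mesh levels follow directly from this refinement rule, starting from a base icosahedron with 12 vertices, 30 edges, and 20 faces; a quick check:

```python
# Each refinement adds a node at every edge midpoint, splits each edge
# in two, and adds three new interior edges per face.
V, E, F = 12, 30, 20                     # base icosahedron
for level in range(7):
    print(f"mesh level {level}: {V:>6} nodes, {E:>7} edges")
    V, E, F = V + E, 2 * E + 3 * F, 4 * F
# prints 12 nodes at level 0 up through 40,962 nodes at level 6
```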
The necessity of multi-scale message passing for accurate AI-NWP weather forecasting makes sense when one considers the peculiar influence of global weather patterns on local forecasts. Large-scale weather patterns like Rossby waves, ENSO, or atmospheric rivers all play a conditioning role in local weather forecasts, even when these anomalies are measured thousands of kilometers away from the forecast location of interest. While the local effects of such teleconnections are generally resolvable, their global influence would be much more difficult to capture at a fine mesh resolution without an extensive number of message-passing operations. By introducing a multi-mesh which provides a shortcut route for integrating global weather information, the extra computational requirements, parameter costs, and oversmoothing risks introduced by excessive message passing may be avoided. Lam et al. (2023) provide some evidence towards this hypothesis by evaluating the performance of the GraphCast model without a multi-mesh, ablating every mesh level but the finest. They find that the multi-mesh architecture outperforms the single-mesh model across all variables and lead times apart from the 50hPa level at lead times greater than 5 days², emphasizing the importance of multi-scale spatial interactions in AI-NWP modeling.
We’ll compare these GNN-based approaches to a few Transformer-based weather forecasting architectures in the next post.
References
Schultz, M. G., Betancourt, C., Gong, B., Kleinert, F., Langguth, M., Leufen, L. H., … Stadtler, S. (2021). Can deep learning beat numerical weather prediction? Philosophical Transactions of the Royal Society A, 379(2194), 20200097.
Dueben, P. D., & Bauer, P. (2018). Challenges and design choices for global weather and climate models based on machine learning. Geoscientific Model Development, 11(10), 3999–4009.
Weyn, J. A., Durran, D. R., & Caruana, R. (2020). Improving data-driven global weather prediction using deep convolutional neural networks on a cubed sphere. Journal of Advances in Modeling Earth Systems, 12(9), e2020MS002109.
Rasp, S., & Thuerey, N. (2021). Data-driven medium-range weather prediction with a ResNet pretrained on climate simulations: A new model for WeatherBench. Journal of Advances in Modeling Earth Systems, 13(2), e2020MS002405.
Veličković, P. (2022). Message passing all the way up. ICLR 2022 Workshop on Geometrical and Topological Representation Learning.
Keisler, R. (2022). Forecasting global weather with graph neural networks. arXiv Preprint arXiv:2202.07575.
Lam, R., Sanchez-Gonzalez, A., Willson, M., Wirnsberger, P., Fortunato, M., Alet, F., et al. (2023). Learning skillful medium-range global weather forecasting. Science, 382(6677), 1416–1421.
¹ In practice, one typically does not need or even want to perform this many message passing operations, as node representations in densely-connected regions of the network become oversmoothed due to the significant number of message-passing and update operations performed within these dense neighborhoods. This constant communication results in learned representations at each node being highly similar to their neighbors. This trend towards learning similar representations is an inherent feature of GNNs (although it is often treated as a bug in the literature).
² It would be interesting to investigate this hypothesis further in future work by determining the extent to which the latent representations learned at coarser resolutions of GraphCast’s multi-mesh reflect global weather oscillations.