Graph neural network(GNN)는 graph의 형태를 갖는 input값을 이용하는 neural network를 말한다. 화학이나 재료과학에서는 분자 및 재료(결정)를 graph로 표현하여 성질 등을 예측하는 머신러닝 모델을 연구하고 있다.
Graph는 다음과 같이 node와 edge로 구성되어 있다. 이 graph가 어떤 graph mapping function을 통과하고 나면 graph의 node feature들이 update된다. Graph의 update는 주변 환경, 즉 주변 원자들에 따라, 그리고 예측하고자 하는 target property에 따라 다르게 일어난다. Node 단위로 살펴보면, 연결되어 있는 이웃한 원자들의 node feature를 받아서 새로운 node feature로 update되는 과정으로 이해할 수 있다. 그리고 graph 학습을 반복할수록 점점 더 멀리 있는 node들도 고려하게 된다. 예를 들어 graph update를 두 번 반복하면 두 번째로 이웃한 node, 세 번 반복하면 세 번째로 이웃한 node들의 feature도 node update에 영향을 주게 된다.
이렇게 update된 graph의 node feature들을 하나의 input vector로 concatenation한 다음 pooling layer와 일반적인 fully connected layer를 통과하면 전체 GNN의 학습 및 예측이 끝난다. Pooling layer는 fully connected layer로 들어가는 input 벡터의 차원을 일정하게 하고 node index의 permutation-invariance를 확보하기 위해 넣어준다.
우리는 특정한 성질을 결정하는 재료의 feature가 무엇인지 정확히 알지 못하기 때문에, 원하는 성질을 잘 예측할 수 있는 input feature를 정의하는 것 역시 어렵다. 하지만 GNN에서는 graph의 업데이트 과정과 마지막 fully connected layer에서 동시에 학습이 진행된다. 따라서 성질을 잘 예측하게끔 graph가 update되는 것이기 때문에, 예측하고자 하는 성질에 맞는 input feature를 우리가 고민해서 결정할 필요가 없는 것이다.
그렇다면 GNN의 핵심이 되는 graph mapping function, 다시 말해 graph를 어떻게 update시키는지에는 여러 방법이 있다. 그 중에서 convolution, attention, message-passing에 관해 소개한다. 그 전에 공통적으로 사용되는 notation을 언급하자면, $h_i^t$는 $t$번째 step일 때 node $i$의 feature vector, $N(i)$는 자기 자신을 포함하여 node $i$와 결합하고 있는 node의 개수, $\phi$는 sigmoid나 ReLU 등과 같은 activation function으로, node update function을 의미한다.
1. Graph convolutional network (GCN)
일반적인 GCN에서는 다음과 같이 node feature가 update된다.
$$h_i^{t+1}=\phi\left(\sum_{j\in N(i)}c_{ij}W^th_j^t\right)\text{ where }c_{ij}=\frac{1}{\sqrt{N_iN_j}}$$
이때 $c_{ij}$는 node $i$에 대한 node $j$의 중요도를 나타낸다. 이 경우에는 $c_{ij}$가 graph의 구조에 따라 결정되는 고정된 값으로, 학습되지 않는다.
이것은 가장 기본적인 GCN의 형태이고, 다양한 방식으로 변형될 수 있다.
2. Graph attention network (GAT)
위 경우에는 $c_{ij}$가 고정된 값이므로 인접한 node 사이의 중요도가 학습되지 않지만, 이것 역시 학습시키기 위해 이를 결정하는 attention score($\alpha$)를 도입한다. 그러면 Node feature는 다음과 같이 학습된다.
$$h_i^{t+1}=\phi\left(\sum_{j\in N(i)}\alpha_{ij}W^th_j^t\right)$$
3. Message-passing network (MPNN)
Message function이라 불리는 $M_t$를 통해 edge feature($e_{ij}$)와 node feature로부터 일종의 message($m_i^{t+1}$)를 생성하고, 이 message가 node update function $U_t$에 들어가 node feature의 학습에 기여한다. MPNN은 제일 일반적인(generalized) GNN의 구조 및 형태를 제시하는 것이고, message function이나 node update function이 어떻게 정의되는지에 따라 세부적인 architecture는 달라질 수 있다.
$$\begin{align*}
m_i^{t+1}&=\sum_{j\in N(i)}M_t(h_i^t, h_j^t, e_{ij})\\
h_i^{t+1}&=U_t(h_i^t,m_i^{t+1})
\end{align*}$$
사실 위에서 언급한 GCN이나 GAT 모두 MPNN의 구조로 표현될 수 있다. 그리고, 마치 convolutional neural network(CNN)가 local한 정보에서 시작해서 convolution을 하면 할수록 점점 넓은 영역의 정보를 수집하는 것과 비슷하게, GNN도 앞서 언급한 것처럼 인접한 node의 정보에서 시작해서 graph update를 하면 할수록 점점 먼 거리의 node의 정보를 수집하게 된다. 또한 CNN에서 filter가 sharing되는 것처럼 node update에 관여하는 weight($W_t$)도 sharing된다. 즉 graph update는 일종의 convolution이고, 이런 의미에서는 GNN 자체가 convolution의 성격을 내재하고 있다고 볼 수 있을 것이다.
위에서 본 예시들은 node feature만 update되면서 GNN의 학습이 진행되었지만, 추가적으로 edge feature도 update시킬 수 있다.
참고:
https://doi.org/10.48550/arXiv.1710.10903
Graph Attention Networks
We present graph attention networks (GATs), novel neural network architectures that operate on graph-structured data, leveraging masked self-attentional layers to address the shortcomings of prior methods based on graph convolutions or their approximations
arxiv.org
https://doi.org/10.48550/arXiv.1704.01212
Neural Message Passing for Quantum Chemistry
Supervised learning on molecules has incredible potential to be useful in chemistry, drug discovery, and materials science. Luckily, several promising and closely related neural network models invariant to molecular symmetries have already been described i
arxiv.org
https://theaisummer.com/gnn-architectures/
Best Graph Neural Network architectures: GCN, GAT, MPNN and more | AI Summer
Explore the most popular gnn architectures such as gcn, gat, mpnn, graphsage and temporal graph networks
theaisummer.com
'머신러닝' 카테고리의 다른 글
Gaussian Process (1) | 2022.12.30 |
---|---|
Variational Autoencoder (0) | 2022.05.10 |
Feature Selection (0) | 2022.04.24 |
Principal Component Analysis (PCA) (0) | 2022.04.24 |
Bias & Variance (0) | 2022.04.24 |