티스토리 뷰

 

 

1. Metric-based vs Model-based vs Optimization-based

Meta learning 정의: 기존의 supervised learning 에서 samples from a single distribution 으로 부터 학습하는 것을 넘어서서, distribution over tasks 를 배우는 게 목적이다. 

어떻게? 학습에 도움을 주는 Support set 이라는 것을 사용해서. 이걸 어떻게 사용하느냐에 따라 3가지 방식이 나뉘는거임.

 

기존의 supervised learning 에서는 $P_{\theta}(y|x)$ 를 구하고자 했다면 meta learning 에서는 여기에 S가 추가로 들어가는 거임. Meta learning 은 궁극적으로 $P_{\theta}(y|x, S)$ 라는 objective function 을 구하고자 함. (물론 다른 estimation 도 들어가서 더 복잡하긴 full 수식은 더 복잡하긴하다).

Conditioning the model not only on x but S to predict Y => 이것이 바로 supervised learning 과 meta learning 의 차이!

 

그리고 meta learning 에는 3가지의 categories 가 있는데, model-based, metric-based 그리고 optimization-based 가 있음. 3가지 방법이 모두 어떻게 모델을 Support set 에 condition 하는지를 다르게 정의한다 (쉽게 말하자면 Support set 을 어떻게 training 에 활용해서 fast adaptation 을 가능하게 할건지)

: optimization-based 는 gradient 를 이용하는 방법이라고 할 수 있음.  metric-based 는 similarity 를 이용해서 그리고 model-based 는 말 그대로 어떤 모델이 support set 이랑, qeuery set 데이터를 전부다 input 으로 받아서 output 으로 probability 를 냄 (some kind of network that reads in an entire (few-shot) training set

 

 

 

2. Metric-based vs Model-based vs Optimization-based

https://youtu.be/h7qyQeXKxZE?t=1134

 

여기서는 살짝 다른 이름으로 소개되는데

black-box meta learning = model-based

non-parametric meta learning = metric-based

gradient-based meta learning = optimization-based

 

우선 black box meta learning 은 보여지는 이미지와 같이 model 이 reads in all the (x, y) pairs in the support set, then reads in the test point and then output the label for the test point 라고 할 수 있음.  주로 RNN 이나 LSTM 같은 sequence model 을 써서 이 데이터 sequence 를 읽는데 관건은 어떤 model architecture 를 쓸것이냐나는 거임. (이걸 중심으로 다양한 연구가 있음)

 

이게 왜 중요하냐면, support set 이 굉장히 클 수 있기 때문에 이 데이터를 어떻게 효율적으로 model 에 입력시키는 지가 중요하기 때문. 10 - way, 10 - shot classification 을 생각해봤을 때 training data 만 100 pairs of x,y 인데 얘네들을 어떻게 효과적으로 읽는지가 중요해지는 거임.

 

non-parametric 는 similarity base 인데 meta learning part with support set 은 parametric 이지만 adaptation part  with query set 은 non-parametric 하기 때문에 이렇게 부름. Support set 을 가지고 feature representation 을 배우고, adaptation 단계에서는 query set 을 가지고 support set 과의 distance 혹은 similarity 를 통해서 prediction

ex) Matching networks, Prototypical networks ...

 

gradient-based 는 adaptation to a new task 가 finetuning with gradienet descent 라서 붙여진 이름

ex) Maml, Reptile, FOMAML ...

 

3. Metric-based meta learning 대표 example :  Prototypical network 

https://youtu.be/rHGPfl0pvLY

Support set 을 이용해서 prototype 을 계산하고, Query set으로 predict 한 다음 거기에서 생긴 loss를 이용해서 모델 업데이트 

여기서 업데이트 되는 부분이 어떤 부분이냐하면 input 들어왔을 때 encoding 하는 부분. 그래서 prototypical network 에서는 feature reprsentation 을 배운다고 하는 거임

 

Prototype 은 어떻게 정하냐? 아래 이미지의 Compute prototype from support examples 이라고 있는데 $f_{\phi}$ 는 단순히 feed forward network 라드지 아무튼 input 을 encoding 하는 function 임. 그래서 이 부분을 traning 을 통해서 학습하는 거임. 5-shot 이라고 할 때 각 클래스마다 5개의 examples 이 있을 거고, 그 input 의 mean 이 prototype 이 되는거임. 그리고 query set 의 input 이 들어왔을 때 가장 가까운 protoype 의 class 를 predict 한다.

 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함