나는야 데이터사이언티스트/Machine Learning

[ML]K-NN 알고리즘 실습 - 기초 버전

우주먼지의하루 2020. 4. 22. 02:53
728x90

먼저 K-NN(K-Nearest Neighbor) 알고리즘에 대해 알아보자.

 

K-NN 알고리즘이란 특정공간내에서 입력과 제일 근접한 k개의 요소를 찾아, 더 많이 일치하는 것으로 분류하는 알고리즘이다. 지도 학습(Supervised Learning)의 한 종류로 레이블이 있는 데이터를 사용하여 분류 작업을한다. 알고리즘의 이름에서 볼 수 있듯이 데이터로부터 거리가 가까운 k개의 다른 데이터의 레이블을 참조하여 분류한다. 주로 거리를 측정할 때 유클리디안 거리 계산법*을 사용하여 거리를 측정하는데, 벡터의 크기가 커지면 계산이 복잡해진다. K-NN은 classification과 regression에 모두 적용할 수 있다. 

 

K-NN의 장점

  • 알고리즘이 간단하여 구현하기 쉽다
  • 수치 기반 데이터 분류 작업에서 성능이 좋다
  • training 단계가 필요 없다.
  • information loss가 없다.

K-NN의 단점

  • 학습 데이터의 양이 많으면 분류 속도가 느려진다 (사실 사전 계산을 할 수 없기 때문에 학습 과정이 따로 없기 때문에 분류 속도가 느리다)
  • 차원(벡터)의 크기가 크면 계산량이 많아진다
  • noise에 민감하다.

 

 

* 유클라디안 거리 계산법이란 ?

https://jason0425.tistory.com/74

 

[통계학] 유클리디안 거리(Euclidean Distance)

유클리디안 거리 n차원의 공간에서 두 점간의 거리를 알아내는 공식 L2 Dsitance라고 불리워진 계산 법 x축과 y축으로 구성된 2차원에 두 점이 있고, 그 두 점 사이의 거리를 측정한다. 즉, 피타고라스 정의를 이..

jason0425.tistory.com

 

* KNN에 대해 조금 더 알고 싶다면.

https://nittaku.tistory.com/275

 

1. 머신러닝 알고리즘 : kNN(k-Nearest Neighbors) 알고리즘(최근접 이웃알고리즘)

캡쳐 사진 및 글작성에 대한 도움 출저 : 유튜브 - 허민석님 kNN 알고리즘 녹색 별모양의 영화가 Activtion영화인지 / Romantic영화인지 분류하고 싶다. 액션영화와 로맨틱영화 사이에 있어서 상당히 분류하기 곤..

nittaku.tistory.com

* 코드 출처

https://doorbw.tistory.com/175

 

알고리즘 #11_ KNN 최근접 이웃 알고리즘이란?

안녕하세요. 문범우입니다. 이번 포스팅에서는 분류나 회귀에서 사용되는 KNN(K - Nearest Neighbors) 알고리즘에 대해서 알아보도록 하겠습니다. 1. KNN(K - Nearest Neighbors) KNN, K-최근접 이웃 알고리즘은..

doorbw.tistory.com

 

 

- KNN 알고리즘 실습

 

1. 먼저 데이터를 만들어줬습니다.

A그룹, B그룹을 만들고 예측할 값을 만들어 줍니다.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#데이터
A_x_list = [0,2,4,1,1,4]
A_y_list = [4,1,5,5,2,6]
A_x = np.array(A_x_list)
A_y = np.array(A_y_list)
 
B_x_list = [7,7,5,7,10,9]
B_y_list = [4,0,2,2,3,3]
B_x = np.array(B_x_list)
B_y = np.array(B_y_list)

#예측할 값
finding_point = [5,4]

 

 

2. 데이터가 어떻게 생긴지 확인해봅니다.

#차트 그려보기
plt.figure()
plt.scatter(A_x,A_y)
plt.scatter(B_x,B_y)
plt.scatter(finding_point[0],finding_point[1], marker='*')
 
plt.show()

 

3. 이제 두 점 사이의 거리를 구하는 식을 이용해 코드를 만들어 줍니다.

 

먼저 두 점 사이의 거리를 구하고 나면 가장 작은 값을 고르기 위한 함수를 만들어 놓습니다. (추후 사용 !)

# L리스트에서 c번째 작은 값 찾는 함수---> 리스트에 있는 거리값을 다 구한 후에 작은 것을 고르기 위해 만들어 놓음 !
def count_min_value(L,c):
    temp = L.copy()
    temp.sort()
    item = temp[c-1]
    return L.index(item),item

 

그 다음 두 점 사이 거리를 구하는 식을 기반으로 코드를 입력합니다.

# 입력값이 A그룹인지 B그룹인지 찾는 함수::KNN Algorithm 적용
def finding_AorB(k,x,y):
    numA = 0
    numB = 0
    A_xy = []
    B_xy = []
    
    # x,y 좌표가 따로 있는 것을 하나의 리스트로 통합
    for i in range(len(A_x_list)):
        A_xy.append([A_x_list[i],A_y_list[i]])
    for i in range(len(B_x_list)):
        B_xy.append([B_x_list[i],B_y_list[i]])
 
    A_distance = []
    B_distance = []
    
    # x,y 좌표에 대해 입력값과의 거리 산출
    for each in A_xy:
        dis = ((each[0] - x)**2 + (each[1] - y)**2)**(1/2)
        A_distance.append(dis)
    for each in B_xy:
        dis = ((each[0] - x)**2 + (each[1] - y)**2)**(1/2)
        B_distance.append(dis)
    A_result = []
    B_result = []
    
    A_min_count = 1
    B_min_count = 1
    
    while(numA + numB < k):
        min_A = 99999
        min_B = 99999
 
        _, min_A = count_min_value(A_distance,A_min_count)
        _, min_B = count_min_value(B_distance,B_min_count)
 
        if min_A < min_B:
            numA += 1
            A_min_count += 1
            A_result.append(A_xy[A_distance.index(min_A)])
            A_distance[A_distance.index(min_A)] = -1
        elif min_A > min_B:
            numB += 1
            B_min_count += 1
            B_result.append(B_xy[B_distance.index(min_B)])
            B_distance[B_distance.index(min_B)] = -1
        elif min_A == min_B:
            numA += 1
            numB += 1
            A_min_count += 1
            B_min_count += 1
            A_result.append(A_xy[A_distance.index(min_A)])
            A_distance[A_distance.index(min_A)] = -1
            B_result.append(B_xy[B_distance.index(min_B)])
            B_distance[B_distance.index(min_B)] = -1
            
    if numA > numB:
        print("RESULT: The point is A")
    elif numA < numB:
        print("RESULT, The point is B")
    elif numA == numB:
        print("I DON'T KNOW")
    print("A point is",A_result,"\nB point is",B_result,"\n")

 

K = 1 부터 4일때 예측할 값이 어디 그룹에 속하는지 알아봅니다.

 

반응형