Processing math: 100%

ABOUT ME

Today
Yesterday
Total
  • [백준] 7571번: 크래머의 공식
    Coding/Problem Solving & Algorithm 2022. 12. 19. 14:56

    문제 링크 : https://www.acmicpc.net/problem/7561

     

    문제

    세 변수로 이루어진 일차 방정식 세 개가 주어진다. 크래머의 공식을 이용해서 해를 구하는 프로그램을 작성하시오.

    입력

    입력은 여러 개의 테스트 케이스로 이루어져 있다.

    각 테스트 케이스는 세 줄로 이루어져 있고, 아래와 같은 순서로 주어진다.

    a11 a12 a13 b1

    a21 a22 a23 b2

    a31 a32 a33 b3

     

    모든 숫자는 -1000보다 크거나 같고, 1000보다 작거나 같은 정수이며, 공백 한 칸으로 구분되어져 있다. 

    출력

    각 테스트 케이스마다 두 줄을 출력한다.

    첫째 줄에는 행렬 A1, A2, A3, A의 행렬식을 출력한다. 

    방정식의 해가 존재하지 않는 경우에는 둘째 줄에 "No unique solution"을 출력하고, 해가 존재하는 경우에는 "Unique solution: "와 x1, x2, x3 값을 소수점 셋째자리까지 출력한다.

    방정식의 해 xi 가 -0.0005 < xi < 0.0005 인 경우에는 "-0.000" 대신에 "0.000"을 출력한다.

    각 테스트 케이스 사이에는 빈 줄을 하나 출력한다.


    우선 행렬식을 구하는 함수를 구현하자.

    주어진 행렬에서 행과 열을 제거하여 리턴하는 함수를 먼저 구현하자. 

    def rmv(M:list, x:int, y:int) -> list:
    
        rstM = []
        for i, v in enumerate(M):
            v = v.copy()
            if i == x:
                continue
            del(v[y])
            rstM.append(v)
    
        return rstM

    이를 이용해서 행렬식 함수는 아래와 같이 구현할 수 있을 것이다. 

    def det(M:list) -> int :
        
        if len(M) == 2:
            return (M[0][0] * M[1][1]) - (M[1][0] * M[0][1])
        else:
            result = 0
    
            for i, a in enumerate(M[0]):
                result += ((-1)**i) * a * det(rmv(M, 0, i))
    
            return result

     

    크래머의 공식은 아래와 같다.

    Ax=b 의 해가 존재한다면, x=[x1,x2,...,xm]일 때 xi=det(Ai(b))/det(A)이다.

    이때 Ai(b)Ai열을 b로 치환한 행렬을 말한다. 

     

    행렬의 열을 치환하는 함수를 구현하자.

    def rpl(M:list, i:int, b:list) -> list:
    
        result = M.copy()
        for c in range(len(M)):
            result[c][i] = b[c]
    
        return result

    마지막으로 지금까지 구현한 함수를 모두 이용해서 크래머의 공식을 구현하면 끝.

    .

    .

    인줄 알았는데 함수를 실행할 때마다 원본 배열의 값이 바뀐다는 사실을 알아내었다.

    copy 모듈을 import 하고 함수마다 M = deepcopyM 한 줄씩 추가해주었다.

     

    그리고 출력할 때 반올림한 뒤 0을 추가한 형식으로 숫자를 출력하고 있어서 그런 반올림 함수도 따로 구현해주었다.

    def roundT(val, digits) -> str:
        
        val = str(round(float(val), digits))
        decLen = len(val) - val.index(".") - 1
        if decLen < digits:
            val += "0" * (digits - decLen)
    
        return val

     

    최종 코드는 아래와 같다.

     

    from copy import deepcopy
    
    def rmv(M:list, x:int, y:int) -> list:
    
        M = deepcopy(M)
        rstM = []
        for i, v in enumerate(M):
            v = v.copy()
            if i == x:
                continue
            del(v[y])
            rstM.append(v)
    
        return rstM
    
    
    
    def det(M:list) -> int :
    
        M = deepcopy(M)
        if len(M) == 2:
            return (M[0][0] * M[1][1]) - (M[1][0] * M[0][1])
        else:
            result = 0
    
            for i, a in enumerate(M[0]):
                result += ((-1)**i) * a * det(rmv(M, 0, i))
    
            return result
    
    def rpl(M:list, i:int, b:list) -> list:
    
        result = deepcopy(M)
        for c in range(len(M)):
            result[c][i] = b[c]
    
        return result
    
    def roundT(val, digits) -> str:
        
        val = str(round(float(val), digits))
        decLen = len(val) - val.index(".") - 1
        if decLen < digits:
            val += "0" * (digits - decLen)
    
        return val
    
    T = int(input())
    for _ in range(T):
    
    
        A = []
        b = []
        
        
        for i in range(3):
            inp = list(map(int, input().split()))
            A.append(inp[:3])
            b.append(inp[-1])
    
        detArr = []
        for i in range(3):
            detArr.append(det(rpl(A, i, b)))
    
        detArr.append(det(A))
    
        print(" ".join(map(str,detArr)))
        
        if detArr[-1] == 0:
            print("No unique solution", end = "")
        
        else:
            print("Unique solution: ", end = "")
            for i in range(3):
                sol = detArr[i]/detArr[-1]
                sol = roundT(sol, 3)
                
                if (-0.0005 < float(sol) < 0.0005):
                    sol = "0.000"
                print(sol, end = " ")
        print("\n")

     

    쓸데없이 길어진 느낌이 있다.

    자잘하게 신경써줘야 하는게 많은 문제였다. 

    'Coding > Problem Solving & Algorithm' 카테고리의 다른 글

    댓글

Designed by Tistory.