BOJ 3874: Matrix Calculator

https://www.acmicpc.net/problem/3874

깡파싱 문제입니다. 재귀 하향 파서를 쓸 수 있도록 (+ 좀 더 간단하게) 문제에서 주어진 EBNF를 다시 정의해봅시다.

<expr>     ::= <term> (("+" | "-") <term>)*
<term>     ::= <factor> ("*" <factor>)*
<factor>   ::= <primary> | "-" <factor>
<primary>  ::= (<inum> | <var> | "(" <expr> ")" | "[" <row-seq> "]") ("'" | "(" <expr> "," <expr> ")")*
<row-seq>  ::= <row> (";" <row>)*
<row>      ::= <expr> (" " <expr>)*

이 다음부터는 노가다입니다…

#include<bits/stdc++.h>


using namespace std;

// A matrix is stored as a vector of rows: mat[row][col].
using Mat = std::vector<std::vector<int>>;
// Modulus applied by the arithmetic nodes (2^15 = 32768).
constexpr int MOD = 1 << 15;

// Kinds of lexical token. The enumerator order is significant only in that it
// fixes the underlying integer values.
enum class TokenType {
    // --- literals and names ---
    INTEGER,    // run of decimal digits
    NOINT,      // "not an integer" marker; tokenize() never emits it as a token
    IDENT,      // one-letter variable name

    // --- operators ---
    ASSIGN,     // '='
    PLUS,       // '+'
    MINUS,      // '-'
    ASTER,      // '*'

    // --- separators ---
    CONCAT,     // ' '  (single space)
    SEMICOL,    // ';'
    COMMA,      // ','

    // --- postfix / grouping ---
    TRANS,      // '\''
    PAR_OPEN,   // '('
    PAR_CLOSE,  // ')'
    BRA_OPEN,   // '['
    BRA_CLOSE,  // ']'
};
// One lexical token: its kind plus an optional piece of text.
// `value` defaults to the empty string when the caller supplies no text.
struct Token {
    TokenType type;
    string value;
    // The text is taken by value and moved into place, so a caller that
    // passes a temporary string pays for no additional copy.
    Token(TokenType t, string v="") : type(t), value(std::move(v)) {}
};
vector<Token> tokenize(string &s) {
    TokenType cur_token_type = TokenType::NOINT;
    string::iterator cur_token_begin;
    vector<Token> res;
    for (auto it = s.begin(); it != s.end(); it++) {
        if (cur_token_type == TokenType::NOINT) {
            if (*it == '.')
                continue;
            else if (isdigit(*it)) {
                cur_token_type = TokenType::INTEGER;
                cur_token_begin = it;
            }
            else if (isalpha(*it))
                res.push_back(Token(TokenType::IDENT, string(it, it+1)));
            else switch (*it) {
                case '=': res.push_back(Token(TokenType::ASSIGN)); break;
                case '+': res.push_back(Token(TokenType::PLUS)); break;
                case '-': res.push_back(Token(TokenType::MINUS)); break;
                case '*': res.push_back(Token(TokenType::ASTER)); break;
                case ' ': res.push_back(Token(TokenType::CONCAT)); break;
                case ';': res.push_back(Token(TokenType::SEMICOL)); break;
                case ',': res.push_back(Token(TokenType::COMMA)); break;
                case '\'': res.push_back(Token(TokenType::TRANS)); break;
                case '(': res.push_back(Token(TokenType::PAR_OPEN)); break;
                case ')': res.push_back(Token(TokenType::PAR_CLOSE)); break;
                case '[': res.push_back(Token(TokenType::BRA_OPEN)); break;
                case ']': res.push_back(Token(TokenType::BRA_CLOSE)); default:;
            }
        }
        else if (!isdigit(*it)) {
            res.push_back(Token(TokenType::INTEGER, string(cur_token_begin, it)));
            cur_token_type = TokenType::NOINT;
            it--;
        }
    }
    return res;
}

// Forward declaration so that pNode can be named before Node is defined.
struct Node;
// Iterator into the token vector; the parse_* functions walk the input with it.
using Vtit = vector<Token>::iterator;
// Shared-ownership handle to an AST node.
using pNode = shared_ptr<Node>;
// Value of each variable, indexed by (name - 'A'). main() clears every slot at
// the start of a dataset and stores into a slot after each assignment.
Mat variables[26];

// Abstract base class of every AST node.
struct Node {
    // Virtual so that a derived node is destroyed correctly even when it is
    // deleted through a pointer to Node.
    virtual ~Node() = default;
    // Evaluates the subtree rooted at this node and returns its matrix value.
    virtual Mat eval(void) = 0;
};
// Leaf node for an integer literal; the number is kept as a 1x1 matrix.
struct IntLiteral : public Node {
    Mat value;
    IntLiteral(int n) : value(1, vector<int>(1, n)) {}
    // Returns a copy of the stored 1x1 matrix.
    Mat eval() override {
        return value;
    }
};
// Leaf node that refers to a variable by its one-character name.
struct Variable : public Node {
    char name;
    Variable(char c) : name(c) {}
    // Returns a copy of the matrix currently stored for this variable.
    Mat eval() override {
        const int slot = name - 'A';  // index into the global variable table
        return variables[slot];
    }
};
// Horizontal concatenation: each row of the right operand is appended to the
// row of the left operand with the same index.
struct Concat : public Node {
    pNode left, right;
    Concat(pNode l, pNode r) : left(l), right(r) {}
    Mat eval() override {
        Mat joined = left->eval();
        const Mat extra = right->eval();
        const int rows = joined.size();
        // The row count is taken from the left operand; the right operand is
        // indexed with the same row numbers.
        for (int row = 0; row < rows; ++row) {
            const vector<int> &tail = extra[row];
            joined[row].insert(joined[row].end(), tail.begin(), tail.end());
        }
        return joined;
    }
};
// Vertical stacking: the rows of the right operand are placed below the rows
// of the left operand.
struct MatStack : public Node {
    pNode left, right;
    MatStack(pNode l, pNode r) : left(l), right(r) {}
    Mat eval() override {
        Mat stacked = left->eval();
        const Mat below = right->eval();
        for (const vector<int> &row : below)
            stacked.push_back(row);
        return stacked;
    }
};
// Element-wise addition modulo MOD. The loop bounds come from the left
// operand (row count, and width of its first row).
struct BinAdd : public Node {
    pNode left, right;
    BinAdd(pNode l, pNode r) : left(l), right(r) {}
    Mat eval() override {
        Mat sum = left->eval();
        const Mat addend = right->eval();
        const int rows = sum.size();
        const int cols = sum[0].size();
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < cols; ++col) {
                sum[row][col] += addend[row][col];
                sum[row][col] %= MOD;
            }
        }
        return sum;
    }
};
// Element-wise subtraction modulo MOD. The loop bounds come from the left
// operand (row count, and width of its first row).
struct BinSub : public Node {
    pNode left, right;
    BinSub(pNode l, pNode r) : left(l), right(r) {}
    Mat eval() override {
        Mat diff = left->eval();
        const Mat subtrahend = right->eval();
        const int rows = diff.size();
        const int cols = diff[0].size();
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < cols; ++col) {
                // Adding MOD before reducing keeps the result non-negative.
                const int raw = diff[row][col] - subtrahend[row][col] + MOD;
                diff[row][col] = raw % MOD;
            }
        }
        return diff;
    }
};
// Multiplication modulo MOD. If either operand is 1x1 it is treated as a
// scalar that scales the other operand (the left operand is checked first);
// otherwise the ordinary matrix product is computed.
struct BinMul : public Node {
    pNode left, right;
    BinMul(pNode l, pNode r) : left(l), right(r) {}
    Mat eval() override {
        Mat lhs = left->eval();
        Mat rhs = right->eval();

        const auto is_scalar = [](const Mat &x) {
            return x.size() == 1 && x[0].size() == 1;
        };
        // Multiplies every entry of `target` by `k` in place; the width of the
        // first row is used for all rows.
        const auto scale = [](Mat &target, int k) {
            const int rows = target.size();
            const int cols = target[0].size();
            for (int row = 0; row < rows; ++row)
                for (int col = 0; col < cols; ++col)
                    target[row][col] = (target[row][col] * k) % MOD;
        };

        if (is_scalar(lhs)) {
            scale(rhs, lhs[0][0]);
            return rhs;
        }
        if (is_scalar(rhs)) {
            scale(lhs, rhs[0][0]);
            return lhs;
        }

        // General case: (rows x inner) times (inner x cols).
        const int rows = lhs.size();
        const int inner = rhs.size();
        const int cols = rhs[0].size();
        Mat product(rows, vector<int>(cols, 0));
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < cols; ++col) {
                int acc = 0;
                // Reduce after every term so the running sum stays small.
                for (int k = 0; k < inner; ++k)
                    acc = (acc + lhs[row][k] * rhs[k][col]) % MOD;
                product[row][col] = acc;
            }
        }
        return product;
    }
};
// Unary negation modulo MOD: every entry x of the operand becomes (-x) mod MOD.
struct UnaryMinus : public Node {
    pNode child;
    UnaryMinus(pNode c) : child(c) {}
    Mat eval() override {
        Mat a = this->child->eval();
        for (vector<int> &row : a)
            for (int &x : row)
                // The trailing "% MOD" matters for x == 0: MOD - 0 would be
                // MOD itself, which lies outside the range [0, MOD).
                x = (MOD - x) % MOD;
        return a;
    }
};
// Matrix transpose: entry (row, col) of the operand becomes entry (col, row)
// of the result.
struct Transpose : public Node {
    pNode child;
    Transpose(pNode c) : child(c) {}
    Mat eval() override {
        const Mat src = child->eval();
        const int rows = src.size();
        const int cols = src[0].size();
        Mat flipped(cols, vector<int>(rows));
        for (int row = 0; row < rows; ++row)
            for (int col = 0; col < cols; ++col)
                flipped[col][row] = src[row][col];
        return flipped;
    }
};
// Sub-matrix selection  matrix(idx_row, idx_col).
//
// Only the first row of each index operand is read, as a list of 1-based
// indices. With row indices (i_1 .. i_n) and column indices (j_1 .. j_m) the
// result is the n x m matrix whose (p, q) entry is matrix[i_p][j_q].
// No bounds checking is performed on the index values.
struct Indexing : public Node {
    pNode matrix, idx_row, idx_col;
    Indexing(pNode m, pNode r, pNode c) : matrix(m), idx_row(r), idx_col(c) {}
    Mat eval() override {
        const Mat a = this->matrix->eval();
        const Mat r = this->idx_row->eval();
        const Mat c = this->idx_col->eval();
        const int n = r[0].size(), m = c[0].size();
        Mat res(n, vector<int>(m));
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                // Both index lists are 1-based, hence the "- 1" on each.
                res[i][j] = a[r[0][i] - 1][c[0][j] - 1];
        return res;
    }
};

// Forward declarations of the mutually recursive parsers.
// Each one reads tokens starting at it_begin (never going past it_end),
// stores the root of the AST it built in `ast`, and returns the number of
// tokens it consumed.
int parse_row(Vtit it_begin, Vtit it_end, pNode &ast);
int parse_row_seq(Vtit it_begin, Vtit it_end, pNode &ast);
int parse_primary(Vtit it_begin, Vtit it_end, pNode &ast);
int parse_factor(Vtit it_begin, Vtit it_end, pNode &ast);
int parse_term(Vtit it_begin, Vtit it_end, pNode &ast);
int parse_expr(Vtit it_begin, Vtit it_end, pNode &ast);

// <row> ::= <expr> (" " <expr>)*
// Builds a left-nested chain of Concat nodes, stores it in `ast`, and returns
// the number of tokens consumed.
int parse_row(Vtit it_begin, Vtit it_end, pNode &ast) {
    pNode acc;
    Vtit cur = it_begin + parse_expr(it_begin, it_end, acc);
    while (cur != it_end && cur->type == TokenType::CONCAT) {
        ++cur;  // skip the separating space
        pNode next;
        cur += parse_expr(cur, it_end, next);
        acc = make_shared<Concat>(acc, next);
    }
    ast = acc;
    return static_cast<int>(cur - it_begin);
}
// <row-seq> ::= <row> (";" <row>)*
// Builds a left-nested chain of MatStack nodes, stores it in `ast`, and
// returns the number of tokens consumed.
int parse_row_seq(Vtit it_begin, Vtit it_end, pNode &ast) {
    pNode acc;
    Vtit cur = it_begin + parse_row(it_begin, it_end, acc);
    while (cur != it_end && cur->type == TokenType::SEMICOL) {
        ++cur;  // skip ';'
        pNode next;
        cur += parse_row(cur, it_end, next);
        acc = make_shared<MatStack>(acc, next);
    }
    ast = acc;
    return static_cast<int>(cur - it_begin);
}
// <primary> ::= (<inum> | <var> | "(" <expr> ")" | "[" <row-seq> "]")
//               ("'" | "(" <expr> "," <expr> ")")*
// Parses the base operand, then folds any number of postfix transposes and
// index operations onto it. Stores the result in `ast` and returns the number
// of tokens consumed.
int parse_primary(Vtit it_begin, Vtit it_end, pNode &ast) {
    pNode node;
    int used = 0;

    switch (it_begin->type) {
        case TokenType::INTEGER:
            node = make_shared<IntLiteral>(stoi(it_begin->value));
            used = 1;
            break;
        case TokenType::IDENT:
            node = make_shared<Variable>(it_begin->value[0]);
            used = 1;
            break;
        case TokenType::PAR_OPEN:
            // The extra 2 accounts for "(" and the closing ")".
            used = 2 + parse_expr(it_begin + 1, it_end, node);
            break;
        default:
            // Any other token is treated as the start of "[" <row-seq> "]";
            // the extra 2 accounts for the two brackets.
            used = 2 + parse_row_seq(it_begin + 1, it_end, node);
            break;
    }

    // Postfix operators, applied left to right.
    while (it_begin + used != it_end) {
        const TokenType next = it_begin[used].type;
        if (next == TokenType::TRANS) {
            ++used;
            node = make_shared<Transpose>(node);
        }
        else if (next == TokenType::PAR_OPEN) {
            pNode row, col;
            ++used;                                            // "("
            used += parse_expr(it_begin + used, it_end, row);
            ++used;                                            // ","
            used += parse_expr(it_begin + used, it_end, col);
            ++used;                                            // ")"
            node = make_shared<Indexing>(node, row, col);
        }
        else {
            break;
        }
    }

    ast = node;
    return used;
}
// <factor> ::= <primary> | "-" <factor>
// Stores the parsed node in `ast` and returns the number of tokens consumed.
int parse_factor(Vtit it_begin, Vtit it_end, pNode &ast) {
    if (it_begin->type == TokenType::MINUS) {
        pNode operand;
        const int inner = parse_factor(it_begin + 1, it_end, operand);
        ast = make_shared<UnaryMinus>(operand);
        return inner + 1;  // +1 for the '-' itself
    }
    return parse_primary(it_begin, it_end, ast);
}
// <term> ::= <factor> ("*" <factor>)*
// Builds a left-associative chain of BinMul nodes, stores it in `ast`, and
// returns the number of tokens consumed.
int parse_term(Vtit it_begin, Vtit it_end, pNode &ast) {
    pNode acc;
    Vtit cur = it_begin + parse_factor(it_begin, it_end, acc);
    while (cur != it_end && cur->type == TokenType::ASTER) {
        ++cur;  // skip '*'
        pNode next;
        cur += parse_factor(cur, it_end, next);
        acc = make_shared<BinMul>(acc, next);
    }
    ast = acc;
    return static_cast<int>(cur - it_begin);
}
// <expr> ::= <term> (("+" | "-") <term>)*
// Builds a left-associative chain of BinAdd / BinSub nodes, stores it in
// `ast`, and returns the number of tokens consumed.
int parse_expr(Vtit it_begin, Vtit it_end, pNode &ast) {
    pNode acc;
    Vtit cur = it_begin + parse_term(it_begin, it_end, acc);
    while (cur != it_end) {
        const TokenType op = cur->type;
        if (op != TokenType::PLUS && op != TokenType::MINUS)
            break;
        ++cur;  // skip the operator
        pNode next;
        cur += parse_term(cur, it_end, next);
        if (op == TokenType::MINUS)
            acc = make_shared<BinSub>(acc, next);
        else
            acc = make_shared<BinAdd>(acc, next);
    }
    ast = acc;
    return static_cast<int>(cur - it_begin);
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int n;
    string s;
    while (true) {
        cin >> n;
        cin.ignore(1);
        if (n == 0)
            break;
        for (int i = 0; i < 26; i++)
            variables[i].clear();
        while (n--) {
            getline(cin, s);
            vector<Token> tokens = tokenize(s);
            pNode ast;
            parse_expr(tokens.begin() + 2, tokens.end(), ast);
            char var = (tokens[0].value)[0];
            variables[var - 'A'] = ast->eval();
            int n = variables[var - 'A'].size(), m = variables[var - 'A'][0].size();
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++)
                    cout << variables[var - 'A'][i][j] << ' ';
                cout << '\n';
            }
        }
        cout << "-----" << endl;
    }

    return 0;
}