/*
 * Decompiled with CFR 0.152.
 */
package net.jsign.bouncycastle.pqc.crypto.cmce;

import net.jsign.bouncycastle.crypto.digests.SHAKEDigest;
import net.jsign.bouncycastle.pqc.crypto.cmce.BENES;
import net.jsign.bouncycastle.pqc.crypto.cmce.BENES12;
import net.jsign.bouncycastle.pqc.crypto.cmce.BENES13;
import net.jsign.bouncycastle.pqc.crypto.cmce.GF;
import net.jsign.bouncycastle.pqc.crypto.cmce.GF12;
import net.jsign.bouncycastle.pqc.crypto.cmce.GF13;
import net.jsign.bouncycastle.pqc.crypto.cmce.Utils;

/**
 * Classic McEliece (CMCE) engine subset: parameter derivation and recomputation of the
 * public key from an encoded private key.
 *
 * <p>NOTE(review): this is a decompiled/shaded copy; only the members below are present.
 */
class CMCEEngine {
    private final int SYS_N;
    private final int SYS_T;
    private final int GFBITS;
    private final int IRR_BYTES;
    private final int COND_BYTES;
    private final int PK_NROWS;
    private final int PK_NCOLS;
    private final int PK_ROW_BYTES;
    private final int SYND_BYTES;
    private final int GFMASK;
    private final int[] poly;
    private final int defaultKeySize;
    private final GF gf;
    private final BENES benes;
    private final boolean usePadding;
    private final boolean countErrorIndices;
    private final boolean usePivots;

    /**
     * Returns the size in bytes of an encoded public key.
     *
     * <p>With padding, each of the PK_NROWS rows is emitted starting at byte
     * {@code (PK_NROWS - 1) / 8} of the full matrix row and re-aligned (see {@code pk_gen}),
     * so each row takes {@code SYS_N / 8 - (PK_NROWS - 1) / 8} bytes. Otherwise the rows are
     * packed as {@code PK_NROWS * PK_NCOLS / 8} bytes.
     */
    public int getPublicKeySize() {
        if (this.usePadding) {
            return this.PK_NROWS * (this.SYS_N / 8 - (this.PK_NROWS - 1) / 8);
        }
        return this.PK_NROWS * this.PK_NCOLS / 8;
    }

    /**
     * Creates an engine for one parameter set.
     *
     * @param m              GFBITS, the extension degree of the field; 12 selects GF12/BENES12,
     *                       anything else GF13/BENES13
     * @param n              SYS_N, the code length
     * @param t              SYS_T, the number of correctable errors (degree of the Goppa polynomial)
     * @param p              field polynomial description, stored as-is
     * @param usePivots      whether column pivoting ({@code mov_columns}) is applied during
     *                       systematic-form reduction
     * @param defaultKeySize value stored as the default key size
     */
    public CMCEEngine(int m, int n, int t, int[] p, boolean usePivots, int defaultKeySize) {
        this.usePivots = usePivots;
        this.SYS_N = n;
        this.SYS_T = t;
        this.GFBITS = m;
        this.poly = p;
        this.defaultKeySize = defaultKeySize;
        this.IRR_BYTES = this.SYS_T * 2;
        this.COND_BYTES = (1 << (this.GFBITS - 4)) * (2 * this.GFBITS - 1);
        this.PK_NROWS = this.SYS_T * this.GFBITS;
        this.PK_NCOLS = this.SYS_N - this.PK_NROWS;
        this.PK_ROW_BYTES = (this.PK_NCOLS + 7) / 8;
        this.SYND_BYTES = (this.PK_NROWS + 7) / 8;
        this.GFMASK = (1 << this.GFBITS) - 1;
        if (this.GFBITS == 12) {
            this.gf = new GF12(this.GFBITS);
            this.benes = new BENES12(this.SYS_N, this.SYS_T, this.GFBITS);
        } else {
            this.gf = new GF13(this.GFBITS);
            this.benes = new BENES13(this.SYS_N, this.SYS_T, this.GFBITS);
        }
        this.usePadding = this.SYS_T % 8 != 0;
        this.countErrorIndices = (1 << this.GFBITS) > this.SYS_N;
    }

    /**
     * Recomputes the public key belonging to an encoded private key.
     *
     * <p>The field ordering is regenerated from the 32-byte seed at the start of {@code byArray}.
     * The generator expands it as SHAKE256(64 || seed). Per the Classic McEliece SeededKeyGen
     * layout, that expansion is {@code s (SYS_N/8 bytes) || perm ((1 << GFBITS) * 4 bytes) ||
     * g (IRR_BYTES) || next seed (32 bytes)}. SHAKE output is prefix-consistent, so only the first
     * {@code SYS_N/8 + permBytes} bytes are squeezed here, and the permutation starts at
     * {@code SYS_N/8}.
     *
     * <p>The earlier code truncated the buffer but still used the full-length offset formula
     * ({@code length - 32 - IRR_BYTES - permBytes}). That read the permutation
     * {@code 32 + IRR_BYTES} bytes too early.
     *
     * @param byArray encoded private key; bytes 0..31 are the seed and the Goppa polynomial
     *                coefficients are read from offset 40 by {@code pk_gen}
     * @return the encoded public key. NOTE(review): if {@code pk_gen} fails, its status is not
     *         propagated and the returned array is all zeros.
     */
    public byte[] generate_public_key_from_private_key(byte[] byArray) {
        byte[] pk = new byte[this.getPublicKeySize()];
        int fieldSize = 1 << this.GFBITS;
        short[] pi = new short[fieldSize];
        long[] pivots = new long[]{0L};
        int[] perm = new int[fieldSize];

        int permOffset = this.SYS_N / 8;
        byte[] hash = new byte[permOffset + fieldSize * 4];
        SHAKEDigest shake = new SHAKEDigest(256);
        shake.update((byte)64);
        shake.update(byArray, 0, 32);
        shake.doFinal(hash, 0, hash.length);
        for (int i = 0; i < fieldSize; ++i) {
            perm[i] = Utils.load4(hash, permOffset + i * 4);
        }
        this.pk_gen(pk, byArray, perm, pi, pivots);
        return pk;
    }

    /**
     * Reads 64 bits of {@code matRow} starting at bit {@code 8 * blockIdx + tail} via
     * {@code Utils.load8}. Without padding the bytes at {@code blockIdx} are loaded directly.
     * With padding, nine bytes are copied into {@code tmp} and shifted right by {@code tail}
     * bits first.
     *
     * @param tmp scratch buffer of length at least 9; overwritten when padding is used
     */
    private long loadWindow(byte[] matRow, int blockIdx, int tail, byte[] tmp) {
        if (!this.usePadding) {
            return Utils.load8(matRow, blockIdx);
        }
        System.arraycopy(matRow, blockIdx, tmp, 0, 9);
        for (int k = 0; k < 8; ++k) {
            tmp[k] = (byte)((tmp[k] & 0xFF) >> tail | tmp[k + 1] << 8 - tail);
        }
        return Utils.load8(tmp, 0);
    }

    /**
     * Column pivoting for the last 32 rows of the systematic-form reduction.
     *
     * <p>Gaussian-eliminates the 32x64 window at rows/columns {@code PK_NROWS - 32 ...} to find
     * 32 pivot columns. It then swaps each window column {@code j} with its pivot column, both in
     * every row of {@code mat} and in the field ordering {@code pi}.
     *
     * @param mat    matrix rows, modified in place
     * @param pi     field ordering, modified in place
     * @param pivots one-element array; receives a bitmask of the chosen pivot column offsets
     * @return 0 on success, -1 if the window has rank below 32
     */
    private int mov_columns(byte[][] mat, short[] pi, long[] pivots) {
        long[] buf = new long[64];
        long[] ctzList = new long[32];
        byte[] tmp = new byte[9];
        int row = this.PK_NROWS - 32;
        int blockIdx = row / 8;
        int tail = row % 8;

        for (int i = 0; i < 32; ++i) {
            buf[i] = this.loadWindow(mat[row + i], blockIdx, tail, tmp);
        }

        // Branch-free elimination on the 32 window rows, recording each pivot column.
        pivots[0] = 0L;
        for (int i = 0; i < 32; ++i) {
            long t = buf[i];
            for (int j = i + 1; j < 32; ++j) {
                t |= buf[j];
            }
            if (t == 0L) {
                return -1;
            }
            int s = ctz(t);
            ctzList[i] = s;
            pivots[0] |= 1L << s;
            for (int j = i + 1; j < 32; ++j) {
                // all-ones if bit s of buf[i] is clear: pull it in from a lower row
                long mask = (buf[i] >> s & 1L) - 1L;
                buf[i] ^= buf[j] & mask;
            }
            for (int j = i + 1; j < 32; ++j) {
                long mask = -(buf[j] >> s & 1L);
                buf[j] ^= buf[i] & mask;
            }
        }

        // Apply the same column swaps to the field ordering (constant-time conditional swap).
        for (int j = 0; j < 32; ++j) {
            for (int k = j + 1; k < 64; ++k) {
                long d = pi[row + j] ^ pi[row + k];
                d &= same_mask64((short)k, (short)ctzList[j]);
                pi[row + j] ^= d;
                pi[row + k] ^= d;
            }
        }

        // Apply the column swaps to every matrix row.
        for (int i = 0; i < this.PK_NROWS; ++i) {
            long t = this.loadWindow(mat[i], blockIdx, tail, tmp);
            for (int j = 0; j < 32; ++j) {
                long d = (t >> j) ^ (t >> (int)ctzList[j]);
                d &= 1L;
                t ^= d << (int)ctzList[j];
                t ^= d << j;
            }
            if (this.usePadding) {
                Utils.store8(tmp, 0, t);
                byte[] r = mat[i];
                // keep the high (8 - tail) bits of the trailing byte; they lie right of the window
                r[blockIdx + 8] = (byte)((r[blockIdx + 8] & 0xFF) >>> tail << tail | (tmp[7] & 0xFF) >>> 8 - tail);
                // keep only the low `tail` bits of the leading byte (columns left of the window).
                // The previous `(x & 0xFF) << 8 - tail >>> 8 - tail` ran in int arithmetic and
                // therefore never cleared the old window bits before OR-ing in the new ones.
                r[blockIdx] = (byte)((tmp[0] & 0xFF) << tail | (r[blockIdx] & ((1 << tail) - 1)));
                for (int k = 7; k >= 1; --k) {
                    r[blockIdx + k] = (byte)((tmp[k] & 0xFF) << tail | (tmp[k - 1] & 0xFF) >>> 8 - tail);
                }
            } else {
                Utils.store8(mat[i], blockIdx, t);
            }
        }
        return 0;
    }

    /**
     * Counts trailing zero bits of {@code l}, returning 64 for 0. The loop always visits all
     * 64 bits and has no data-dependent branches.
     */
    private static int ctz(long l) {
        int seenOne = 0;
        int count = 0;
        for (int i = 0; i < 64; ++i) {
            int bit = (int)(l >> i & 1L);
            count += ((seenOne |= bit) ^ 1) & (bit ^ 1);
        }
        return count;
    }

    /**
     * Returns all ones ({@code -1L}) when {@code s == s2}, otherwise 0, without branching.
     * Assumes {@code s ^ s2} is non-negative, which holds for the non-negative indices used here.
     */
    private static long same_mask64(short s, short s2) {
        long l = s ^ s2;
        --l;
        l >>>= 63;
        return -l;
    }

    /**
     * Builds the public key (systematic form of the parity-check matrix).
     *
     * @param pk     output buffer of {@link #getPublicKeySize()} bytes, or {@code null} to skip
     *               encoding
     * @param sk     private key; SYS_T Goppa-polynomial coefficients are read from byte 40
     * @param perm   {@code 1 << GFBITS} values defining the field ordering
     * @param pi     output field ordering, length {@code 1 << GFBITS}
     * @param pivots one-element output written by {@code mov_columns} when pivoting is used
     * @return 0 on success, -1 if {@code perm} has duplicate keys, a pivot is missing, or
     *         column pivoting fails
     */
    private int pk_gen(byte[] pk, byte[] sk, int[] perm, short[] pi, long[] pivots) {
        int fieldSize = 1 << this.GFBITS;

        // Monic Goppa polynomial: SYS_T stored coefficients plus leading 1.
        short[] g = new short[this.SYS_T + 1];
        g[this.SYS_T] = 1;
        for (int i = 0; i < this.SYS_T; ++i) {
            g[i] = Utils.load_gf(sk, 40 + i * 2, this.GFMASK);
        }

        // Sort (perm[i] << 31 | i); the low bits of the sorted entries give the field ordering.
        long[] buf = new long[fieldSize];
        for (int i = 0; i < fieldSize; ++i) {
            buf[i] = ((long)perm[i] << 31 | (long)i) & Long.MAX_VALUE;
        }
        sort64(buf, 0, buf.length);
        for (int i = 1; i < fieldSize; ++i) {
            if (buf[i - 1] >> 31 == buf[i] >> 31) {
                return -1;
            }
        }
        for (int i = 0; i < fieldSize; ++i) {
            pi[i] = (short)(buf[i] & (long)this.GFMASK);
        }
        short[] support = new short[this.SYS_N];
        for (int i = 0; i < this.SYS_N; ++i) {
            support[i] = Utils.bitrev(pi[i], this.GFBITS);
        }

        // inv[j] = 1 / g(support[j]), then multiplied by support[j] once per row block.
        short[] inv = new short[this.SYS_N];
        this.root(inv, g, support);
        for (int i = 0; i < this.SYS_N; ++i) {
            inv[i] = this.gf.gf_inv(inv[i]);
        }

        // Bit-sliced matrix: row (i * GFBITS + k) holds bit k of inv * support^i, 8 columns per byte.
        byte[][] mat = new byte[this.PK_NROWS][this.SYS_N / 8];
        for (int i = 0; i < this.SYS_T; ++i) {
            for (int j = 0; j < this.SYS_N; j += 8) {
                for (int k = 0; k < this.GFBITS; ++k) {
                    int b = 0;
                    for (int c = 7; c >= 0; --c) {
                        b = b << 1 | (inv[j + c] >>> k & 1);
                    }
                    mat[i * this.GFBITS + k][j / 8] = (byte)b;
                }
            }
            for (int j = 0; j < this.SYS_N; ++j) {
                inv[j] = this.gf.gf_mul(inv[j], support[j]);
            }
        }

        // Branch-free Gaussian elimination to systematic form.
        int rowBytesFull = this.SYS_N / 8;
        for (int i = 0; i < (this.PK_NROWS + 7) / 8; ++i) {
            for (int j = 0; j < 8; ++j) {
                int row = i * 8 + j;
                if (row >= this.PK_NROWS) {
                    break;
                }
                if (this.usePivots && row == this.PK_NROWS - 32 && this.mov_columns(mat, pi, pivots) != 0) {
                    return -1;
                }
                for (int k = row + 1; k < this.PK_NROWS; ++k) {
                    // all-ones when rows `row` and `k` differ in the pivot bit
                    byte mask = (byte)-((mat[row][i] ^ mat[k][i]) >> j & 1);
                    for (int c = 0; c < rowBytesFull; ++c) {
                        mat[row][c] ^= (byte)(mat[k][c] & mask);
                    }
                }
                if ((mat[row][i] >> j & 1) == 0) {
                    return -1;
                }
                for (int k = 0; k < this.PK_NROWS; ++k) {
                    if (k == row) {
                        continue;
                    }
                    byte mask = (byte)-(mat[k][i] >> j & 1);
                    for (int c = 0; c < rowBytesFull; ++c) {
                        mat[k][c] ^= (byte)(mat[row][c] & mask);
                    }
                }
            }
        }

        if (pk == null) {
            return 0;
        }
        if (this.usePadding) {
            // PK_NROWS is not byte aligned: shift each row right by `tail` bits while copying.
            int idx = 0;
            int tail = this.PK_NROWS % 8;
            int first = (this.PK_NROWS - 1) / 8;
            int last = rowBytesFull - 1;
            for (int i = 0; i < this.PK_NROWS; ++i) {
                int j;
                for (j = first; j < last; ++j) {
                    pk[idx++] = (byte)((mat[i][j] & 0xFF) >>> tail | mat[i][j + 1] << 8 - tail);
                }
                pk[idx++] = (byte)((mat[i][j] & 0xFF) >>> tail);
            }
        } else {
            int skip = this.PK_NROWS / 8;
            for (int i = 0; i < this.PK_NROWS; ++i) {
                System.arraycopy(mat[i], skip, pk, i * this.PK_ROW_BYTES, this.PK_ROW_BYTES);
            }
        }
        return 0;
    }

    /** Evaluates polynomial {@code f} (degree SYS_T, coefficients low to high) at {@code a} by Horner's rule. */
    private short eval(short[] f, short a) {
        short acc = f[this.SYS_T];
        for (int i = this.SYS_T - 1; i >= 0; --i) {
            acc = this.gf.gf_mul(acc, a);
            acc = this.gf.gf_add(acc, f[i]);
        }
        return acc;
    }

    /** Stores {@code eval(f, points[i])} into {@code out[i]} for the first SYS_N points. */
    private void root(short[] out, short[] f, short[] points) {
        for (int i = 0; i < this.SYS_N; ++i) {
            out[i] = this.eval(f, points[i]);
        }
    }

    /**
     * Sorts {@code x[from .. to)} ascending using a fixed, data-independent sequence of
     * compare-exchange steps (the djbsort merging network).
     *
     * <p>Values must be non-negative, so that the subtraction-based comparison
     * ({@code (b - a) >>> 63}) cannot overflow. The caller masks with {@code Long.MAX_VALUE}.
     *
     * <p>The previous (decompiled) version lost the {@code top} bound. It did not compile, and
     * its merge phase never executed.
     */
    private static void sort64(long[] x, int from, int to) {
        int len = to - from;
        if (len < 2) {
            return;
        }
        int top = 1;
        while (top < len - top) {
            top += top;
        }
        for (int p = top; p > 0; p >>>= 1) {
            for (int i = 0; i < len - p; ++i) {
                if ((i & p) == 0) {
                    long c = x[from + i + p] - x[from + i];
                    c >>>= 63;
                    c = -c;
                    c &= x[from + i] ^ x[from + i + p];
                    x[from + i] ^= c;
                    x[from + i + p] ^= c;
                }
            }
            int idx = 0;
            for (int q = top; q > p; q >>>= 1) {
                for (; idx < len - q; ++idx) {
                    if ((idx & p) == 0) {
                        long a = x[from + idx + p];
                        for (int r = q; r > p; r >>>= 1) {
                            long c = x[from + idx + r] - a;
                            c >>>= 63;
                            c = -c;
                            c &= a ^ x[from + idx + r];
                            a ^= c;
                            x[from + idx + r] ^= c;
                        }
                        x[from + idx + p] = a;
                    }
                }
            }
        }
    }
}

