package org.apache.mahout.math.decomposer.hebbian;

import java.util.ArrayList;
import java.util.Properties;
import java.util.Random;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
import org.apache.mahout.math.decomposer.EigenStatus;
import org.apache.mahout.math.decomposer.SingularVectorVerifier;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.function.TimesFunction;
import org.apache.xpath.XPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/math/decomposer/hebbian/HebbianSolver.class */
/**
 * Finds singular vectors of a matrix one at a time by repeatedly streaming the matrix rows
 * through an {@link EigenUpdater} and asking a {@link SingularVectorVerifier} whether the
 * current candidate vector has converged.
 *
 * <p>Instances keep a mutable pass counter and are therefore not safe to share between
 * concurrent {@code solve} calls.
 */
public class HebbianSolver {
    private static final Logger log = LoggerFactory.getLogger(HebbianSolver.class);

    /** A row needs at least this many non-default elements to be used as the first row of a pass. */
    private static final int MIN_NONDEFAULT_ELEMENTS_FOR_START = 5;

    private final EigenUpdater updater;
    private final SingularVectorVerifier verifier;
    private final double convergenceTarget;
    private final int maxPassesPerEigen;
    private final Random rng;

    /** Number of convergence checks made for the vector currently being trained. */
    private int numPasses;

    /**
     * Creates a solver with explicit collaborators.
     *
     * @param eigenUpdater applies one training row to the candidate vector
     * @param singularVectorVerifier reports how close the candidate is to a singular vector
     * @param convergenceTarget training of a vector stops once {@code 1 - cosAngle} is no
     *     longer greater than this value
     * @param maxPassesPerEigen training of a vector stops once more than this many completed
     *     verification results have been recorded for it
     */
    public HebbianSolver(EigenUpdater eigenUpdater, SingularVectorVerifier singularVectorVerifier, double convergenceTarget, int maxPassesPerEigen) {
        this.rng = new Random();
        this.numPasses = 0;
        this.updater = eigenUpdater;
        this.verifier = singularVectorVerifier;
        this.convergenceTarget = convergenceTarget;
        this.maxPassesPerEigen = maxPassesPerEigen;
    }

    /**
     * Creates a solver with explicit collaborators and no practical limit on passes
     * ({@link Integer#MAX_VALUE}).
     *
     * @param eigenUpdater applies one training row to the candidate vector
     * @param singularVectorVerifier reports how close the candidate is to a singular vector
     * @param convergenceTarget threshold on {@code 1 - cosAngle}
     */
    public HebbianSolver(EigenUpdater eigenUpdater, SingularVectorVerifier singularVectorVerifier, double convergenceTarget) {
        this(eigenUpdater, singularVectorVerifier, convergenceTarget, Integer.MAX_VALUE);
    }

    /**
     * Creates a solver backed by a new {@link HebbianUpdater} and a new {@link AsyncEigenVerifier}.
     *
     * @param convergenceTarget threshold on {@code 1 - cosAngle}
     * @param maxPassesPerEigen limit on recorded verification results per vector
     */
    public HebbianSolver(double convergenceTarget, int maxPassesPerEigen) {
        this(new HebbianUpdater(), new AsyncEigenVerifier(), convergenceTarget, maxPassesPerEigen);
    }

    /**
     * Creates a solver with the default collaborators and no practical limit on passes.
     *
     * @param convergenceTarget threshold on {@code 1 - cosAngle}
     */
    public HebbianSolver(double convergenceTarget) {
        this(convergenceTarget, Integer.MAX_VALUE);
    }

    /**
     * Creates a solver with the default collaborators and a convergence target of {@code 0.0},
     * so that the pass limit is, in practice, what ends training of each vector.
     *
     * @param maxPassesPerEigen limit on recorded verification results per vector
     */
    public HebbianSolver(int maxPassesPerEigen) {
        this(0.0d, maxPassesPerEigen);
    }

    /**
     * Trains {@code desiredRank} vectors against {@code corpus}, one after another.
     *
     * <p>Each pass starts with a randomly chosen usable row and then feeds every other row to the
     * updater in index order. When a vector is finished it is scaled by the reciprocal of its
     * 2-norm, stored as a row of the result matrix, and the per-vector training state is reset.
     *
     * @param corpus matrix whose rows are the training examples
     * @param desiredRank number of vectors to find
     * @return the training state, whose current-eigens matrix has {@code desiredRank} rows and
     *     {@code corpus.numCols()} columns
     * @throws IllegalArgumentException if a pass needs a starting row and {@code corpus} has no
     *     row that is non-null, has non-zero 2-norm and has at least
     *     {@value #MIN_NONDEFAULT_ELEMENTS_FOR_START} non-default elements
     */
    public TrainingState solve(Matrix corpus, int desiredRank) {
        int numCols = corpus.numCols();
        DenseMatrix eigens = new DenseMatrix(desiredRank, numCols);
        ArrayList<Double> eigenValues = new ArrayList<Double>();
        log.info("Finding {} singular vectors of matrix with {} rows, via Hebbian",
                new Object[] {desiredRank, corpus.numRows()});
        TrainingState state = new TrainingState(eigens, new DenseMatrix(corpus.numRows(), desiredRank));
        for (int eigenIndex = 0; eigenIndex < desiredRank; eigenIndex++) {
            DenseVector currentEigen = new DenseVector(numCols);
            while (hasNotConverged(currentEigen, corpus, state)) {
                int startingIndex = getRandomStartingIndex(corpus);
                Vector startingRow = corpus.getRow(startingIndex);
                state.setTrainingIndex(startingIndex);
                this.updater.update(currentEigen, startingRow, state);
                for (int rowIndex = 0; rowIndex < corpus.numRows(); rowIndex++) {
                    // The training index is advanced even for the row that is skipped below.
                    state.setTrainingIndex(rowIndex);
                    if (rowIndex != startingIndex) {
                        this.updater.update(currentEigen, corpus.getRow(rowIndex), state);
                    }
                }
                state.setFirstPass(false);
            }
            // The most recently recorded verification result carries the eigenvalue estimate.
            double eigenValue = state.getStatusProgress().get(state.getStatusProgress().size() - 1).getEigenValue();
            currentEigen.assign(new TimesFunction(), 1.0d / currentEigen.norm(2.0d));
            eigens.assignRow(eigenIndex, currentEigen);
            eigenValues.add(eigenValue);
            state.setCurrentEigenValues(eigenValues);
            log.info("Found eigenvector {}, eigenvalue: {}", eigenIndex, eigenValue);

            // Reset the per-vector state before training the next vector.
            state.setFirstPass(true);
            state.setNumEigensProcessed(state.getNumEigensProcessed() + 1);
            state.setActivationDenominatorSquared(0.0d);
            state.setActivationNumerator(0.0d);
            state.getStatusProgress().clear();
            this.numPasses = 0;
        }
        return state;
    }

    /**
     * Picks a random row index whose row is usable as the first training example of a pass.
     *
     * <p>Random probing is bounded; if it finds nothing, every row is scanned and one of the
     * usable rows is picked at random, so the method always terminates.
     *
     * @param corpus matrix to pick a row from
     * @return index of a usable row
     * @throws IllegalArgumentException if {@code corpus} has no rows or no usable row
     */
    private int getRandomStartingIndex(Matrix corpus) {
        int numRows = corpus.numRows();
        if (numRows <= 0) {
            throw new IllegalArgumentException("Cannot pick a starting row: corpus has no rows");
        }
        for (int attempt = 0; attempt < numRows; attempt++) {
            int index = this.rng.nextInt(numRows);
            if (isUsableStartingRow(corpus.getRow(index))) {
                return index;
            }
        }
        // Random probing was unlucky or no row qualifies: scan everything so we cannot spin forever.
        ArrayList<Integer> candidates = new ArrayList<Integer>();
        for (int index = 0; index < numRows; index++) {
            if (isUsableStartingRow(corpus.getRow(index))) {
                candidates.add(index);
            }
        }
        if (candidates.isEmpty()) {
            throw new IllegalArgumentException("Cannot pick a starting row: none of the " + numRows
                    + " rows is non-zero with at least " + MIN_NONDEFAULT_ELEMENTS_FOR_START
                    + " non-default elements");
        }
        return candidates.get(this.rng.nextInt(candidates.size()));
    }

    /** Returns whether {@code row} is non-null, has a non-zero 2-norm and enough non-default elements. */
    private static boolean isUsableStartingRow(Vector row) {
        return row != null
                && row.norm(2.0d) != 0.0d
                && row.getNumNondefaultElements() >= MIN_NONDEFAULT_ELEMENTS_FOR_START;
    }

    /**
     * Decides whether another training pass is needed for the current candidate vector.
     *
     * <p>On the first pass this returns {@code true} without consulting the verifier. Otherwise
     * it folds every previously found vector into the candidate via {@link PlusMult} with the
     * negated helper-vector entry (presumably removing that vector's component; confirm
     * against {@code PlusMult}), zeroes the helper entries, and asks the verifier for a status.
     * A finished status is appended to the state's progress list.
     *
     * @param currentPseudoEigen candidate vector; modified in place
     * @param corpus matrix being decomposed
     * @param state shared training state
     * @return {@code true} while the number of recorded statuses does not exceed the pass limit
     *     and {@code 1 - cosAngle} is still above the convergence target
     */
    protected boolean hasNotConverged(Vector currentPseudoEigen, Matrix corpus, TrainingState state) {
        this.numPasses++;
        if (state.isFirstPass()) {
            log.info("First pass through the corpus, no need to check convergence...");
            return true;
        }
        Matrix previousEigens = state.getCurrentEigens();
        log.info("Have made {} passes through the corpus, checking convergence...", this.numPasses);
        for (int i = 0; i < state.getNumEigensProcessed(); i++) {
            currentPseudoEigen.assign(previousEigens.getRow(i), new PlusMult(-state.getHelperVector().get(i)));
            state.getHelperVector().set(i, 0.0d);
        }
        EigenStatus status = verify(corpus, currentPseudoEigen);
        if (status.inProgress()) {
            log.info("Verifier not finished, making another pass...");
        } else {
            log.info("Has 1 - cosAngle: {}, convergence target is: {}", 1.0d - status.getCosAngle(), this.convergenceTarget);
            state.getStatusProgress().add(status);
        }
        return state.getStatusProgress().size() <= this.maxPassesPerEigen
                && 1.0d - status.getCosAngle() > this.convergenceTarget;
    }

    /**
     * Delegates to the configured verifier; overridable so subclasses can change verification.
     *
     * @param corpus matrix being decomposed
     * @param vector candidate vector
     * @return the verifier's status for {@code vector}
     */
    protected EigenStatus verify(Matrix corpus, Vector vector) {
        return this.verifier.verify(corpus, vector);
    }

    /**
     * Command-line entry point that reads solver settings from properties and runs a solve.
     *
     * <p>NOTE(review): nothing in this method loads the properties file into the
     * {@link Properties} object, so the required-directory check below always fails and the
     * method returns after logging an error. The corpus passed to {@code solve} is also
     * {@code null}, and nothing here writes to the output directory. No corpus loader is
     * visible in this file, so these gaps are flagged rather than papered over.
     *
     * @param strArr optional first element: path of the properties file (defaults to
     *     {@code config/solver.properties}); used only in the error message at present
     */
    public static void main(String[] strArr) {
        Properties props = new Properties();
        String propertiesFile = strArr.length > 0 ? strArr[0] : "config/solver.properties";
        String inputDir = props.getProperty("solver.input.dir");
        String outputDir = props.getProperty("solver.output.dir");
        if (inputDir == null || inputDir.length() == 0 || outputDir == null || outputDir.length() == 0) {
            log.error("{} must contain values for solver.input.dir and solver.output.dir", propertiesFile);
            return;
        }
        // Parsed only so that a missing or malformed value is rejected; the value is not used yet.
        Integer.parseInt(props.getProperty("solver.input.bufferSize"));
        int desiredRank = Integer.parseInt(props.getProperty("solver.output.desiredRank"));
        double convergence = Double.parseDouble(props.getProperty("solver.convergence"));
        int maxPasses = Integer.parseInt(props.getProperty("solver.maxPasses"));
        // Parsed only so that a missing or malformed value is rejected; the value is not used yet.
        Integer.parseInt(props.getProperty("solver.verifier.numThreads"));

        HebbianSolver solver = new HebbianSolver(new HebbianUpdater(), new AsyncEigenVerifier(), convergence, maxPasses);
        Matrix corpus = null; // NOTE(review): no corpus is ever loaded from inputDir.

        // Take the start time before solving so the logged duration covers the solve itself.
        long startMillis = System.currentTimeMillis();
        TrainingState finalState = solver.solve(corpus, desiredRank);
        long elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000;
        log.info("Solved {} eigenVectors in {} seconds.  Persisted to {}",
                new Object[] {finalState.getCurrentEigens().size()[0], elapsedSeconds, outputDir});
    }
}
