/*
 * Decompiled with CFR 0.152.
 */
package org.carrot2.text.vsm;

import java.util.Objects;
import org.carrot2.attrs.AttrComposite;
import org.carrot2.attrs.AttrObject;
import org.carrot2.math.mahout.matrix.DoubleMatrix2D;
import org.carrot2.math.mahout.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.math.matrix.IterativeMatrixFactorizationFactory;
import org.carrot2.math.matrix.MatrixFactorization;
import org.carrot2.math.matrix.MatrixFactorizationFactory;
import org.carrot2.math.matrix.MatrixUtils;
import org.carrot2.math.matrix.NonnegativeMatrixFactorizationEDFactory;
import org.carrot2.text.vsm.ReducedVectorSpaceModelContext;
import org.carrot2.text.vsm.VectorSpaceModelContext;

/**
 * Reduces a vector space model's term-document matrix by factorizing it with a configurable
 * {@link MatrixFactorizationFactory} and storing the resulting base and coefficient matrices.
 *
 * <p>The factory is registered as the {@code factorizationFactory} attribute and defaults to a
 * {@link NonnegativeMatrixFactorizationEDFactory}.
 */
public class TermDocumentMatrixReducer
extends AttrComposite {
    /** Factory used to factorize the term-document matrix; exposed as an attribute. */
    public MatrixFactorizationFactory factorizationFactory;

    public TermDocumentMatrixReducer() {
        // Register the factory as a configurable attribute (getter/setter pair plus a default).
        this.attributes.register("factorizationFactory", ((AttrObject.Builder)AttrObject.builder(MatrixFactorizationFactory.class).label("Term-document matrix factorization method")).getset(() -> this.factorizationFactory, v -> {
            this.factorizationFactory = v;
        }).defaultValue(NonnegativeMatrixFactorizationEDFactory::new));
    }

    /**
     * Factorizes the term-document matrix held by {@code context.vsmContext} and stores the
     * factorization's U matrix in {@code context.baseMatrix} and its V matrix in
     * {@code context.coefficientMatrix}, both trimmed to {@code dimensions} columns for
     * non-iterative factories.
     *
     * <p>Side effect: the term-document matrix is passed to
     * {@link MatrixUtils#normalizeColumnL2} before factorization and the call's result is not
     * captured, so this relies on the normalization modifying the matrix in place.
     *
     * <p>If the matrix has no rows or no columns, only {@code context.baseMatrix} is set (to a
     * newly allocated matrix of the same shape); {@code context.coefficientMatrix} is left as is.
     *
     * @param context    supplies the input VSM context and receives the reduced matrices
     * @param dimensions requested number of dimensions (columns) of the reduced matrices
     * @throws NullPointerException if the matrix is non-empty and no factorization factory is
     *                              configured; thrown before the input matrix is modified
     */
    public void reduce(ReducedVectorSpaceModelContext context, int dimensions) {
        final VectorSpaceModelContext vsmContext = context.vsmContext;
        final DoubleMatrix2D termDocumentMatrix = vsmContext.termDocumentMatrix;
        if (termDocumentMatrix.columns() == 0 || termDocumentMatrix.rows() == 0) {
            // Nothing to factorize: return a fresh matrix of the same (degenerate) shape.
            context.baseMatrix = new DenseDoubleMatrix2D(termDocumentMatrix.rows(), termDocumentMatrix.columns());
            return;
        }

        // Read the public field once so every step uses the same factory, and fail fast before
        // the caller's matrix is normalized if none is configured.
        final MatrixFactorizationFactory factory = Objects.requireNonNull(this.factorizationFactory, "factorizationFactory must not be null");
        if (factory instanceof IterativeMatrixFactorizationFactory) {
            // Let iterative factories derive their iteration count from the requested rank.
            ((IterativeMatrixFactorizationFactory) factory).estimateIterationsNumber(dimensions, termDocumentMatrix);
        }
        MatrixUtils.normalizeColumnL2(termDocumentMatrix, null);

        final MatrixFactorization factorization = factory.factorize(termDocumentMatrix);
        context.baseMatrix = trim(factory, factorization.getU(), dimensions);
        context.coefficientMatrix = trim(factory, factorization.getV(), dimensions);
    }

    /**
     * Limits {@code matrix} to its first {@code dimensions} columns (via
     * {@link DoubleMatrix2D#viewPart}) when {@code factory} is not iterative and the matrix is
     * wider than requested; otherwise returns {@code matrix} unchanged.
     *
     * <p>NOTE(review): iterative factories are skipped, presumably because their output already
     * has the requested rank — confirm against the factory implementations.
     */
    private DoubleMatrix2D trim(MatrixFactorizationFactory factory, DoubleMatrix2D matrix, int dimensions) {
        if (!(factory instanceof IterativeMatrixFactorizationFactory) && matrix.columns() > dimensions) {
            return matrix.viewPart(0, 0, matrix.rows(), dimensions);
        }
        return matrix;
    }
}

