/*
 * Decompiled with CFR 0.152.
 */
package org.nuxeo.ecm.platform.categorization.categorizer.tfidf;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.io.Serializable;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.Token;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.nuxeo.ecm.platform.categorization.categorizer.tfidf.HashingVectorizer;
import org.nuxeo.ecm.platform.categorization.categorizer.tfidf.PrimitiveVectorHelper;
import org.nuxeo.ecm.platform.categorization.service.Categorizer;

/**
 * Categorizer that compares a text against a set of named topics using
 * TF-IDF weighted term-count vectors.
 * <p>
 * Terms are turned into fixed-size count vectors by a {@link HashingVectorizer}.
 * For every topic the accumulated counts are kept in {@link #topicTermCount};
 * the counts over all topics are kept in {@link #allTermCounts}. Similarities
 * are computed as {@code dot(a, b) / (norm(a) * norm(b))} with the vector
 * helpers inherited from {@link PrimitiveVectorHelper}.
 * <p>
 * The instance is written and read with standard Java serialization (see
 * {@link #saveToStream(OutputStream)} and {@link #load(InputStream)}); the
 * Lucene analyzer is transient and re-created lazily by {@link #getAnalyzer()}.
 * <p>
 * NOTE(review): {@link #disableUpdate()} is synchronized but the other
 * mutating methods are not; no thread-safety guarantee is claimed here.
 */
public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {

    private static final long serialVersionUID = 1L;

    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);

    /** Vector dimension used by the no-argument constructor. */
    private static final int DEFAULT_DIMENSION = 524288;

    /** Number of analyzed lines between two progress log messages. */
    private static final int PROGRESS_LOG_INTERVAL = 10000;

    /** Ratio assigned to every topic when the median similarity is zero. */
    private static final double RATIO_WHEN_MEDIAN_IS_ZERO = 100.0;

    /** Names of the known topics, in natural (sorted) order. */
    protected final Set<String> topicNames = new TreeSet<String>();

    /** Topic name to accumulated term counts; values are {@code long[]}. */
    protected final Map<String, Object> topicTermCount = new ConcurrentHashMap<String, Object>();

    /** Topic name to cached TF-IDF vector; values are {@code float[]}. */
    protected final Map<String, Object> cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();

    /** Topic name to cached norm of the topic TF-IDF vector. */
    protected final Map<String, Float> cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();

    /** Term counts over all topics; set to {@code null} by {@link #disableUpdate()}. */
    protected long[] allTermCounts;

    /** Dimension of the count vectors. */
    protected final int dim;

    /** Cached inverse document frequencies, or {@code null} when they must be recomputed. */
    protected float[] cachedIdf;

    /** Sum of all the term counts passed to {@link #update(String, List)}. */
    protected long totalTermCount = 0L;

    protected final HashingVectorizer vectorizer;

    /** Lazily created by {@link #getAnalyzer()}; not serialized. */
    protected transient Analyzer analyzer;

    /** Default minimum similarity / median ratio used by {@link #guessCategories(String, int)}. */
    protected Double ratioOverMedian = 3.0;

    /** Set once by {@link #disableUpdate()}; blocks further updates and saving. */
    protected boolean updateDisabled = false;

    /**
     * Creates a categorizer with the default dimension (524288).
     */
    public TfIdfCategorizer() {
        this(DEFAULT_DIMENSION);
    }

    /**
     * Creates a categorizer whose count vectors have the given dimension.
     *
     * @param dim the vector dimension, must be strictly positive
     * @throws IllegalArgumentException if {@code dim} is not strictly positive
     */
    public TfIdfCategorizer(int dim) {
        if (dim <= 0) {
            throw new IllegalArgumentException(String.format("dimension must be strictly positive, got %d", dim));
        }
        this.dim = dim;
        this.allTermCounts = new long[dim];
        this.vectorizer = new HashingVectorizer().dimension(dim);
    }

    /**
     * @return the vectorizer used to turn term lists into count vectors
     */
    public HashingVectorizer getVectorizer() {
        return vectorizer;
    }

    /**
     * Returns the analyzer used by {@link #tokenize(String)}, creating a
     * {@link StandardAnalyzer} on first use (the field is transient and is
     * therefore {@code null} right after deserialization).
     *
     * @return the analyzer, never {@code null}
     */
    public Analyzer getAnalyzer() {
        if (analyzer == null) {
            analyzer = new StandardAnalyzer();
        }
        return analyzer;
    }

    /**
     * Switches the model to read-only mode: pre-computes the IDF vector and the
     * TF-IDF vector and norm of every topic, then drops the raw counts.
     * <p>
     * After this call {@link #update(String, List)} and
     * {@link #saveToStream(OutputStream)} throw {@link IllegalStateException}.
     */
    public synchronized void disableUpdate() {
        updateDisabled = true;
        // Fill every cache while the raw counts are still available.
        getIdf();
        for (String topicName : topicNames) {
            tfidf(topicName);
            tfidfNorm(topicName);
        }
        // The raw counts are no longer needed once the caches are populated.
        topicTermCount.clear();
        allTermCounts = null;
    }

    /**
     * Adds the counts of the given terms to a topic, registering the topic if
     * it is not known yet, and invalidates the caches that depend on them.
     *
     * @param topicName the topic to update
     * @param terms the terms observed for that topic
     * @throws IllegalStateException if {@link #disableUpdate()} has been called
     */
    public void update(String topicName, List<String> terms) {
        if (updateDisabled) {
            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
        }
        long[] counts = vectorizer.count(terms);
        totalTermCount += sum(counts);
        long[] topicCounts = (long[]) topicTermCount.get(topicName);
        if (topicCounts == null) {
            topicCounts = new long[dim];
            topicTermCount.put(topicName, topicCounts);
            topicNames.add(topicName);
        }
        add(topicCounts, counts);
        add(allTermCounts, counts);
        invalidateCache(topicName);
    }

    /**
     * Tokenizes the text with {@link #tokenize(String)} and updates the topic
     * with the resulting terms.
     *
     * @param topicName the topic to update
     * @param textContent the raw text observed for that topic
     * @throws IllegalStateException if {@link #disableUpdate()} has been called
     */
    public void update(String topicName, String textContent) {
        update(topicName, tokenize(textContent));
    }

    /**
     * Drops the cached TF-IDF vector and norm of one topic as well as the
     * cached IDF vector.
     *
     * @param topicName the topic whose cached values are dropped
     */
    protected void invalidateCache(String topicName) {
        cachedTopicTfIdf.remove(topicName);
        cachedTopicTfIdfNorm.remove(topicName);
        cachedIdf = null;
    }

    /**
     * Drops the cached values of every known topic.
     */
    protected void invalidateCache() {
        for (String topicName : topicNames) {
            invalidateCache(topicName);
        }
    }

    /**
     * Computes the similarity between the given terms and every topic.
     * <p>
     * Topics whose TF-IDF norm is zero are left out. When the norm of the
     * TF-IDF vector of the terms is zero, an empty map is returned.
     *
     * @param terms the terms of the document to compare
     * @return topic name to similarity, iterated by decreasing similarity
     */
    public Map<String, Float> getSimilarities(List<String> terms) {
        Map<String, Float> similarities = new TreeMap<String, Float>();
        float[] documentTfIdf = getTfIdf(vectorizer.count(terms));
        float documentNorm = normOf(documentTfIdf);
        if (documentNorm == 0.0f) {
            return similarities;
        }
        for (String topicName : topicNames) {
            float[] topicTfIdf = tfidf(topicName);
            float topicNorm = tfidfNorm(topicName);
            if (topicNorm == 0.0f) {
                continue;
            }
            similarities.put(topicName, dot(documentTfIdf, topicTfIdf) / (documentNorm * topicNorm));
        }
        return sortByDecreasingValue(similarities);
    }

    /**
     * Tokenizes the text and delegates to {@link #getSimilarities(List)}.
     *
     * @param allThePets the raw text of the document to compare
     * @return topic name to similarity, iterated by decreasing similarity
     */
    public Map<String, Float> getSimilarities(String allThePets) {
        return getSimilarities(tokenize(allThePets));
    }

    /**
     * Returns the norm of the TF-IDF vector of a topic, computing and caching
     * it on first access.
     *
     * @param topicName the topic name
     * @return the norm of the topic TF-IDF vector
     */
    protected float tfidfNorm(String topicName) {
        Float norm = cachedTopicTfIdfNorm.get(topicName);
        if (norm == null) {
            norm = normOf(tfidf(topicName));
            cachedTopicTfIdfNorm.put(topicName, norm);
        }
        return norm;
    }

    /**
     * Returns the TF-IDF vector of a topic, computing and caching it on first
     * access from the counts stored in {@link #topicTermCount}.
     *
     * @param topicName the topic name
     * @return the TF-IDF vector of the topic
     */
    protected float[] tfidf(String topicName) {
        float[] vector = (float[]) cachedTopicTfIdf.get(topicName);
        if (vector == null) {
            vector = getTfIdf((long[]) topicTermCount.get(topicName));
            cachedTopicTfIdf.put(topicName, vector);
        }
        return vector;
    }

    /**
     * Converts a count vector into a TF-IDF vector: each count is divided by
     * the sum of the counts and multiplied by the matching IDF component.
     *
     * @param counts the term counts
     * @return a new vector of the same length; all zeros when the counts sum to zero
     */
    protected float[] getTfIdf(long[] counts) {
        float[] idf = getIdf();
        float[] result = new float[counts.length];
        long total = sum(counts);
        if (total == 0L) {
            return result;
        }
        for (int i = 0; i < counts.length; i++) {
            result[i] = (float) counts[i] / (float) total * idf[i];
        }
        return result;
    }

    /**
     * Returns the IDF vector, computing and caching it when needed. Component
     * {@code i} is {@code log1p(totalTermCount / allTermCounts[i])}, or zero
     * when {@code allTermCounts[i]} is zero.
     *
     * @return the cached IDF vector
     */
    protected float[] getIdf() {
        if (cachedIdf == null) {
            float[] idf = new float[allTermCounts.length];
            for (int i = 0; i < idf.length; i++) {
                long count = allTermCounts[i];
                idf[i] = count == 0L ? 0.0f : (float) Math.log1p((float) totalTermCount / (float) count);
            }
            cachedIdf = idf;
        }
        return cachedIdf;
    }

    /**
     * @return the dimension of the count vectors
     */
    public int getDimension() {
        return dim;
    }

    /**
     * Trains the model with every regular file directly inside a folder.
     * <p>
     * The topic name is the file name up to its first dot; each line of the
     * file, decoded as UTF-8, is one update for that topic. Sub-folders are
     * skipped.
     *
     * @param folder the folder holding one text file per topic
     * @throws IOException if {@code folder} is not a folder, cannot be listed,
     *             or one of its files cannot be read
     */
    public void learnFiles(File folder) throws IOException {
        if (!folder.isDirectory()) {
            throw new IOException(String.format("%s is not a folder", folder.getAbsolutePath()));
        }
        // File#listFiles returns null (rather than throwing) on I/O errors.
        File[] files = folder.listFiles();
        if (files == null) {
            throw new IOException(String.format("Unable to list the content of folder %s", folder.getAbsolutePath()));
        }
        for (File file : files) {
            if (file.isDirectory()) {
                continue;
            }
            String topicName = file.getName();
            int dotIndex = topicName.indexOf('.');
            if (dotIndex >= 0) {
                topicName = topicName.substring(0, dotIndex);
            }
            learnTopicFile(topicName, file);
        }
    }

    /**
     * Feeds every line of a UTF-8 text file to the given topic, logging the
     * progress every {@value #PROGRESS_LOG_INTERVAL} lines.
     */
    private void learnTopicFile(String topicName, File file) throws IOException {
        log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
        InputStream is = new FileInputStream(file);
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")));
            int lineCount = 0;
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                update(topicName, line);
                lineCount++;
                if (lineCount % PROGRESS_LOG_INTERVAL == 0) {
                    log.info(String.format("Analyzed %d lines from '%s'", lineCount, file.getAbsolutePath()));
                }
            }
        } finally {
            is.close();
        }
    }

    /**
     * Saves the model to a file using {@link #saveToStream(OutputStream)}.
     *
     * @param file the destination file
     * @throws IOException if the file cannot be written
     * @throws IllegalStateException if {@link #disableUpdate()} has been called
     */
    public void saveToFile(File file) throws IOException {
        OutputStream out = new FileOutputStream(file);
        try {
            saveToStream(out);
        } finally {
            out.close();
        }
    }

    /**
     * Writes this instance to the stream as a gzipped serialized object.
     * <p>
     * The caches are invalidated first, so the cached values are not part of
     * the written data. The given stream is not closed.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     * @throws IllegalStateException if {@link #disableUpdate()} has been called
     */
    public void saveToStream(OutputStream out) throws IOException {
        if (updateDisabled) {
            throw new IllegalStateException("model in disabled update mode cannot be saved");
        }
        invalidateCache();
        GZIPOutputStream gzOut = new GZIPOutputStream(out);
        ObjectOutputStream objOut = new ObjectOutputStream(gzOut);
        objOut.writeObject(this);
        // Push any data buffered by the object stream before ending the gzip stream.
        objOut.flush();
        gzOut.finish();
    }

    /**
     * Reads a model written by {@link #saveToStream(OutputStream)}. The given
     * stream is not closed.
     * <p>
     * NOTE(review): this uses Java native deserialization; only feed it data
     * from a trusted source.
     *
     * @param in the stream holding the gzipped serialized model
     * @return the loaded categorizer
     * @throws IOException if the stream cannot be read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static TfIdfCategorizer load(InputStream in) throws IOException, ClassNotFoundException {
        GZIPInputStream gzIn = new GZIPInputStream(in);
        ObjectInputStream objIn = new ObjectInputStream(gzIn);
        TfIdfCategorizer cat = (TfIdfCategorizer) objIn.readObject();
        log.info(String.format("Sucessfully loaded model with %d topics, dimension %d and density %f",
                cat.getTopicNames().size(), cat.getDimension(), cat.getDensity()));
        return cat;
    }

    /**
     * Returns the fraction of non-zero entries among the per-topic count
     * vectors and the global count vector.
     * <p>
     * Once {@link #disableUpdate()} has dropped the raw counts there is nothing
     * left to count and the result is {@code 0.0}.
     *
     * @return the number of non-zero counts divided by
     *         {@code (number of topics + 1) * dimension}
     */
    public double getDensity() {
        long nonZero = 0L;
        for (Object singleTopicTermCount : topicTermCount.values()) {
            for (long count : (long[]) singleTopicTermCount) {
                if (count != 0L) {
                    nonZero++;
                }
            }
        }
        // allTermCounts is null after disableUpdate().
        if (allTermCounts != null) {
            for (long count : allTermCounts) {
                if (count != 0L) {
                    nonZero++;
                }
            }
        }
        // Multiply as double: (topics + 1) * dim overflows int with many topics.
        double cellCount = ((double) topicNames.size() + 1.0) * (double) getDimension();
        return (double) nonZero / cellCount;
    }

    /**
     * @return the live set of known topic names (not a copy)
     */
    public Set<String> getTopicNames() {
        return topicNames;
    }

    /**
     * Loads a model from a classpath resource resolved with the thread context
     * class loader (or the loader of this class when there is none).
     *
     * @param modelPath the resource path of the gzipped serialized model
     * @return the loaded categorizer
     * @throws IOException if the resource does not exist or cannot be read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public static TfIdfCategorizer load(String modelPath) throws IOException, ClassNotFoundException {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            loader = TfIdfCategorizer.class.getClassLoader();
        }
        // getResourceAsStream returns null when the resource is missing.
        InputStream in = loader.getResourceAsStream(modelPath);
        if (in == null) {
            throw new FileNotFoundException(String.format("Model resource '%s' not found in the classpath", modelPath));
        }
        try {
            return load(in);
        } finally {
            in.close();
        }
    }

    /**
     * Command line entry point: loads the model file if it exists (or creates a
     * new model otherwise), trains it with the files of a folder and saves it
     * back to the model file.
     *
     * @param args model file name, training folder and optional dimension
     * @throws FileNotFoundException if the model file cannot be opened
     * @throws IOException if reading the training files or writing the model fails
     * @throws ClassNotFoundException if the existing model cannot be deserialized
     */
    public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {
        if (args.length < 2 || args.length > 3) {
            System.out.println("Train a model:\nFirst argument is the model filename (e.g. my-model.gz)\nSecond argument is the path to a folder with UTF-8 text files\nThird optional argument is the dimension of the model");
            System.exit(0);
        }
        File modelFile = new File(args[0]);
        TfIdfCategorizer categorizer;
        if (modelFile.exists()) {
            log.info("Loading model from: " + modelFile.getAbsolutePath());
            InputStream is = new FileInputStream(modelFile);
            try {
                categorizer = load(is);
            } finally {
                is.close();
            }
        } else {
            if (args.length == 3) {
                categorizer = new TfIdfCategorizer(Integer.parseInt(args[2]));
            } else {
                categorizer = new TfIdfCategorizer();
            }
            log.info("Initializing new model with dimension: " + categorizer.getDimension());
        }
        categorizer.learnFiles(new File(args[1]));
        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
        categorizer.saveToFile(modelFile);
    }

    /**
     * Suggests topics for a text using {@link #ratioOverMedian} as threshold.
     *
     * @param textContent the text to categorize
     * @param maxSuggestions the maximum number of topics to return
     * @return the suggested topic names, by decreasing similarity
     */
    @Override
    public List<String> guessCategories(String textContent, int maxSuggestions) {
        return guessCategories(textContent, maxSuggestions, null);
    }

    /**
     * Suggests the topics whose similarity to the text is at least
     * {@code precisionThreshold} times the median similarity.
     * <p>
     * Topics are visited by decreasing similarity and the walk stops at the
     * first topic under the threshold or once {@code maxSuggestions} topics
     * have been collected. When the median is zero every topic is given a
     * ratio of 100.
     *
     * @param textContent the text to categorize
     * @param maxSuggestions the maximum number of topics to return
     * @param precisionThreshold the minimum similarity / median ratio, or
     *            {@code null} to use {@link #ratioOverMedian}
     * @return the suggested topic names, by decreasing similarity
     */
    @Override
    public List<String> guessCategories(String textContent, int maxSuggestions, Double precisionThreshold) {
        double threshold = precisionThreshold == null ? ratioOverMedian : precisionThreshold;
        Map<String, Float> sims = getSimilarities(tokenize(textContent));
        float median = findMedian(sims);
        List<String> suggested = new ArrayList<String>();
        for (Map.Entry<String, Float> sim : sims.entrySet()) {
            double ratio = median != 0.0f ? (double) (sim.getValue() / median) : RATIO_WHEN_MEDIAN_IS_ZERO;
            if (suggested.size() >= maxSuggestions || ratio < threshold) {
                break;
            }
            suggested.add(sim.getKey());
        }
        return suggested;
    }

    /**
     * Splits a text into terms with the analyzer of {@link #getAnalyzer()}.
     *
     * @param textContent the text to split; {@code null} yields no terms
     * @return the terms, in stream order
     * @throws IllegalStateException if the token stream fails with an {@link IOException}
     */
    public List<String> tokenize(String textContent) {
        List<String> terms = new ArrayList<String>();
        if (textContent == null) {
            return terms;
        }
        TokenStream ts = getAnalyzer().tokenStream(null, new StringReader(textContent));
        try {
            Token reusableToken = new Token();
            // next(Token) may return another instance than the one passed in,
            // so the term must be read from the returned token.
            Token current = ts.next(reusableToken);
            while (current != null) {
                terms.add(current.termText());
                current = ts.next(reusableToken);
            }
        } catch (IOException e) {
            throw new IllegalStateException(e);
        } finally {
            try {
                ts.close();
            } catch (IOException e) {
                // The terms are already collected: a failing close is not fatal.
                log.warn("Failed to close the token stream", e);
            }
        }
        return terms;
    }

    /**
     * Returns a copy of the map whose iteration order follows decreasing
     * values. Entries with equal values keep the relative order they have in
     * the given map.
     *
     * @param map the map to sort; it is not modified
     * @return a new insertion-ordered map sorted by decreasing value
     */
    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
        List<Map.Entry<String, Float>> entries = new ArrayList<Map.Entry<String, Float>>(map.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> e1, Map.Entry<String, Float> e2) {
                // Arguments swapped on purpose to sort in decreasing order.
                return e2.getValue().compareTo(e1.getValue());
            }
        });
        Map<String, Float> result = new LinkedHashMap<String, Float>();
        for (Map.Entry<String, Float> entry : entries) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /**
     * Returns the value found at position {@code size / 2} in the iteration
     * order of the map, which is the median when the map is sorted by value.
     *
     * @param sortedMap a map iterated in value order
     * @return the value at the middle position, or {@code 0} for an empty map
     */
    public static Float findMedian(Map<String, Float> sortedMap) {
        int medianIndex = sortedMap.size() / 2;
        Float median = 0.0f;
        int index = 0;
        for (Float value : sortedMap.values()) {
            median = value;
            if (index >= medianIndex) {
                break;
            }
            index++;
        }
        return median;
    }
}

