/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.NoopNormalizer;
import org.tribuo.math.util.VectorNormalizer;

/**
 * A multiclass hinge objective for SGD-trained linear label models.
 * <p>
 * The true label's score is reduced by {@link #margin}. If the true label is then still the
 * highest-scoring label, the example contributes zero loss and an empty gradient. Otherwise the
 * gradient pushes the true label's score up and the winning label's score down by {@code margin}.
 * <p>
 * Scores produced with this objective are not probabilities (see {@link #isProbabilistic()}).
 */
public class Hinge
implements LabelObjective {
    @Config(description="The classification margin.")
    private double margin = 1.0;

    /**
     * Constructs a hinge objective with the supplied margin.
     *
     * @param margin The classification margin.
     */
    public Hinge(double margin) {
        this.margin = margin;
    }

    /**
     * Constructs a hinge objective with a margin of 1.0.
     */
    public Hinge() {
        this(1.0);
    }

    /**
     * Computes the loss and gradient; delegates to {@link #lossAndGradient(Integer, SGDVector)}.
     *
     * @param truth The index of the true label.
     * @param prediction The predicted label scores (mutated in place).
     * @return The loss and gradient.
     * @deprecated Use {@link #lossAndGradient(Integer, SGDVector)} instead.
     */
    @Override
    @Deprecated
    public Pair<Double, SGDVector> valueAndGradient(int truth, SGDVector prediction) {
        return this.lossAndGradient(truth, prediction);
    }

    /**
     * Computes the hinge loss and its gradient for a single example.
     * <p>
     * Note: this method modifies {@code prediction} in place by subtracting the margin from the
     * true label's score before locating the highest-scoring label.
     *
     * @param truth The index of the true label.
     * @param prediction The predicted label scores; modified in place.
     * @return A pair of (loss, gradient). When the true label still wins after the margin is
     *         subtracted, the loss is 0.0 and the gradient is an empty sparse vector. Otherwise the
     *         loss is {@code score[truth] - score[predicted]} computed after the margin
     *         subtraction, so it is negative or zero. The gradient is a two-element sparse vector
     *         with {@code +margin} at the true index and {@code -margin} at the predicted index.
     */
    @Override
    public Pair<Double, SGDVector> lossAndGradient(Integer truth, SGDVector prediction) {
        int truthIdx = truth;
        // Penalise the true label by the margin: it must win by at least `margin` to incur no loss.
        prediction.add(truthIdx, -this.margin);
        int predIndex = prediction.indexOfMax();
        if (truthIdx == predIndex) {
            return new Pair<>(0.0, SparseVector.createSparseVector(prediction.size(), new int[0], new double[0]));
        }
        // Indices are written in ascending order, presumably as createSparseVector expects.
        int[] indices;
        double[] values;
        if (truthIdx < predIndex) {
            indices = new int[]{truthIdx, predIndex};
            values = new double[]{this.margin, -this.margin};
        } else {
            indices = new int[]{predIndex, truthIdx};
            values = new double[]{-this.margin, this.margin};
        }
        SparseVector output = SparseVector.createSparseVector(prediction.size(), indices, values);
        double loss = prediction.get(truthIdx) - prediction.get(predIndex);
        return new Pair<>(loss, output);
    }

    /**
     * Returns a {@link NoopNormalizer}, as hinge scores are used without normalisation.
     *
     * @return A new no-op normalizer.
     */
    @Override
    public VectorNormalizer getNormalizer() {
        return new NoopNormalizer();
    }

    /**
     * Hinge scores are not probabilities.
     *
     * @return false.
     */
    @Override
    public boolean isProbabilistic() {
        return false;
    }

    @Override
    public String toString() {
        return "Hinge(margin=" + this.margin + ")";
    }

    /**
     * Returns the configuration provenance of this objective, recorded under the type name
     * {@code "LabelObjective"}.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "LabelObjective");
    }
}

