package org.apache.cassandra.index.sai.disk.v1.vector;

import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.GraphSearcher;
import io.github.jbellis.jvector.graph.NeighborSimilarity;
import io.github.jbellis.jvector.graph.SearchResult;
import io.github.jbellis.jvector.pq.CompressedVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.IntStream;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.VectorType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.index.sai.disk.format.IndexComponent;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.io.IndexFileUtils;
import org.apache.cassandra.index.sai.disk.io.IndexOutputWriter;
import org.apache.cassandra.index.sai.disk.v1.IndexWriterConfig;
import org.apache.cassandra.index.sai.disk.v1.SAICodecUtils;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.utils.IndexIdentifier;
import org.apache.cassandra.io.util.SequentialWriter;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.lucene.util.StringHelper;
import org.cliffc.high_scale_lib.NonBlockingHashMapLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/cassandra/index/sai/disk/v1/vector/OnHeapGraph.class */
/**
 * In-memory vector index: keeps one graph node per distinct vector and, for every vector, the
 * list of keys ("postings") that currently carry it.
 *
 * <p>Vectors are de-duplicated through {@link #postingsMap}; each distinct vector is assigned a
 * sequential ordinal which identifies it in {@link #vectorValues}, in {@link #postingsByOrdinal}
 * and in the graph maintained by {@link #builder}.
 *
 * @param <T> type of the keys stored in the postings lists
 */
public class OnHeapGraph<T> {
    private static final Logger logger = LoggerFactory.getLogger(OnHeapGraph.class);

    /** Largest absolute value a vector component may have; see {@link #checkInBounds(float[])}. */
    public static final float MAX_FLOAT32_COMPONENT = 1.0E17f;

    /** Graphs holding fewer vectors than this are written without product-quantization data. */
    private static final int MIN_VECTORS_FOR_PQ = 1024;

    /** Vector storage, addressed by ordinal. */
    private final RamAwareVectorValues vectorValues;
    private final GraphIndexBuilder<float[]> builder;
    private final VectorType<?> vectorType;
    private final VectorSimilarityFunction similarityFunction;
    /** Distinct vector -> postings; keys are ordered by {@code Arrays.compare(float[], float[])}. */
    private final ConcurrentMap<float[], VectorPostings<T>> postingsMap;
    /** Ordinal -> postings, for translating search hits back into keys. */
    private final NonBlockingHashMapLong<VectorPostings<T>> postingsByOrdinal;
    private final AtomicInteger nextOrdinal;
    /** Set by the first {@link #remove} that finds its vector; never reset. */
    private volatile boolean hasDeletions;

    /** What {@link #add} does with a vector that fails {@link #validateIndexable}. */
    public enum InvalidVectorBehavior {
        /** Log the vector at trace level, skip it and report zero bytes used. */
        IGNORE,
        /** Let the {@link InvalidRequestException} propagate to the caller. */
        FAIL
    }

    /**
     * Equivalent to {@link #OnHeapGraph(AbstractType, IndexWriterConfig, boolean)} with
     * {@code true} as the last argument.
     */
    public OnHeapGraph(AbstractType<?> abstractType, IndexWriterConfig indexWriterConfig) {
        this(abstractType, indexWriterConfig, true);
    }

    /**
     * @param abstractType      column type; must be a {@link VectorType} (it is cast unconditionally)
     * @param indexWriterConfig supplies the similarity function, the maximum node connections and
     *                          the construction beam width
     * @param z                 {@code true} to store vectors in a {@link ConcurrentVectorValues},
     *                          {@code false} to use a {@link CompactionVectorValues}
     */
    public OnHeapGraph(AbstractType<?> abstractType, IndexWriterConfig indexWriterConfig, boolean z) {
        this.nextOrdinal = new AtomicInteger();
        this.vectorType = (VectorType<?>) abstractType;
        // NOTE(review): the raw cast for CompactionVectorValues is kept on purpose; the element
        // type its constructor expects is not visible from this class.
        this.vectorValues = z
                ? new ConcurrentVectorValues(this.vectorType.dimension)
                : new CompactionVectorValues((VectorType) abstractType);
        this.similarityFunction = indexWriterConfig.getSimilarityFunction();
        // float[] has identity equals/hashCode, so the map is keyed through an explicit
        // content-based comparator instead.
        this.postingsMap = new ConcurrentSkipListMap<>(Arrays::compare);
        this.postingsByOrdinal = new NonBlockingHashMapLong<>();
        this.builder = new GraphIndexBuilder<>(this.vectorValues,
                                               VectorEncoding.FLOAT32,
                                               this.similarityFunction,
                                               indexWriterConfig.getMaximumNodeConnections(),
                                               indexWriterConfig.getConstructionBeamWidth(),
                                               1.2f,
                                               1.4f);
    }

    /** @return the number of vectors held in the vector storage */
    public int size() {
        return this.vectorValues.size();
    }

    /**
     * @return {@code true} when every postings list is empty, which includes the case where no
     *         vector has been added at all
     */
    public boolean isEmpty() {
        return this.postingsMap.values().stream().allMatch(VectorPostings::isEmpty);
    }

    /**
     * Registers {@code t} under the vector encoded in {@code byteBuffer}.
     *
     * <p>The first time a vector is seen it is assigned the next ordinal and is published, in this
     * order, to the postings map, the vector storage, the ordinal lookup and finally the graph.
     * When the vector is already known only its postings list is extended.
     *
     * @param byteBuffer            serialized vector; must be non-null and non-empty (asserted)
     * @param t                     key to associate with the vector
     * @param invalidVectorBehavior how to react when the vector fails {@link #validateIndexable}
     * @return the sum of the byte estimates reported by the structures that grew, or {@code 0}
     *         when nothing was added
     * @throws InvalidRequestException if the vector is invalid and the behavior is
     *                                 {@link InvalidVectorBehavior#FAIL}
     */
    public long add(ByteBuffer byteBuffer, T t, InvalidVectorBehavior invalidVectorBehavior) {
        assert byteBuffer != null && byteBuffer.remaining() != 0;

        float[] vector = this.vectorType.composeAsFloat(byteBuffer);
        if (invalidVectorBehavior == InvalidVectorBehavior.IGNORE) {
            try {
                validateIndexable(vector, this.similarityFunction);
            } catch (InvalidRequestException e) {
                logger.trace("Ignoring invalid vector during index build against existing data: {}", vector, e);
                return 0L;
            }
        } else {
            assert invalidVectorBehavior == InvalidVectorBehavior.FAIL;
            validateIndexable(vector, this.similarityFunction);
        }

        VectorPostings<T> postings = this.postingsMap.get(vector);
        if (postings == null) {
            VectorPostings<T> created = new VectorPostings<>(t);
            // putIfAbsent hands back the winner of a concurrent insert, so no second lookup is needed.
            postings = this.postingsMap.putIfAbsent(vector, created);
            if (postings == null) {
                // This call won the race: assign the ordinal and publish to the other structures.
                int ordinal = this.nextOrdinal.getAndIncrement();
                created.setOrdinal(ordinal);
                long bytesUsed = RamEstimation.concurrentHashMapRamUsed(1);
                bytesUsed += this.vectorValues instanceof ConcurrentVectorValues
                        ? ((ConcurrentVectorValues) this.vectorValues).add(ordinal, vector)
                        : ((CompactionVectorValues) this.vectorValues).add(ordinal, byteBuffer);
                bytesUsed += VectorPostings.emptyBytesUsed() + VectorPostings.bytesPerPosting();
                this.postingsByOrdinal.put(ordinal, created);
                // The graph node goes in last, after the ordinal is resolvable everywhere else.
                bytesUsed += this.builder.addGraphNode(ordinal, this.vectorValues);
                return bytesUsed;
            }
        }

        // The vector already has a postings list; only a key that is actually added costs memory.
        return postings.add(t) ? VectorPostings.bytesPerPosting() : 0L;
    }

    /**
     * Verifies that every component is finite and no larger in magnitude than
     * {@link #MAX_FLOAT32_COMPONENT}.
     *
     * @param fArr vector to check
     * @throws IllegalArgumentException naming the first offending index and value
     */
    public static void checkInBounds(float[] fArr) {
        for (int i = 0; i < fArr.length; i++) {
            float component = fArr[i];
            if (!Float.isFinite(component)) {
                throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + component);
            }
            if (Math.abs(component) > MAX_FLOAT32_COMPONENT) {
                throw new IllegalArgumentException("Out-of-bounds value at vector[" + i + "]=" + component);
            }
        }
    }

    /**
     * Checks that a vector can be indexed or used as a query: it must pass
     * {@link #checkInBounds(float[])} and, for cosine similarity, contain at least one non-zero
     * component.
     *
     * @param fArr                     vector to validate
     * @param vectorSimilarityFunction similarity function the vector will be used with
     * @throws InvalidRequestException if either check fails; bounds failures carry the message of
     *                                 the underlying {@link IllegalArgumentException}
     */
    public static void validateIndexable(float[] fArr, VectorSimilarityFunction vectorSimilarityFunction) {
        try {
            checkInBounds(fArr);
            if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
                for (float component : fArr) {
                    if (component != 0.0f) {
                        return;
                    }
                }
                throw new InvalidRequestException("Zero vectors cannot be indexed or queried with cosine similarity");
            }
        } catch (IllegalArgumentException e) {
            throw new InvalidRequestException(e.getMessage());
        }
    }

    /**
     * @param i ordinal of a vector previously registered by {@link #add}
     * @return the postings recorded for that ordinal
     */
    public Collection<T> keysFromOrdinal(int i) {
        return this.postingsByOrdinal.get(i).getPostings();
    }

    /**
     * Removes {@code t} from the postings of the vector encoded in {@code byteBuffer}.
     *
     * @param byteBuffer serialized vector; must be non-null and non-empty (asserted)
     * @param t          key to remove
     * @return {@code 0} when the vector is unknown to this graph, otherwise the value returned by
     *         {@code VectorPostings.remove}
     */
    public long remove(ByteBuffer byteBuffer, T t) {
        assert byteBuffer != null && byteBuffer.remaining() != 0;

        float[] vector = this.vectorType.composeAsFloat(byteBuffer);
        VectorPostings<T> postings = this.postingsMap.get(vector);
        if (postings == null) {
            return 0L;
        }
        // From now on searches filter their accepted ordinals through bitsIgnoringDeleted.
        this.hasDeletions = true;
        return postings.remove(t);
    }

    /**
     * Searches the graph for the neighbours of {@code fArr} and returns the keys attached to them.
     *
     * @param fArr query vector; validated with {@link #validateIndexable}
     * @param i    number of results requested from the graph searcher
     * @param bits ordinals the searcher may accept
     * @return the keys of all returned nodes; empty when no vector has been added
     * @throws InvalidRequestException if the query vector is not valid
     */
    public PriorityQueue<T> search(float[] fArr, int i, Bits bits) {
        validateIndexable(fArr, this.similarityFunction);
        if (this.vectorValues.size() == 0) {
            return new PriorityQueue<>();
        }

        SearchResult result = new GraphSearcher.Builder(this.builder.getGraph().getView())
                .withConcurrentUpdates()
                .build()
                .search(node -> vectorCompareFunction(fArr, node),
                        (NeighborSimilarity.ReRanker) null,
                        i,
                        this.hasDeletions ? BitsUtil.bitsIgnoringDeleted(bits, this.postingsByOrdinal) : bits);

        SearchResult.NodeScore[] nodes = result.getNodes();
        Tracing.trace("ANN search visited {} in-memory nodes to return {} results", result.getVisitedCount(), nodes.length);

        PriorityQueue<T> keys = new PriorityQueue<>();
        for (SearchResult.NodeScore hit : nodes) {
            keys.addAll(keysFromOrdinal(hit.node));
        }
        return keys;
    }

    /**
     * Writes the compressed vectors, the postings lists and the graph to their index components.
     *
     * <p>Each file gets a header, its payload and a footer. The three writers are closed in
     * reverse order of opening whether or not writing succeeds.
     *
     * @param indexDescriptor resolves the file of each component
     * @param indexIdentifier index the components belong to
     * @param function        passed to {@code VectorPostings.computeRowIds} for every postings list
     * @return offset and length of the payload written to each of the three components
     * @throws IOException if opening, writing or closing a component fails
     */
    public SegmentMetadata.ComponentMetadataMap writeData(IndexDescriptor indexDescriptor, IndexIdentifier indexIdentifier, Function<T, Integer> function) throws IOException {
        int insertsInProgress = this.builder.insertsInProgress();
        assert insertsInProgress == 0
                : String.format("Attempting to write graph while %d inserts are in progress", insertsInProgress);
        assert this.nextOrdinal.get() == this.builder.getGraph().size()
                : String.format("nextOrdinal %d != graph size %d -- ordinals should be sequential", this.nextOrdinal.get(), this.builder.getGraph().size());
        assert this.vectorValues.size() == this.builder.getGraph().size()
                : String.format("vector count %d != graph size %d", this.vectorValues.size(), this.builder.getGraph().size());
        assert this.postingsMap.keySet().size() == this.vectorValues.size()
                : String.format("postings map entry count %d != vector count %d", this.postingsMap.keySet().size(), this.vectorValues.size());

        logger.debug("Writing graph with {} rows and {} distinct vectors",
                     this.postingsMap.values().stream().mapToInt(VectorPostings::size).sum(),
                     this.vectorValues.size());

        try (IndexOutputWriter pqOutput = IndexFileUtils.instance.openOutput(indexDescriptor.fileFor(IndexComponent.COMPRESSED_VECTORS, indexIdentifier), true);
             IndexOutputWriter postingsOutput = IndexFileUtils.instance.openOutput(indexDescriptor.fileFor(IndexComponent.POSTING_LISTS, indexIdentifier), true);
             IndexOutputWriter graphOutput = IndexFileUtils.instance.openOutput(indexDescriptor.fileFor(IndexComponent.TERMS_DATA, indexIdentifier), true)) {
            SAICodecUtils.writeHeader(pqOutput);
            SAICodecUtils.writeHeader(postingsOutput);
            SAICodecUtils.writeHeader(graphOutput);

            // Compressed vectors.
            long pqOffset = pqOutput.getFilePointer();
            long pqLength = writePQ(pqOutput.asSequentialWriter()) - pqOffset;

            // Ordinals whose postings are already empty, plus those flagged once row ids are computed.
            HashSet<Integer> deletedOrdinals = new HashSet<>();
            for (VectorPostings<T> postings : this.postingsMap.values()) {
                if (postings.isEmpty()) {
                    deletedOrdinals.add(postings.getOrdinal());
                }
            }
            for (VectorPostings<T> postings : this.postingsMap.values()) {
                postings.computeRowIds(function);
                if (postings.shouldAppendDeletedOrdinal()) {
                    deletedOrdinals.add(postings.getOrdinal());
                }
            }

            // Postings lists.
            long postingsOffset = postingsOutput.getFilePointer();
            long postingsLength = new VectorPostingsWriter().writePostings(postingsOutput.asSequentialWriter(), this.vectorValues, this.postingsMap, deletedOrdinals) - postingsOffset;

            // Graph; the builder is completed before it is serialized.
            this.builder.complete();
            long graphOffset = graphOutput.getFilePointer();
            OnDiskGraphIndex.write(this.builder.getGraph(), this.vectorValues, graphOutput.asSequentialWriter());
            long graphLength = graphOutput.getFilePointer() - graphOffset;

            SAICodecUtils.writeFooter(pqOutput);
            SAICodecUtils.writeFooter(postingsOutput);
            SAICodecUtils.writeFooter(graphOutput);

            SegmentMetadata.ComponentMetadataMap metadata = new SegmentMetadata.ComponentMetadataMap();
            metadata.put(IndexComponent.TERMS_DATA, -1L, graphOffset, graphLength, Map.of());
            metadata.put(IndexComponent.POSTING_LISTS, -1L, postingsOffset, postingsLength, Map.of());
            metadata.put(IndexComponent.COMPRESSED_VECTORS, -1L, pqOffset, pqLength, Map.of("SEGMENT_ID", ByteBufferUtil.bytesToHex(ByteBuffer.wrap(StringHelper.randomId()))));
            return metadata;
        }
    }

    /** Scores the stored vector with ordinal {@code i} against {@code fArr}. */
    private float vectorCompareFunction(float[] fArr, int i) {
        return this.similarityFunction.compare(fArr, (float[]) this.vectorValues.vectorValue(i));
    }

    /**
     * Writes a boolean flag saying whether product-quantization data follows and, when the graph
     * holds at least {@link #MIN_VECTORS_FOR_PQ} vectors, the quantizer with every encoded vector.
     *
     * @param sequentialWriter destination
     * @return the writer position after the last byte written
     * @throws IOException if writing fails
     */
    private long writePQ(SequentialWriter sequentialWriter) throws IOException {
        int halfDimension = this.vectorValues.dimension() / 2;
        // Read the count once so the flag on disk and the branch below cannot disagree.
        int vectorCount = this.vectorValues.size();
        boolean usePq = vectorCount >= MIN_VECTORS_FOR_PQ;
        sequentialWriter.writeBoolean(usePq);
        if (!usePq) {
            logger.debug("Skipping PQ for only {} vectors", vectorCount);
            return sequentialWriter.position();
        }

        logger.debug("Computing PQ for {} vectors", vectorCount);
        ProductQuantization pq;
        byte[][] encoded;
        // Class-wide lock: only one graph at a time trains and encodes.
        synchronized (OnHeapGraph.class) {
            pq = ProductQuantization.compute(this.vectorValues, halfDimension, false);
            // Presumably required because the encoding below reads vectors from parallel threads.
            assert !this.vectorValues.isValueShared();
            encoded = IntStream.range(0, this.vectorValues.size())
                               .parallel()
                               .mapToObj(ordinal -> pq.encode((float[]) this.vectorValues.vectorValue(ordinal)))
                               .toArray(byte[][]::new);
        }
        new CompressedVectors(pq, encoded).write(sequentialWriter);
        return sequentialWriter.position();
    }
}
