package org.eclipse.sirius.components.graphql.ws;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import graphql.GraphQL;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.MeterRegistry;
import java.io.IOException;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.eclipse.sirius.components.graphql.ws.api.IGraphQLWebSocketHandlerListener;
import org.eclipse.sirius.components.graphql.ws.dto.IOperationMessage;
import org.eclipse.sirius.components.graphql.ws.dto.input.ConnectionInitMessage;
import org.eclipse.sirius.components.graphql.ws.dto.input.ConnectionTerminateMessage;
import org.eclipse.sirius.components.graphql.ws.dto.input.StartMessage;
import org.eclipse.sirius.components.graphql.ws.dto.input.StopMessage;
import org.eclipse.sirius.components.graphql.ws.dto.output.ConnectionErrorMessage;
import org.eclipse.sirius.components.graphql.ws.dto.output.ConnectionKeepAliveMessage;
import org.eclipse.sirius.components.graphql.ws.handlers.ConnectionInitMessageHandler;
import org.eclipse.sirius.components.graphql.ws.handlers.ConnectionTerminateMessageHandler;
import org.eclipse.sirius.components.graphql.ws.handlers.StartMessageHandler;
import org.eclipse.sirius.components.graphql.ws.handlers.StopMessageHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;

/* loaded from: input_file:BOOT-INF/lib/sirius-components-graphql-2024.1.4.jar:org/eclipse/sirius/components/graphql/ws/GraphQLWebSocketHandler.class */
/**
 * WebSocket handler speaking the {@code graphql-ws} sub-protocol.
 *
 * <p>Each incoming text frame is parsed into one of {@link ConnectionInitMessage}, {@link StartMessage},
 * {@link StopMessage} or {@link ConnectionTerminateMessage} and dispatched to the matching handler; anything else is
 * answered with a {@link ConnectionErrorMessage}. Every open session also receives a {@link ConnectionKeepAliveMessage}
 * every 28 seconds. The number of messages per type and the number of sessions with an active keep-alive are
 * published to the given {@link MeterRegistry}.
 */
public class GraphQLWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable {

    /** Delay between two keep-alive messages sent to each open session. */
    private static final Duration GRAPHQL_KEEP_ALIVE_INTERVAL = Duration.ofSeconds(28);

    /** Name of the counters tracking the received/emitted messages, one counter per message type. */
    private static final String COUNTER_METRIC_NAME = "siriusweb_graphql_ws_messages";

    /** Name of the gauge tracking the number of sessions with a running keep-alive. */
    private static final String TIMER_METRIC_NAME = "siriusweb_graphql_ws_sessions";

    /** Tag key distinguishing the message counters. */
    private static final String MESSAGE = "message";

    private static final String GRAPHQL_WS = "graphql-ws";

    /** JSON property holding the type of an operation message. */
    private static final String TYPE = "type";

    private final ObjectMapper objectMapper;

    private final GraphQL graphQL;

    private final Counter connectionInitCounter;

    private final Counter startMessageCounter;

    private final Counter stopMessageCounter;

    private final Counter connectionTerminateCounter;

    private final Counter connectionErrorCounter;

    private final MeterRegistry meterRegistry;

    private final IGraphQLWebSocketHandlerListener listener;

    private final Logger logger = LoggerFactory.getLogger(GraphQLWebSocketHandler.class);

    /** Subscription entries of each session, shared with the start/stop/terminate message handlers. */
    private final Map<WebSocketSession, List<SubscriptionEntry>> sessions2entries = new ConcurrentHashMap<>();

    /** Keep-alive subscription of each open session; its size is exposed as a gauge. */
    private final Map<WebSocketSession, Disposable> sessions2keepAliveSubscriptions = new ConcurrentHashMap<>();

    /**
     * Creates the handler and registers its metrics.
     *
     * @param objectMapper
     *            used to parse incoming frames and serialize outgoing messages
     * @param graphQL
     *            the GraphQL engine used to execute started operations
     * @param meterRegistry
     *            the registry receiving the message counters and the session gauge
     * @param iGraphQLWebSocketHandlerListener
     *            notified of every connection, message and disconnection
     * @throws NullPointerException
     *             if any argument is {@code null}
     */
    public GraphQLWebSocketHandler(ObjectMapper objectMapper, GraphQL graphQL, MeterRegistry meterRegistry, IGraphQLWebSocketHandlerListener iGraphQLWebSocketHandlerListener) {
        this.objectMapper = Objects.requireNonNull(objectMapper);
        this.graphQL = Objects.requireNonNull(graphQL);
        this.meterRegistry = Objects.requireNonNull(meterRegistry);
        this.listener = Objects.requireNonNull(iGraphQLWebSocketHandlerListener);

        this.startMessageCounter = createMessageCounter(meterRegistry, "Start");
        this.stopMessageCounter = createMessageCounter(meterRegistry, "Stop");
        this.connectionInitCounter = createMessageCounter(meterRegistry, "Connection Init");
        this.connectionErrorCounter = createMessageCounter(meterRegistry, "Connection Error");
        this.connectionTerminateCounter = createMessageCounter(meterRegistry, "Connection Terminate");

        Gauge.builder(TIMER_METRIC_NAME, this.sessions2keepAliveSubscriptions::size).register(meterRegistry);
    }

    /** Registers (or retrieves) the message counter tagged with the given message type. */
    private static Counter createMessageCounter(MeterRegistry meterRegistry, String messageType) {
        return Counter.builder(COUNTER_METRIC_NAME).tag(MESSAGE, messageType).register(meterRegistry);
    }

    @Override
    public List<String> getSubProtocols() {
        return Collections.singletonList(GRAPHQL_WS);
    }

    /**
     * Notifies the listener, then parses the frame and dispatches it to the handler of its message type. Frames which
     * cannot be parsed into a supported message are answered with a {@link ConnectionErrorMessage}.
     */
    @Override
    protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage textMessage) throws Exception {
        this.listener.handleTextMessage(webSocketSession, textMessage);

        Optional<IOperationMessage> optionalOperationMessage = this.parseRequest(textMessage);
        if (!optionalOperationMessage.isPresent()) {
            this.sendConnectionError(webSocketSession);
            return;
        }

        IOperationMessage operationMessage = optionalOperationMessage.get();
        this.logger.trace("Message received: {}", operationMessage);

        if (operationMessage instanceof ConnectionInitMessage) {
            new ConnectionInitMessageHandler(webSocketSession, this.objectMapper).handle();
            this.connectionInitCounter.increment();
        } else if (operationMessage instanceof StartMessage) {
            new StartMessageHandler(webSocketSession, this.graphQL, this.objectMapper, this.sessions2entries, this.meterRegistry).handle((StartMessage) operationMessage);
            this.startMessageCounter.increment();
        } else if (operationMessage instanceof StopMessage) {
            new StopMessageHandler(webSocketSession, this.sessions2entries).handle((StopMessage) operationMessage);
            this.stopMessageCounter.increment();
        } else if (operationMessage instanceof ConnectionTerminateMessage) {
            // The handler must actually be run, otherwise the session's subscriptions are left untouched.
            new ConnectionTerminateMessageHandler(webSocketSession, this.sessions2entries).handle();
            this.connectionTerminateCounter.increment();
        } else {
            this.sendConnectionError(webSocketSession);
        }
    }

    /** Sends a {@link ConnectionErrorMessage} to the session and records it in the error counter. */
    private void sendConnectionError(WebSocketSession webSocketSession) {
        this.send(webSocketSession, new ConnectionErrorMessage());
        this.connectionErrorCounter.increment();
    }

    /**
     * Serializes the message as JSON and sends it as a text frame. I/O and serialization failures are logged and
     * otherwise ignored.
     */
    private void send(WebSocketSession webSocketSession, IOperationMessage iOperationMessage) {
        try {
            TextMessage textMessage = new TextMessage(this.objectMapper.writeValueAsString(iOperationMessage));
            this.logger.trace("Message sent: {}", iOperationMessage);
            webSocketSession.sendMessage(textMessage);
        } catch (IOException exception) {
            this.logger.warn(exception.getMessage(), exception);
        }
    }

    /**
     * Parses the payload of the frame into an operation message.
     *
     * @return the parsed message, or an empty optional if the payload is not valid JSON, has no textual
     *         {@code type} property, has an unsupported type, or cannot be mapped to the message class
     */
    private Optional<IOperationMessage> parseRequest(TextMessage textMessage) {
        try {
            JsonNode jsonNode = this.objectMapper.readTree(textMessage.getPayload());
            if (jsonNode == null) {
                // Older Jackson versions return null instead of a MissingNode for empty content.
                return Optional.empty();
            }
            return this.getType(jsonNode).flatMap(type -> this.getOperationMessage(jsonNode, type));
        } catch (IOException exception) {
            this.logger.warn(exception.getMessage(), exception);
            return Optional.empty();
        }
    }

    /** Returns the value of the {@code type} property if it exists and is textual. */
    private Optional<String> getType(JsonNode jsonNode) {
        JsonNode typeNode = jsonNode.get(TYPE);
        if (typeNode != null && typeNode.isTextual()) {
            return Optional.of(typeNode.asText());
        }
        return Optional.empty();
    }

    /**
     * Converts the JSON node into the operation message matching the given type.
     *
     * @return the message, or an empty optional if the type is unsupported or the conversion fails
     */
    private Optional<IOperationMessage> getOperationMessage(JsonNode jsonNode, String type) {
        Optional<Class<? extends IOperationMessage>> optionalMessageClass = this.getOperationMessageClass(type);
        if (!optionalMessageClass.isPresent()) {
            return Optional.empty();
        }
        try {
            IOperationMessage operationMessage = this.objectMapper.treeToValue(jsonNode, optionalMessageClass.get());
            return Optional.ofNullable(operationMessage);
        } catch (JsonProcessingException exception) {
            this.logger.warn(exception.getMessage(), exception);
            return Optional.empty();
        }
    }

    /** Maps a {@code type} value to the class of the supported incoming operation message, if any. */
    private Optional<Class<? extends IOperationMessage>> getOperationMessageClass(String type) {
        Class<? extends IOperationMessage> messageClass;
        switch (type) {
            case ConnectionInitMessage.CONNECTION_INIT:
                messageClass = ConnectionInitMessage.class;
                break;
            case ConnectionTerminateMessage.CONNECTION_TERMINATE:
                messageClass = ConnectionTerminateMessage.class;
                break;
            case "start":
                messageClass = StartMessage.class;
                break;
            case "stop":
                messageClass = StopMessage.class;
                break;
            default:
                messageClass = null;
                break;
        }
        return Optional.ofNullable(messageClass);
    }

    /** Notifies the listener and starts sending a keep-alive message to the session periodically. */
    @Override
    public void afterConnectionEstablished(WebSocketSession webSocketSession) throws Exception {
        this.listener.afterConnectionEstablished(webSocketSession);
        Disposable keepAliveSubscription = Flux.interval(GRAPHQL_KEEP_ALIVE_INTERVAL).subscribe(tick -> this.send(webSocketSession, new ConnectionKeepAliveMessage()));
        this.sessions2keepAliveSubscriptions.put(webSocketSession, keepAliveSubscription);
    }

    /**
     * Notifies the listener, then stops the keep-alive of the session and terminates its subscriptions. The cleanup
     * runs even if the listener throws, and tolerates sessions for which no keep-alive was registered.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) throws Exception {
        try {
            this.listener.afterConnectionClosed(webSocketSession, closeStatus);
        } finally {
            Disposable keepAliveSubscription = this.sessions2keepAliveSubscriptions.remove(webSocketSession);
            if (keepAliveSubscription != null) {
                keepAliveSubscription.dispose();
            }
            new ConnectionTerminateMessageHandler(webSocketSession, this.sessions2entries).handle();
        }
    }
}
