/*
 * Decompiled with CFR 0.152.
 */
package openmods.calc.types.multi;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import openmods.calc.BinaryOperator;
import openmods.calc.Environment;
import openmods.calc.ExecutionErrorException;
import openmods.calc.Frame;
import openmods.calc.FrameFactory;
import openmods.calc.ICallable;
import openmods.calc.IExecutable;
import openmods.calc.ISymbol;
import openmods.calc.LocalSymbolMap;
import openmods.calc.SymbolCall;
import openmods.calc.SymbolMap;
import openmods.calc.UnaryOperator;
import openmods.calc.Value;
import openmods.calc.parsing.BinaryOpNode;
import openmods.calc.parsing.ICompilerState;
import openmods.calc.parsing.IExprNode;
import openmods.calc.parsing.ISymbolCallStateTransition;
import openmods.calc.parsing.SameStateSymbolTransition;
import openmods.calc.parsing.SymbolCallNode;
import openmods.calc.parsing.SymbolGetNode;
import openmods.calc.types.multi.BindPatternTranslator;
import openmods.calc.types.multi.ClosureCompilerHelper;
import openmods.calc.types.multi.Code;
import openmods.calc.types.multi.Cons;
import openmods.calc.types.multi.IBindPattern;
import openmods.calc.types.multi.ScopeModifierNode;
import openmods.calc.types.multi.Symbol;
import openmods.calc.types.multi.TypeDomain;
import openmods.calc.types.multi.TypedCalcUtils;
import openmods.calc.types.multi.TypedValue;
import openmods.utils.OptionalInt;
import openmods.utils.Stack;

/**
 * Provides the {@code let}, {@code letseq} and {@code letrec} binding forms of the typed calculator.
 *
 * <p>Each form has two halves:
 * <ul>
 * <li>a compiler state transition ({@link #createLetStateTransition},
 * {@link #createLetSeqStateTransition}, {@link #createLetRecStateTransition}) that turns the
 * parsed syntax into bind-pattern/code pairs;</li>
 * <li>a runtime callable ({@link #registerSymbol}) that evaluates those pairs into a fresh local
 * frame and then runs the body code in that frame.</li>
 * </ul>
 *
 * <p>The three forms differ only in which symbols an initializer can see:
 * <ul>
 * <li>{@code let}: only the caller's symbols;</li>
 * <li>{@code letseq}: earlier bindings of the same form as well;</li>
 * <li>{@code letrec}: one shared map in which every bound name is pre-declared as a placeholder.</li>
 * </ul>
 */
public class LetExpressionFactory {
    private final TypeDomain domain;
    private final TypedValue nullValue;
    private final BinaryOperator<TypedValue> colonOperator;
    private final BinaryOperator<TypedValue> assignOperator;
    private final BinaryOperator<TypedValue> lambdaOperator;
    private final ClosureCompilerHelper closureCompiler;

    /**
     * @param domain type domain used to wrap symbols and compiled code into values
     * @param nullValue value handed to {@link Cons.ListVisitor} when walking the binding list
     *        (presumably the list terminator — confirm against Cons)
     * @param colonOperator operator forwarded to {@link ScopeModifierNode} as a name/value separator
     * @param assignOperator operator forwarded to {@link ScopeModifierNode} as a name/value separator
     * @param lambdaOperator operator that introduces a function definition ({@code name(args) -> body})
     * @param varArgMarker marker forwarded to the {@link ClosureCompilerHelper} that compiles lambdas
     */
    public LetExpressionFactory(TypeDomain domain, TypedValue nullValue, BinaryOperator<TypedValue> colonOperator, BinaryOperator<TypedValue> assignOperator, BinaryOperator<TypedValue> lambdaOperator, UnaryOperator<TypedValue> varArgMarker) {
        this.domain = domain;
        this.nullValue = nullValue;
        this.colonOperator = colonOperator;
        this.assignOperator = assignOperator;
        this.lambdaOperator = lambdaOperator;
        this.closureCompiler = new ClosureCompilerHelper(domain, varArgMarker);
    }

    /** Creates the compiler transition for the {@code let} form. */
    public ISymbolCallStateTransition<TypedValue> createLetStateTransition(ICompilerState<TypedValue> parentState) {
        return new LetStateTransition("let", parentState);
    }

    /** Creates the compiler transition for the {@code letseq} form. */
    public ISymbolCallStateTransition<TypedValue> createLetSeqStateTransition(ICompilerState<TypedValue> parentState) {
        return new LetStateTransition("letseq", parentState);
    }

    /** Creates the compiler transition for the {@code letrec} form. */
    public ISymbolCallStateTransition<TypedValue> createLetRecStateTransition(ICompilerState<TypedValue> parentState) {
        return new LetStateTransition("letrec", parentState);
    }

    /**
     * Copies the symbols named in {@code names} from {@code from} into {@code to}.
     *
     * @throws IllegalStateException if any of the names is not defined in {@code from}
     */
    private static void copySymbols(Set<String> names, SymbolMap<TypedValue> from, SymbolMap<TypedValue> to) {
        for (String bindName : names) {
            final ISymbol<TypedValue> outputSymbol = from.get(bindName);
            Preconditions.checkState(outputSymbol != null, "Symbol not defined: %s", bindName);
            // Store the symbol object itself; an ISymbol is not a TypedValue, so casting it to one
            // would always fail at runtime.
            to.put(bindName, outputSymbol);
        }
    }

    /** Binds every name in {@code bindNames} to a {@link PlaceholderSymbol} in {@code symbols}. */
    private static void fillPlaceholders(Set<String> bindNames, SymbolMap<TypedValue> symbols) {
        for (String bindName : bindNames) {
            symbols.put(bindName, new PlaceholderSymbol());
        }
    }

    /** Returns the set of variable names that {@code pattern} binds. */
    private static Set<String> extractBindNames(IBindPattern pattern) {
        final Set<String> bindNames = Sets.newHashSet();
        pattern.listBoundVars(bindNames);
        return bindNames;
    }

    /** Executes {@code expr} in {@code frame} and pops its result, which must be the only value left on the stack. */
    private static TypedValue executeForSingleResult(Frame<TypedValue> frame, Code expr) {
        expr.execute(frame);
        return frame.stack().popAndExpectEmptyStack();
    }

    /** Registers the {@code let}, {@code letseq} and {@code letrec} callables as global symbols of {@code env}. */
    public void registerSymbol(Environment<TypedValue> env) {
        // Register the callables directly; casting them to TypedValue always threw ClassCastException.
        env.setGlobalSymbol("let", new LetSymbol());
        env.setGlobalSymbol("letseq", new LetSeqSymbol());
        env.setGlobalSymbol("letrec", new LetRecSymbol());
    }

    /**
     * {@code letrec}: every bound name is first declared as a placeholder in one map layered over the
     * caller's symbols. All initializers are evaluated against that shared map. Finally the bound
     * results are copied back into it (presumably so that closures created by the initializers see
     * the final values — confirm against FrameFactory/closure semantics).
     */
    private class LetRecSymbol extends LetSymbolBase {
        @Override
        protected void prepareFrame(SymbolMap<TypedValue> outputSymbols, SymbolMap<TypedValue> callSymbols, Cons vars) {
            final Set<String> bindNames = Sets.newHashSet();
            final List<PatternInitializerCodePair> varsToExecute = Lists.newArrayList();
            vars.visit(new ArgPairVisitor() {
                @Override
                protected void acceptVar(IBindPattern pattern, Code expr) {
                    pattern.listBoundVars(bindNames);
                    varsToExecute.add(new PatternInitializerCodePair(pattern, expr));
                }
            });

            final LocalSymbolMap<TypedValue> placeholderSymbols = new LocalSymbolMap<TypedValue>(callSymbols);
            LetExpressionFactory.fillPlaceholders(bindNames, placeholderSymbols);

            for (PatternInitializerCodePair pair : varsToExecute) {
                final Frame<TypedValue> executionFrame = FrameFactory.newLocalFrame(placeholderSymbols);
                final TypedValue result = LetExpressionFactory.executeForSingleResult(executionFrame, pair.code);
                TypedCalcUtils.matchPattern(pair.pattern, executionFrame, outputSymbols, result);
            }

            LetExpressionFactory.copySymbols(bindNames, outputSymbols, placeholderSymbols);
        }
    }

    /** A bind pattern together with the code that produces the value it is matched against. */
    private static class PatternInitializerCodePair {
        public final IBindPattern pattern;
        public final Code code;

        public PatternInitializerCodePair(IBindPattern pattern, Code code) {
            this.pattern = pattern;
            this.code = code;
        }
    }

    /**
     * {@code letseq}: bindings are processed in order. Each initializer runs in a frame built from the
     * output symbols, so it sees the earlier bindings; its own names are placeholders while it runs.
     */
    private class LetSeqSymbol extends LetSymbolBase {
        @Override
        protected void prepareFrame(final SymbolMap<TypedValue> outputSymbols, SymbolMap<TypedValue> callSymbols, Cons vars) {
            vars.visit(new ArgPairVisitor() {
                @Override
                protected void acceptVar(IBindPattern pattern, Code expr) {
                    final Set<String> bindNames = LetExpressionFactory.extractBindNames(pattern);
                    LetExpressionFactory.fillPlaceholders(bindNames, outputSymbols);
                    final Frame<TypedValue> executionFrame = FrameFactory.symbolsToFrame(outputSymbols);
                    final TypedValue result = LetExpressionFactory.executeForSingleResult(executionFrame, expr);
                    TypedCalcUtils.matchPattern(pattern, executionFrame, outputSymbols, result);
                    LetExpressionFactory.copySymbols(bindNames, outputSymbols, executionFrame.symbols());
                }
            });
        }
    }

    /**
     * {@code let}: each initializer runs in its own local frame over the caller's symbols, so it
     * cannot see the other bindings of the same {@code let}. Its own names are placeholders while it
     * runs, and the matched results are copied back into that frame afterwards.
     */
    private class LetSymbol extends LetSymbolBase {
        @Override
        protected void prepareFrame(final SymbolMap<TypedValue> outputSymbols, final SymbolMap<TypedValue> callSymbols, Cons vars) {
            vars.visit(new ArgPairVisitor() {
                @Override
                protected void acceptVar(IBindPattern pattern, Code expr) {
                    final Frame<TypedValue> executionFrame = FrameFactory.newLocalFrame(callSymbols);
                    final SymbolMap<TypedValue> executionSymbols = executionFrame.symbols();
                    final Set<String> bindNames = LetExpressionFactory.extractBindNames(pattern);
                    LetExpressionFactory.fillPlaceholders(bindNames, executionSymbols);
                    final TypedValue result = LetExpressionFactory.executeForSingleResult(executionFrame, expr);
                    TypedCalcUtils.matchPattern(pattern, executionFrame, outputSymbols, result);
                    LetExpressionFactory.copySymbols(bindNames, outputSymbols, executionSymbols);
                }
            });
        }
    }

    /**
     * Walks a list of {@code (pattern . code)} pairs and hands each one to
     * {@link #acceptVar(IBindPattern, Code)}. A bare {@link Symbol} in pattern position is turned
     * into a simple variable pattern.
     */
    private abstract class ArgPairVisitor extends Cons.ListVisitor {
        public ArgPairVisitor() {
            super(LetExpressionFactory.this.nullValue);
        }

        @Override
        public void value(TypedValue value, boolean isLast) {
            if (!value.is(Cons.class)) {
                throw new InvalidArgsException();
            }
            final Cons pair = value.as(Cons.class);

            final TypedValue patternValue = pair.car;
            final IBindPattern pattern;
            if (patternValue.is(IBindPattern.class)) {
                pattern = patternValue.as(IBindPattern.class);
            } else if (patternValue.is(Symbol.class)) {
                pattern = BindPatternTranslator.createPatternForVarName(patternValue.as(Symbol.class).value);
            } else {
                throw new IllegalArgumentException("Invalid bind pattern: " + patternValue);
            }

            if (!pair.cdr.is(Code.class)) {
                throw new InvalidArgsException();
            }
            this.acceptVar(pattern, pair.cdr.as(Code.class));
        }

        /** Handles one binding: {@code pattern} is matched against the result of {@code expr}. */
        protected abstract void acceptVar(IBindPattern pattern, Code expr);

        @Override
        public void end(TypedValue terminator) {
        }

        @Override
        public void begin() {
        }
    }

    /**
     * Shared runtime behaviour of the let forms. The callable takes exactly two arguments: the
     * binding list and the body code. It fills a new local frame via
     * {@link #prepareFrame(SymbolMap, SymbolMap, Cons)} and then runs the body in that frame.
     */
    private static abstract class LetSymbolBase implements ICallable<TypedValue> {
        @Override
        public void call(Frame<TypedValue> currentFrame, OptionalInt argumentsCount, OptionalInt returnsCount) {
            TypedCalcUtils.expectExactArgCount(argumentsCount, 2);
            final Frame<TypedValue> letFrame = FrameFactory.newLocalFrameWithSubstack(currentFrame, 2);
            final Stack<TypedValue> letStack = letFrame.stack();
            // Arguments are popped in reverse order: the body (second argument) is on top.
            final Code code = letStack.pop().as(Code.class, "second (code) 'let' parameter");
            final Cons vars = letStack.pop().as(Cons.class, "first (var list) 'let'  parameter");
            try {
                this.prepareFrame(letFrame.symbols(), currentFrame.symbols(), vars);
            } catch (InvalidArgsException e) {
                throw new IllegalArgumentException("Expected list of name:value pairs on first 'let' parameter, got " + vars, e);
            }
            code.execute(letFrame);
            TypedCalcUtils.expectExactReturnCount(returnsCount, letStack.size());
        }

        /**
         * Evaluates the bindings in {@code vars} into {@code outputSymbols}.
         *
         * @param outputSymbols symbols of the frame the body will run in
         * @param callSymbols symbols of the calling frame
         * @param vars list of {@code (pattern . code)} pairs
         */
        protected abstract void prepareFrame(SymbolMap<TypedValue> outputSymbols, SymbolMap<TypedValue> callSymbols, Cons vars);
    }

    /** Internal signal that the binding list is malformed; converted to IllegalArgumentException in LetSymbolBase. */
    private static class InvalidArgsException extends RuntimeException {
        private InvalidArgsException() {
        }
    }

    /** Stand-in bound to a name while its initializer runs; any access to it is an error. */
    private static class PlaceholderSymbol implements ISymbol<TypedValue> {
        @Override
        public void call(Frame<TypedValue> frame, OptionalInt argumentsCount, OptionalInt returnsCount) {
            throw new ExecutionErrorException("Cannot call symbol during definition");
        }

        @Override
        public TypedValue get() {
            throw new ExecutionErrorException("Cannot reference symbol during definition");
        }
    }

    /** Compiler transition that wraps the children of a let-form call into a {@link LetNode}. */
    private class LetStateTransition extends SameStateSymbolTransition<TypedValue> {
        private final String letState;

        public LetStateTransition(String letState, ICompilerState<TypedValue> parentState) {
            super(parentState);
            this.letState = letState;
        }

        @Override
        public IExprNode<TypedValue> createRootNode(List<IExprNode<TypedValue>> children) {
            return new LetNode(this.letState, children);
        }
    }

    /**
     * Expression node for a let form. In addition to the separators handled by
     * {@link ScopeModifierNode}, it accepts {@code name(args) -> body} and {@code name -> body}
     * function definitions.
     */
    private class LetNode extends ScopeModifierNode {
        public LetNode(String letSymbol, List<IExprNode<TypedValue>> children) {
            super(LetExpressionFactory.this.domain, letSymbol, LetExpressionFactory.this.colonOperator, LetExpressionFactory.this.assignOperator, children);
        }

        @Override
        protected void handlePairOp(List<IExecutable<TypedValue>> output, BinaryOpNode<TypedValue> opNode) {
            if (opNode.operator != LetExpressionFactory.this.lambdaOperator) {
                throw new UnsupportedOperationException("Expected '=', ':' or '->' as pair separators, got " + opNode.operator);
            }
            this.flattenLambdaDefinition(output, opNode);
        }

        @Override
        protected void flattenNameAndValue(List<IExecutable<TypedValue>> output, IExprNode<TypedValue> bindPattern, IExprNode<TypedValue> value) {
            this.flattenBindPattern(output, bindPattern);
            output.add(Value.create(Code.flattenAndWrap(LetExpressionFactory.this.domain, value)));
        }

        /**
         * Emits the variable name and the compiled closure for {@code name(args) -> body} (or
         * {@code name -> body}, with no arguments).
         */
        private void flattenLambdaDefinition(List<IExecutable<TypedValue>> output, BinaryOpNode<TypedValue> opNode) {
            final IExprNode<TypedValue> nameNode = opNode.left;
            final IExprNode<TypedValue> lambdaBody = opNode.right;

            final TypedValue varName;
            final Iterable<IExprNode<TypedValue>> lambdaArgs;
            if (nameNode instanceof SymbolCallNode) {
                varName = Symbol.get(LetExpressionFactory.this.domain, ((SymbolCallNode)nameNode).symbol());
                // The call node's children are the lambda's formal arguments.
                lambdaArgs = nameNode.getChildren();
            } else if (nameNode instanceof SymbolGetNode) {
                varName = Symbol.get(LetExpressionFactory.this.domain, ((SymbolGetNode)nameNode).symbol());
                lambdaArgs = ImmutableList.<IExprNode<TypedValue>>of();
            } else {
                throw new IllegalArgumentException("Cannot extract value name from " + nameNode);
            }

            output.add(Value.create(varName));
            output.add(Value.create(this.createLambdaWrapperCode(lambdaArgs, lambdaBody)));
        }

        /** Compiles a closure over {@code args} and {@code body} and wraps it as a code value. */
        private TypedValue createLambdaWrapperCode(Iterable<IExprNode<TypedValue>> args, IExprNode<TypedValue> body) {
            final List<IExecutable<TypedValue>> result = Lists.newArrayList();
            LetExpressionFactory.this.closureCompiler.compile(result, args, body);
            return Code.wrap(LetExpressionFactory.this.domain, result);
        }

        /**
         * Emits a bind pattern. A plain symbol becomes a {@link Symbol} value. Anything else is
         * wrapped as code and passed to the {@code pattern} symbol (called with 1 argument,
         * 1 result).
         */
        private void flattenBindPattern(List<IExecutable<TypedValue>> output, IExprNode<TypedValue> bindPattern) {
            if (bindPattern instanceof SymbolGetNode) {
                final SymbolGetNode var = (SymbolGetNode)bindPattern;
                output.add(Value.create(Symbol.get(LetExpressionFactory.this.domain, var.symbol())));
            } else {
                output.add(Value.create(Code.flattenAndWrap(LetExpressionFactory.this.domain, bindPattern)));
                output.add(new SymbolCall("pattern", 1, 1));
            }
        }
    }
}

