import React, { useState, useCallback } from 'react';
import PropTypes from 'prop-types';
import { Button, Card, CardContent, CircularProgress, Typography, TextField, Box } from '@mui/material';
import ReactMarkdown from 'react-markdown';
import useApi from "../ApiService";

/**
 * Builds the endpoint path for a RAG generation request.
 *
 * @param {Array<string|number>|null|undefined} ids - Source IDs. When truthy, the
 *   source-list endpoint is used and the IDs are sent as one comma-separated `ids`
 *   query value; otherwise the inline-content endpoint is used.
 * @returns {string} Relative endpoint path handed to `apiCall`.
 */
const buildRagEndpoint = (ids) => {
    if (!ids) {
        return '/ai/rag/content';
    }
    // Encode each ID on its own so reserved characters ('&', '#', '?', ',', ...)
    // inside an ID cannot corrupt the query string, while the separating commas
    // stay literal. `?? ''` mirrors Array#join, which renders null/undefined as ''.
    const encodedIds = ids.map((id) => encodeURIComponent(id ?? '')).join(',');
    return `/ai/rag/source?ids=${encodedIds}`;
};

/**
 * React hook that POSTs a RAG generation request and tracks its loading/error state.
 *
 * @returns {{
 *   generateContent: (requestBody: object, ids?: Array<string|number>|null) => Promise<*>,
 *   isLoading: boolean,
 *   error: string|null
 * }} `generateContent` resolves to `response.data.generation` on success, or to
 *   `null` on failure, in which case `error` holds a user-facing message.
 */
const useRagGeneration = () => {
    const { apiCall } = useApi();
    const [isLoading, setIsLoading] = useState(false);
    const [error, setError] = useState(null);

    const generateContent = useCallback(async (requestBody, ids) => {
        setIsLoading(true);
        setError(null);
        try {
            // Built inside `try` so a malformed `ids` value is reported through
            // `error` instead of escaping as a rejection.
            const endpoint = buildRagEndpoint(ids);
            const response = await apiCall(endpoint, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                data: JSON.stringify(requestBody),
            });
            return response.data.generation;
        } catch (err) {
            console.error('Error fetching deeper content:', err);
            setError('Failed to load content. Please try again.');
            return null;
        } finally {
            // Single place to clear the flag, whichever path was taken.
            setIsLoading(false);
        }
    }, [apiCall]);

    return { generateContent, isLoading, error };
};

// Component: RAGGeneration
const RAGGeneration = ({
                           contentSupplier,
                           contextSupplier,
                           specialInstructionsSupplier,
                           querySupplier,
                           buttonText = 'Go Deeper',
                           userQueryLabel = "Enter your query",
                           userContentLabel = "Enter your content",
                           userContextLabel = "Enter your context",
                           userSpecialInstructionsLabel = "Enter special instructions",
                           contentIsViaSourceIdList = false
                       }) => {
    const [deeperContent, setDeeperContent] = useState(null);
    const [isContentVisible, setIsContentVisible] = useState(false);
    const [userQuery, setUserQuery] = useState('');
    const [userContent, setUserContent] = useState('');
    const [userContext, setUserContext] = useState('');
    const [userSpecialInstructions, setUserSpecialInstructions] = useState('');
    const [inputsChanged, setInputsChanged] = useState(false);

    const { generateContent, isLoading, error } = useRagGeneration();

    const handleInputChange = (setter) => (e) => {
        setter(e.target.value);
        setInputsChanged(true);
    };

    const handleGoDeeper = useCallback(async () => {
        if (!inputsChanged && deeperContent) {
            setIsContentVisible(!isContentVisible);
            return;
        }

        const content = contentSupplier ? await contentSupplier() : userContent;
        const query = querySupplier ? await querySupplier() : userQuery;
        const context = contextSupplier ? await contextSupplier() : userContext;
        const specialInstructions = specialInstructionsSupplier ? await specialInstructionsSupplier() : userSpecialInstructions;

        const requestBody = {
            query,
            context,
            specialInstructions
        };

        if (!contentIsViaSourceIdList) {
            requestBody.content = content;
        }

        setDeeperContent(null);
        setIsContentVisible(false);
        setInputsChanged(false);

        const result = await generateContent(requestBody, contentIsViaSourceIdList ? content : null);
        if (result) {
            setDeeperContent(result);
            setIsContentVisible(true);
        }
    }, [contentSupplier, contextSupplier, specialInstructionsSupplier, querySupplier,
        userContent, userQuery, userContext, userSpecialInstructions,
        deeperContent, generateContent, inputsChanged, setIsContentVisible,
        setDeeperContent, setInputsChanged, contentIsViaSourceIdList]);

    return (
        <Box sx={{ display: 'flex', flexDirection: 'column', gap: 1 }}>
            {!querySupplier && (
                <TextField
                    fullWidth
                    variant="outlined"
                    label={userQueryLabel}
                    value={userQuery}
                    onChange={handleInputChange(setUserQuery)}
                    multiline
                    rows={2}
                    size="small"
                    sx={{ mb: 1 }}
                    aria-label="User query input"
                />
            )}
            {!contentSupplier && (
                <TextField
                    fullWidth
                    variant="outlined"
                    label={userContentLabel}
                    value={userContent}
                    onChange={handleInputChange(setUserContent)}
                    multiline
                    rows={3}
                    size="small"
                    sx={{ mb: 1 }}
                    aria-label="User content input"
                />
            )}
            {!contextSupplier && (
                <TextField
                    fullWidth
                    variant="outlined"
                    label={userContextLabel}
                    value={userContext}
                    onChange={handleInputChange(setUserContext)}
                    multiline
                    rows={2}
                    size="small"
                    sx={{ mb: 1 }}
                    aria-label="User context input"
                />
            )}
            {!specialInstructionsSupplier && (
                <TextField
                    fullWidth
                    variant="outlined"
                    label={userSpecialInstructionsLabel}
                    value={userSpecialInstructions}
                    onChange={handleInputChange(setUserSpecialInstructions)}
                    multiline
                    rows={2}
                    size="small"
                    sx={{ mb: 1 }}
                    aria-label="User special instructions input"
                />
            )}
            <Box sx={{ display: 'flex', gap: 1 }}>
                <Button
                    variant="contained"
                    color="primary"
                    onClick={handleGoDeeper}
                    startIcon={isLoading ? <CircularProgress size={20} color="inherit" /> : null}
                    disabled={isLoading}
                    size="small"
                    aria-label="Generate deeper content"
                >
                    {isLoading ? 'Loading...' : buttonText}
                </Button>
                {deeperContent && (
                    <Button
                        variant="outlined"
                        color="primary"
                        onClick={() => setIsContentVisible(!isContentVisible)}
                        size="small"
                        aria-label={isContentVisible ? 'Hide content' : 'Show content'}
                    >
                        {isContentVisible ? 'Hide' : 'Show'}
                    </Button>
                )}
            </Box>

            {error && (
                <Typography color="error" variant="body2">
                    Error: {error}
                </Typography>
            )}

            {isContentVisible && deeperContent && (
                <Card>
                    <CardContent>
                        <Typography component="div">
                            <ReactMarkdown>{deeperContent}</ReactMarkdown>
                        </Typography>
                    </CardContent>
                </Card>
            )}
        </Box>
    );
};

// Runtime prop validation for RAGGeneration. No prop is marked `isRequired`.
// Defaults for the label, button-text and flag props are set in the component's
// destructured parameters.
RAGGeneration.propTypes = {
    // Optional value suppliers. When one is given, its text field is not rendered
    // and the supplier's (possibly async) result is used instead.
    contentSupplier: PropTypes.func,
    contextSupplier: PropTypes.func,
    specialInstructionsSupplier: PropTypes.func,
    querySupplier: PropTypes.func,

    // Visible text: the main button label and the fallback text-field labels.
    buttonText: PropTypes.string,
    userQueryLabel: PropTypes.string,
    userContentLabel: PropTypes.string,
    userContextLabel: PropTypes.string,
    userSpecialInstructionsLabel: PropTypes.string,

    // When true, the content value is sent as source IDs rather than in the body.
    contentIsViaSourceIdList: PropTypes.bool,
};

export default RAGGeneration;