package es7

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	elastic "github.com/elastic/go-elasticsearch/v7"
	"github.com/elastic/go-elasticsearch/v7/esapi"
	"github.com/loveuer/nfflow/internal/model"
	"github.com/loveuer/nfflow/internal/opt"
	"github.com/loveuer/nfflow/internal/util"
	"github.com/sirupsen/logrus"
)

// ES7 reads documents from an Elasticsearch 7 cluster using the scroll API.
type ES7 struct {
	// cli is the Elasticsearch client; it is (re)created by init.
	cli *elastic.Client
	// scroll holds the scroll id returned by the search started in Start; it
	// is used for follow-up Scroll calls and for ClearScroll on shutdown.
	scroll string
	// cfg is the reader configuration.
	cfg struct {
		// Endpoints are the cluster addresses handed to the client.
		Endpoints []string
		// Username and Password are the credentials handed to the client.
		Username string
		Password string
		// Size is the page size requested per search; Start also treats a
		// page shorter than Size as the last page.
		Size int
		// Query, when non-empty, is JSON-encoded and sent as the search body.
		Query map[string]any
		// Source is not referenced by the code in this file.
		// NOTE(review): presumably intended as a _source field filter — confirm.
		Source []string
	}
}

// init creates the Elasticsearch client from e.cfg and verifies that the
// cluster is reachable by calling the Info API.
//
// It returns the client construction error, the transport error of the Info
// call, or an error describing the response when its status is not 200.
func (e *ES7) init() error {
	var (
		err error
		cfg = elastic.Config{
			Addresses:     e.cfg.Endpoints,
			Username:      e.cfg.Username,
			Password:      e.cfg.Password,
			RetryOnStatus: []int{429},
		}
		info *esapi.Response
	)

	if e.cli, err = elastic.NewClient(cfg); err != nil {
		return err
	}

	if info, err = e.cli.Info(e.cli.Info.WithContext(util.Timeout(5))); err != nil {
		return err
	}
	// The response body must always be closed, otherwise the underlying HTTP
	// connection is leaked. info.String() below reads the body, so the close
	// is deferred until after the status check.
	defer info.Body.Close()

	if info.StatusCode != 200 {
		return fmt.Errorf("status=%d msg=%s", info.StatusCode, info.String())
	}

	return nil
}

// Start initialises the client, then launches a goroutine that runs the
// configured search with a scroll cursor and streams every hit to rowCh.
//
// Parameters:
//   - ctx:   parent context for every search/scroll request; cancelling it
//     also releases the goroutine when it is blocked sending to
//     rowCh or errCh.
//   - task:  not used by this function.
//   - rowCh: receives each hit in the order returned by Elasticsearch.
//   - errCh: receives at most one error, after which streaming stops.
//
// Both rowCh and errCh are closed by the goroutine when it finishes, whether
// it ran to completion, failed, or was cancelled. Before they are closed the
// scroll cursor (if one was opened) is cleared on the server.
//
// The returned error is non-nil only for failures that happen before the
// goroutine is started (client init, query marshalling).
func (e *ES7) Start(ctx context.Context, task *model.Task, rowCh chan<- model.TaskRow, errCh chan<- error) error {
	if err := e.init(); err != nil {
		logrus.Debugf("ES7.Start: init err=%v", err)
		return err
	}

	qs := []func(*esapi.SearchRequest){
		e.cli.Search.WithContext(util.TimeoutCtx(ctx, opt.ES7OperationTimeout)),
		e.cli.Search.WithScroll(opt.ScrollTimeout),
		e.cli.Search.WithSize(e.cfg.Size),
	}

	// len of a nil map is 0, so no separate nil check is needed.
	if len(e.cfg.Query) > 0 {
		bs, err := json.Marshal(e.cfg.Query)
		if err != nil {
			logrus.Debugf("ES7.Start: marshal query err=%v", err)
			return err
		}

		qs = append(qs, e.cli.Search.WithBody(bytes.NewReader(bs)))
	}

	// fail reports err to the consumer. A receiver that is already waiting
	// always gets the error; otherwise the send is abandoned once ctx is
	// done so the goroutine cannot block forever.
	fail := func(err error) {
		select {
		case errCh <- err:
			return
		default:
		}

		select {
		case errCh <- err:
		case <-ctx.Done():
		}
	}

	// emit forwards every hit of page to rowCh. It returns false when ctx is
	// done before all hits were delivered.
	emit := func(page *model.ESResponse) bool {
		for idx := range page.Hits.Hits {
			select {
			case rowCh <- page.Hits.Hits[idx]:
			case <-ctx.Done():
				return false
			}
		}

		return true
	}

	// readPage validates and decodes one search/scroll response and always
	// closes its body so the HTTP connection is released.
	readPage := func(stage string, resp *esapi.Response) (*model.ESResponse, error) {
		defer resp.Body.Close()

		if err := util.CheckES7Response(resp); err != nil {
			logrus.Debugf("ES7.Start: %s resp err=%v", stage, err)
			return nil, err
		}

		page := new(model.ESResponse)
		if err := json.NewDecoder(resp.Body).Decode(page); err != nil {
			logrus.Debugf("ES7.Start: decode err=%v", err)
			return nil, err
		}

		if page.TimedOut {
			logrus.Debugf("ES7.Start: %s timeout", stage)
			return nil, fmt.Errorf("timeout")
		}

		return page, nil
	}

	ready := make(chan struct{})

	go func() {
		defer func() {
			if e.scroll != "" {
				// util.Timeout takes no parent context, so this cleanup
				// request is not tied to ctx and can still be sent after ctx
				// has been cancelled.
				csr, err := e.cli.ClearScroll(
					e.cli.ClearScroll.WithContext(util.Timeout(5)),
					e.cli.ClearScroll.WithScrollID(e.scroll),
				)
				if err != nil {
					logrus.Warnf("ES7.Start: clear scroll=%s err=%v", e.scroll, err)
				} else {
					if csr.StatusCode != 200 {
						logrus.Warnf("ES7.Start: clear scroll=%s status=%d msg=%s", e.scroll, csr.StatusCode, csr.String())
					}
					csr.Body.Close()
				}

				// The cursor is gone (or will expire); do not try to clear
				// the same id again on a later run.
				e.scroll = ""
			}

			close(rowCh)
			close(errCh)
		}()

		// Signal the caller that the goroutine is running.
		close(ready)

		resp, err := e.cli.Search(qs...)
		if err != nil {
			logrus.Debugf("ES7.Start: search err=%v", err)
			fail(err)
			return
		}

		page, err := readPage("search", resp)
		if err != nil {
			fail(err)
			return
		}

		for {
			// Always continue with the most recently returned scroll id.
			if page.ScrollId != "" {
				e.scroll = page.ScrollId
			}

			if !emit(page) {
				fail(ctx.Err())
				return
			}

			// A short page is the last page. An empty page is checked
			// explicitly so that a non-positive Size cannot loop forever.
			if n := len(page.Hits.Hits); n == 0 || n < e.cfg.Size {
				return
			}

			resp, err = e.cli.Scroll(
				e.cli.Scroll.WithContext(util.TimeoutCtx(ctx, opt.ES7OperationTimeout)),
				e.cli.Scroll.WithScrollID(e.scroll),
			)
			if err != nil {
				logrus.Debugf("ES7.Start: scroll err=%v", err)
				fail(err)
				return
			}

			if page, err = readPage("scroll", resp); err != nil {
				fail(err)
				return
			}
		}
	}()

	<-ready

	return nil
}