// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package account

import (
	"context"
	"encoding/csv"
	"errors"
	"fmt"
	"mime/multipart"

	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
	"github.com/superseriousbusiness/gotosocial/internal/gtserror"
	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
	"github.com/superseriousbusiness/gotosocial/internal/log"
)

func (p *Processor) ImportData(
	ctx context.Context,
	requester *gtsmodel.Account,
	data *multipart.FileHeader,
	importType string,
	overwrite bool,
) gtserror.WithCode {
	switch importType {

	case "following":
		return p.importFollowing(
			ctx,
			requester,
			data,
			overwrite,
		)

	case "blocks":
		return p.importBlocks(
			ctx,
			requester,
			data,
			overwrite,
		)

	default:
		const text = "import type not yet supported"
		return gtserror.NewErrorUnprocessableEntity(errors.New(text), text)
	}
}

func (p *Processor) importFollowing(
	ctx context.Context,
	requester *gtsmodel.Account,
	followingData *multipart.FileHeader,
	overwrite bool,
) gtserror.WithCode {
	file, err := followingData.Open()
	if err != nil {
		err := fmt.Errorf("error opening following data file: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}
	defer file.Close()

	// Parse records out of the file.
	records, err := csv.NewReader(file).ReadAll()
	if err != nil {
		err := fmt.Errorf("error reading following data file: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}

	// Convert the records into a slice of barebones follows.
	//
	// Only TargetAccount.Username, TargetAccount.Domain,
	// and ShowReblogs will be set on each Follow.
	follows, err := p.converter.CSVToFollowing(ctx, records)
	if err != nil {
		err := fmt.Errorf("error converting records to follows: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}

	// Do remaining processing of this import asynchronously.
	f := importFollowingAsyncF(p, requester, follows, overwrite)
	p.state.Workers.Processing.Queue.Push(f)

	return nil
}

func importFollowingAsyncF(
	p *Processor,
	requester *gtsmodel.Account,
	follows []*gtsmodel.Follow,
	overwrite bool,
) func(context.Context) {
	return func(ctx context.Context) {
		// Map used to store wanted
		// follow targets (if overwriting).
		var wantedFollows map[string]struct{}

		if overwrite {
			// If we're overwriting, we need to get current
			// follow(-req)s owned by requester *before*
			// making any changes, so that we can remove
			// unwanted follows after we've created new ones.
			prevFollows, err := p.state.DB.GetAccountFollows(ctx, requester.ID, nil)
			if err != nil {
				log.Errorf(ctx, "db error getting following: %v", err)
				return
			}

			prevFollowReqs, err := p.state.DB.GetAccountFollowRequesting(ctx, requester.ID, nil)
			if err != nil {
				log.Errorf(ctx, "db error getting follow requesting: %v", err)
				return
			}

			// Initialize new follows map.
			wantedFollows = make(map[string]struct{}, len(follows))

			// Once we've created (or tried to create)
			// the required follows, go through previous
			// follow(-request)s and remove unwanted ones.
			defer func() {

				// AccountIDs to unfollow.
				toRemove := []string{}

				// Check previous follows.
				for _, prev := range prevFollows {
					username := prev.TargetAccount.Username
					domain := prev.TargetAccount.Domain

					_, wanted := wantedFollows[username+"@"+domain]
					if !wanted {
						toRemove = append(toRemove, prev.TargetAccountID)
					}
				}

				// Now any pending follow requests.
				for _, prev := range prevFollowReqs {
					username := prev.TargetAccount.Username
					domain := prev.TargetAccount.Domain

					_, wanted := wantedFollows[username+"@"+domain]
					if !wanted {
						toRemove = append(toRemove, prev.TargetAccountID)
					}
				}

				// Remove each discovered
				// unwanted follow.
				for _, accountID := range toRemove {
					if _, errWithCode := p.FollowRemove(
						ctx,
						requester,
						accountID,
					); errWithCode != nil {
						log.Errorf(ctx, "could not unfollow account: %v", errWithCode.Unwrap())
						continue
					}
				}
			}()
		}

		// Go through the follows parsed from CSV
		// file, and create / update each one.
		for _, follow := range follows {
			var (
				// Username of the target.
				username = follow.TargetAccount.Username

				// Domain of the target.
				// Empty for our domain.
				domain = follow.TargetAccount.Domain

				// Show reblogs on
				// the new follow.
				showReblogs = follow.ShowReblogs

				// Notify when new
				// follow posts.
				notify = follow.Notify
			)

			if overwrite {
				// We'll be overwriting, so store
				// this new follow in our handy map.
				wantedFollows[username+"@"+domain] = struct{}{}
			}

			// Get the target account, dereferencing it if necessary.
			targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain(
				ctx,
				requester.Username,
				username,
				domain,
			)
			if err != nil {
				log.Errorf(ctx, "could not retrieve account: %v", err)
				continue
			}

			// Use the processor's FollowCreate function
			// to create or update the follow. This takes
			// account of existing follows, and also sends
			// the follow to the FromClientAPI processor.
			if _, errWithCode := p.FollowCreate(
				ctx,
				requester,
				&apimodel.AccountFollowRequest{
					ID:      targetAcct.ID,
					Reblogs: showReblogs,
					Notify:  notify,
				},
			); errWithCode != nil {
				log.Errorf(ctx, "could not follow account: %v", errWithCode.Unwrap())
				continue
			}
		}
	}
}

func (p *Processor) importBlocks(
	ctx context.Context,
	requester *gtsmodel.Account,
	blocksData *multipart.FileHeader,
	overwrite bool,
) gtserror.WithCode {
	file, err := blocksData.Open()
	if err != nil {
		err := fmt.Errorf("error opening blocks data file: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}
	defer file.Close()

	// Parse records out of the file.
	records, err := csv.NewReader(file).ReadAll()
	if err != nil {
		err := fmt.Errorf("error reading blocks data file: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}

	// Convert the records into a slice of barebones blocks.
	//
	// Only TargetAccount.Username and TargetAccount.Domain,
	// will be set on each Block.
	blocks, err := p.converter.CSVToBlocks(ctx, records)
	if err != nil {
		err := fmt.Errorf("error converting records to blocks: %w", err)
		return gtserror.NewErrorBadRequest(err, err.Error())
	}

	// Do remaining processing of this import asynchronously.
	f := importBlocksAsyncF(p, requester, blocks, overwrite)
	p.state.Workers.Processing.Queue.Push(f)

	return nil
}

func importBlocksAsyncF(
	p *Processor,
	requester *gtsmodel.Account,
	blocks []*gtsmodel.Block,
	overwrite bool,
) func(context.Context) {
	return func(ctx context.Context) {
		// Map used to store wanted
		// block targets (if overwriting).
		var wantedBlocks map[string]struct{}

		if overwrite {
			// If we're overwriting, we need to get current
			// blocks owned by requester *before* making any
			// changes, so that we can remove unwanted blocks
			// after we've created new ones.
			var (
				prevBlocks []*gtsmodel.Block
				err        error
			)

			prevBlocks, err = p.state.DB.GetAccountBlocks(ctx, requester.ID, nil)
			if err != nil {
				log.Errorf(ctx, "db error getting blocks: %v", err)
				return
			}

			// Initialize new blocks map.
			wantedBlocks = make(map[string]struct{}, len(blocks))

			// Once we've created (or tried to create)
			// the required blocks, go through previous
			// blocks and remove unwanted ones.
			defer func() {
				for _, prev := range prevBlocks {
					username := prev.TargetAccount.Username
					domain := prev.TargetAccount.Domain

					_, wanted := wantedBlocks[username+"@"+domain]
					if wanted {
						// Leave this
						// one alone.
						continue
					}

					if _, errWithCode := p.BlockRemove(
						ctx,
						requester,
						prev.TargetAccountID,
					); errWithCode != nil {
						log.Errorf(ctx, "could not unblock account: %v", errWithCode.Unwrap())
						continue
					}
				}
			}()
		}

		// Go through the blocks parsed from CSV
		// file, and create / update each one.
		for _, block := range blocks {
			var (
				// Username of the target.
				username = block.TargetAccount.Username

				// Domain of the target.
				// Empty for our domain.
				domain = block.TargetAccount.Domain
			)

			if overwrite {
				// We'll be overwriting, so store
				// this new block in our handy map.
				wantedBlocks[username+"@"+domain] = struct{}{}
			}

			// Get the target account, dereferencing it if necessary.
			targetAcct, _, err := p.federator.Dereferencer.GetAccountByUsernameDomain(
				ctx,
				// Provide empty request user to use the
				// instance account to deref the account.
				//
				// It's pointless to make lots of calls
				// to a remote from an account that's about
				// to block that account.
				"",
				username,
				domain,
			)
			if err != nil {
				log.Errorf(ctx, "could not retrieve account: %v", err)
				continue
			}

			// Use the processor's BlockCreate function
			// to create or update the block. This takes
			// account of existing blocks, and also sends
			// the block to the FromClientAPI processor.
			if _, errWithCode := p.BlockCreate(
				ctx,
				requester,
				targetAcct.ID,
			); errWithCode != nil {
				log.Errorf(ctx, "could not block account: %v", errWithCode.Unwrap())
				continue
			}
		}
	}
}