// -*- c++ -*-
// Distributed under the BSD 2-Clause License.
// See accompanying file LICENSE for details.
#pragma once

#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

namespace arg
{
/// Tag type marking an option that takes no argument (its callback is `int()`).
struct noarg
{
};
/// Description of one command line option whose argument has type T.
/// Kept an aggregate so it can be built with `Opt<T>{short, long, cb, help}`.
template<typename T>
struct Opt
{
	/// Short form character (matched after a single '-'); '\0' means none.
	/// Default-initialized to '\0' so a default-constructed Opt is never
	/// left with an indeterminate value.
	char shortopt{'\0'};
	/// Long form including the leading "--" (e.g. "--count"); empty means none.
	std::string longopt;
	/// Invoked with the converted argument; a non-zero result aborts parsing.
	std::function<int(T)> cb;
	/// Help text printed by Parser::help().
	std::string help;
	/// Value-initialized T; only its type is used, to select the conversion
	/// overload in Parser (the value itself is ignored).
	T t{};
};

/// Specialization for options that take no argument: the callback is invoked
/// without parameters when the option is seen.
template<>
struct Opt<noarg>
{
	/// Short form character (matched after a single '-'); '\0' means none.
	/// Default-initialized so a default-constructed Opt is never indeterminate.
	char shortopt{'\0'};
	/// Long form including the leading "--"; empty means none.
	std::string longopt;
	/// Invoked when the option is present; a non-zero result aborts parsing.
	std::function<int()> cb;
	/// Help text printed by Parser::help().
	std::string help;
	/// Unused placeholder mirroring Opt<T>::t.
	noarg t{};
};

/// Invokes `cb(args...)` when `cb` is set (contextually converts to true).
/// When it is not set, nothing is called. A non-void callable then yields a
/// value-initialized result (e.g. 0 for int).
///
/// The callable and the arguments are taken by forwarding reference. This
/// avoids copying e.g. a std::function, and its possible heap allocation,
/// on every call.
template<typename Callable, typename... Args>
auto call_if(Callable&& cb, Args&&... args)
	-> std::invoke_result_t<Callable&, Args&&...>
{
	using Ret = std::invoke_result_t<Callable&, Args&&...>;
	if(cb)
	{
		return cb(std::forward<Args>(args)...);
	}
	if constexpr (!std::is_void_v<Ret>)
	{
		return Ret{};
	}
}

/// Kinds of failure passed to the Parser error callback, together with the
/// offending token.
enum class error
{
	missing_arg, ///< an option that requires an argument got none
	invalid_arg, ///< an option argument was rejected (e.g. could not be converted)
	invalid_opt  ///< the token does not name any registered option
};

template<typename... Ts>
class Parser
{
public:
	struct missing_arg{};

	Parser(int argc_, const char* const* argv_)
		: argc(argc_)
		, argv(argv_)
	{
	}

	int parse() const
	{
		bool demarcate{false};
		for(int i = 1; i < argc; ++i) // skip argv[0] which is program name
		{
			std::string_view arg{argv[i]};
			if(arg.size() == 0)
			{
				// Empty arg - This shouldn't happen
				continue;
			}

			if(arg[0] != '-' || demarcate) // positional arg
			{
				auto res = call_if(pos_cb, arg);
				if(res != 0)
				{
					return res;
				}
				continue;
			}

			if(arg == "--")
			{
				demarcate = true;
				continue;
			}

			bool was_handled{false};
			enum class state { handled, unhandled };

			if(arg.size() > 1 && arg[0] == '-' && arg[1] == '-') // long
			{
				for(const auto& option : options)
				{
					auto ret =
						std::visit([&](auto&& opt) -> std::pair<int, state>
						           {
							           if(opt.longopt != arg &&
							              !arg.starts_with(opt.longopt+'='))
							           {
								           return {0, state::unhandled};
							           }
							           try
							           {
								           using T = std::decay_t<decltype(opt)>;
								           if constexpr (std::is_same_v<T, Opt<noarg>>)
								           {
									           return {opt.cb(), state::handled};
								           }
								           else
								           {
									           return {opt.cb(convert(i, opt.t)),
									                   state::handled};
								           }
							           }
							           catch(std::invalid_argument&)
							           {
								           call_if(err_cb, error::invalid_arg, argv[i]);
								           return {1, state::handled};
							           }
							           catch(missing_arg&)
							           {
								           call_if(err_cb, error::missing_arg, argv[i]);
								           return {1, state::handled};
							           }
						           }, option);
					if(ret.second == state::handled && ret.first != 0)
					{
						return ret.first;
					}
					was_handled |= ret.second == state::handled;
					if(was_handled)
					{
						break;
					}
				}
			}
			else
			if(arg.size() > 1 && arg[0] == '-') // short
			{
				for(auto index = 1u; index < arg.size(); ++index)
				{
					was_handled = false;
					for(const auto& option : options)
					{
						auto ret =
							std::visit([&](auto&& opt) -> std::pair<int, state>
							           {
								           char c = arg[index];
								           if(opt.shortopt != c)
								           {
									           return {0, state::unhandled};
								           }
								           try
								           {
									           using T = std::decay_t<decltype(opt)>;
									           if constexpr (std::is_same_v<T, Opt<noarg>>)
									           {
										           return {opt.cb(), state::handled};
									           }
									           else
									           {
										           // Note: the rest of arg is converted to opt
										           auto idx = index;
										           // set index out of range all was eaten as arg
										           index = std::numeric_limits<int>::max();
										           return {opt.cb(convert_short(&arg[idx],
										                                        i, opt.t)),
										                   state::handled};
									           }
								           }
								           catch(std::invalid_argument&)
								           {
									           call_if(err_cb, error::invalid_arg, argv[i]);
									           return {1, state::handled};
								           }
								           catch(missing_arg&)
								           {
									           call_if(err_cb, error::missing_arg, argv[i]);
									           return {1, state::handled};
								           }
							           }, option);
						if(ret.second == state::handled && ret.first != 0)
						{
							return ret.first;
						}
						was_handled |= ret.second == state::handled;
						if(was_handled)
						{
							break;
						}
					}
				}
			}

			if(!was_handled)
			{
				call_if(err_cb, error::invalid_opt, arg);
				return 1;
			}
		}
		return 0;
	}

	template<typename T>
	void add(char shortopt,
	         const std::string& longopt,
	         std::function<int(T)> cb,
	         const std::string& help)
	{
		options.emplace_back(Opt<T>{shortopt, longopt, cb, help});
	}

	void add(char shortopt,
	         const std::string& longopt,
	         std::function<int()> cb,
	         const std::string& help)
	{
		options.emplace_back(Opt<noarg>{shortopt, longopt, cb, help});
	}

	void set_pos_cb(std::function<int(std::string_view)> cb)
	{
		pos_cb = cb;
	}

	void set_err_cb(std::function<void(error, std::string_view)> cb)
	{
		err_cb = cb;
	}

	std::string prog_name() const
	{
		if(argc < 1)
		{
			return {};
		}
		return argv[0];
	}

	void help() const
	{
		constexpr std::size_t width{26};
		constexpr std::size_t column_width{80};

		for(const auto& option : options)
		{
			std::visit(
				[&](auto&& opt)
				{
					std::string _args;
					using T = std::decay_t<decltype(opt)>;
					if constexpr (std::is_same_v<T, Opt<noarg>>)
					{
					}
					else if constexpr (std::is_same_v<T, Opt<int>>)
					{
						_args = "<int>";
					}
					else if constexpr (std::is_same_v<T, Opt<std::optional<int>>>)
					{
						_args = "[int]";
					}
					else if constexpr (std::is_same_v<T, Opt<std::string>>)
					{
						_args = "<str>";
					}
					else if constexpr (std::is_same_v<T, Opt<std::optional<std::string>>>)
					{
						_args = "[str]";
					}
					else if constexpr (std::is_same_v<T, Opt<double>>)
					{
						_args = "<real>";
					}
					else if constexpr (std::is_same_v<T, Opt<std::optional<double>>>)
					{
						_args = "[real]";
					}
					else
					{
						static_assert(std::is_same_v<T, void>, "missing");
					}

					std::string option_str;
					if(opt.shortopt != '\0' && !opt.longopt.empty())
					{
						option_str = "  -" + std::string(1, opt.shortopt) + ", " +
							opt.longopt + " " + _args;
					}
					else if(opt.shortopt != '\0')
					{
						option_str = "  -" + std::string(1, opt.shortopt) + _args;
					}
					else if(!opt.longopt.empty())
					{
						option_str = "      " + std::string(opt.longopt) + " " + _args;
					}

					std::string padding;
					if(option_str.size() < width)
					{
						padding.append(width - option_str.size(), ' ');
					}
					else
					{
						padding = "\n";
						padding.append(width, ' ');
					}

					std::cout << option_str << padding;

					auto i = width;
					for(auto c : opt.help)
					{
						if((c == '\n') || (i > column_width && (c == ' ' || c == '\t')))
						{
							std::string _padding(width, ' ');
							std::cout << '\n' << _padding;
							i = width;
							continue;
						}
						std::cout << c;
						++i;
					}
					std::cout << '\n';
				}, option);
		}
	}

private:
	template<typename T>
	T convert(int& i, T) const
	{
		auto opt = convert(i, std::optional<T>{});
		if(!opt)
		{
			throw missing_arg{};
		}
		return *opt;
	}

	template<typename T>
	std::optional<T> convert(int& i, std::optional<T>) const
	{
		std::string arg;
		bool has_arg{false};
		std::string opt = argv[i];
		if(opt.starts_with("--"))
		{
			// long opt
			auto equals_pos = opt.find('=');
			if(equals_pos != std::string::npos)
			{
				arg = opt.substr(equals_pos + 1);
				has_arg = true;
			}
			else if(i+1 < argc)
			{
				arg = argv[i+1];
				has_arg = !arg.starts_with("-");
				if(has_arg)
				{
					++i;
				}
			}
		}

		if(!has_arg)
		{
			return {};
		}

		if constexpr (std::is_same_v<T, int>)
		{
			return std::stoi(arg);
		}
		else if constexpr (std::is_same_v<T, double>)
		{
			return std::stod(arg);
		}
		else if constexpr (std::is_same_v<T, std::string>)
		{
			return arg;
		}
		else
		{
			static_assert(std::is_same_v<T, void>, "missing");
		}
		return {};
	}

	template<typename T>
	T convert_short(const char* arg_, int& i, T) const
	{
		auto opt = convert_short(arg_, i, std::optional<T>{}, false);
		if(!opt)
		{
			throw missing_arg{};
		}
		return *opt;
	}

	template<typename T>
	std::optional<T> convert_short(const char* arg_, int& i,
	                               std::optional<T>, bool optional = true) const
	{
		std::string arg;
		bool has_arg{false};
		std::string opt = arg_;
		if(opt.length() > 1)
		{
			// arg in same token
			arg = opt.substr(1);
			has_arg = true;
		}
		else if(!optional && i+1 < argc)
		{
			arg = argv[i+1];
			has_arg = true;//!arg.starts_with("-");
			if(has_arg)
			{
				++i;
			}
		}

		if(!has_arg)
		{
			return {};
		}

		if constexpr (std::is_same_v<T, int>)
		{
			return std::stoi(arg);
		}
		else if constexpr (std::is_same_v<T, double>)
		{
			return std::stod(arg);
		}
		else if constexpr (std::is_same_v<T, std::string>)
		{
			return arg;
		}
		else
		{
			static_assert(std::is_same_v<T, void>, "missing");
		}
		return {};
	}

	using Opts = std::variant<Opt<noarg>, Opt<Ts>...>;
	std::vector<Opts> options;
	std::function<int(std::string_view)> pos_cb;
	int argc;
	const char* const* argv;
	std::function<void(error, std::string_view)> err_cb;
};

} // arg::