#include <buildboxcommon_grpcretrier.h>
#include <buildboxcommon_requestmetadata.h>
#include <buildboxcommon_schedulingmetadata.h>

#include <gtest/gtest.h>

#include <google/rpc/error_details.grpc.pb.h>

#include <chrono>
#include <functional>
#include <iostream>

using buildboxcommon::GrpcRetrier;

// Verifies that, with no extra codes supplied, the retrier considers exactly
// one status code -- UNAVAILABLE -- retryable.
TEST(GrpcRetrier, TestDefaultRetriableCode)
{
    const int retryLimit = 4;
    const std::chrono::milliseconds retryDelay(150);

    // The invocation is never issued in this test; it only has to satisfy the
    // constructor, so it captures nothing.
    auto lambda = [](grpc::ClientContext &) { return grpc::Status::OK; };

    GrpcRetrier r(retryLimit, retryDelay, lambda, "lambda()");
    // Container size() is unsigned; comparing it against a signed `1` inside
    // ASSERT_EQ's template triggers -Wsign-compare, so use an unsigned literal.
    ASSERT_EQ(r.retryableStatusCodes().size(), 1u);
    ASSERT_TRUE(
        r.retryableStatusCodes().count(grpc::StatusCode::UNAVAILABLE));
}

// Verifies that, by default, the retrier accepts exactly one status code --
// OK -- as a successful outcome.
TEST(GrpcRetrier, TestDefaultOkCode)
{
    const int retryLimit = 4;
    const std::chrono::milliseconds retryDelay(150);

    // The invocation is never issued in this test; it only has to satisfy the
    // constructor, so it captures nothing.
    auto lambda = [](grpc::ClientContext &) { return grpc::Status::OK; };

    GrpcRetrier r(retryLimit, retryDelay, lambda, "lambda()");
    // Unsigned literal: avoids a signed/unsigned comparison against size().
    ASSERT_EQ(r.okStatusCodes().size(), 1u);
    ASSERT_TRUE(r.okStatusCodes().count(grpc::StatusCode::OK));
}

// Checks that the accessors echo the constructor arguments, and that the
// per-request timeout is zero when none is passed.
TEST(GrpcRetrier, TestGetters)
{
    const int expectedLimit = 4;
    const std::chrono::milliseconds expectedDelay(150);
    auto alwaysOk = [](grpc::ClientContext &) { return grpc::Status::OK; };

    GrpcRetrier retrier(expectedLimit, expectedDelay, alwaysOk, "lambda()");

    EXPECT_EQ(retrier.retryLimit(), expectedLimit);
    EXPECT_EQ(retrier.retryDelayBase(), expectedDelay);
    // No timeout argument was given, so the default applies.
    EXPECT_EQ(retrier.requestTimeout(), std::chrono::seconds::zero());
}

// A request that succeeds on its first try is issued exactly once and
// records no retries.
TEST(GrpcRetrier, SimpleSucceedTest)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);

    int callCount = 0;
    auto countingOk = [&callCount](grpc::ClientContext &) {
        ++callCount;
        return grpc::Status::OK;
    };

    GrpcRetrier retrier(retryLimit, retryDelay, countingOk, "lambda()");

    const bool completed = retrier.issueRequest();
    EXPECT_TRUE(completed);
    EXPECT_EQ(callCount, 1);
    EXPECT_TRUE(retrier.status().ok());
    EXPECT_EQ(retrier.retryAttempts(), 0);
}

// Status codes passed in the extra-retryable argument (here DEADLINE_EXCEEDED)
// are retried, just like the default retryable code.
TEST(GrpcRetrier, OtherException)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);

    /* While `failures` < 1, increment it and fail with DEADLINE_EXCEEDED;
       otherwise succeed. Starting from 0 this fails once and then succeeds;
       starting from -1 it fails twice before succeeding. */
    int failures = 0;
    auto lambda = [&](grpc::ClientContext &) {
        if (failures < 1) {
            failures++;
            return grpc::Status(grpc::DEADLINE_EXCEEDED, "failing in test");
        }
        else {
            return grpc::Status::OK;
        }
    };

    const GrpcRetrier::GrpcStatusCodes otherExceptions = {
        grpc::DEADLINE_EXCEEDED};
    GrpcRetrier r(retryLimit, retryDelay, lambda, "lambda()", otherExceptions);

    // One failure, one retry, then success.
    EXPECT_TRUE(r.issueRequest());
    EXPECT_TRUE(r.status().ok());
    EXPECT_EQ(r.retryAttempts(), 1);

    // Rewind the counter so the original attempt and the single permitted
    // retry both fail, exhausting the retry limit; the last error is kept.
    failures = -1;

    EXPECT_FALSE(r.issueRequest());
    EXPECT_EQ(r.status().error_code(), grpc::DEADLINE_EXCEEDED);
    EXPECT_EQ(r.status().error_message(), "failing in test");
    EXPECT_EQ(r.retryAttempts(), 1);
}

// Each code in the extra retryable set, plus the default UNAVAILABLE, is
// retried. Whether the request succeeds depends only on whether the retry
// limit reaches the first successful call.
TEST(GrpcRetrier, MultipleException)
{
    const std::chrono::milliseconds retryDelay(100);
    const GrpcRetrier::GrpcStatusCodes otherExceptions = {
        grpc::DEADLINE_EXCEEDED, grpc::INVALID_ARGUMENT};

    // The first three calls fail with these codes, in this order; any call
    // after that succeeds.
    const grpc::StatusCode failureSequence[] = {
        grpc::DEADLINE_EXCEEDED, grpc::INVALID_ARGUMENT, grpc::UNAVAILABLE};
    const unsigned int failingCalls = 3;

    unsigned int failures = 0;
    auto lambda = [&](grpc::ClientContext &) {
        if (failures >= failingCalls) {
            return grpc::Status::OK;
        }
        const grpc::StatusCode code = failureSequence[failures];
        ++failures;
        return grpc::Status(code, "failing in test");
    };

    // Three retries are enough to reach the fourth call, which succeeds.
    {
        const unsigned int retryLimit = 3;
        GrpcRetrier retrier(retryLimit, retryDelay, lambda, "",
                            otherExceptions);
        EXPECT_TRUE(retrier.issueRequest());
        EXPECT_TRUE(retrier.status().ok());
        EXPECT_EQ(retrier.retryAttempts(), 3);
        // No timeout was passed, so the default applies.
        EXPECT_EQ(retrier.requestTimeout(), std::chrono::seconds::zero());
    }

    failures = 0;

    // Two retries stop at the third call; its UNAVAILABLE is reported.
    {
        const unsigned int retryLimit = 2;
        GrpcRetrier retrier(retryLimit, retryDelay, lambda, "",
                            otherExceptions);
        EXPECT_FALSE(retrier.issueRequest());
        EXPECT_EQ(retrier.status().error_code(), grpc::UNAVAILABLE);
        EXPECT_EQ(retrier.status().error_message(), "failing in test");
        EXPECT_EQ(retrier.retryAttempts(), 2);
        // No timeout was passed, so the default applies.
        EXPECT_EQ(retrier.requestTimeout(), std::chrono::seconds::zero());
    }
}

// A code registered with addOkStatusCode() finishes the request as accepted,
// even though it is not grpc::OK, and it is not retried.
TEST(GrpcRetrier, AdditionalOkayStatus)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);

    auto alwaysCancelled = [](grpc::ClientContext &) {
        return grpc::Status(grpc::CANCELLED, "Returning CANCELLED in test");
    };

    GrpcRetrier retrier(retryLimit, retryDelay, alwaysCancelled, "lambda()");
    retrier.addOkStatusCode(grpc::CANCELLED);

    const bool accepted = retrier.issueRequest();
    EXPECT_TRUE(accepted);
    EXPECT_EQ(retrier.status().error_code(), grpc::CANCELLED);
    EXPECT_EQ(retrier.retryAttempts(), 0);
}

// Exercises addRetryableStatusCode() and addOkStatusCode() together: a code
// made retryable is retried, and a code made "ok" then ends the request.
TEST(GrpcRetrier, AdditionalOkayAndRetryableStatus)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);

    // The first call returns ABORTED; every later call returns ALREADY_EXISTS.
    bool abortedReturned = false;
    auto lambda = [&abortedReturned](grpc::ClientContext &) {
        if (!abortedReturned) {
            abortedReturned = true;
            return grpc::Status(grpc::ABORTED, "Returning ABORTED in test");
        }
        return grpc::Status(grpc::ALREADY_EXISTS,
                            "Returning ALREADY_EXISTS in test");
    };

    GrpcRetrier retrier(retryLimit, retryDelay, lambda, "lambda()");
    retrier.addOkStatusCode(grpc::ALREADY_EXISTS);
    retrier.addRetryableStatusCode(grpc::ABORTED);

    const bool accepted = retrier.issueRequest();
    EXPECT_TRUE(accepted);
    EXPECT_EQ(retrier.status().error_code(), grpc::ALREADY_EXISTS);
    EXPECT_EQ(retrier.retryAttempts(), 1);
}

// With a per-request timeout, the ClientContext handed to the invocation has
// a deadline no earlier than (test start + timeout) and no later than
// (time of the invocation + timeout).
TEST(GrpcRetrier, RequestTimeoutSet)
{
    const int retryLimit = 3;
    const std::chrono::milliseconds retryDelay(100);
    const GrpcRetrier::GrpcStatusCodes otherExceptions = {
        grpc::DEADLINE_EXCEEDED, grpc::INVALID_ARGUMENT};

    const std::chrono::seconds requestTimeout(123);
    const auto testStartTime = std::chrono::system_clock::now();

    auto checkDeadline = [&](grpc::ClientContext &context) {
        const auto deadline = context.deadline();
        const auto earliestAllowed = testStartTime + requestTimeout;
        const auto latestAllowed =
            std::chrono::system_clock::now() + requestTimeout;
        EXPECT_TRUE(deadline >= earliestAllowed);
        EXPECT_TRUE(deadline <= latestAllowed);
        return grpc::Status::OK;
    };

    GrpcRetrier retrier(retryLimit, retryDelay, checkDeadline, "",
                        otherExceptions, nullptr, requestTimeout);
    EXPECT_TRUE(retrier.issueRequest());
    EXPECT_TRUE(retrier.status().ok());
    EXPECT_EQ(retrier.requestTimeout(), requestTimeout);
}

// With a 1 s request timeout and a 1000 ms retry delay, an UNAVAILABLE
// failure is attempted only once, even though the retry limit is 3.
TEST(GrpcRetrier, FailWithRequestTimeoutSet)
{
    const int retryLimit = 3;
    const std::chrono::milliseconds retryDelay(1000);

    const std::chrono::seconds requestTimeout(1);

    int failures = 0;
    // The ClientContext is not used, so the parameter is left unnamed to
    // avoid an -Wunused-parameter warning.
    auto lambda = [&](grpc::ClientContext &) {
        failures++;
        return grpc::Status(grpc::UNAVAILABLE, "failing in test");
    };

    GrpcRetrier r(retryLimit, retryDelay, lambda, "", {}, nullptr,
                  requestTimeout);
    // NOTE(review): issueRequest() reports true here although status() is an
    // error; presumably stopping early because of the timeout is treated
    // differently from exhausting the retry limit -- confirm in GrpcRetrier.
    EXPECT_TRUE(r.issueRequest());
    EXPECT_FALSE(r.status().ok());
    // Expect the retrier to give up after 1 attempt, presumably because the
    // retry delay does not fit within the request timeout.
    EXPECT_EQ(failures, 1);
}

// A status outside the retryable set stops retrying immediately. The request
// is reported as completed and carries that status.
TEST(GrpcRetrier, ExceptionNotIncluded)
{
    const int retryLimit = 3;
    const std::chrono::milliseconds retryDelay(100);
    const GrpcRetrier::GrpcStatusCodes otherExceptions = {
        grpc::DEADLINE_EXCEEDED, grpc::INVALID_ARGUMENT};

    int failures = 0;
    auto lambda = [&](grpc::ClientContext &) {
        // Original attempt: retryable error => retry.
        if (failures == 0) {
            ++failures;
            return grpc::Status(grpc::DEADLINE_EXCEEDED, "failing in test");
        }
        // Retry #1: another retryable error => retry again.
        if (failures == 1) {
            ++failures;
            return grpc::Status(grpc::INVALID_ARGUMENT, "failing in test");
        }
        // Retry #2: non-retryable error => abort.
        if (failures == 2) {
            ++failures;
            return grpc::Status(grpc::PERMISSION_DENIED, "failing in test");
        }
        return grpc::Status::OK;
    };

    GrpcRetrier retrier(retryLimit, retryDelay, lambda, "", otherExceptions);
    EXPECT_TRUE(retrier.issueRequest());
    EXPECT_EQ(retrier.status().error_code(), grpc::PERMISSION_DENIED);
    EXPECT_EQ(retrier.retryAttempts(), 2);
}

// A single UNAVAILABLE failure followed by success is recovered by one retry.
TEST(GrpcRetrier, SimpleRetrySucceedTest)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);

    bool hasFailed = false;
    auto failOnceThenSucceed = [&hasFailed](grpc::ClientContext &) {
        if (hasFailed) {
            return grpc::Status::OK;
        }
        hasFailed = true;
        return grpc::Status(grpc::UNAVAILABLE, "failing in test");
    };

    GrpcRetrier retrier(retryLimit, retryDelay, failOnceThenSucceed, "");
    EXPECT_TRUE(retrier.issueRequest());
    EXPECT_TRUE(retrier.status().ok());
    EXPECT_EQ(retrier.retryAttempts(), 1);
}

// Registering the (UNKNOWN, "retry") pair makes an UNKNOWN status with the
// message "retry on this magic message" retryable. The message is presumably
// matched by prefix or substring; confirm in GrpcRetrier.
TEST(GrpcRetrier, SimpleRetryWithStatusCodeMessagePair)
{
    const int retryLimit = 1;
    const std::chrono::milliseconds retryDelay(100);
    const std::set<std::pair<grpc::StatusCode, std::string>> retryablePairs{
        {grpc::UNKNOWN, "retry"}};

    int calls = 0;
    auto lambda = [&calls](grpc::ClientContext &) {
        const bool firstCall = (calls == 0);
        ++calls;
        if (firstCall) {
            return grpc::Status(grpc::UNKNOWN, "retry on this magic message");
        }
        return grpc::Status::OK;
    };

    GrpcRetrier retrier(retryLimit, retryDelay, lambda, "", {},
                        retryablePairs);
    EXPECT_TRUE(retrier.issueRequest());
    EXPECT_TRUE(retrier.status().ok());
    EXPECT_EQ(retrier.retryAttempts(), 1);
}

// UNKNOWN statuses with the messages "Stream removed" and
// "Stream removed: Truncated message" are retried, even though no extra
// retryable codes or pairs are configured.
TEST(GrpcRetrier, SimpleRetryWithStreamRemoved)
{
    const int retryLimit = 2;
    const std::chrono::milliseconds retryDelay(5);

    int callIndex = 0;
    auto lambda = [&callIndex](grpc::ClientContext &) {
        const int current = callIndex++;
        if (current == 0) {
            return grpc::Status(grpc::UNKNOWN, "Stream removed");
        }
        if (current == 1) {
            return grpc::Status(grpc::UNKNOWN,
                                "Stream removed: Truncated message");
        }
        return grpc::Status::OK;
    };

    GrpcRetrier retrier(retryLimit, retryDelay, lambda, "", {});
    EXPECT_TRUE(retrier.issueRequest());
    EXPECT_TRUE(retrier.status().ok());
    EXPECT_EQ(retrier.retryAttempts(), 2);
}

// Three consecutive UNAVAILABLE failures exhaust a retry limit of two. The
// request reports failure and keeps the last error.
TEST(GrpcRetrier, SimpleRetryFailTest)
{
    const int retryLimit = 2;
    const std::chrono::milliseconds retryDelay(100);

    // Fail three times, then succeed. The success is never reached with this
    // retry limit.
    const int failingCalls = 3;
    int failures = 0;
    auto lambda = [&](grpc::ClientContext &) {
        if (failures >= failingCalls) {
            return grpc::Status::OK;
        }
        ++failures;
        return grpc::Status(grpc::UNAVAILABLE, "failing in test");
    };

    GrpcRetrier retrier(retryLimit, retryDelay, lambda, "");
    EXPECT_FALSE(retrier.issueRequest());
    EXPECT_EQ(retrier.status().error_code(), grpc::UNAVAILABLE);
    EXPECT_EQ(retrier.status().error_message(), "failing in test");
    EXPECT_EQ(retrier.retryAttempts(), 2);
}

// A google.rpc.RetryInfo detail attached to a failed status overrides the
// retry delay. After the retry, retryDelayBase() reports the value the
// server provided.
TEST(GrpcRetrier, ServerProvidedDelay)
{
    const int retryLimit = 2;
    const std::chrono::milliseconds retryDelay(100);

    /* Fail one time, then succeed. */
    bool firstRequest = true;
    const std::chrono::milliseconds serverSpecifiedDelay(500);
    auto lambda = [&](grpc::ClientContext &) {
        if (!firstRequest) {
            return grpc::Status::OK;
        }

        firstRequest = false;

        // google.protobuf.Duration requires `nanos` to lie in
        // [0, 999,999,999], so split the delay into whole seconds plus the
        // sub-second remainder. Packing everything into `nanos` breaks for
        // delays of 1 s or more, and overflows int32 above about 2.1 s.
        const auto wholeSeconds =
            std::chrono::duration_cast<std::chrono::seconds>(
                serverSpecifiedDelay);
        const auto remainderNanos =
            std::chrono::duration_cast<std::chrono::nanoseconds>(
                serverSpecifiedDelay - wholeSeconds);

        google::protobuf::Duration delay;
        delay.set_seconds(wholeSeconds.count());
        delay.set_nanos(static_cast<int32_t>(remainderNanos.count()));

        google::rpc::RetryInfo retryInfo;
        *retryInfo.mutable_retry_delay() = delay;
        google::rpc::Status detailed_status;
        detailed_status.set_code(grpc::UNAVAILABLE);
        detailed_status.set_message("failing in test");
        detailed_status.add_details()->PackFrom(retryInfo);

        return grpc::Status(grpc::UNAVAILABLE, "failing in test",
                            detailed_status.SerializeAsString());
    };

    GrpcRetrier r(retryLimit, retryDelay, lambda, "");
    EXPECT_TRUE(r.issueRequest());
    EXPECT_TRUE(r.status().ok());
    EXPECT_EQ(r.retryAttempts(), 1);
    EXPECT_EQ(r.retryDelayBase(), serverSpecifiedDelay); // 500 ms
}

// The metadata attacher passed to the constructor is called exactly once
// for a request that succeeds on its first attempt.
TEST(GrpcRetrier, AttachMetadata)
{
    buildboxcommon::RequestMetadataGenerator requestMetadata(
        "testing tool name", "v0.1");
    requestMetadata.set_action_id("action1");

    buildboxcommon::SchedulingMetadataGenerator schedulingMetadata;
    schedulingMetadata.setLocalityHint("test-locality");

    // Succeeds immediately, so no retry is ever needed.
    auto succeedImmediately = [](grpc::ClientContext &) {
        return grpc::Status::OK;
    };

    int attacherCalls = 0;
    auto attachAll = [&](grpc::ClientContext *context) {
        requestMetadata.attach_request_metadata(context);
        schedulingMetadata.attachSchedulingMetadata(context);
        ++attacherCalls;
    };

    const unsigned int retryLimit = 0;
    const std::chrono::milliseconds retryDelay(0);

    GrpcRetrier retrier(retryLimit, retryDelay, succeedImmediately,
                        "grpc_invocation()", {}, attachAll);

    ASSERT_TRUE(retrier.issueRequest());
    ASSERT_EQ(attacherCalls, 1);
}
