#include "./light.hpp"

// Builds a light from the given color and power; both values are copied
// into the members of the same name.
Light::Light(const Color& color, double power)
	: color(color),
	  power(power)
{
}

// Empty destructor body: no explicit cleanup is performed here.
Light::~Light()
{
}

// Always answers true for a Light.
bool Light::isLight() const
{
	return true;
}

// Always answers true for a ShapeLight.
bool ShapeLight::isLight() const
{
	return true;
}

// Always answers true for a RectangleLight.
bool RectangleLight::isLight() const
{
	return true;
}

// Returns the light's color multiplied by its power.
Color Light::emitted() const
{
	Color scaled = color * power;
	return scaled;
}

// Base-class position: always the zero vector (0, 0, 0).
// RectangleLight and ShapeLight provide their own getPosition().
Vec3 Light::getPosition() const
{
	Vec3 zero(0.0, 0.0, 0.0);
	return zero;
}

// Builds a rectangular light.
//   pos          - stored as `origin`; sampling and intersection treat it as
//                  the corner from which both sides extend.
//   side1, side2 - edge vectors leaving `pos`; their lengths give the extents.
//   color, power - forwarded unchanged to the Light base constructor.
RectangleLight::RectangleLight(const Point& pos, const Vec3& side1, const Vec3& side2, const Color& color, double power)
	: Light(color, power),
	  origin(pos),
	  side1(side1),
	  side2(side2)
{
}

// Empty destructor body: no explicit cleanup is performed here.
RectangleLight::~RectangleLight()
{
}

// Returns the rectangle's `origin` (the `pos` given at construction),
// not its center.
Vec3 RectangleLight::getPosition() const
{
	Vec3 position = origin;
	return position;
}

// Intersects the ray stored in `intersection` with this rectangle and, when
// the hit is accepted, records it in `intersection`.
//
// The rectangle lies in the plane through `origin` whose normal is
// cross(side1, side2). Returns false when:
//   - the ray direction is exactly perpendicular to that normal (parallel ray),
//   - the plane hit distance is below RAY_T_MIN or not below intersection.t,
//   - the hit point projects outside [0, |side1|] x [0, |side2|].
// On success writes intersection.t, intersection.pShape (this light) and
// intersection.normal (flipped so it opposes the ray direction) and returns
// true.
bool RectangleLight::intersect(Intersection& intersection)
{
	Vec3 planeNormal = cross(side1, side2).normalized();
	double facing = dot(planeNormal, intersection.ray.direction);

	// A ray exactly parallel to the plane never reaches it.
	if (facing == 0.0) {
		return false;
	}

	// Ray parameter at which the ray meets the rectangle's plane.
	double hitT = dot(origin - intersection.ray.origin, planeNormal) / facing;

	if (hitT < RAY_T_MIN) {
		return false;
	}
	if (hitT >= intersection.t) {
		return false;
	}

	// Unit directions along each side; normalize()'s return value is used as
	// the side length.
	Vec3 axisU = side1;
	Vec3 axisV = side2;
	double extentU = axisU.normalize();
	double extentV = axisV.normalize();

	// Express the hit point in the rectangle's own (u, v) coordinates.
	// NOTE(review): projecting onto each side independently presumably relies
	// on side1 and side2 being perpendicular - confirm against callers.
	Point hitPoint = intersection.ray.calculate(hitT);
	Point fromCorner = hitPoint - origin;
	Point planar(dot(fromCorner, axisU), dot(fromCorner, axisV), 0.0);

	if (planar.x < 0.0 || planar.x > extentU) {
		return false;
	}
	if (planar.y < 0.0 || planar.y > extentV) {
		return false;
	}

	intersection.t = hitT;
	intersection.pShape = this;
	//intersection.pMaterial = &material;
	intersection.normal = planeNormal;

	// Make the reported normal face against the incoming ray.
	if (facing > 0.0) {
		intersection.normal *= -1.0;
	}

	return true;
}

bool RectangleLight::doesIntersect(const Ray& ray){
	
	Vec3 normal = cross(side1, side2).normalized();
	double nDotD = dot(normal, ray.direction);
	
	if (nDotD == 0.0){
		return false;
	}

	double t = dot(origin - ray.origin, normal) / nDotD;

	if (t >= ray.tMax || t < RAY_T_MIN){
		return false;
	}

	Vec3 side1Norm = side1;
	Vec3 side2Norm = side2;
	double side1Length = side1Norm.normalize();
	double side2Length = side2Norm.normalize();

	Point worldPoint = ray.calculate(t);
	Point relativePoint = worldPoint - origin;
	Point localPoint = Point(dot(relativePoint, side1Norm), dot(relativePoint, side2Norm), 0.0);

	if (localPoint.x < 0.0 || localPoint.x > side1Length || localPoint.y < 0.0 || localPoint.y > side2Length){
		return false;
	}

	return true;
}

// Picks a point on the rectangle and reports the density of that choice as
// seen from `surfPosition`.
//
//   surfPosition - point the light is being sampled from.
//   surfNormal   - unused.
//   u1, u2       - coordinates along side1 and side2; the sampled point is
//                  origin + side1 * u1 + side2 * u2 (presumably both in
//                  [0, 1] - confirm against callers).
//   u3           - unused.
//   outPosition  - receives the sampled point.
//   outNormal    - receives the unit rectangle normal, oriented towards
//                  surfPosition.
//   outPdf       - receives dist^2 / (area * |cos|), where dist is the distance
//                  between the two points, area is |side1 x side2| and cos is
//                  the cosine between outNormal and the direction to
//                  surfPosition. Set to 0.0 when the sample is rejected.
//
// Returns false (with outPdf == 0.0) when the density exceeds RAY_T_MAX or is
// not a number; true otherwise.
bool RectangleLight::sampleSurface(const Point& surfPosition, const Vec3& surfNormal, double u1, double u2, double u3, Point& outPosition, Vec3& outNormal, double& outPdf) {
	(void)u3;
	(void)surfNormal;

	outPosition = origin + side1 * u1 + side2 * u2;

	// Unit direction from the sampled point towards the receiving surface;
	// the value returned by normalize() is used as the distance between them.
	Vec3 outgoing = surfPosition - outPosition;
	double dist = outgoing.normalize();

	// The un-normalized cross product has the rectangle's area as its length.
	outNormal = cross(side1, side2);
	double area = outNormal.normalize();

	// Orient the normal towards the side the receiving surface is on.
	if (dot(outNormal, outgoing) < 0.0) {
		outNormal *= -1.0;
	}

	outPdf = squared(dist) / (area * std::fabs(dot(outNormal, outgoing)));

	// Reject unusable densities. The comparison is deliberately negated:
	// a NaN density (e.g. 0/0 when the sampled point coincides with
	// surfPosition, or a zero-area rectangle) compares false against
	// everything, so a plain `outPdf > RAY_T_MAX` test would let it through.
	if (!(outPdf <= RAY_T_MAX)) {
		outPdf = 0.0;
		return false;
	}

	return true;
}

// Density associated with an intersection that was recorded on this light.
//
// Returns 0.0 when `isect` refers to a different object. Otherwise evaluates
// t^2 / (|dot(normal, -direction)| * |side1 x side2|) and returns it, except
// that values greater than RAY_T_MAX are reported as 0.0.
// NOTE(review): this is the same expression sampleSurface() writes to outPdf,
// presumably so both sampling paths agree - confirm.
double RectangleLight::intersectPdf(const Intersection& isect)
{
	// Hits that belong to some other object contribute nothing.
	if (isect.pShape != this) {
		return 0.0;
	}

	const double cosine = std::fabs(dot(isect.normal, -isect.ray.direction));
	const double area = cross(side1, side2).length();
	const double pdf = squared(isect.t) / (cosine * area);

	return (pdf > RAY_T_MAX) ? 0.0 : pdf;
}

// Forwards the intersection test to the wrapped shape. When the shape reports
// a hit, intersection.pShape is overwritten so the hit refers to this light.
// Returns whatever the wrapped shape's intersect() returned.
bool ShapeLight::intersect(Intersection& intersection)
{
	const bool hit = pShape->intersect(intersection);

	if (!hit) {
		return false;
	}

	//intersection.pMaterial = &material;
	intersection.pShape = this;
	return true;
}

// Forwards the boolean hit test to the wrapped shape and returns its answer.
bool ShapeLight::doesIntersect(const Ray& ray)
{
	const bool hit = pShape->doesIntersect(ray);
	return hit;
}

// Samples a point on the wrapped shape by forwarding every argument to
// pShape->sampleSurface(), then rejects samples whose normal faces away from
// `surfPosition`.
//
//   surfPosition, surfNormal - passed through to the wrapped shape.
//   u1, u2, u3               - passed through to the wrapped shape.
//   outPosition, outNormal   - filled in by the wrapped shape.
//   outPdf                   - filled in by the wrapped shape on success;
//                              set to 0.0 on every failure path.
//
// Returns true only when the wrapped shape produced a sample and
// `surfPosition` is not behind the sampled point's normal.
bool ShapeLight::sampleSurface(const Point& surfPosition, const Point& surfNormal, double u1, double u2, double u3, Point& outPosition, Vec3& outNormal, double& outPdf) {

	if (!pShape->sampleSurface(surfPosition, surfNormal, u1, u2, u3, outPosition, outNormal, outPdf)) {
		outPdf = 0.0;
		return false;
	}

	// The receiving point lies behind the sampled point, so the sample is
	// rejected. Zero the density here too, so a rejected sample always
	// reports outPdf == 0.0, matching the failure path above.
	if (dot(outNormal, surfPosition - outPosition) < 0.0) {
		outPdf = 0.0;
		return false;
	}

	return true;
}

// Density associated with an intersection that was recorded on this light.
// Returns 0.0 when `isect` refers to a different object; otherwise returns
// the wrapped shape's pdfSA() evaluated with the ray origin, ray direction,
// hit position and hit normal taken from `isect`.
double ShapeLight::intersectPdf(const Intersection& isect)
{
	// Hits that belong to some other object contribute nothing.
	if (isect.pShape != this) {
		return 0.0;
	}

	return pShape->pdfSA(isect.ray.origin, isect.ray.direction, isect.position(), isect.normal);
}

// Builds a light around an existing shape.
//   pShape       - stored as-is; the destructor deletes it, so the light
//                  takes ownership of the pointer.
//   color, power - forwarded unchanged to the Light base constructor.
ShapeLight::ShapeLight(Shape *pShape, const Color& color, double power)
	: Light(color, power),
	  pShape(pShape)
{
}

// Deletes the shape pointer that was handed to the constructor.
// NOTE(review): copy construction/assignment are not visible in this file;
// if ShapeLight is copyable, two copies would delete the same shape - confirm
// that copying is disabled in the header.
ShapeLight::~ShapeLight()
{
	delete pShape;
}

// Returns the position reported by the wrapped shape.
Vec3 ShapeLight::getPosition() const
{
	Vec3 position = pShape->getPosition();
	return position;
}