package theGhastModding.midiVideoGen.renderer;

import java.awt.Color;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.util.List;

import com.aparapi.Kernel;
import com.aparapi.Range;
import com.aparapi.device.Device;
import com.aparapi.device.OpenCLDevice;

import theGhastModding.midiPlayer.midi.Note;

/**
 * {@link NotesRenderer} that rasterizes notes with an Aparapi {@link Kernel}.
 * <p>
 * Each frame, the notes passed to {@link #render} are copied into flat primitive arrays.
 * One kernel work-item is launched per note. Each work-item draws its note's darkened
 * outline and a horizontal gradient fill into a copy of the background image, stored as
 * interleaved R,G,B ints. The result is then drawn onto the caller's {@link Graphics2D}.
 */
public class GPURenderer extends NotesRenderer {
	
	private final int width;
	private final int height;
	/** Background composited onto black at the output size; only read after construction. */
	private final BufferedImage backgroundImage;
	/** Palette as packed ints from {@link Color#getRGB()}, indexed by track or channel. */
	private final int[] trackColors;
	private final ActualRenderer r;
	/** Horizontal pixels per pitch (the kernel places pitch p at x = keyLength * p). */
	private final double keyLength;
	/** When true, notes are coloured by channel; otherwise by track. */
	private final boolean channelColoring;
	
	/**
	 * Creates a renderer and prints the available OpenCL GPU devices for debugging.
	 *
	 * @param width           output frame width in pixels
	 * @param height          output frame height in pixels; the kernel maps one tick to one
	 *                        pixel row, so this is also the number of ticks visible
	 * @param channelColoring colour notes by channel instead of by track
	 * @param keyLength       horizontal pixels per pitch
	 * @param colors          palette indexed by track (or channel); indices past its end wrap
	 * @param backgroundImage optional image scaled to fill the frame; {@code null} means black
	 */
	public GPURenderer(int width, int height, boolean channelColoring, double keyLength, List<Color> colors, BufferedImage backgroundImage) {
		this.width = width;
		this.height = height;
		this.keyLength = keyLength;
		this.channelColoring = channelColoring;
		this.backgroundImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
		Graphics gg = this.backgroundImage.getGraphics();
		gg.setColor(Color.BLACK);
		gg.fillRect(0, 0, width, height);
		if(backgroundImage != null) gg.drawImage(backgroundImage, 0, 0, width, height, null);
		gg.dispose();
		this.trackColors = new int[colors.size()];
		for(int i = 0; i < colors.size(); i++) {
			this.trackColors[i] = colors.get(i).getRGB();
		}
		this.r = new ActualRenderer();
		System.out.println("[DEBUG] Printing all available device info for Device.TYPE = GPU:\n");
		for(OpenCLDevice d : OpenCLDevice.listDevices(Device.TYPE.GPU)) {
			System.out.println(d.toString() + "\n");
		}
		System.out.println("\n");
	}
	
	/**
	 * Draws the background plus every note in {@code notes} for the window of ticks starting at
	 * {@code tick} and spanning {@code height} ticks.
	 */
	@Override
	public void render(List<Note> notes, Graphics2D g, long tick) {
		int count = notes.size();
		if(count == 0) {
			// Nothing to rasterize: the frame is just the background. This also avoids
			// launching a zero-sized kernel range over zero-length buffers.
			g.drawImage(backgroundImage, 0, 0, width, height, null);
			return;
		}
		r.keyLength = (float)keyLength;
		r.notePitches = new int[count];
		r.noteColorIndexes = new int[count];
		r.noteStarts = new long[count];
		r.noteEnds = new long[count];
		// Iterate rather than calling get(i), which is O(n) per call on linked lists.
		int i = 0;
		for(Note n : notes) {
			r.notePitches[i] = n.getPitch();
			int colorIndex = channelColoring ? n.getChannel() : n.getTrack();
			// Wrap into the palette so the kernel never reads past the end of `colors`.
			r.noteColorIndexes[i] = trackColors.length > 0 ? Math.floorMod(colorIndex, trackColors.length) : colorIndex;
			r.noteStarts[i] = n.getStart();
			r.noteEnds[i] = n.getEnd();
			i++;
		}
		r.tickPos = tick;
		r.frameHeight = height;
		r.frameWidth = width;
		// Fresh copy every frame because the kernel draws into it in place.
		r.image = toIntArray(backgroundImage);
		r.colors = trackColors;
		r.execute(Range.create(count));
		g.drawImage(toImage(r.image), 0, 0, width, height, null);
	}
	
	/**
	 * Kernel that draws one note per work-item into {@link #image}.
	 * <p>
	 * {@code image} holds interleaved R,G,B samples, row-major, {@code frameWidth} pixels per row.
	 * Row 0 is tick {@code tickPos + frameHeight}, and rows increase as ticks decrease.
	 * Work-items that overlap write to the same pixels without ordering.
	 * <p>
	 * Static so the kernel does not carry a hidden reference to the enclosing renderer.
	 */
	private static class ActualRenderer extends Kernel {
		
		float keyLength;
		int[] notePitches;
		int[] noteColorIndexes;
		long[] noteStarts;
		long[] noteEnds;
		long tickPos;
		int frameHeight;
		int frameWidth;
		
		int[] image;
		int[] colors;
		
		@Override
		public void run() {
			int x = getGlobalId(0);
			long top = tickPos + frameHeight;
			// Skip notes entirely outside the visible window [tickPos, top].
			if(noteEnds[x] < tickPos || noteStarts[x] > top) return;
			// Row of the note's end (upper edge) and start (lower edge), still in 64-bit.
			long endRow = top - noteEnds[x];
			long startRow = top - noteStarts[x];
			if(startRow < 0 || startRow < endRow) {
				return;
			}
			// Clamp before narrowing so long notes cannot overflow the int cast.
			if(endRow < 0) {
				endRow = 0;
			}
			if(endRow > frameHeight) {
				endRow = frameHeight;
			}
			if(startRow > frameHeight) {
				startRow = frameHeight;
			}
			if(startRow < 1) {
				startRow = 1;
			}
			int endOffset = (int)endRow;
			int offset = (int)startRow;
			int pitch = notePitches[x];
			int color = colors[noteColorIndexes[x]];
			int r = (color >> 16) & 0xFF;
			int g = (color >> 8) & 0xFF;
			int b = color & 0xFF;
			// Outline colour: the note colour darkened by 118 per channel, floored at 0.
			int r2 = r - 118 > 0 ? r - 118 : 0;
			int g2 = g - 118 > 0 ? g - 118 : 0;
			int b2 = b - 118 > 0 ? b - 118 : 0;
			int widthHere = (int)(keyLength * (float)(pitch + 1) - keyLength * (float)pitch);
			int a = (int)(keyLength * (float)pitch);
			int l = image.length;
			int bottomRow = offset - 1;
			// Top and bottom edges of the outline.
			for(int i = 0; i < widthHere - 1; i++) {
				int topIdx = (endOffset * frameWidth + i + a) * 3;
				int bottomIdx = (bottomRow * frameWidth + i + a) * 3;
				if(topIdx >= l || bottomIdx >= l) {
					break;
				}
				image[topIdx] = r2;
				image[topIdx + 1] = g2;
				image[topIdx + 2] = b2;
				image[bottomIdx] = r2;
				image[bottomIdx + 1] = g2;
				image[bottomIdx + 2] = b2;
			}
			// Left and right edges of the outline.
			for(int i = 0; i < offset - endOffset - 1; i++) {
				int rowBase = (i + endOffset) * frameWidth;
				int leftIdx = (rowBase + a) * 3;
				int rightIdx = (rowBase + a + widthHere - 1) * 3;
				if(leftIdx >= l || rightIdx >= l) {
					break;
				}
				image[leftIdx] = r2;
				image[leftIdx + 1] = g2;
				image[leftIdx + 2] = b2;
				image[rightIdx] = r2;
				image[rightIdx + 1] = g2;
				image[rightIdx + 2] = b2;
			}
			// Interior fill: one column at a time, walking right to left. Each column is
			// darkened by up to 90, with less darkening the further left the column is.
			float gradientStepSize = 90f / (float)widthHere;
			for(int llll = 2; llll < widthHere; llll++) {
				float darken = 90f - ((float)(llll - 1) * gradientStepSize);
				int rr = (int)(r - darken);
				int gg = (int)(g - darken);
				int bb = (int)(b - darken);
				r2 = rr > 0 ? rr : 0;
				g2 = gg > 0 ? gg : 0;
				b2 = bb > 0 ? bb : 0;
				int column = (int)(keyLength * (float)pitch + widthHere - (float)llll);
				for(int i = endOffset; i < offset - 2; i++) {
					int idx = (i * frameWidth + column) * 3;
					if(idx >= l) {
						break;
					}
					image[idx] = r2;
					image[idx + 1] = g2;
					image[idx + 2] = b2;
				}
			}
		}
		
	}
	
	/**
	 * Copies {@code img}'s pixels into a new array of interleaved R,G,B samples (3 ints per
	 * pixel for TYPE_INT_RGB). This reads the live raster directly, which avoids the extra
	 * full-raster copy that {@code getData()} makes. The returned array is a copy, so callers
	 * may modify it freely.
	 */
	private int[] toIntArray(BufferedImage img) {
		return img.getRaster().getPixels(0, 0, width, height, new int[width * height * 3]);
	}
	
	/**
	 * Builds a TYPE_INT_RGB image from interleaved R,G,B samples. Writes directly into the
	 * image's own raster; {@code getData()} would return a detached copy.
	 */
	private BufferedImage toImage(int[] data) {
		BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
		img.getRaster().setPixels(0, 0, width, height, data);
		return img;
	}
	
}