#include "imgload.h"
#include "sys_log.h"
#include <Windows.h>
#include <wincodec.h>
#include <Shlwapi.h>
#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#pragma comment(lib, "windowscodecs.lib")
#pragma comment(lib, "ole32.lib")
#pragma comment(lib, "shlwapi.lib")

/* Shared WIC factory: created by yrImgLoad_init, released by yrImgLoad_shutdown. */
static IWICImagingFactory *factory = NULL;
/* Nonzero once CoInitializeEx succeeded, so yrImgLoad_shutdown knows to call CoUninitialize. */
static int coinit;

int yrImgLoad_init(void)
{
	HRESULT hr;
	factory = NULL;
	hr = CoInitializeEx(NULL, 0);
	if(FAILED(hr)) {yrLog(0, "CoInitializeEx failed, hresult: %x", hr); return -1;}
	coinit = 1;
	hr = CoCreateInstance(&CLSID_WICImagingFactory1, NULL, CLSCTX_INPROC_SERVER, &IID_IWICImagingFactory, &factory);
	if(FAILED(hr)) {yrLog(0, "Could not create WICImagingFactory, hresult: %x", hr); return -1;}
	return 0;
}

void yrImgLoad_shutdown(void)
{
	if(factory) factory->lpVtbl->Release(factory);
	factory = NULL;
	if(coinit) CoUninitialize();
	coinit = 0;
}

/*
 * Static error messages returned by the yrImgLoad_* functions (NULL means success).
 * Index 0: decoding failed, 1: encoding failed,
 * 2: out of memory while decoding, 3: out of memory while encoding.
 * Both the strings and the table itself are read-only.
 */
static const char* const errmsgs[] = {
	"An error occurred while decoding the image.",
	"An error occurred while encoding the image.",
	"Not enough memory to decode the image.",
	"Not enough memory to encode the image.",
};

/****************************************
* IStream compatible memory reader/writer
*****************************************/
/*
 * In-memory stream object handed to WIC as an IStream*. Instances are cast
 * directly to IStream*, so lpVtbl must remain the first member.
 */
struct yrISM
{
	struct IStreamVtbl *lpVtbl; // method table (yrISM_vtable)
	volatile LONG _refcount; // COM reference count, changed only via InterlockedIncrement/InterlockedDecrement
	unsigned char* data; // backing buffer: the caller's buffer when decoding, a malloc'd buffer when encoding
	size_t ptr; // current read/write position in bytes
	size_t len; // number of valid bytes in data
	size_t cap; // allocated size of data in bytes
};

/* IUnknown::AddRef: atomically increment the reference count and return the new count. */
ULONG STDMETHODCALLTYPE yrISM_AddRef(__RPC__in IStream * This)
{
	struct yrISM* self = (struct yrISM*) This;
	LONG count = InterlockedIncrement(&self->_refcount);
	return (ULONG) count;
}

/*
 * IUnknown::Release: atomically decrement the reference count and free the
 * wrapper when it reaches zero. The data buffer is not freed here; its owner
 * is whoever created the stream.
 */
ULONG STDMETHODCALLTYPE yrISM_Release( __RPC__in IStream * This)
{
	struct yrISM* self = (struct yrISM*) This;
	LONG remaining = InterlockedDecrement(&self->_refcount);
	if(remaining != 0)
		return (ULONG) remaining;
	free(self);
	return 0;
}

/*
 * IUnknown::QueryInterface: the object supports IUnknown, ISequentialStream
 * and IStream, all served by the same vtable.
 * Returns S_OK and an AddRef'd pointer on success. Otherwise it returns
 * E_NOINTERFACE and stores NULL in *ppvObject, or returns E_POINTER when
 * ppvObject itself is NULL.
 */
HRESULT	STDMETHODCALLTYPE yrISM_QueryInterface( __RPC__in IStream * This, __RPC__in REFIID riid, _COM_Outptr_  void **ppvObject)
{
	if(!ppvObject) return E_POINTER;
	//compare GUID values, not addresses: callers (e.g. WIC in windowscodecs.dll)
	//pass pointers to their own copies of the IIDs
	if(IsEqualIID(riid, &IID_IUnknown)
	   || IsEqualIID(riid, &IID_IStream)
	   || IsEqualIID(riid, &IID_ISequentialStream))
	{
		*ppvObject = This;
		yrISM_AddRef(This);
		return S_OK;
	}
	//COM rule: the out pointer must be NULL on failure
	*ppvObject = NULL;
	return E_NOINTERFACE;
}
/*
 * ISequentialStream::Read for the memory stream: copy up to cb bytes from
 * the current position into pv, then advance the position.
 * Returns S_OK when all cb bytes were copied, S_FALSE when the end of the data
 * was reached first, and STG_E_INVALIDPOINTER when pv is NULL for a non-empty read.
 * *pcbRead, if non-NULL, receives the number of bytes copied.
 */
HRESULT	STDMETHODCALLTYPE yrISM_Read( IStream * This, _Out_writes_bytes_to_(cb, *pcbRead)  void *pv,_In_  ULONG cb,_Out_opt_  ULONG *pcbRead)
{
	struct yrISM* ism =(struct yrISM*) This;
	size_t readable, to_read;
	if(pcbRead) *pcbRead = 0; //optional out parameter: may legitimately be NULL
	if(!pv && cb) return STG_E_INVALIDPOINTER;
	readable = (ism->ptr < ism->len) ? ism->len - ism->ptr : 0;
	to_read = (cb <= readable) ? cb : readable;
	if(to_read) memcpy(pv, ism->data + ism->ptr, to_read);
	ism->ptr += to_read;
	if(pcbRead) *pcbRead = (ULONG) to_read;
	return (cb <= readable) ? S_OK : S_FALSE;
}
/*
 * ISequentialStream::Write for the memory stream: copy cb bytes from pv to the
 * current position, growing the backing buffer with realloc when needed.
 * The length is extended if the write goes past it.
 * If the buffer cannot grow enough, as many bytes as fit are written and
 * STG_E_MEDIUMFULL is returned. A NULL pv with cb > 0 returns STG_E_INVALIDPOINTER.
 * *pcbWritten, if non-NULL, receives the number of bytes written.
 */
HRESULT	STDMETHODCALLTYPE yrISM_Write( IStream * This,_In_reads_bytes_(cb)  const void *pv,_In_  ULONG cb,_Out_opt_  ULONG *pcbWritten)
{
	struct yrISM* ism =(struct yrISM*) This;
	size_t writeable, to_write;
	if(pcbWritten) *pcbWritten = 0; //optional out parameter: may legitimately be NULL
	if(!pv && cb) return STG_E_INVALIDPOINTER;

	//check capacity and try to grow once, straight to a size that fits the whole write
	writeable = (ism->ptr < ism->cap) ? ism->cap - ism->ptr : 0;
	if(cb > writeable && ism->ptr <= SIZE_MAX - cb) {
		size_t needed = ism->ptr + cb;
		size_t newcap = ism->cap + (ism->cap >> 1); //amortized 1.5x growth
		unsigned char* newdata;
		//the 1.5x step may be too small (or 0 for tiny buffers) or may wrap around
		if(newcap < needed || newcap < ism->cap) newcap = needed;
		newdata = realloc(ism->data, newcap);
		if(!newdata && newcap > needed) {
			//the generous size failed; retry with exactly what this write needs
			newcap = needed;
			newdata = realloc(ism->data, newcap);
		}
		if(newdata) {
			ism->data = newdata;
			ism->cap = newcap;
			writeable = newcap - ism->ptr;
		}
	}

	//do write (partial if the buffer could not grow)
	to_write = (cb <= writeable) ? cb : writeable;
	if(to_write) memcpy(ism->data + ism->ptr, pv, to_write);
	ism->ptr += to_write;
	if(ism->ptr > ism->len) ism->len = ism->ptr;
	if(pcbWritten) *pcbWritten = (ULONG) to_write;
	return (cb <= writeable) ? S_OK : STG_E_MEDIUMFULL;
}
/*
 * IStream::Seek for the memory stream: move the position relative to the
 * start, the end, or the current position.
 * Returns STG_E_INVALIDFUNCTION for an unknown origin, a target before the
 * start, or a target past the current end. Extending the stream is not supported yet.
 * *plibNewPosition, if non-NULL, receives the resulting position; on failure
 * it receives the unchanged one.
 */
HRESULT	STDMETHODCALLTYPE yrISM_Seek( IStream * This, LARGE_INTEGER dlibMove, DWORD dwOrigin, _Out_opt_  ULARGE_INTEGER *plibNewPosition)
{
	struct yrISM* ism =(struct yrISM*) This;
	LONGLONG move = dlibMove.QuadPart;
	size_t base, newptr;
	if(plibNewPosition) plibNewPosition->QuadPart = (ULONGLONG) ism->ptr;
	switch(dwOrigin) {
		case STREAM_SEEK_SET: base = 0; break;
		case STREAM_SEEK_END: base = ism->len; break;
		case STREAM_SEEK_CUR: base = ism->ptr; break;
		default: return STG_E_INVALIDFUNCTION;
	}
	if(move < 0) {
		//magnitude computed without negating LLONG_MIN (which would overflow)
		ULONGLONG back = (ULONGLONG) (-(move + 1)) + 1u;
		if(back > (ULONGLONG) base) return STG_E_INVALIDFUNCTION;
		newptr = base - (size_t) back;
	} else {
		//compare in 64 bits before converting, so a 32-bit size_t cannot truncate the target
		ULONGLONG fwd = (ULONGLONG) move;
		if(base > ism->len || fwd > (ULONGLONG) (ism->len - base)) { yrLog(0, "Seek beyond end");  return STG_E_INVALIDFUNCTION;} //TODO: actually the stream should just extend here
		newptr = base + (size_t) fwd;
	}
	ism->ptr = newptr;
	if(plibNewPosition) plibNewPosition->QuadPart = (ULONGLONG) ism->ptr;
	return S_OK;
}
/*
 * IStream methods that this memory stream does not implement.
 * Each ignores its arguments, logs a warning and reports E_NOTIMPL.
 */
HRESULT	STDMETHODCALLTYPE yrISM_SetSize( __RPC__in IStream * This, ULARGE_INTEGER libNewSize)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
HRESULT	STDMETHODCALLTYPE yrISM_CopyTo( IStream * This,_In_  IStream *pstm,ULARGE_INTEGER cb,_Out_opt_  ULARGE_INTEGER *pcbRead,_Out_opt_  ULARGE_INTEGER *pcbWritten)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
HRESULT	STDMETHODCALLTYPE yrISM_Commit( __RPC__in IStream * This,DWORD grfCommitFlags)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
HRESULT	STDMETHODCALLTYPE yrISM_Revert( __RPC__in IStream * This)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
HRESULT	STDMETHODCALLTYPE yrISM_LockRegion( __RPC__in IStream * This,ULARGE_INTEGER libOffset,ULARGE_INTEGER cb,DWORD dwLockType)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
HRESULT	STDMETHODCALLTYPE yrISM_UnlockRegion( __RPC__in IStream * This,ULARGE_INTEGER libOffset,ULARGE_INTEGER cb,DWORD dwLockType)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}
/*
 * IStream::Stat: report only the stream type and its current length. Every
 * other STATSTG field, including pwcsName, is zero, so grfStatFlag has nothing
 * to control.
 */
HRESULT	STDMETHODCALLTYPE yrISM_Stat( __RPC__in IStream * This,__RPC__out STATSTG *pstatstg,DWORD grfStatFlag)
{
	const struct yrISM* self = (const struct yrISM*) This;
	STATSTG info;
	memset(&info, 0, sizeof info);
	info.type = STGTY_STREAM;
	info.cbSize.QuadPart = (ULONGLONG) self->len;
	*pstatstg = info;
	return S_OK;
}
/* IStream::Clone: not supported by this memory stream; logs a warning and reports E_NOTIMPL. */
HRESULT	STDMETHODCALLTYPE yrISM_Clone( __RPC__in IStream * This,__RPC__deref_out_opt IStream **ppstm)
{
	yrLog(0, "Don't use this");
	return E_NOTIMPL;
}

static struct IStreamVtbl yrISM_vtable = {
	yrISM_QueryInterface,
	yrISM_AddRef,
	yrISM_Release,
	yrISM_Read,
	yrISM_Write,
	yrISM_Seek,
	yrISM_SetSize,
	yrISM_CopyTo,
	yrISM_Commit,
	yrISM_Revert,
	yrISM_LockRegion,
	yrISM_UnlockRegion,
	yrISM_Stat,
	yrISM_Clone,
};

/************************
* Image encode and decode
*************************/
wchar_t* widen(const char* str); //in platform_win64

/*
 * Read the pixel size of the first frame of an image file without decoding its pixels.
 * file: path to the image, converted to a wide string with widen() (the returned buffer is free()d here).
 * w, h: receive the width and height on success.
 * Returns NULL on success. Otherwise returns a static error message: errmsgs[2]
 * if widen() fails, errmsgs[0] for any other failure, including yrImgLoad_init
 * never having succeeded.
 */
const char* yrImgLoad_get_dims(const char* file, unsigned* w, unsigned* h)
{
	int ok = 0;
	HRESULT hr;
	IWICBitmapDecoder* decoder = NULL;
	IWICBitmapFrameDecode* frame = NULL;
	wchar_t* wname;

	if(!factory) {yrLog(0, "yrImgLoad_init has not been called successfully"); return errmsgs[0];}
	wname = widen(file);
	if(!wname) return errmsgs[2];
	hr = factory->lpVtbl->CreateDecoderFromFilename(factory, wname, NULL, GENERIC_READ, WICDecodeMetadataCacheOnDemand, &decoder);
	free(wname);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapDecoder, hresult: %x", (unsigned) hr); goto onexit;}
	hr = decoder->lpVtbl->GetFrame(decoder, 0, &frame);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapFrameDecode, hresult: %x", (unsigned) hr); goto onexit;}
	hr = frame->lpVtbl->GetSize(frame, w, h);
	if(FAILED(hr)) {yrLog(0, "GetSize failed, hresult: %x", (unsigned) hr); goto onexit;}
	ok = 1;

onexit:
	//single cleanup path for success and failure
	if(frame) frame->lpVtbl->Release(frame);
	if(decoder) decoder->lpVtbl->Release(decoder);
	return ok ? NULL : errmsgs[0];
}

const char*	yrImgLoad_decode(size_t len , void* data, unsigned* w, unsigned* h, void** img)
{
	int err = 0;
	HRESULT hr;
	IWICBitmapDecoder* decoder = NULL;
	IWICBitmapFrameDecode* frame = NULL;
	IWICBitmapSource* source = NULL;
	IWICFormatConverter* converter = NULL;

	//init memory stream interface
	struct yrISM* ism = NULL;
	ism = malloc(sizeof(struct yrISM)); 
	IStream* stream = (IStream*) ism;
	if(!ism) {yrLog(0, "Out of memory"); err = 2; goto onerror;}
	ism->lpVtbl = &yrISM_vtable;
	ism->_refcount = 1;
	ism->data = data;
	ism->ptr = 0;
	ism->len = len;
	ism->cap = ism->len;
	//init com components
	hr = factory->lpVtbl->CreateDecoderFromStream(factory, stream, NULL, WICDecodeMetadataCacheOnDemand, &decoder);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapDecoder, hresult: %x", hr); goto onerror;}
	hr = decoder->lpVtbl->GetFrame(decoder, 0, &frame);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapFrameDecode, hresult: %x", hr); goto onerror;}
	hr = frame->lpVtbl->QueryInterface(frame, &IID_IWICBitmapSource, &source);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapSource, hresult: %x", hr); goto onerror;}
	hr = factory->lpVtbl->CreateFormatConverter(factory, &converter);
	if(FAILED(hr)) {yrLog(0, "Could not create WICFormatConverter, hresult: %x", hr); goto onerror;}
	converter->lpVtbl->Initialize(converter, source, &GUID_WICPixelFormat32bppRGBA, WICBitmapDitherTypeErrorDiffusion, NULL, 0.0f, WICBitmapPaletteTypeCustom);
	if(FAILED(hr)) {yrLog(0, "Could not initialize format converter, hresult: %x", hr); goto onerror;}

	//get image data
	hr = converter->lpVtbl->GetSize(converter, w, h);
	if(FAILED(hr)) {yrLog(0, "GetSize failed, hresult: %x", hr); goto onerror;}
	*img = malloc((*w) * (*h) * 4);
	if(!*img) {yrLog(0, "Out of memory"); err = 2; goto onerror;}
	hr = converter->lpVtbl->CopyPixels(converter, NULL, (*w) * 4, (*w) * (*h) * 4, *img);
	if(FAILED(hr)) {yrLog(0, "CopyPixels failed, hresult: %x", hr); goto onerror;}

	//done
	converter->lpVtbl->Release(converter);
	source->lpVtbl->Release(source);
	frame->lpVtbl->Release(frame);
	decoder->lpVtbl->Release(decoder);
	stream->lpVtbl->Release(stream);
	return NULL;
onerror:
	if(converter) converter->lpVtbl->Release(converter);
	if(source) source->lpVtbl->Release(source);
	if(frame) frame->lpVtbl->Release(frame);
	if(decoder) decoder->lpVtbl->Release(decoder);
	if(stream) stream->lpVtbl->Release(stream);
	return errmsgs[err];
}

const char*	yrImgLoad_encode(size_t* len , void** data, unsigned w, unsigned h, void* img)
{
	int err = 1;
	HRESULT hr;
	IWICBitmapEncoder* encoder = NULL;
	IWICBitmapFrameEncode* frame = NULL;
	IWICBitmap* bitmap = NULL;
	IWICBitmapSource* source =  NULL;
	IWICFormatConverter* converter = NULL;
	IWICBitmapSource* convsource =  NULL;

	//init memory stream interface
	struct yrISM* ism = NULL;
	ism = malloc(sizeof(struct yrISM)); 
	IStream* stream = (IStream*) ism;
	if(!ism) {yrLog(0, "Out of memory"); err = 3; goto onerror;}
	ism->lpVtbl = &yrISM_vtable;
	ism->_refcount = 1;
	ism->data = malloc(1024);
	ism->ptr = 0;
	ism->len = 0;
	ism->cap = 1024;
	if(!ism->data) {yrLog(0, "Out of memory"); err = 3; goto onerror;}
	//init com components
	hr = factory->lpVtbl->CreateEncoder(factory, &GUID_ContainerFormatPng, NULL, &encoder);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapEncoder, hresult: %x", hr); goto onerror;}
	encoder->lpVtbl->Initialize(encoder, stream, WICBitmapEncoderNoCache);
	if(FAILED(hr)) {yrLog(0, "Could not initialize encoder, hresult: %x", hr); goto onerror;}
	hr = encoder->lpVtbl->CreateNewFrame(encoder, &frame, NULL);
	if(FAILED(hr)) {yrLog(0, "Could not create WICBitmapFrameEncode, hresult: %x", hr); goto onerror;}
	hr = frame->lpVtbl->Initialize(frame, NULL);
	if(FAILED(hr)) {yrLog(0, "Could not initialize frame encode, hresult: %x", hr); goto onerror;}
	hr = factory->lpVtbl->CreateBitmapFromMemory(factory, w, h, &GUID_WICPixelFormat32bppRGBA, w*4, w*h*4, img, &bitmap);
	if(FAILED(hr)) {yrLog(0, "Could not create bitmap, hresult: %x", hr); goto onerror;}
	hr = bitmap->lpVtbl->QueryInterface(bitmap, &IID_IWICBitmapSource, &source);
	if(FAILED(hr)) {yrLog(0, "Could not create bitmap source, hresult: %x", hr); goto onerror;}

	//set size and format, handle conversion
	hr = frame->lpVtbl->SetSize(frame, w, h);
	if(FAILED(hr)) {yrLog(0, "SetSize failed, hresult: %x", hr); goto onerror;}
	WICPixelFormatGUID fmt = GUID_WICPixelFormat32bppRGBA;
	hr = frame->lpVtbl->SetPixelFormat(frame, &fmt);
	if(FAILED(hr)) {yrLog(0, "SetPixelFormat failed, hresult: %x", hr); goto onerror;}
	hr = factory->lpVtbl->CreateFormatConverter(factory, &converter);
	if(FAILED(hr)) {yrLog(0, "Could not create WICFormatConverter, hresult: %x", hr); goto onerror;}
	converter->lpVtbl->Initialize(converter, source, &fmt, WICBitmapDitherTypeErrorDiffusion, NULL, 0.0f, WICBitmapPaletteTypeCustom);
	if(FAILED(hr)) {yrLog(0, "Could not initialize format converter, hresult: %x", hr); goto onerror;}
	hr = converter->lpVtbl->QueryInterface(converter, &IID_IWICBitmapSource, &convsource);
	if(FAILED(hr)) {yrLog(0, "Could not create converter bitmap source, hresult: %x", hr); goto onerror;}

	//write data
	hr = frame->lpVtbl->WriteSource(frame, convsource, NULL);
	if(FAILED(hr)) {yrLog(0, "WriteSource failed, hresult: %x", hr); goto onerror;}
	hr = frame->lpVtbl->Commit(frame);
	if(FAILED(hr)) {yrLog(0, "Frame Commit failed, hresult: %x", hr); goto onerror;}
	hr = encoder->lpVtbl->Commit(encoder);
	if(FAILED(hr)) {yrLog(0, "Encoder Commit failed, hresult: %x", hr); goto onerror;}

	//done
	convsource->lpVtbl->Release(convsource);
	converter->lpVtbl->Release(converter);
	source->lpVtbl->Release(source);
	bitmap->lpVtbl->Release(bitmap);
	frame->lpVtbl->Release(frame);
	encoder->lpVtbl->Release(encoder);
	*len = ism->len;
	*data = ism->data;
	stream->lpVtbl->Release(stream);
	return NULL;
onerror:
	if(convsource) convsource->lpVtbl->Release(convsource);
	if(converter) converter->lpVtbl->Release(converter);
	if(source) source->lpVtbl->Release(source);
	if(bitmap) bitmap->lpVtbl->Release(bitmap);
	if(frame) frame->lpVtbl->Release(frame);
	if(encoder) encoder->lpVtbl->Release(encoder);
	if(stream) {
		free(ism->data);
		stream->lpVtbl->Release(stream);
	}
	return errmsgs[err];
	//mentioned for easy searching:	WINCODEC_ERR_CODECTOOMANYSCANLINES
}