/**
 * Cleversafe open-source code header - Version 1.1 - December 1, 2006
 *
 * Cleversafe Dispersed Storage(TM) is software for secure, private and
 * reliable storage of the world's data using information dispersal.
 *
 * Copyright (C) 2005-2007 Cleversafe, Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
 * USA.
 *
 * Contact Information: Cleversafe, 10 W. 35th Street, 16th Floor #84,
 * Chicago IL 60616
 * email licensing@cleversafe.org
 *
 * Author: Greg Dhuse <gdhuse@cleversafe.com>
 *
 */
#include "dsd.h"
#include "dsdnet.h"
#include "ksocket.h"

/* Fast mutex; initialized once in DriverEntry(). */
static FAST_MUTEX tagMutex;
/* Returns a tag used to associate daemon responses with pending requests. */
static int DsdGetRequestTag();
/* Looks up (and retrieves) the request in 'queue' whose context tag is 'tag';
 * returns NULL when no such request is found. */
static WDFREQUEST DsdFindTaggedRequest( IN WDFQUEUE queue, IN int tag );
/* Per-device receive loop; 'data' is the WDFDEVICE handle. */
static void DsdDeviceThread( IN PVOID data );


/**
 * Driver entry point.
 *
 * Resets the module-wide device counter and tag mutex, then registers the
 * driver with the framework using EvtDeviceAdd as the add-device callback.
 *
 * @param driverObject WDM driver object supplied by the I/O manager
 * @param registryPath Driver's service registry key
 * @return Status of WdfDriverCreate
 */
NTSTATUS
DriverEntry( IN PDRIVER_OBJECT driverObject, 
             IN PUNICODE_STRING registryPath )
{
   WDF_DRIVER_CONFIG driverConfig;
   WDF_OBJECT_ATTRIBUTES driverAttributes;
   NTSTATUS result;

   KdPrint(( DSD_TAG "--> DriverEntry\n" ));

   /* Module-level state must be ready before the first EvtDeviceAdd */
   deviceNum = 0;
   ExInitializeFastMutex( &tagMutex );

   WDF_OBJECT_ATTRIBUTES_INIT( &driverAttributes );
   WDF_DRIVER_CONFIG_INIT( &driverConfig, EvtDeviceAdd );

   result = WdfDriverCreate( driverObject, registryPath, &driverAttributes,
                             &driverConfig, WDF_NO_HANDLE );
   if( NT_SUCCESS(result) == FALSE )
   {
      KdPrint(( DSD_TAG "WdfDriverCreate failed with status 0x%08x\n",
               result ));
   }

   KdPrint(( DSD_TAG "<-- DriverEntry\n" ));
   return result;
}

/**
 * Called when a new device stack that involves this driver is
 * being constructed.
 *
 * Names the device, copies geometry and the socket from the PDO's device
 * extension, creates the default (parallel) and pending (manual) I/O
 * queues, registers the disk and mounted-device interfaces, creates the
 * hard-coded drive-letter link and starts the per-device receive thread.
 *
 * @param driver     Unused
 * @param deviceInit Framework device-initialization structure
 * @return STATUS_SUCCESS, or the status of the first call that failed
 */
NTSTATUS
EvtDeviceAdd( IN WDFDRIVER driver, 
              IN PWDFDEVICE_INIT deviceInit )
{
   PDSD_PDO pdo;
   NTSTATUS status;
   NTSTATUS linkStatus;
   WDFDEVICE device;
   PDEVICE_OBJECT pdoDevice;
   PDSD_DEV dsd = NULL;
   WDF_OBJECT_ATTRIBUTES attributes, childAttributes;
   WDF_IO_QUEUE_CONFIG ioQConfig, pendingQConfig;
   OBJECT_ATTRIBUTES threadAttributes;

   DECLARE_UNICODE_STRING_SIZE( deviceNameUS, DSD_DEVICE_NAME_LENGTH );
   DECLARE_CONST_UNICODE_STRING( linkName, L"\\DosDevices\\G:" );

   UNREFERENCED_PARAMETER( driver );

   KdPrint(( DSD_TAG "--> EvtDeviceAdd\n" ));

   // Name device
   deviceNum++;
   status = RtlUnicodeStringPrintf( &deviceNameUS, L"%ws%d", 
      DSD_DEVICE_NAME, deviceNum );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "Printf failed with status 0x%08x\n", status ));
      return status;
   }
   
   status = WdfDeviceInitAssignName( deviceInit, &deviceNameUS );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "WdfDeviceInitAssignName failed with status 0x%08x\n",
                status ));
      return status;
   }
   WdfDeviceInitSetIoType( deviceInit, WdfDeviceIoDirect );
   WdfDeviceInitSetDeviceType( deviceInit, FILE_DEVICE_DISK );
   WdfDeviceInitSetExclusive( deviceInit, FALSE );

   /** 
    * Get PDO information
    * FIXME: Is accessing the extension of our PDO completely kosher?
    */
   pdoDevice = WdfFdoInitWdmGetPhysicalDevice( deviceInit );
   if( !pdoDevice )
   {
      KdPrint(( DSD_TAG "No physical device object\n" ));
      return STATUS_INVALID_PARAMETER;
   }
   pdo = (PDSD_PDO)pdoDevice->DeviceExtension;
   if( !pdo )
   {
      KdPrint(( DSD_TAG "PDO has no device extension\n" ));
      return STATUS_INVALID_PARAMETER;
   }

   // Initialize context storage for the DSD_DEV object
   WDF_OBJECT_ATTRIBUTES_INIT_CONTEXT_TYPE( &attributes, DSD_DEV );
   attributes.EvtCleanupCallback = EvtDeviceCleanup;

   // Create device
   status = WdfDeviceCreate( &deviceInit, &attributes, &device );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "WdfDeviceCreate failed with status 0x%08x\n",
                status ));
      return status;
   }

   // Set device parameters
   dsd = GetDsd( device );
   dsd->numBlocks = pdo->numBlocks;
   dsd->blockSize = pdo->blockSize;
   dsd->socket    = pdo->socket;
   dsd->thread    = NULL;
   InterlockedExchange( &dsd->pendingWork, 0 );
   InterlockedExchange( &dsd->error, (LONG)STATUS_SUCCESS );

   // Objects below are parented to the device so they are released with it
   WDF_OBJECT_ATTRIBUTES_INIT( &childAttributes );
   childAttributes.ParentObject = device;

   // Initialize socket mutex
   status = WdfWaitLockCreate( &childAttributes, &dsd->socketLock );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "WdfWaitLockCreate failed with status 0x%08x\n",
                status ));
      return status;
   }
   
   // Save device name
   status = WdfStringCreate( &deviceNameUS, &childAttributes,
      &dsd->deviceName );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "WdfStringCreate failed with status 0x%08x\n",
         status ));
      return status;
   }

   // Initialize IO queue
   WDF_IO_QUEUE_CONFIG_INIT_DEFAULT_QUEUE( &ioQConfig, 
      WdfIoQueueDispatchParallel );
   ioQConfig.EvtIoDeviceControl  = EvtDeviceIoDeviceControl;
   ioQConfig.EvtIoRead           = EvtDeviceIoRead;
   ioQConfig.EvtIoWrite          = EvtDeviceIoWrite;

   status = WdfIoQueueCreate( device, &ioQConfig, 
                              WDF_NO_OBJECT_ATTRIBUTES, 
                              &dsd->defaultQueue );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "WdfIoQueueCreate failed with status 0x%08x\n",
                status ));
      return status;
   }

   // Initialize manual pending request queue
   WDF_IO_QUEUE_CONFIG_INIT( &pendingQConfig, WdfIoQueueDispatchManual );
   pendingQConfig.PowerManaged = WdfFalse;
   status = WdfIoQueueCreate( device, &pendingQConfig, 
                              WDF_NO_OBJECT_ATTRIBUTES, 
                              &dsd->pendingQueue );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG 
                "WdfIoQueueCreate (pending) failed with status 0x%08x\n",
                status ));
      return status;
   }

   // Initialize interfaces; both are required
   status = WdfDeviceCreateDeviceInterface( device, 
      &GUID_DEVINTERFACE_DISK, NULL );
   if( NT_SUCCESS(status) )
   {
      status = WdfDeviceCreateDeviceInterface( device, 
         &MOUNTDEV_MOUNTED_DEVICE_GUID, NULL );
   }
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG 
                "WdfDeviceCreateInterface failed with status 0x%08x\n",
                status ));
      return status;
   }

   // FIXME: This hard-coded link (drive letter) should be
   // dynamically assigned by the mount manager.  A failure here is logged
   // but, as before, does not fail the device.
   linkStatus = IoCreateSymbolicLink( (PUNICODE_STRING)&linkName,
                                      &deviceNameUS );
   if( !NT_SUCCESS(linkStatus) )
   {
      KdPrint(( DSD_TAG "IoCreateSymbolicLink failed with status 0x%08x\n",
                linkStatus ));
   }

   // Create device thread.  OBJ_KERNEL_HANDLE keeps the returned handle out
   // of the current process's handle table, so EvtDeviceCleanup can use it
   // regardless of the process context it runs in.
   InitializeObjectAttributes( &threadAttributes, NULL, OBJ_KERNEL_HANDLE,
                               NULL, NULL );
   status = PsCreateSystemThread( &dsd->thread,
                                  (ACCESS_MASK) 0L,
                                  &threadAttributes,
                                  NULL,
                                  NULL,
                                  DsdDeviceThread,
                                  (PVOID)device );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "PsCreateSystemThread failed with status 0x%08x\n",
                status ));
      // Do not leave a drive letter pointing at a device that is going away
      if( NT_SUCCESS(linkStatus) )
      {
         IoDeleteSymbolicLink( (PUNICODE_STRING)&linkName );
      }
      return status;
   }

   KdPrint(( DSD_TAG "Created device with name: %wZ\n", &deviceNameUS ));
   KdPrint(( DSD_TAG "<-- EvtDeviceAdd\n" ));
   return status;
}

/**
 * Called by the framework at PASSIVE_LEVEL before the device is destroyed.
 *
 * Puts the device in an error state, waits for pendingWork to drain,
 * closes the network socket(s), waits for the device thread to exit and
 * releases the thread object reference and handle.
 *
 * @param deviceObject The WDFDEVICE being destroyed
 */
VOID
EvtDeviceCleanup( IN WDFOBJECT deviceObject )
{
   PVOID threadObj;
   NTSTATUS status;
   PDSD_DEV dsd;
   WDFDEVICE device;
   LARGE_INTEGER sleepTime;

   KdPrint(( DSD_TAG "Destroying device - start\n" ));
   VERIFY_IS_IRQL_PASSIVE_LEVEL();
   
   device = (WDFDEVICE)deviceObject;
   dsd = GetDsd( device );

   // Set error state so new read/write requests are rejected
   InterlockedExchange( &dsd->error, (LONG)STATUS_INVALID_DEVICE_STATE );

   /**
    * Wait for all in-flight requests 
    * FIXME: Is it necessary to use an interlocked operation to read this value?
    */
   sleepTime.QuadPart = -100000;    // relative 10ms (100ns units)
   while( InterlockedOr( &dsd->pendingWork, 0x0 ) > 0 )
   {
      KeDelayExecutionThread( KernelMode, TRUE, &sleepTime );
   }

   // Clean up network socket(s)
   DsdNetClose( device );

   // Wait for device thread to exit
   // Once this thread has exited, all pending requests have been purged
   if( dsd->thread )
   {
      status = ObReferenceObjectByHandle( dsd->thread, 
                                          SYNCHRONIZE, 
                                          *PsThreadType, 
                                          KernelMode, 
                                          &threadObj, 
                                          NULL );
      if( NT_SUCCESS(status) )
      {
         KeWaitForSingleObject( threadObj, Executive, KernelMode, FALSE, NULL );
         // Drop the reference taken by ObReferenceObjectByHandle
         ObDereferenceObject( threadObj );
      }
      else
      {
         KdPrint(( DSD_TAG "Unable to reference device thread: 0x%08x\n",
                   status ));
      }

      // The handle returned by PsCreateSystemThread is owned by this driver
      ZwClose( dsd->thread );
      dsd->thread = NULL;
   }

   KdPrint(( DSD_TAG "Destroying device - done\n" ));
}

/**
 * Process a read request to the device.
 *
 * The request is tagged, parked on the manual pending queue and a
 * DSD_MSG_READ_SECTORS message is sent to the daemon.  The device thread
 * completes it when the matching response arrives.  Any failure before the
 * send completes the request here.  A send failure pulls the request back
 * from the pending queue (if it is still there) and completes it.
 *
 * @param queue   Queue that dispatched the request
 * @param request The read request
 * @param length  Number of bytes requested
 */
VOID
EvtDeviceIoRead( IN WDFQUEUE queue, 
                 IN WDFREQUEST request, 
                 IN size_t length )
{
   PMDL mdl;
   PVOID buffer;
   NTSTATUS status;
   PDSD_DEV dsd;
   WDFDEVICE device;
   WDFREQUEST pending;
   WDF_REQUEST_PARAMETERS params;
   PDSD_REQUEST_CONTEXT requestContext;
   WDF_OBJECT_ATTRIBUTES requestAttributes;
   struct dsd_msg_read_sectors* msg = NULL;
   int msg_size = sizeof( struct dsd_msg_read_sectors );
   size_t numSectors;
   int tag;

   WDF_REQUEST_PARAMETERS_INIT( &params );
   WdfRequestGetParameters( request, &params );
   device = WdfIoQueueGetDevice( queue );
   dsd = GetDsd( device );

   KdPrint(( DSD_TAG "--> EvtDeviceIoRead(fb=%I64d,sz=%Iu,irql=%d)\n", 
      params.Parameters.Read.DeviceOffset, length, KeGetCurrentIrql() ));

   // Device is in an error state
   status = (NTSTATUS)InterlockedOr( &dsd->error, 0x0 );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }

   // The message carries the sector count as a uint8_t; reject transfers
   // that would otherwise be silently truncated
   numSectors = length / dsd->blockSize;
   if( numSectors > 0xFF )
   {
      status = STATUS_INVALID_PARAMETER;
      goto error;
   }

   status = WdfRequestRetrieveOutputWdmMdl( request, &mdl );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "MDL error: 0x%08x\n", status ));
      goto error;
   }

   // Verify the output buffer can be mapped; the device thread maps it
   // again when the response arrives
   buffer = MmGetSystemAddressForMdlSafe( mdl, NormalPagePriority );
   if( !buffer )
   {
      status = STATUS_INSUFFICIENT_RESOURCES;
      goto error;
   }

   // Build the message while the request is still owned by this routine,
   // so an allocation failure can be completed directly
   msg = ExAllocatePoolWithTag( NonPagedPool, msg_size, '1erG' );
   if( !msg )
   {
      status = STATUS_NO_MEMORY;
      goto error;
   }
   RtlZeroMemory( msg, msg_size );

   // Tag this request
   WDF_OBJECT_ATTRIBUTES_INIT_CONTEXT_TYPE( &requestAttributes, 
                                            DSD_REQUEST_CONTEXT );
   status = WdfObjectAllocateContext( request, 
                                      &requestAttributes, 
                                      &requestContext );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }
   tag = DsdGetRequestTag();
   requestContext->tag = tag;

   msg->request_tag  = htonl( tag );
   msg->first_sector 
      = htonll( params.Parameters.Read.DeviceOffset / dsd->blockSize );
   msg->num_sectors  = (uint8_t)numSectors;

   // Move the request to the pending queue.  After this succeeds the queue
   // owns the request and it must not be completed directly.
   status = WdfRequestForwardToIoQueue( request, dsd->pendingQueue );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }

   // Send request to the grid
   // NOTE(review): msg is never freed after being handed to DsdNetSendAtomic
   // (same as before); confirm whether that routine takes ownership.
   status = DsdNetSendAtomic( device, 
                              DSD_MSG_READ_SECTORS, 
                              (uint8_t*)msg, 
                              msg_size );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "<-- EvtDeviceIoRead(FAILURE:0x%08x)\n", status ));
      // Reclaim the request from the pending queue before completing it; if
      // it is gone, the device thread has already completed it
      pending = DsdFindTaggedRequest( dsd->pendingQueue, tag );
      if( pending )
      {
         WdfRequestComplete( pending, status );
      }
      return;
   }

   KdPrint(( DSD_TAG "<-- EvtDeviceIoRead\n" ));
   return;

error:
   // Only reached while this routine still owns both the request and msg
   KdPrint(( DSD_TAG "<-- EvtDeviceIoRead(FAILURE:0x%08x)\n", status ));
   if( msg )
   {
      ExFreePool( msg );
   }
   WdfRequestComplete( request, status );
   return;
}

/**
 * Process a write request to the device.
 *
 * The caller's data is copied into a DSD_MSG_WRITE_SECTORS message.  The
 * request is tagged and parked on the manual pending queue, then the
 * message is sent to the daemon.  The device thread completes the request
 * when the matching response arrives.  Any failure before the send
 * completes the request here.  A send failure pulls the request back from
 * the pending queue (if it is still there) and completes it.
 *
 * @param queue   Queue that dispatched the request
 * @param request The write request
 * @param length  Number of bytes to write
 */
VOID
EvtDeviceIoWrite( IN WDFQUEUE queue, 
                  IN WDFREQUEST request, 
                  IN size_t length )
{
   PMDL mdl;
   PVOID buffer;
   NTSTATUS status;
   PDSD_DEV dsd;
   WDFDEVICE device;
   WDFREQUEST pending;
   WDF_REQUEST_PARAMETERS params;
   PDSD_REQUEST_CONTEXT requestContext;
   WDF_OBJECT_ATTRIBUTES requestAttributes;
   struct dsd_msg_write_sectors* msg = NULL;
   int msg_size;
   size_t numSectors;
   int tag;

   WDF_REQUEST_PARAMETERS_INIT( &params );
   WdfRequestGetParameters( request, &params );
   device = WdfIoQueueGetDevice( queue );
   dsd = GetDsd( device );

   KdPrint(( DSD_TAG "--> EvtDeviceIoWrite(fb=%I64d,sz=%Iu)\n", 
      params.Parameters.Write.DeviceOffset, length ));

   // Device is in an error state
   status = (NTSTATUS)InterlockedOr( &dsd->error, 0x0 );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }

   // The message carries the sector count as a uint8_t; reject transfers
   // that would otherwise be silently truncated.  This also bounds msg_size.
   numSectors = length / dsd->blockSize;
   if( numSectors > 0xFF )
   {
      status = STATUS_INVALID_PARAMETER;
      goto error;
   }
   msg_size = (int)( sizeof( struct dsd_msg_write_sectors ) + length );

   // Retrieve the data to write; previously a failure here left the
   // request uncompleted
   status = WdfRequestRetrieveInputWdmMdl( request, &mdl );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "MDL error: 0x%08x\n", status ));
      goto error;
   }

   buffer = MmGetSystemAddressForMdlSafe( mdl, NormalPagePriority );
   if( !buffer )
   {
      status = STATUS_INSUFFICIENT_RESOURCES;
      goto error;
   }

   // Build the message while the request is still owned by this routine,
   // so an allocation failure can be completed directly
   msg = ExAllocatePoolWithTag( NonPagedPool, msg_size, '0erG' );
   if( !msg )
   {
      status = STATUS_NO_MEMORY;
      goto error;
   }
   RtlZeroMemory( msg, msg_size );

   // Tag this request
   WDF_OBJECT_ATTRIBUTES_INIT_CONTEXT_TYPE( &requestAttributes, 
                                            DSD_REQUEST_CONTEXT );
   status = WdfObjectAllocateContext( request, 
                                      &requestAttributes, 
                                      &requestContext );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }
   tag = DsdGetRequestTag();
   requestContext->tag = tag;

   msg->request_tag  = htonl( tag );
   msg->first_sector 
      = htonll( params.Parameters.Write.DeviceOffset / dsd->blockSize );
   msg->num_sectors  = (uint8_t)numSectors;
   RtlCopyMemory( msg->data, buffer, length );

   // Move the request to the pending queue.  After this succeeds the queue
   // owns the request and it must not be completed directly.
   status = WdfRequestForwardToIoQueue( request, dsd->pendingQueue );
   if( !NT_SUCCESS(status) )
   {
      goto error;
   }

   // Send request to the grid
   // NOTE(review): msg is never freed after being handed to DsdNetSendAtomic
   // (same as before); confirm whether that routine takes ownership.
   status = DsdNetSendAtomic( device, 
                              DSD_MSG_WRITE_SECTORS, 
                              (uint8_t*)msg, 
                              msg_size );
   if( !NT_SUCCESS(status) )
   {
      KdPrint(( DSD_TAG "<-- EvtDeviceIoWrite(FAILURE:0x%08x)\n", status ));
      // Reclaim the request from the pending queue before completing it; if
      // it is gone, the device thread has already completed it
      pending = DsdFindTaggedRequest( dsd->pendingQueue, tag );
      if( pending )
      {
         WdfRequestComplete( pending, status );
      }
      return;
   }

   KdPrint(( DSD_TAG "<-- EvtDeviceIoWrite\n" ));
   return;

error:
   // Only reached while this routine still owns both the request and msg
   KdPrint(( DSD_TAG "<-- EvtDeviceIoWrite(FAILURE:0x%08x)\n", status ));
   if( msg )
   {
      ExFreePool( msg );
   }
   WdfRequestComplete( request, status );
   return;
}

/**
 * Complete the pending write request matching a DSD_MSG_WRITE_SECTORS_RSP.
 * A response whose tag matches no pending request is logged and dropped.
 *
 * @param dsd Device context owning the pending queue
 * @param msg Received message (converted to host order in place)
 */
static void
DsdHandleWriteResponse( IN PDSD_DEV dsd, IN uint8_t* msg )
{
   WDFREQUEST request;
   WDF_REQUEST_PARAMETERS params;
   struct dsd_msg_write_sectors_rsp* response = 
      (struct dsd_msg_write_sectors_rsp*)msg;

   response->request_tag = ntohl( response->request_tag );

   KdPrint(( DSD_TAG "Got: DSD_MSG_WRITE_SECTORS_RSP\n" )); 

   // Find tagged request
   request = DsdFindTaggedRequest( dsd->pendingQueue, 
                                   response->request_tag );
   if( !request )
   {
      KdPrint(( DSD_TAG "No pending write for tag %u\n",
                (unsigned)response->request_tag ));
      return;
   }

   WDF_REQUEST_PARAMETERS_INIT( &params );
   WdfRequestGetParameters( request, &params );

   WdfRequestCompleteWithInformation( 
      request, 
      STATUS_SUCCESS, 
      (ULONG_PTR)params.Parameters.Write.Length );
}

/**
 * Copy the data of a DSD_MSG_READ_SECTORS_RSP into the matching pending
 * read request and complete it.  The copy is bounded by the request's
 * length, because the byte count comes from the network.  The request
 * reports the number of bytes actually copied.
 *
 * @param dsd Device context owning the pending queue
 * @param msg Received message (converted to host order in place)
 */
static void
DsdHandleReadResponse( IN PDSD_DEV dsd, IN uint8_t* msg )
{
   PMDL mdl;
   PVOID buffer;
   NTSTATUS status;
   size_t copyBytes;
   WDFREQUEST request;
   WDF_REQUEST_PARAMETERS params;
   struct dsd_msg_read_sectors_rsp* response = 
      (struct dsd_msg_read_sectors_rsp*)msg;

   response->request_tag = ntohl( response->request_tag );
   response->bytes = ntohll( response->bytes );

   KdPrint(( DSD_TAG "Got: DSD_MSG_READ_SECTORS_RSP\n" )); 

   // Find tagged request
   request = DsdFindTaggedRequest( dsd->pendingQueue, 
                                   response->request_tag );
   if( !request )
   {
      KdPrint(( DSD_TAG "No pending read for tag %u\n",
                (unsigned)response->request_tag ));
      return;
   }
   WDF_REQUEST_PARAMETERS_INIT( &params );
   WdfRequestGetParameters( request, &params );

   status = WdfRequestRetrieveOutputWdmMdl( request, &mdl );
   if( !NT_SUCCESS(status) ) 
   {
      WdfRequestComplete( request, status );
      return;
   }

   buffer = MmGetSystemAddressForMdlSafe( mdl, NormalPagePriority );
   if( !buffer )
   {
      WdfRequestComplete( request, STATUS_INSUFFICIENT_RESOURCES );
      return;
   }

   // Never copy more than the caller's buffer holds.
   // NOTE(review): DsdNetRecv does not report the received message length,
   // so response->data cannot be checked against response->bytes here.
   copyBytes = params.Parameters.Read.Length;
   if( (uint64_t)response->bytes < (uint64_t)copyBytes )
   {
      copyBytes = (size_t)response->bytes;
   }

   RtlCopyMemory( buffer, response->data, copyBytes );
   WdfRequestCompleteWithInformation( request, 
                                      STATUS_SUCCESS, 
                                      (ULONG_PTR)copyBytes );
}

/**
 * Thread function - runs while device is alive processing
 * incoming network events.
 *
 * Blocks in DsdNetRecv and dispatches each response to the matching
 * pending request.  When receiving fails, the thread marks the device
 * failed and completes every request left in the pending queue with
 * STATUS_IO_DEVICE_ERROR.
 *
 * @param data WDFDEVICE handle
 */
static void 
DsdDeviceThread( IN PVOID data )
{
   NTSTATUS status;
   WDFDEVICE device;
   PDSD_DEV dsd;
   WDFREQUEST request;

   device = (WDFDEVICE)data;
   dsd = GetDsd( device );

   for(;;)
   {
      uint8_t msgType;
      uint8_t* msg;

      // Blocking recv
      status = DsdNetRecv( device, &msg, &msgType );
      if( !NT_SUCCESS(status) )
      {
         KdPrint(( DSD_TAG "Device communication error: 0x%08x\n", status ));
         break;
      }

      switch( msgType )
      {
         // Request for device removal
         case DSD_MSG_REMOVE_DEVICE:
            KdPrint(( DSD_TAG "Got: DSD_MSG_REMOVE_DEVICE\n" ));
            // FIXME: Handle this event
            break;

         case DSD_MSG_WRITE_SECTORS_RSP:
            DsdHandleWriteResponse( dsd, msg );
            break;

         case DSD_MSG_READ_SECTORS_RSP:
            DsdHandleReadResponse( dsd, msg );
            break;

         default:
            break;
      }

      ExFreePool( msg );
   }

   KdPrint(( DSD_TAG "Device thread beginning cleanup\n" ));

   // Fail device, current, and future requests
   InterlockedExchange( &dsd->error, (LONG)STATUS_IO_DEVICE_ERROR );
   WdfDeviceSetFailed( device, WdfDeviceFailedNoRestart );

   // FIXME: There may be a race condition here for Read/Write requests 
   //        that are in-flight between the dsd->error test and being
   //        forwarded to the pending queue
   do
   {
      status = WdfIoQueueRetrieveNextRequest( dsd->pendingQueue, &request );
      if( NT_SUCCESS(status) )
      {
         WdfRequestComplete( request, STATUS_IO_DEVICE_ERROR );
      }
   }
   while( NT_SUCCESS(status) );

   KdPrint(( DSD_TAG "Device thread exiting\n" ));

   // System threads terminate themselves explicitly
   PsTerminateSystemThread( STATUS_SUCCESS );
}

/**
 * Get a unique tag for request/response association
 */
static int
DsdGetRequestTag()
{
   static uint32_t nextTag = 0;
   uint32_t tag;

   ExAcquireFastMutex( &tagMutex );

   tag = nextTag++;
   if( nextTag == 0xffffffff )
   {
      nextTag = 0;
   }

   ExReleaseFastMutex( &tagMutex );

   return (int)tag;
}

/**
 * Find a request in the provided queue with a tag matching 'tag' and
 * retrieve it, so the caller owns it and must complete it.
 *
 * Follows the WdfIoQueueFindRequest iteration pattern:
 * - Each found request carries a reference that is always dropped.
 * - If the request the scan is positioned on leaves the queue
 *   (STATUS_NOT_FOUND), the scan restarts from the head.
 *
 * @param queue Queue to search (the manual pending queue)
 * @param tag   Tag stored in the request's DSD_REQUEST_CONTEXT
 * @return The retrieved request, or NULL when no matching request exists
 */
static WDFREQUEST 
DsdFindTaggedRequest( IN WDFQUEUE queue, IN int tag )
{
   NTSTATUS status;
   BOOLEAN hadPrevious;
   WDFREQUEST previous = NULL;
   WDFREQUEST candidate;
   WDFREQUEST request;
   PDSD_REQUEST_CONTEXT context;

   for(;;)
   {
      status = WdfIoQueueFindRequest( queue, previous, NULL, NULL,
                                      &candidate );

      // The reference taken on 'previous' by the prior find is no longer
      // needed, whatever this call returned
      hadPrevious = ( previous != NULL );
      if( previous )
      {
         WdfObjectDereference( previous );
         previous = NULL;
      }

      if( status == STATUS_NOT_FOUND && hadPrevious )
      {
         // 'previous' left the queue; restart the scan from the head
         continue;
      }
      if( !NT_SUCCESS(status) )
      {
         // STATUS_NO_MORE_ENTRIES (no match) or an unexpected failure
         return NULL;
      }

      context = GetRequestContext( candidate );
      if( context && context->tag == tag )
      {
         // Request was found
         status = WdfIoQueueRetrieveFoundRequest( queue, candidate, &request );
         WdfObjectDereference( candidate );
         if( status == STATUS_NOT_FOUND )
         {
            // Removed (e.g. cancelled) between find and retrieve; rescan
            continue;
         }
         if( !NT_SUCCESS(status) )
         {
            return NULL;
         }
         return request;
      }

      // Not a match: keep the reference and continue after this request
      previous = candidate;
   }
}
