diff --git a/payloads/libpayload/drivers/usb/usbmsc.c b/payloads/libpayload/drivers/usb/usbmsc.c index d1e33022c3..ac1284129b 100644 --- a/payloads/libpayload/drivers/usb/usbmsc.c +++ b/payloads/libpayload/drivers/usb/usbmsc.c @@ -70,9 +70,11 @@ static const char *msc_protocol_strings[0x51] = { static void usb_msc_destroy (usbdev_t *dev) { - if (usbdisk_remove) - usbdisk_remove (dev); - free (dev->data); + if (dev->data) { + if (MSC_INST (dev)->usbdisk_created && usbdisk_remove) + usbdisk_remove (dev); + free (dev->data); + } dev->data = 0; } @@ -417,6 +419,9 @@ usb_msc_init (usbdev_t *dev) { int i, timeout; + /* init .data before setting .destroy */ + dev->data = NULL; + dev->destroy = usb_msc_destroy; dev->poll = usb_msc_poll; @@ -451,6 +456,7 @@ usb_msc_init (usbdev_t *dev) MSC_INST (dev)->protocol = interface->bInterfaceSubClass; MSC_INST (dev)->bulk_in = 0; MSC_INST (dev)->bulk_out = 0; + MSC_INST (dev)->usbdisk_created = 0; for (i = 1; i <= dev->num_endp; i++) { if (dev->endpoints[i].endpoint == 0) @@ -465,10 +471,14 @@ usb_msc_init (usbdev_t *dev) MSC_INST (dev)->bulk_out = &dev->endpoints[i]; } - if (MSC_INST (dev)->bulk_in == 0) - fatal ("couldn't find bulk-in endpoint"); - if (MSC_INST (dev)->bulk_out == 0) - fatal ("couldn't find bulk-out endpoint"); + if (MSC_INST (dev)->bulk_in == 0) { + printf("couldn't find bulk-in endpoint"); + return; + } + if (MSC_INST (dev)->bulk_out == 0) { + printf("couldn't find bulk-out endpoint"); + return; + } debug (" using endpoint %x as in, %x as out\n", MSC_INST (dev)->bulk_in->endpoint, MSC_INST (dev)->bulk_out->endpoint); @@ -514,6 +524,8 @@ usb_msc_init (usbdev_t *dev) } debug ("\n"); - if ((read_capacity (dev) == MSC_COMMAND_OK) && usbdisk_create) + if ((read_capacity (dev) == MSC_COMMAND_OK) && usbdisk_create) { usbdisk_create (dev); + MSC_INST (dev)->usbdisk_created = 1; + } } diff --git a/payloads/libpayload/include/usb/usbmsc.h b/payloads/libpayload/include/usb/usbmsc.h index b8a8ec13c9..fafa6f5979 100644 --- a/payloads/libpayload/include/usb/usbmsc.h +++ b/payloads/libpayload/include/usb/usbmsc.h @@ -35,6 +35,7 @@ typedef struct { unsigned int protocol; endpoint_t *bulk_in; endpoint_t *bulk_out; + int usbdisk_created; } usbmsc_inst_t; #define MSC_INST(dev) ((usbmsc_inst_t*)(dev)->data)