summaryrefslogtreecommitdiffstats
path: root/net/base/nss_memio.c
blob: 5f7fd0085afe85e9d9be5d43de57142817b07917 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
// Copyright (c) 2008 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Written in NSPR style to also be suitable for adding to the NSS demo suite

/* memio is a simple NSPR I/O layer that lets you decouple NSS from
 * the real network.  It's rather like openssl's memory bio,
 * and is useful when your app absolutely, positively doesn't
 * want to let NSS do its own networking.
 */

#include <stdlib.h>
#include <string.h>

#include <prerror.h>
#include <prinit.h>
#include <prlog.h>

#include "nss_memio.h"

/*--------------- private memio types -----------------------*/

/*----------------------------------------------------------------------
 Simple private circular buffer class.  Size cannot be changed once allocated.

 Invariants relied on by the helper functions in this file:
   - head == tail means the buffer is empty
     (memio_buffer_used_contiguous then returns 0).
   - One byte of storage is always left unused
     (see memio_buffer_unused_contiguous), so at most bufsize - 1 bytes
     are buffered at any time.
   - head and tail are reset to 0 as soon as they reach bufsize.
----------------------------------------------------------------------*/

struct memio_buffer {
    int head;     /* where to take next byte out of buf */
    int tail;     /* where to put next byte into buf */
    int bufsize;  /* number of bytes allocated to buf */
    /* TODO(port): error handling is pessimistic right now.
     * Once an error is set, the socket is considered broken
     * (PR_WOULD_BLOCK_ERROR not included).
     */
    /* Stored by memio_PutReadResult / memio_PutWriteResult when they are
     * handed a negative byte count; reported by memio_Recv / memio_Send.
     * Nothing in this file ever clears it again.
     */
    PRErrorCode last_err;
    char *buf;    /* malloc'ed storage of bufsize bytes; NULL after destroy */
};


/* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
 * to one of these.
 * In the public header, we use struct memio_Private as a typesafe alias
 * for this.  This causes a few ugly typecasts in the private file, but
 * seems safer.
 *
 * The structure is zero-filled when it is created, so eof and both
 * last_err fields start out as 0.
 */
struct PRFilePrivate {
    /* read requests are satisfied from this buffer
     * (filled through memio_GetReadParams / memio_PutReadResult,
     * drained by memio_Recv)
     */
    struct memio_buffer readbuf;

    /* write requests are satisfied from this buffer
     * (filled by memio_Send, drained through memio_GetWriteParams /
     * memio_PutWriteResult)
     */
    struct memio_buffer writebuf;

    /* SSL needs to know socket peer's name
     * (stored by memio_SetPeerName, returned by memio_GetPeerName)
     */
    PRNetAddr peername;

    /* if set, empty I/O returns EOF instead of EWOULDBLOCK
     * (set by memio_PutReadResult when it is given a byte count of 0)
     */
    int eof;
};

/*--------------- private memio_buffer functions ---------------------*/

/* Forward declarations of the private circular-buffer helpers.  */

/* Initialize a memio_buffer and allocate 'size' bytes of storage for it. */
static void memio_buffer_new(struct memio_buffer *mb, int size);

/* Free the storage of a memio_buffer set up by memio_buffer_new. */
static void memio_buffer_destroy(struct memio_buffer *mb);

/* How many bytes can be read out of the buffer without wrapping */
static int memio_buffer_used_contiguous(const struct memio_buffer *mb);

/* How many bytes can be written into the buffer without wrapping */
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);

/* Write up to n bytes into the buffer.  Returns the number of bytes
 * actually written, which is less than n when the buffer fills up.
 */
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);

/* Read up to n bytes from the buffer.  Returns the number of bytes
 * actually read, which is less than n when the buffer runs empty.
 */
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);

/* Initialize *mb as an empty circular buffer backed by 'size' bytes of
 * freshly malloc'ed storage.
 *
 * mb   - descriptor to initialize; previous contents are overwritten
 *        without being freed.
 * size - number of bytes to allocate.  One byte is always kept unused,
 *        so at most size - 1 bytes can be buffered at a time.
 *
 * If size is not positive or the allocation fails, mb->buf is NULL and
 * mb->bufsize is 0.  Callers can test mb->buf to detect the failure, and
 * a put or get with a non-negative count then transfers nothing instead
 * of dereferencing the missing storage.
 */
static void memio_buffer_new(struct memio_buffer *mb, int size)
{
    mb->head = 0;
    mb->tail = 0;
    /* Do not rely on the caller having zeroed the structure. */
    mb->last_err = 0;
    mb->buf = (size > 0) ? malloc(size) : NULL;
    /* Never advertise capacity that is not backed by memory. */
    mb->bufsize = (mb->buf != NULL) ? size : 0;
}

/* Release the storage owned by *mb and mark the buffer as empty.
 * The descriptor itself is not freed, and bufsize is left untouched.
 */
static void memio_buffer_destroy(struct memio_buffer *mb)
{
    char *storage = mb->buf;

    mb->buf = NULL;
    mb->head = mb->tail = 0;
    free(storage);
}

/* Number of buffered bytes that can be read starting at head without
 * running past the end of the storage.
 */
static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
{
    int end;

    if (mb->tail >= mb->head)
        end = mb->tail;     /* data is one run: [head, tail) */
    else
        end = mb->bufsize;  /* data wraps: first run is [head, bufsize) */

    return end - mb->head;
}

/* Number of bytes that can be written starting at tail without running
 * past the end of the storage.  One byte is always held back so that
 * tail never catches up with head.
 */
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
{
    int room;

    if (mb->head > mb->tail) {
        /* Free space is the gap between tail and head, minus one. */
        room = mb->head - mb->tail - 1;
    } else {
        /* Free space runs to the end of the storage ... */
        room = mb->bufsize - mb->tail;
        /* ... but the last byte must stay free when head sits at 0. */
        if (mb->head == 0)
            room -= 1;
    }

    return room;
}

/* Copy up to n bytes from buf into the circular buffer.
 * Returns the number of bytes actually stored; this is less than n when
 * the buffer fills up, and 0 when it was already full.
 */
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
{
    int total = 0;
    int pass;

    /* Two passes at most: one for the run up to the end of the storage,
     * one for the run that starts again at index 0 after wrapping.
     */
    for (pass = 0; pass < 2; pass++) {
        int chunk = PR_MIN(n, memio_buffer_unused_contiguous(mb));

        if (chunk <= 0)
            break;      /* buffer full, or nothing left to store */

        memcpy(&mb->buf[mb->tail], buf, chunk);
        mb->tail += chunk;
        if (mb->tail == mb->bufsize)
            mb->tail = 0;

        buf += chunk;
        n -= chunk;
        total += chunk;
    }

    return total;
}


/* Copy up to n bytes out of the circular buffer into buf.
 *
 * mb  - buffer to drain.
 * buf - destination; must have room for n bytes.
 * n   - maximum number of bytes wanted.
 *
 * Returns the number of bytes actually copied: less than n when the
 * buffer runs empty, and 0 when it was already empty or n is not
 * positive.
 */
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
{
    int transferred = 0;
    int pass;

    /* Two passes at most: one for the run up to the end of the storage,
     * one for the run that starts again at index 0 after wrapping.
     */
    for (pass = 0; pass < 2; pass++) {
        int len = PR_MIN(n, memio_buffer_used_contiguous(mb));

        /* Test for "> 0", as memio_buffer_put does, rather than for
         * non-zero: a negative n must never reach memcpy, where it would
         * be converted to a huge unsigned size.
         */
        if (len <= 0)
            break;

        memcpy(buf, &mb->buf[mb->head], len);
        mb->head += len;
        if (mb->head == mb->bufsize)
            mb->head = 0;

        n -= len;
        buf += len;
        transferred += len;
    }

    return transferred;
}

/*--------------- private memio functions -----------------------*/

/* close method of the memio layer: frees both circular buffers and the
 * private data, then calls the descriptor's own destructor.
 * Always returns PR_SUCCESS.
 */
static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
{
    struct PRFilePrivate *priv;

    priv = fd->secret;

    memio_buffer_destroy(&priv->readbuf);
    memio_buffer_destroy(&priv->writebuf);
    free(priv);

    /* fd must not be touched after this call. */
    fd->dtor(fd);

    return PR_SUCCESS;
}

/* shutdown method of the memio layer.  Currently a no-op that reports
 * success; both arguments are ignored.
 */
static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
{
    (void)fd;
    (void)how;

    /* TODO: pass shutdown status to app somehow */
    return PR_SUCCESS;
}

/* recv method of the memio layer: hands out bytes previously placed in
 * the read buffer.
 *
 * Any non-zero flags value is rejected with PR_NOT_IMPLEMENTED_ERROR.
 * The timeout argument is not used.
 *
 * Returns the number of bytes copied into buf.  When nothing could be
 * copied:
 *   - if EOF has been recorded, returns 0;
 *   - otherwise returns -1 with the error set to the recorded network
 *     error if there is one, or to PR_WOULD_BLOCK_ERROR if there is none.
 */
static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
                                  PRIntn flags, PRIntervalTime timeout)
{
    struct PRFilePrivate *priv;
    struct memio_buffer *mb;
    PRErrorCode err;
    int got;

    if (flags != 0) {
        PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
        return -1;
    }

    priv = fd->secret;
    mb = &priv->readbuf;
    PR_ASSERT(mb->bufsize);

    got = memio_buffer_get(mb, buf, len);
    if (got != 0 || priv->eof)
        return got;

    /* Buffer is dry and no EOF has been seen. */
    err = mb->last_err;
    if (!err)
        err = PR_WOULD_BLOCK_ERROR;
    PR_SetError(err, 0);
    return -1;
}

/* read method of the memio layer: same as memio_Recv with no flags and
 * no timeout.
 */
static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
{
    const PRIntn no_flags = 0;

    /* pull bytes from buffer */
    return memio_Recv(fd, buf, len, no_flags, PR_INTERVAL_NO_TIMEOUT);
}

/* send method of the memio layer: appends bytes to the write buffer.
 *
 * The flags and timeout arguments are not used.
 *
 * Returns the number of bytes accepted, which may be less than len.
 * Returns -1 with the recorded error if a write error was reported
 * earlier, or -1 with PR_WOULD_BLOCK_ERROR if no bytes could be stored.
 */
static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
                                  PRIntn flags, PRIntervalTime timeout)
{
    struct PRFilePrivate *priv = fd->secret;
    struct memio_buffer *mb = &priv->writebuf;
    int queued;

    PR_ASSERT(mb->bufsize);

    /* A recorded error fails every later send. */
    if (mb->last_err != 0) {
        PR_SetError(mb->last_err, 0);
        return -1;
    }

    queued = memio_buffer_put(mb, buf, len);
    if (queued != 0)
        return queued;

    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    return -1;
}

/* write method of the memio layer: same as memio_Send with no flags and
 * no timeout.
 */
static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
{
    const PRIntn no_flags = 0;

    /* append bytes to buffer */
    return memio_Send(fd, buf, len, no_flags, PR_INTERVAL_NO_TIMEOUT);
}

/* getpeername method of the memio layer: copies the stored peer address
 * into *addr and reports success.
 */
static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
{
    const struct PRFilePrivate *priv;

    /* TODO: fail if memio_SetPeerName has not been called */
    priv = fd->secret;
    *addr = priv->peername;

    return PR_SUCCESS;
}

static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
{
    /*
     * Even in the original version for real tcp sockets,
     * PR_SockOpt_Nonblocking is a special case that does not
     * translate to a getsockopt() call
     */
    if (PR_SockOpt_Nonblocking == data->option) {
        data->value.non_blocking = PR_TRUE;
        return PR_SUCCESS;
    }
    PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
    return PR_FAILURE;
}

/*--------------- private memio data -----------------------*/

/*
 * Implement just the bare minimum number of methods needed to make ssl happy.
 *
 * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
 * PR_GetSocketOption, so we have to provide an implementation of
 * PR_GetSocketOption that just says "I'm nonblocking".
 *
 * The initializer is positional, so the order of the entries matters.
 * The slot names in the comments below follow the PRIOMethods declaration
 * in NSPR's prio.h -- confirm against the NSPR version in use.
 * Slots left NULL are not implemented by this layer.
 */

static struct PRIOMethods  memio_layer_methods = {
    PR_DESC_LAYERED,        /* file_type */
    memio_Close,            /* close */
    memio_Read,             /* read */
    memio_Write,            /* write */
    NULL,                   /* available */
    NULL,                   /* available64 */
    NULL,                   /* fsync */
    NULL,                   /* seek */
    NULL,                   /* seek64 */
    NULL,                   /* fileInfo */
    NULL,                   /* fileInfo64 */
    NULL,                   /* writev */
    NULL,                   /* connect */
    NULL,                   /* accept */
    NULL,                   /* bind */
    NULL,                   /* listen */
    memio_Shutdown,         /* shutdown */
    memio_Recv,             /* recv */
    memio_Send,             /* send */
    NULL,                   /* recvfrom */
    NULL,                   /* sendto */
    NULL,                   /* poll */
    NULL,                   /* acceptread */
    NULL,                   /* transmitfile */
    NULL,                   /* getsockname */
    memio_GetPeerName,      /* getpeername */
    NULL,                   /* reserved_fn_6 */
    NULL,                   /* reserved_fn_5 */
    memio_GetSocketOption,  /* getsocketoption */
    NULL,                   /* setsocketoption */
    NULL,                   /* sendfile */
    NULL,                   /* connectcontinue */
    NULL,                   /* reserved_fn_3 */
    NULL,                   /* reserved_fn_2 */
    NULL,                   /* reserved_fn_1 */
    NULL,                   /* reserved_fn_0 */
};

static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;

static PRStatus memio_InitializeLayerName(void)
{
    memio_identity = PR_GetUniqueIdentity("memio");
    return PR_SUCCESS;
}

/*--------------- public memio functions -----------------------*/

PRFileDesc *memio_CreateIOLayer(int bufsize)
{
    PRFileDesc *fd;
    struct PRFilePrivate *secret;
    static PRCallOnceType once;

    PR_CallOnce(&once, memio_InitializeLayerName);

    fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
    secret = malloc(sizeof(struct PRFilePrivate));
    memset(secret, 0, sizeof(*secret));

    memio_buffer_new(&secret->readbuf, bufsize);
    memio_buffer_new(&secret->writebuf, bufsize);
    fd->secret = secret;
    return fd;
}

/* Store *peername in the memio layer found in fd's layer stack; the
 * stored address is what memio_GetPeerName later returns.
 * NOTE(review): the result of PR_GetIdentitiesLayer is used unchecked,
 * so fd is assumed to contain a memio layer.
 */
void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
{
    struct PRFilePrivate *priv;

    priv = PR_GetIdentitiesLayer(fd, memio_identity)->secret;
    priv->peername = *peername;
}

/* Return the private data of the memio layer found in fd's layer stack,
 * as the opaque handle taken by the memio_Get/Put functions.
 * NOTE(review): the result of PR_GetIdentitiesLayer is used unchecked,
 * so fd is assumed to contain a memio layer.
 */
memio_Private *memio_GetSecret(PRFileDesc *fd)
{
    PRFileDesc *layer;

    layer = PR_GetIdentitiesLayer(fd, memio_identity);
    return (memio_Private *)layer->secret;
}

/* Tell the caller where it may deposit incoming bytes.
 * Sets *buf to the current write position of the read buffer and returns
 * how many bytes may be stored there without wrapping.
 */
int memio_GetReadParams(memio_Private *secret, char **buf)
{
    PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *mb = &priv->readbuf;

    PR_ASSERT(mb->bufsize);

    *buf = mb->buf + mb->tail;
    return memio_buffer_unused_contiguous(mb);
}

/* Record the outcome of a read into the area handed out by
 * memio_GetReadParams.
 *
 * bytes_read > 0  - that many bytes were stored; advance the write index.
 * bytes_read == 0 - end of file; remember it.
 * bytes_read < 0  - an error code; remember it as the buffer's last error.
 */
void memio_PutReadResult(memio_Private *secret, int bytes_read)
{
    PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *mb = &priv->readbuf;

    PR_ASSERT(mb->bufsize);

    if (bytes_read < 0) {
        mb->last_err = bytes_read;
        return;
    }

    if (bytes_read == 0) {
        /* Record EOF condition and report to caller when buffer runs dry */
        priv->eof = PR_TRUE;
        return;
    }

    mb->tail += bytes_read;
    if (mb->tail == mb->bufsize)
        mb->tail = 0;
}

/* Tell the caller which buffered outgoing bytes are ready to go out.
 * Sets *buf to the current read position of the write buffer and returns
 * how many bytes may be taken from there without wrapping.
 */
int memio_GetWriteParams(memio_Private *secret, const char **buf)
{
    PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *mb = &priv->writebuf;

    PR_ASSERT(mb->bufsize);

    *buf = mb->buf + mb->head;
    return memio_buffer_used_contiguous(mb);
}

/* Record the outcome of writing out the area handed out by
 * memio_GetWriteParams.
 *
 * bytes_written > 0  - that many bytes were consumed; advance the read
 *                      index.
 * bytes_written == 0 - nothing happened; nothing is changed.
 * bytes_written < 0  - an error code; remember it as the buffer's last
 *                      error.
 */
void memio_PutWriteResult(memio_Private *secret, int bytes_written)
{
    PRFilePrivate *priv = (PRFilePrivate *)secret;
    struct memio_buffer *mb = &priv->writebuf;

    PR_ASSERT(mb->bufsize);

    if (bytes_written < 0) {
        mb->last_err = bytes_written;
    } else if (bytes_written > 0) {
        mb->head += bytes_written;
        if (mb->head == mb->bufsize)
            mb->head = 0;
    }
}

/*--------------- private memio_buffer self-test -----------------*/

/* Even a trivial unit test is very helpful when doing circular buffers. */
/*#define TRIVIAL_SELF_TEST*/
#ifdef TRIVIAL_SELF_TEST
#include <stdio.h>

/* Storage size of the buffer used by the self-test. */
#define TEST_BUFLEN 7

/* Abort the self-test with a message unless the two int expressions are
 * equal.
 *
 * Each argument is evaluated exactly once.  Several callers pass
 * expressions with side effects (memio_buffer_put / memio_buffer_get),
 * and the failure message must show the value that was compared, not the
 * result of running the operation a second time.
 *
 * Wrapped in do { } while (0) so that "CHECKEQ(a, b);" is a single
 * statement that is safe inside an unbraced if/else.
 */
#define CHECKEQ(a, b) do { \
    int checkeq_lhs = (a); \
    int checkeq_rhs = (b); \
    if (checkeq_lhs != checkeq_rhs) { \
        printf("%d != %d, Test failed line %d\n", \
               checkeq_lhs, checkeq_rhs, __LINE__); \
        exit(1); \
    } \
} while (0)

/* Trivial self-test of the circular buffer helpers.
 * Prints "Test passed" and returns 0 on success; CHECKEQ exits with
 * status 1 at the first failed check.
 */
int main(void)
{
    struct memio_buffer mb;
    char buf[100];

    memio_buffer_new(&mb, TEST_BUFLEN);

    /* Empty buffer: one byte of the storage is never usable. */
    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);

    CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 5);

    CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);

    /* Buffer is now full. */
    CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 6);

    CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    CHECKEQ(memcmp(buf, "howdy!", 6), 0);

    /* Empty again, with head and tail at the last index. */
    CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);

    /* This put wraps around the end of the storage. */
    CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);

    CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);

    CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);

    CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_used_contiguous(&mb), 1);

    /* Reading the wrapped data back must return it in order. */
    CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    CHECKEQ(memcmp(buf, "012345", 6), 0);

    /* The buffer is empty once more, so a further read gets nothing. */
    CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    CHECKEQ(memio_buffer_get(&mb, buf, 1), 0);

    /* TODO: add more cases */

    memio_buffer_destroy(&mb);

    printf("Test passed\n");
    return 0;
}

#endif