diff options
| author | David Lamparter <equinox@opensourcerouting.org> | 2022-02-26 13:20:16 +0100 | 
|---|---|---|
| committer | David Lamparter <equinox@opensourcerouting.org> | 2022-02-26 16:49:12 +0100 | 
| commit | 89087f23b589b051910c26ae7772256adacc35a7 (patch) | |
| tree | c9def558726faf899559f96edd04f727fb54cf29 | |
| parent | 264c806da958e1b16619ed2a66a42f2b3ac69558 (diff) | |
lib: use iovec for checksum code
... to allow checksumming noncontiguous blurbs of data.
Signed-off-by: David Lamparter <equinox@opensourcerouting.org>
| -rw-r--r-- | lib/checksum.c | 85 | ||||
| -rw-r--r-- | lib/checksum.h | 38 | ||||
| -rw-r--r-- | tests/lib/test_checksum.c | 55 | 
3 files changed, 136 insertions, 42 deletions
diff --git a/lib/checksum.c b/lib/checksum.c index 3473370041..6c5f06de45 100644 --- a/lib/checksum.c +++ b/lib/checksum.c @@ -9,13 +9,24 @@  #include <zebra.h>  #include "checksum.h" -int /* return checksum in low-order 16 bits */ -	in_cksum(void *parg, int nbytes) +#define add_carry(dst, add)                                                    \ +	do {                                                                   \ +		typeof(dst) _add = (add);                                      \ +		dst += _add;                                                   \ +		if (dst < _add)                                                \ +			dst++;                                                 \ +	} while (0) + +uint16_t in_cksumv(const struct iovec *iov, size_t iov_len)  { -	unsigned short *ptr = parg; -	register long sum; /* assumes long == 32 bits */ -	unsigned short oddbyte; -	register unsigned short answer; /* assumes unsigned short == 16 bits */ +	const struct iovec *iov_end; +	uint32_t sum = 0; + +	union { +		uint8_t bytes[2]; +		uint16_t word; +	} wordbuf; +	bool have_oddbyte = false;  	/*  	 * Our algorithm is simple, using a 32-bit accumulator (sum), @@ -23,17 +34,42 @@ int /* return checksum in low-order 16 bits */  	 * all the carry bits from the top 16 bits into the lower 16 bits.  	 */ -	sum = 0; -	while (nbytes > 1) { -		sum += *ptr++; -		nbytes -= 2; +	for (iov_end = iov + iov_len; iov < iov_end; iov++) { +		const uint8_t *ptr, *end; + +		ptr = (const uint8_t *)iov->iov_base; +		end = ptr + iov->iov_len; +		if (ptr == end) +			continue; + +		if (have_oddbyte) { +			have_oddbyte = false; +			wordbuf.bytes[1] = *ptr++; + +			add_carry(sum, wordbuf.word); +		} + +		while (ptr + 8 <= end) { +			add_carry(sum, *(const uint32_t *)(ptr + 0)); +			add_carry(sum, *(const uint32_t *)(ptr + 4)); +			ptr += 8; +		} + +		while (ptr + 2 <= end) { +			add_carry(sum, *(const uint16_t *)ptr); +			ptr += 2; +		} + +		if (ptr + 1 <= end) { +			wordbuf.bytes[0] = *ptr++; +			have_oddbyte = true; +		}  	}  	/* mop up an odd byte, if necessary */ -	if (nbytes == 1) { -		oddbyte = 0; /* make sure top half is zero */ -		*((uint8_t *)&oddbyte) = *(uint8_t *)ptr; /* one byte only */ -		sum += oddbyte; +	if (have_oddbyte) { +		wordbuf.bytes[1] = 0; +		add_carry(sum, wordbuf.word);  	}  	/* @@ -42,26 +78,7 @@ int /* return checksum in low-order 16 bits */  	sum = (sum >> 16) + (sum & 0xffff); /* add high-16 to low-16 */  	sum += (sum >> 16);		    /* add carry */ -	answer = ~sum; /* ones-complement, then truncate to 16 bits */ -	return (answer); -} - -int in_cksum_with_ph4(struct ipv4_ph *ph, void *data, int nbytes) -{ -	uint8_t dat[sizeof(struct ipv4_ph) + nbytes]; - -	memcpy(dat, ph, sizeof(struct ipv4_ph)); -	memcpy(dat + sizeof(struct ipv4_ph), data, nbytes); -	return in_cksum(dat, sizeof(dat)); -} - -int in_cksum_with_ph6(struct ipv6_ph *ph, void *data, int nbytes) -{ -	uint8_t dat[sizeof(struct ipv6_ph) + nbytes]; - -	memcpy(dat, ph, sizeof(struct ipv6_ph)); -	memcpy(dat + sizeof(struct ipv6_ph), data, nbytes); -	return in_cksum(dat, sizeof(dat)); +	return ~sum;  }  /* Fletcher Checksum -- Refer to RFC1008. */ diff --git a/lib/checksum.h b/lib/checksum.h index 16e6945422..508c3f38a6 100644 --- a/lib/checksum.h +++ b/lib/checksum.h @@ -27,9 +27,41 @@ struct ipv6_ph {  	uint8_t next_hdr;  } __attribute__((packed)); -extern int in_cksum(void *data, int nbytes); -extern int in_cksum_with_ph4(struct ipv4_ph *ph, void *data, int nbytes); -extern int in_cksum_with_ph6(struct ipv6_ph *ph, void *data, int nbytes); + +extern uint16_t in_cksumv(const struct iovec *iov, size_t iov_len); + +static inline uint16_t in_cksum(const void *data, size_t nbytes) +{ +	struct iovec iov[1]; + +	iov[0].iov_base = (void *)data; +	iov[0].iov_len = nbytes; +	return in_cksumv(iov, array_size(iov)); +} + +static inline uint16_t in_cksum_with_ph4(const struct ipv4_ph *ph, +					 const void *data, size_t nbytes) +{ +	struct iovec iov[2]; + +	iov[0].iov_base = (void *)ph; +	iov[0].iov_len = sizeof(*ph); +	iov[1].iov_base = (void *)data; +	iov[1].iov_len = nbytes; +	return in_cksumv(iov, array_size(iov)); +} + +static inline uint16_t in_cksum_with_ph6(const struct ipv6_ph *ph, +					 const void *data, size_t nbytes) +{ +	struct iovec iov[2]; + +	iov[0].iov_base = (void *)ph; +	iov[0].iov_len = sizeof(*ph); +	iov[1].iov_base = (void *)data; +	iov[1].iov_len = nbytes; +	return in_cksumv(iov, array_size(iov)); +}  #define FLETCHER_CHECKSUM_VALIDATE 0xffff  extern uint16_t fletcher_checksum(uint8_t *, const size_t len, diff --git a/tests/lib/test_checksum.c b/tests/lib/test_checksum.c index 301078867a..0eedb96a5e 100644 --- a/tests/lib/test_checksum.c +++ b/tests/lib/test_checksum.c @@ -477,6 +477,8 @@ int main(int argc, char **argv)  		exercise += EXERCISESTEP;  		exercise %= MAXDATALEN; +		printf("\rexercising length %d\033[K", exercise); +  		for (i = 0; i < exercise; i += sizeof(long int)) {  			long int rand = frr_weak_random(); @@ -489,24 +491,67 @@ int main(int argc, char **argv)  		in_csum_res = in_cksum_optimized(buffer, exercise);  		in_csum_rfc = in_cksum_rfc(buffer, exercise);  		if (in_csum_res != in_csum || in_csum != in_csum_rfc) -			printf("verify: in_chksum failed in_csum:%x, in_csum_res:%x,in_csum_rfc %x, len:%d\n", +			printf("\nverify: in_chksum failed in_csum:%x, in_csum_res:%x,in_csum_rfc %x, len:%d\n",  			       in_csum, in_csum_res, in_csum_rfc, exercise); +		struct iovec iov[3]; +		uint16_t in_csum_iov; + +		iov[0].iov_base = buffer; +		iov[0].iov_len = exercise / 2; +		iov[1].iov_base = buffer + iov[0].iov_len; +		iov[1].iov_len = exercise - iov[0].iov_len; + +		in_csum_iov = in_cksumv(iov, 2); +		if (in_csum_iov != in_csum) +			printf("\nverify: in_cksumv failed, lens: %zu+%zu\n", +			       iov[0].iov_len, iov[1].iov_len); + +		if (exercise >= 6) { +			/* force split with byte leftover */ +			iov[0].iov_base = buffer; +			iov[0].iov_len = (exercise / 2) | 1; +			iov[1].iov_base = buffer + iov[0].iov_len; +			iov[1].iov_len = 2; +			iov[2].iov_base = buffer + iov[0].iov_len + 2; +			iov[2].iov_len = exercise - iov[0].iov_len - 2; + +			in_csum_iov = in_cksumv(iov, 3); +			if (in_csum_iov != in_csum) +				printf("\nverify: in_cksumv failed, lens: %zu+%zu+%zu, got %04x, expected %04x\n", +				       iov[0].iov_len, iov[1].iov_len, +				       iov[2].iov_len, in_csum_iov, in_csum); + +			/* force split without byte leftover */ +			iov[0].iov_base = buffer; +			iov[0].iov_len = (exercise / 2) & ~1UL; +			iov[1].iov_base = buffer + iov[0].iov_len; +			iov[1].iov_len = 2; +			iov[2].iov_base = buffer + iov[0].iov_len + 2; +			iov[2].iov_len = exercise - iov[0].iov_len - 2; + +			in_csum_iov = in_cksumv(iov, 3); +			if (in_csum_iov != in_csum) +				printf("\nverify: in_cksumv failed, lens: %zu+%zu+%zu, got %04x, expected %04x\n", +				       iov[0].iov_len, iov[1].iov_len, +				       iov[2].iov_len, in_csum_iov, in_csum); +		} +  		ospfd = ospfd_checksum(buffer, exercise + sizeof(uint16_t),  				       exercise);  		if (verify(buffer, exercise + sizeof(uint16_t))) -			printf("verify: ospfd failed\n"); +			printf("\nverify: ospfd failed\n");  		isisd = iso_csum_create(buffer, exercise + sizeof(uint16_t),  					exercise);  		if (verify(buffer, exercise + sizeof(uint16_t))) -			printf("verify: isisd failed\n"); +			printf("\nverify: isisd failed\n");  		lib = fletcher_checksum(buffer, exercise + sizeof(uint16_t),  					exercise);  		if (verify(buffer, exercise + sizeof(uint16_t))) -			printf("verify: lib failed\n"); +			printf("\nverify: lib failed\n");  		if (ospfd != lib) { -			printf("Mismatch in values at size %d\n" +			printf("\nMismatch in values at size %d\n"  			       "ospfd: 0x%04x\tc0: %d\tc1: %d\tx: %d\ty: %d\n"  			       "isisd: 0x%04x\tc0: %d\tc1: %d\tx: %d\ty: %d\n"  			       "lib: 0x%04x\n",  | 
