/*	$NetBSD$	*/

/*-
 * Copyright (c) 2015 Taylor R. Campbell
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD$");

#include <sys/errno.h>
#include <sys/ioctl_pb.h>

#include <pb.h>
#include <pb_decode.h>
#include <pb_encode.h>

int
pb_copyin_init(struct plistref *pref, struct pb_msg msg)
{
	int error;

	pb_init(msg);
	error = pb_copyin(pref, msg);
	if (error)
		pb_destroy(msg);

	return error;
}

/* Largest user buffer, in bytes (64 KiB), that pb_copyin will accept.  */
const size_t pb_copyin_limit = 64 * 1024;

/*
 * pb_copyin compares sizeof(PB_IOCTL_MAGIC) bytes at the head of the
 * user buffer, so pin that size at compile time (if PB_IOCTL_MAGIC is
 * a string literal, this count includes its terminating NUL).
 */
CTASSERT(sizeof(PB_IOCTL_MAGIC) == 16);

int
pb_copyin(struct plistref *pref, struct pb_msg msg)
{
	const size_t nmagic = sizeof(PB_IOCTL_MAGIC);
	unsigned char *buf;
	int error;

	if (pref->pref_len > pb_copyin_limit)
		return E2BIG;	/* XXX What error code?  */
	if (pref->pref_len < nmagic)
		return EINVAL;	/* XXX What error code?  */

	buf = kmem_alloc(pref->pref_len, KM_SLEEP);
	error = copyin(pref->pref_plist, buf, pref->pref_len);
	if (error)
		goto out;

	if (memcmp(buf, PB_IOCTL_MAGIC, nmagic) != 0) {
		/* XXX Fall back to parsing a plist.  */
		error = EINVAL;	/* XXX What error code?  */
		goto out;
	}

	error = pb_decode_from_memory(msg, buf + nmagic,
	    pref->pref_len - nmagic);
	if (error)
		goto out;

	/* Success!  */
	error = 0;

out:	kmem_free(buf, pref->pref_len);
	return error;
}

int
pb_copyout_destroy(struct plistref *pref, struct pb_msg msg)
{
	int error;

	error = pb_copyout(pref, msg);
	pb_destroy(msg);

	return error;
}

int
pb_copyout(struct plistref *pref, struct pb_msg msg)
{
	size_t len;
	void *buf, *uaddr;
	int error;

	error = pb_encoding_size(msg, &len);
	if (error)
		return error;

	/* XXX Generate a plist if the input was a plist.  */

	buf = kmem_alloc(len, KM_SLEEP);
	error = pb_encode_to_memory(msg, buf, len);
	if (error)
		goto out0;

	error = uvm_mmap_anon(curproc, &uaddr, round_page(len));
	if (error)
		goto out0;

	error = copyout(buf, uaddr, len);
	if (error)
		goto out1;

	/* Success!  */
	pref->pref_plist = uaddr;
	pref->pref_len = len;
	error = 0;

out1:	if (error) {
		/* XXX Copypasta from sys_munmap.  */
		const vaddr_t vaddr = uaddr;
		struct vm_map_entry *dead_entries = NULL;

		vm_map_lock(&curproc->p_vmspace->vm_map);
		uvm_unmap_remove(&curproc->p_vmspace->vm_map, vaddr,
		    vaddr + round_page(len), &dead_entries, 0);
		vm_map_unlock(&curproc->p_vmspace->vm_map);
		if (dead_entries != NULL)
			uvm_unmap_detach(dead_entries, 0);
	}
out0:	kmem_free(buf, len);
	return error;
}

