"""
gfx.py

All graphics display and loading code lives here.
"""

import sys, os
import pygame


# The display surface returned by pygame.display.set_mode; None until init() runs.
Surf = None


def init():
    """Create the 640x480 display and store its surface in the module global Surf."""
    global Surf
    resolution = 640, 480
    Surf = pygame.display.set_mode(resolution)


def flip():
    """Update the screen with everything drawn on the display surface so far."""
    display = pygame.display
    display.flip()


def draw(surf, pos):
    """Blit *surf* onto the display surface at *pos*."""
    target = Surf
    target.blit(surf, pos)

def drawmulti(objs):
    """Blit each ``(surface, position)`` pair from *objs* onto the display surface.

    The display's bound ``blit`` method is looked up once, before the loop,
    to avoid a repeated attribute lookup per item.
    """
    put = Surf.blit
    for image, where in objs:
        put(image, where)

def clear(img):
    """Blit *img* at the top-left corner of the display surface.

    Presumably *img* is a full-screen background used to wipe the frame --
    TODO confirm against callers.
    """
    top_left = (0, 0)
    Surf.blit(img, top_left)


# (name, size) -> pygame font object, filled lazily by getfont().
Fontcache = {}


def getfont(name, size):
    """Return a font for *name* at *size*, creating and caching it on first use.

    A falsy *name* (None or '') selects pygame's default font through
    ``pygame.font.Font(None, size)``; any other name is looked up with
    ``pygame.font.SysFont(name, size)``.  Later calls with the same
    ``(name, size)`` return the same cached object.
    """
    key = name, size
    # `in` rather than dict.has_key(): has_key no longer exists in Python 3.
    if key not in Fontcache:
        if name:
            font = pygame.font.SysFont(name, size)
        else:
            font = pygame.font.Font(None, size)
        Fontcache[key] = font
    return Fontcache[key]
        


def text(size, color, bgd, text, **pos):
    """Render *text* in the default font onto an opaque *bgd* background.

    Sizes below 20 are rendered plainly and converted to the display format.
    Larger sizes get a drop shadow in half-intensity *color*, offset by 1px
    (size < 30) or 2px (size >= 30), on a surface 4px larger than the text
    in each dimension.

    Keyword arguments are pygame Rect attribute assignments used to position
    the result, e.g. ``center=(320, 240)`` or ``topleft=(10, 10)``.  All of
    them are applied, in the order given.

    Returns a two-item list ``[surface, rect]``.
    """
    font = getfont(None, size)
    if size < 20:
        img = font.render(text, 1, color, bgd).convert()
    else:
        # Floor division keeps the shadow channels integral (`/` yields
        # floats under Python 3; under Python 2 both forms are identical).
        shadcolor = [c // 2 for c in color]
        img1 = font.render(text, 1, color)
        img2 = font.render(text, 1, shadcolor, bgd)
        img = pygame.Surface((img1.get_width() + 4, img1.get_height() + 4))
        img.fill(bgd)
        offset = 1 if size < 30 else 2
        img.blit(img2, (offset, offset))  # shadow first...
        img.blit(img1, (0, 0))            # ...then the text on top of it
    r = img.get_rect()
    # Previously only one arbitrary keyword (via popitem) was honoured and
    # the rest were silently dropped; apply every positioning keyword.
    for attr, value in pos.items():
        setattr(r, attr, value)
    return [img, r]


def texta(size, color, text, **pos):
    """Render antialiased *text* in the default font with no background color.

    Keyword arguments are pygame Rect attribute assignments used to position
    the result, e.g. ``midtop=(320, 10)``.  All of them are applied, in the
    order given.

    Returns a two-item list ``[surface, rect]``.
    """
    font = getfont(None, size)
    img = font.render(text, 1, color)
    r = img.get_rect()
    # Apply every positioning keyword; popitem() used to honour only one.
    for attr, value in pos.items():
        setattr(r, attr, value)
    return [img, r]


    
# image name -> surface exactly as loaded from disk, filled lazily by getimg().
Imgcache = {}


def getimg(name):
    """Return the image ``data/<name>.pcx``, loading it on first request.

    The path is relative, so it is resolved against the current working
    directory.  The returned surface is the cached object itself, not a
    copy: any modification a caller makes is seen by every later caller.
    """
    # `in` rather than dict.has_key(): has_key no longer exists in Python 3.
    if name not in Imgcache:
        fullname = os.path.join('data', name + '.pcx')
        Imgcache[name] = pygame.image.load(fullname)
    return Imgcache[name]

def load(name):
    """Return a display-format copy of the image *name* obtained via getimg()."""
    return getimg(name).convert()

def loadcolored(name, color):
    """Return a tinted, display-format copy of the paletted image *name*.

    Each palette entry becomes ``(r*color[0]>>8, r*color[1]>>8, r*color[2]>>8)``.
    The palette of the shared cached surface from getimg() is swapped only
    for the duration of ``convert()``.  It is restored even if the
    conversion raises, so the cache is never left tinted.
    """
    img = getimg(name)
    pal = img.get_palette()
    # NOTE(review): every output channel is derived from the source's *red*
    # channel only, unlike loadcoloredkey(), which scales r, g and b
    # separately.  This is equivalent for grayscale art; confirm it is
    # intentional before using it with coloured images.
    newpal = [(r*color[0]>>8, r*color[1]>>8, r*color[2]>>8) for r, _, _ in pal]
    img.set_palette(newpal)
    try:
        newimg = img.convert()
    finally:
        img.set_palette(pal)
    return newimg

def loadcoloredkey(name, color):
    """Return a tinted, display-format copy of image *name* keyed on pure green.

    Each palette entry's r, g and b are scaled independently by
    color[0], color[1] and color[2] (``value*c >> 8``).  The colorkey
    (0, 255, 0) is set on the source before ``convert()``, presumably so that
    the key is carried into the converted copy (TODO confirm).

    The palette and colorkey of the shared cached surface from getimg() are
    changed only temporarily.  Both are reset even if tinting or conversion
    raises, so the cache is never left modified.
    """
    img = getimg(name)
    pal = img.get_palette()
    img.set_colorkey((0, 255, 0))
    try:
        newpal = [(r*color[0]>>8, g*color[1]>>8, b*color[2]>>8) for r, g, b in pal]
        img.set_palette(newpal)
        newimg = img.convert()
    finally:
        img.set_palette(pal)
        img.set_colorkey(None)  # same as the original no-argument call: disables the key
    return newimg

def key(img, color=(255,255,255)):
    """Set *color* (white by default) as *img*'s colorkey, with RLE acceleration.

    *img* is modified in place and is also returned for convenience.
    """
    flags = pygame.RLEACCEL
    img.set_colorkey(color, flags)
    return img