# Source code for pyprocar.plotter.ebs_plot

__author__ = "Pedram Tavadze and Logan Lang"
__maintainer__ = "Pedram Tavadze and Logan Lang"
__email__ = "petavazohi@mail.wvu.edu, lllang@mix.wvu.edu"
__date__ = "March 31, 2020"

import os 
import yaml
from typing import List

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib as mpl
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator

from pyprocar.core import ElectronicBandStructure, KPath


class EBSPlot:
    """
    A class to plot an electronic band structure.

    Parameters
    ----------
    ebs : ElectronicBandStructure
        An electronic band structure object pyprocar.core.ElectronicBandStructure.
    kpath : KPath, optional
        A kpath object pyprocar.core.KPath. The default is None.
    ax : mpl.axes.Axes, optional
        A matplotlib Axes object. If provided the plot will be located at that
        ax (and on the figure that owns it). The default is None.
    spins : List[int], optional
        A list of the spins to plot. The default is None, meaning every spin
        channel of ``ebs``. Non-collinear data always uses ``[0]``.
    kdirect : bool, optional
        When False, the end points of every k-path segment are multiplied by
        ``ebs.reciprocal_lattice`` before the path length is measured.
        The default is True.
    config : object, optional
        Plot configuration. Its attributes (``figure_size``, ``color``,
        ``spin_colors``, ``linewidth``, ``clim``, ...) are read by the
        plotting methods.

    Returns
    -------
    None.
    """

    # (lower bound, upper bound, major step, minor step) used by `set_yticks`
    # to choose tick spacing from the size of the energy window.
    _TICK_STEPS = (
        (20, 30, 5, 1),
        (10, 20, 4, 0.5),
        (5, 10, 2, 0.2),
        (3, 5, 1, 0.1),
        (1, 3, 0.5, 0.1),
    )

    def __init__(self,
                 ebs: "ElectronicBandStructure",
                 kpath: "KPath" = None,
                 ax: "mpl.axes.Axes" = None,
                 spins: List[int] = None,
                 kdirect: bool = True,
                 config=None):
        self.config = config
        self.ebs = ebs
        self.kpath = kpath
        self.spins = spins
        self.kdirect = kdirect
        if self.spins is None:
            self.spins = range(self.ebs.nspins)
        self.nspins = len(self.spins)
        if self.ebs.is_non_collinear:
            self.spins = [0]
        self.handles = []

        if ax is None:
            self.fig = plt.figure(figsize=tuple(self.config.figure_size))
            self.ax = self.fig.add_subplot(111)
        else:
            # Use the figure that owns `ax`; the "current" pyplot figure may
            # be a different one.
            self.fig = ax.get_figure()
            self.ax = ax

        # The x axis falls back to the k-point index when no kpath is given.
        self.x = self._get_x()
        self._initiate_plot_args()

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _initiate_plot_args(self):
        """Helper method to initialize the plot options."""
        self.set_xticks()
        self.set_yticks()
        self.set_xlabel()
        self.set_ylabel()
        self.set_xlim()
        self.set_ylim()

    def _get_x(self):
        """
        Provides the x axis data of the plots.

        When a kpath with one grid size per segment is available, the x axis
        is the cumulative length along the path. Otherwise (no kpath, no
        segments, or mismatched grids) it is the k-point index.

        Returns
        -------
        np.ndarray
            x-axis data, one dimensional.
        """
        kpath = self.kpath
        use_path = (kpath is not None
                    and kpath.nsegments > 0
                    and kpath.nsegments == len(kpath.ngrids))
        if use_path:
            pieces = []
            pos = 0
            for isegment in range(kpath.nsegments):
                kstart, kend = (np.asarray(k, dtype=float)
                                for k in kpath.special_kpoints[isegment])
                if self.kdirect is False:
                    kstart = np.dot(self.ebs.reciprocal_lattice, kstart)
                    kend = np.dot(self.ebs.reciprocal_lattice, kend)
                distance = np.linalg.norm(kend - kstart)
                pieces.append(
                    np.linspace(pos, pos + distance, kpath.ngrids[isegment]))
                pos += distance
            x = np.concatenate(pieces)
        else:
            x = np.arange(0, self.ebs.kpoints.shape[0])
        return np.array(x).reshape(-1,)

    def _resolve_spins(self, spins):
        """Return the spin channels to draw (``[0]`` for non-collinear data)."""
        if self.ebs.is_non_collinear:
            return [0]
        if spins is None:
            return range(self.ebs.nspins)
        return spins

    def _band_color(self, ispin):
        """Return the plain line/marker colour used for spin ``ispin``."""
        if len(self.spins) == 1:
            return self.config.color
        return self.config.spin_colors[ispin]

    def _masked_bands(self, width_weights, color_weights, width_mask, color_mask):
        """Return the bands masked wherever a weight is below its threshold.

        Both thresholds are honoured together (logical OR). With no threshold
        nothing is masked.
        """
        bands = self.ebs.bands
        hidden = np.zeros(bands.shape, dtype=bool)
        if width_mask is not None:
            hidden |= np.broadcast_to(
                np.abs(width_weights) < width_mask, bands.shape)
        if color_mask is not None and color_weights is not None:
            hidden |= np.broadcast_to(
                np.abs(color_weights) < color_mask, bands.shape)
        return np.ma.masked_array(bands, hidden)

    def _color_limits(self, color_weights, spins):
        """Return ``(vmin, vmax)`` from ``config.clim``, filling ``None``
        entries from the weights of the plotted spins only."""
        vmin = self.config.clim[0]
        vmax = self.config.clim[1]
        plotted = color_weights[:, :, list(spins)]
        if vmin is None:
            vmin = plotted.min()
        if vmax is None:
            vmax = plotted.max()
        return vmin, vmax

    def _default_energy_window(self):
        """Return the band energy range padded by 10 % of |E| on each side."""
        emin = self.ebs.bands.min()
        emax = self.ebs.bands.max()
        # `emax * 1.1` would move the upper edge *below* a negative maximum.
        return (emin - abs(emin) * 0.1, emax + abs(emax) * 0.1)

    def _draw_level_labels(self, texts):
        """Write the level labels, shifting a label sideways when it would
        vertically overlap the previous one. ``texts`` must be non-empty and
        sorted by energy; entries are ``[x, y, string]``."""
        # Draw one label temporarily to measure its size in data units.
        probe = self.ax.text(*texts[-1])
        bbox = probe.get_window_extent()
        bbox_data = self.ax.transData.inverted().transform_bbox(bbox)
        w, h = bbox_data.width, bbox_data.height
        probe.remove()

        shift = 0
        self.ax.text(*texts[0])
        for previous, current in zip(texts, texts[1:]):
            if current[1] < previous[1] + h:
                # Alternate between the base column and one shifted column.
                shift = (shift + 1) % 2
                current[0] = current[0] + w * 1.5 * shift
            else:
                shift = 0
            self.ax.text(*current)

    # ------------------------------------------------------------------
    # plotting
    # ------------------------------------------------------------------
    def plot_bands(self):
        """
        Plot the plain band structure.

        One legend handle per spin channel is appended to ``self.handles`` so
        that `legend` can pair the handles with ``config.label``.

        Returns
        -------
        None.
        """
        for ispin in self.spins:
            color = self._band_color(ispin)
            for iband in range(self.ebs.nbands):
                # `ax.plot` returns a list of lines; keep the artist itself.
                (line,) = self.ax.plot(
                    self.x,
                    self.ebs.bands[:, iband, ispin],
                    color=color,
                    alpha=self.config.opacity[ispin],
                    linestyle=self.config.linestyle[ispin],
                    label=self.config.label[ispin],
                    linewidth=self.config.linewidth[ispin],
                )
                if iband == 0:
                    self.handles.append(line)

    def plot_scatter(self,
                     width_mask: np.ndarray = None,
                     color_mask: np.ndarray = None,
                     spins: List[int] = None,
                     width_weights: np.ndarray = None,
                     color_weights: np.ndarray = None,
                     ):
        """A method to plot a scatter plot

        Parameters
        ----------
        width_mask : np.ndarray, optional
            Points whose ``|width_weights|`` is below this value are hidden,
            by default None
        color_mask : np.ndarray, optional
            Points whose ``|color_weights|`` is below this value are hidden,
            by default None
        spins : List[int], optional
            A list of spins, by default None
        width_weights : np.ndarray, optional
            The width weight of each point, by default None
        color_weights : np.ndarray, optional
            The color weights at each point, by default None
        """
        spins = self._resolve_spins(spins)

        if width_weights is None:
            width_weights = np.ones_like(self.ebs.bands)
            markersize = self.config.markersize
        else:
            markersize = [l * 30 for l in self.config.markersize]

        mbands = self._masked_bands(
            width_weights, color_weights, width_mask, color_mask)

        if color_weights is not None:
            vmin, vmax = self._color_limits(color_weights, spins)

        sc = None
        for ispin in spins:
            color = self._band_color(ispin)
            for iband in range(self.ebs.nbands):
                sizes = width_weights[:, iband, ispin].round(2) * markersize[ispin]
                if color_weights is None:
                    # Plain colour: no colormap arguments are involved.
                    sc = self.ax.scatter(
                        self.x,
                        mbands[:, iband, ispin],
                        c=color,
                        s=sizes,
                        linewidths=self.config.linewidth[ispin],
                        marker=self.config.marker[ispin],
                        alpha=self.config.opacity[ispin],
                    )
                else:
                    sc = self.ax.scatter(
                        self.x,
                        mbands[:, iband, ispin],
                        c=color_weights[:, iband, ispin].round(2),
                        s=sizes,
                        linewidths=self.config.linewidth[ispin],
                        cmap=self.config.cmap,
                        vmin=vmin,
                        vmax=vmax,
                        marker=self.config.marker[ispin],
                        alpha=self.config.opacity[ispin],
                    )

        if (self.config.plot_color_bar
                and color_weights is not None
                and sc is not None):
            self.cb = self.fig.colorbar(sc, ax=self.ax)

    def plot_parameteric(self,
                         spins: List[int] = None,
                         width_mask: np.ndarray = None,
                         color_mask: np.ndarray = None,
                         width_weights: np.ndarray = None,
                         color_weights: np.ndarray = None,
                         elimit: List[float] = None
                         ):
        """A method to plot the bands as line segments whose width and color
        follow the given weights

        Parameters
        ----------
        spins : List[int], optional
            A list of spins, by default None
        width_mask : np.ndarray, optional
            Segments touching a point whose ``|width_weights|`` is below this
            value are hidden, by default None
        color_mask : np.ndarray, optional
            Segments touching a point whose ``|color_weights|`` is below this
            value are hidden, by default None
        width_weights : np.ndarray, optional
            The width weight of each point, by default None
        color_weights : np.ndarray, optional
            The color weights at each point, by default None
        elimit : List[float], optional
            Energy range to plot. Only forwarded to `plot_atomic_levels`
            when there is a single k-point.
        """
        # With a single k-point `plot_atomic_levels` fakes a second k-point
        # and then calls this method again to draw the actual plot.
        if len(self.ebs.kpoints) == 1:
            self.plot_atomic_levels(color_weights=color_weights,
                                    width_weights=width_weights,
                                    color_mask=color_mask,
                                    width_mask=width_mask,
                                    spins=spins,
                                    elimit=elimit)
            return

        if width_weights is None:
            width_weights = np.ones_like(self.ebs.bands)
            linewidth = self.config.linewidth
        else:
            linewidth = [l * 5 for l in self.config.linewidth]

        spins = self._resolve_spins(spins)
        mbands = self._masked_bands(
            width_weights, color_weights, width_mask, color_mask)

        norm = None
        if color_weights is not None:
            vmin, vmax = self._color_limits(color_weights, spins)
            norm = mpl.colors.Normalize(vmin, vmax)

        lc = None
        for ispin in spins:
            color = self._band_color(ispin)
            for iband in range(self.ebs.nbands):
                # `np.array([x, masked])` silently drops the mask, so masked
                # energies become NaN, which leaves those segments undrawn.
                energies = np.ma.filled(
                    mbands[:, iband, ispin].astype(float), np.nan)
                points = np.array([self.x, energies]).T.reshape(-1, 1, 2)
                segments = np.concatenate([points[:-1], points[1:]], axis=1)

                if color_weights is None:
                    lc = LineCollection(
                        segments,
                        colors=color,
                        linestyle=self.config.linestyle[ispin])
                else:
                    lc = LineCollection(
                        segments,
                        cmap=plt.get_cmap(self.config.cmap),
                        norm=norm)
                    lc.set_array(color_weights[:, iband, ispin])
                # Per-point widths; they must not be overwritten afterwards.
                lc.set_linewidth(
                    width_weights[:, iband, ispin] * linewidth[ispin])
                lc.set_linestyle(self.config.linestyle[ispin])
                handle = self.ax.add_collection(lc)
                if iband == 0:
                    # one legend handle per spin, matching `config.label`
                    self.handles.append(handle)

        if (self.config.plot_color_bar
                and color_weights is not None
                and lc is not None):
            self.cb = self.fig.colorbar(lc, ax=self.ax)

    def plot_parameteric_overlay(self,
                                 spins: List[int] = None,
                                 weights: np.ndarray = None,
                                 ):
        """A method to plot the parametric overlay

        Every entry of ``weights`` is drawn on top of the bands with its own
        colormap; line width and colour intensity follow the weight. One
        legend handle per entry is appended to ``self.handles``.

        Parameters
        ----------
        spins : List[int], optional
            A list of spins, by default None
        weights : np.ndarray, optional
            The weights of each point, one array per overlay,
            by default None (nothing is drawn)
        """
        if weights is None:
            return

        linewidth = [l * 7 for l in self.config.linewidth]
        if isinstance(self.config.cmap, str):
            color_map = ['Reds', "Blues", "Greens",
                         "Purples", "Oranges", "Greys"]
        else:
            color_map = self.config.cmap
        spins = self._resolve_spins(spins)

        vmin = self.config.clim[0]
        vmax = self.config.clim[1]
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = 1
        norm = mpl.colors.Normalize(vmin, vmax)

        # Zero-length segments sit on repeated high symmetry points.
        x = self.x
        repeated = np.where(x[1:] == x[:-1])[0]

        for iweight, weight in enumerate(weights):
            # Cycle through the colormaps when there are more overlays.
            cmap_spec = color_map[iweight % len(color_map)]
            cmap = plt.get_cmap(cmap_spec)

            lc = None
            for ispin in spins:
                for iband in range(self.ebs.nbands):
                    points = np.array(
                        [self.x, self.ebs.bands[:, iband, ispin]]
                    ).T.reshape(-1, 1, 2)
                    segments = np.concatenate(
                        [points[:-1], points[1:]], axis=1)
                    segments = np.delete(segments, repeated, axis=0)
                    # Drop the same entries from the weights so that every
                    # remaining segment keeps its own value.
                    seg_weight = np.delete(weight[:-1, iband, ispin], repeated)

                    lc = LineCollection(
                        segments,
                        cmap=cmap,
                        norm=norm,
                        alpha=self.config.opacity[ispin])
                    lc.set_array(seg_weight)
                    lc.set_linewidth(seg_weight * linewidth[ispin])
                    self.ax.add_collection(lc)

            # Legend entry: a proxy line, so no plotted band is restyled.
            cmap_name = cmap_spec if isinstance(cmap_spec, str) else cmap.name
            legend_color = cmap_name[:-1].lower()   # e.g. 'Reds' -> 'red'
            if not mpl.colors.is_color_like(legend_color):
                legend_color = cmap(1.0)
            self.handles.append(
                mpl.lines.Line2D([], [], color=legend_color,
                                 linewidth=linewidth[0]))

            # One colorbar per overlay, since each has its own colormap;
            # `self.cb` ends up referring to the last one.
            if self.config.plot_color_bar and lc is not None:
                self.cb = self.fig.colorbar(lc, ax=self.ax)

    def plot_atomic_levels(self,
                           spins: List[int] = None,
                           width_mask: np.ndarray = None,
                           color_mask: np.ndarray = None,
                           width_weights: np.ndarray = None,
                           color_weights: np.ndarray = None,
                           elimit: List[float] = None
                           ):
        """A method to plot the energy levels of a single k-point

        The bands, projections and k-points stored in ``self.ebs`` are
        duplicated in place so that every level becomes a flat two-point
        band, each level inside the energy window is labelled, and the lines
        are drawn with `plot_parameteric`.

        Parameters
        ----------
        spins : List[int], optional
            A list of spins, by default None
        width_mask : np.ndarray, optional
            The width mask, by default None
        color_mask : np.ndarray, optional
            The color mask, by default None
        width_weights : np.ndarray, optional
            The width weight of each point, by default None
        color_weights : np.ndarray, optional
            The color weights at each point, by default None
        elimit : List[float], optional
            The energy range to plot. By default all levels.
        """
        self.ebs.bands = np.vstack((self.ebs.bands, self.ebs.bands))
        self.ebs.projected = np.vstack((self.ebs.projected, self.ebs.projected))
        self.ebs.kpoints = np.vstack((self.ebs.kpoints, self.ebs.kpoints))
        # make the two k-points differ
        self.ebs.kpoints[0][-1] += 1
        self.x = self._get_x()
        print("Atomic plot: bands.shape  :", self.ebs.bands.shape)
        print("Atomic plot: spd.shape    :", self.ebs.projected.shape)
        print("Atomic plot: kpoints.shape:", self.ebs.kpoints.shape)

        self.ax.xaxis.set_major_locator(plt.NullLocator())

        if elimit:
            emin, emax = elimit[0], elimit[1]
        else:
            emin, emax = np.min(self.ebs.bands), np.max(self.ebs.bands)

        spins = self._resolve_spins(spins)

        # [x, y, label] for every level inside the window. The comparison is
        # inclusive so the lowest and highest levels are labelled as well.
        texts = []
        for ispin in spins:
            for iband, energy in enumerate(self.ebs.bands[0, :, ispin]):
                if emin <= energy <= emax:
                    texts.append(
                        [0, energy, f"s-{ispin} : " + "b-" + str(iband + 1)])
        texts.sort(key=lambda entry: entry[1])

        # The limits must be set before the label size is measured.
        self.set_ylim(elimit)
        self.set_xlim()

        if texts:
            self._draw_level_labels(texts)

        self.plot_parameteric(color_weights=color_weights,
                              width_weights=width_weights,
                              color_mask=color_mask,
                              width_mask=width_mask,
                              spins=spins)

    # ------------------------------------------------------------------
    # axes decoration
    # ------------------------------------------------------------------
    def set_xticks(self,
                   tick_positions: List[int] = None,
                   tick_names: List[str] = None,
                   color: str = "black"):
        """A method to set the x ticks

        Missing values are taken from ``self.kpath`` when it is available.
        If no tick positions can be determined nothing is drawn.

        Parameters
        ----------
        tick_positions : List[int], optional
            A list of tick positions (indices into the x axis),
            by default None
        tick_names : List[str], optional
            A list of tick names, by default None
        color : str, optional
            A color for the vertical lines at the ticks, by default "black"
        """
        if self.kpath is not None:
            if tick_positions is None:
                tick_positions = self.kpath.tick_positions
            if tick_names is None:
                tick_names = self.kpath.tick_names
        if tick_positions is None:
            return

        for ipos in tick_positions:
            self.ax.axvline(self.x[ipos], color=color)
        self.ax.set_xticks(self.x[np.asarray(tick_positions, dtype=int)])
        if tick_names is not None:
            self.ax.set_xticklabels(tick_names)
        self.ax.tick_params(
            which='major',
            axis='x',
            direction='in')

    def set_yticks(self,
                   major: float = None,
                   minor: float = None,
                   interval: List[float] = None):
        """A method to set the y ticks

        A spacing that is not given is chosen from the size of ``interval``
        (windows of 1 to 30 only); a spacing that is given is always kept.

        Parameters
        ----------
        major : float, optional
            A float to set the major tick locator, by default None
        minor : float, optional
            A float to set the the minor tick Locator, by default None
        interval : List[float], optional
            The interval of the ticks, by default the padded band range
        """
        if major is None or minor is None:
            if interval is None:
                interval = self._default_energy_window()
            span = abs(interval[1] - interval[0])
            for low, high, auto_major, auto_minor in self._TICK_STEPS:
                if low <= span < high:
                    if major is None:
                        major = auto_major
                    if minor is None:
                        minor = auto_minor
                    break

        if major is not None:
            self.ax.yaxis.set_major_locator(MultipleLocator(major))
        if minor is not None:
            self.ax.yaxis.set_minor_locator(MultipleLocator(minor))

        self.ax.tick_params(
            which='major',
            axis="y",
            direction="inout",
            width=1,
            length=5,
            labelright=False,
            right=True,
            left=True)
        self.ax.tick_params(
            which='minor',
            axis="y",
            direction="in",
            left=True,
            right=True)

    def set_xlim(self, interval: List[float] = None):
        """A method to set the x limit

        Parameters
        ----------
        interval : List[float], optional
            A list containing the beginning and the end of the interval,
            by default the first and last x values
        """
        if interval is None:
            interval = (self.x[0], self.x[-1])
        self.ax.set_xlim(interval)

    def set_ylim(self, interval: List[float] = None):
        """A method to set the y limit

        Parameters
        ----------
        interval : List[float], optional
            A list containing the beginning and the end of the interval,
            by default the band range padded by 10 % on each side
        """
        if interval is None:
            interval = self._default_energy_window()
        self.ax.set_ylim(interval)

    def set_xlabel(self, label: str = "K vector"):
        """A method to set the x label

        Parameters
        ----------
        label : str, optional
            String for the x label name, by default "K vector"
        """
        self.ax.set_xlabel(label)

    def set_ylabel(self, label: str = r"E - E$_F$ (eV)"):
        """A method to set the y label

        Parameters
        ----------
        label : str, optional
            String for the y label name, by default r"E - E$_F$ (eV)"
        """
        self.ax.set_ylabel(label)

    def set_title(self, title: str = None):
        """A method to set the title

        Parameters
        ----------
        title : str, optional
            String for the title. By default None, in which case
            ``config.title`` is used. Nothing is drawn when both are empty.
        """
        text = title if title else self.config.title
        if text:
            self.ax.set_title(label=text)

    def set_colorbar_title(self, title: str = None):
        """A method to set the title of the color bar

        Requires a colorbar created by one of the plotting methods
        (``self.cb``).

        Parameters
        ----------
        title : str, optional
            String for the title, by default ``config.colorbar_title``
        """
        if not title:
            title = self.config.colorbar_title
        self.cb.ax.tick_params(labelsize=self.config.colorbar_tick_labelsize)
        self.cb.set_label(title,
                          size=self.config.colorbar_title_size,
                          rotation=270,
                          labelpad=self.config.colorbar_title_padding)

    def legend(self, labels: List[str] = None):
        """A method to plot the legend

        The legend is drawn only when ``config.legend`` is truthy.

        Parameters
        ----------
        labels : List[str], optional
            A list of strings for the labels of each element for the legend,
            by default ``config.label``
        """
        # identity test: `== None` is ambiguous for array-like labels
        if labels is None:
            labels = self.config.label
        if self.config.legend:
            self.ax.legend(self.handles, labels)

    def draw_fermi(self, fermi_level: float = 0):
        """A method to draw the fermi line

        Parameters
        ----------
        fermi_level : float, optional
            The energy level to draw the line, by default 0
        """
        self.ax.axhline(y=fermi_level,
                        color=self.config.fermi_color,
                        linestyle=self.config.fermi_linestyle,
                        linewidth=self.config.fermi_linewidth)

    def grid(self):
        """A method to plot a grid (only when ``config.grid`` is truthy)"""
        if self.config.grid:
            self.ax.grid(
                self.config.grid,
                which=self.config.grid_which,
                color=self.config.grid_color,
                # NOTE(review): attribute name kept exactly as spelled in the
                # original code; confirm it matches the config definition.
                linestyle=self.config.grid_linestlye,
                linewidth=self.config.grid_linewidth)

    def show(self):
        """A method to show the plot"""
        plt.show()

    def save(self, filename: str = 'bands.pdf'):
        """A method to save the plot

        The figure owned by this object is written and then cleared.

        Parameters
        ----------
        filename : str, optional
            A string for the file name, by default 'bands.pdf'
        """
        # Act on this plot's figure, not on whichever figure is current.
        self.fig.savefig(filename, dpi=self.config.dpi, bbox_inches="tight")
        self.fig.clf()