12#include "llvm/ADT/SmallSet.h" 
   13#include "llvm/ADT/TypeSwitch.h" 
   14#include "llvm/Support/FileSystem.h" 
   15#include "llvm/Support/MathExtras.h" 
   20#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc" 
   21#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc" 
   24static constexpr uint32_t subgroupSize = 16;
 
   // Shared verifier for the 2D block operations. The static_assert below
   // restricts `Op` to BlockLoad2dOp, BlockStore2dOp or BlockPrefetch2dOp.
   // Visible checks, in order: base pitch >= base width, elem_size_in_bits,
   // tile_height, v_blocks. Each failing check reports through emitOpError.
   // NOTE(review): the embedded line numbers are not contiguous (26, 31-33, 35,
   // 37, ... are absent), so this listing omits lines -- among them the
   // declarations of `pitch` and `width` and the function's final return.
   27LogicalResult verifyMatrixInput(
Op op) {
 
   28  static_assert(llvm::is_one_of<
Op, BlockLoad2dOp, BlockStore2dOp,
 
   29                                BlockPrefetch2dOp>::value,
 
   30                "Unexpected template parameter");
 
   // Only compared when both `pitch` and `width` test true; both are
   // dereferenced, so they are optional- or pointer-like values.
   34  if (pitch && width && *pitch < *width)
 
   36        "4th operand (base pitch) should be >= 2nd operand (base width)");
 
   // Element size must be a power of two in the range [8, 32].
   38  uint32_t elemSize = op.getElemSizeInBits();
 
   39  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
 
   40    return op->
emitOpError(
"expecting 'elem_size_in_bits' to be 8, 16, or 32");
 
   // Tile height must be a power of two no larger than 32.
   42  uint32_t tileHeight = op.getTileHeight();
 
   43  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
 
   44    return op->
emitOpError(
"expecting tile_height to be 1, 2, 4, 8, 16, or 32");
 
   // Block count must be a power of two no larger than 8.
   46  uint32_t vBlocks = op.getVBlocks();
 
   47  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
 
   48    return op->
emitOpError(
"expecting v_blocks to be 1, 2, 4, or 8");
 
   // Checks the result type and tile-shape limits of a BlockLoad2dOp.
   // After validating the result size, the visible code handles three modes:
   //   1. neither transpose nor pack_register set,
   //   2. transpose set,
   //   3. pack_register set (asserted once the other two modes are handled).
   // NOTE(review): this listing omits lines (the embedded line numbers skip).
   // The `case`/`default` labels of every switch, several `if` conditions and
   // some string continuations are not visible; which element size a group of
   // checks applies to can only be read from the error messages.
   53LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
 
   54  VectorType resTy = op.getRes().getType();
 
   55  if (!resTy.getElementType().isIntOrFloat())
 
   56    return op.emitOpError()
 
   57           << 
"expecting result element type to be int or float";
 
   58  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
 
   // Total number of bits in the result vector.
   59  unsigned resSize = resTy.getNumElements() * resElemTySize;
 
   // Expected bits: elem bits * tile_height * tile_width * v_blocks, divided by
   // subgroupSize (integer division).
   60  unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
 
   61                          op.getTileWidth() * op.getVBlocks() / subgroupSize;
 
   62  if (resSize != expectedSize)
 
   63    return op.emitOpError() << 
"result size of " << resSize
 
   64                            << 
" bits does not match the expected size of " 
   65                            << expectedSize << 
" bits";
 
   67  if (op.getTranspose() && op.getPackRegister())
 
   68    return op.emitOpError(
"transpose and pack_register are mutually exclusive");
 
   // Mode 1: neither transpose nor pack_register.
   70  if (!op.getTranspose() && !op.getPackRegister()) {
 
   71    uint32_t tileHeight = op.getTileHeight();
 
   72    if (tileHeight < 1 || tileHeight > 32)
 
   73      return op.emitOpError(
"expecting tile_height to be between 1 and 32");
 
   75    uint32_t tileWidth = op.getTileWidth();
 
   76    uint32_t vBlocks = op.getVBlocks();
 
   // Per-element-size limits on tile_width, v_blocks and their product.
   // The case labels are not visible in this listing.
   77    switch (op.getElemSizeInBits()) {
 
   79      if (tileWidth < 4 || tileWidth > 64)
 
   80        return op.emitOpError(
"expecting tile_width to be between 4 and 64");
 
   81      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
 
   82        return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
 
   83      if (tileWidth * vBlocks > 64)
 
   84        return op.emitOpError(
 
   85            "tile_width * v_blocks should be less than or equal " 
   86            "to 64 for 8 bit elements");
 
   89      if (tileWidth < 2 || tileWidth > 32)
 
   90        return op.emitOpError(
"expecting tile_width to be between 2 and 32");
 
   91      if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
 
   92        return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
 
   93      if (tileWidth * vBlocks > 32)
 
   94        return op.emitOpError(
 
   95            "tile_width * v_blocks should be less than or equal " 
   96            "to 32 for 16 bit elements");
 
   99      if (tileWidth < 1 || tileWidth > 16)
 
  100        return op.emitOpError(
"expecting tile_width to be between 1 and 16");
 
  101      if (vBlocks != 1 && vBlocks != 2)
 
  102        return op.emitOpError(
"expecting v_blocks to be 1 or 2");
 
  103      if (tileWidth * vBlocks > 16)
 
  104        return op.emitOpError(
 
  105            "tile_width * v_blocks should be less than or equal " 
  106            "to 16 for 32 bit elements");
 
  109      if (tileWidth < 1 || tileWidth > 8)
 
  110        return op.emitOpError(
"expecting tile_width to be between 1 and 8");
 
   // The condition guarding this error is not visible in the listing.
  112        return op.emitOpError(
"expecting v_blocks to be 1");
 
  115      return op.emitOpError(
 
  116          "expecting elem_size_in_bits to be 8, 16, 32, or 64");
 
   // Mode 2: transpose.
  122  if (op.getTranspose()) {
 
  123    assert(!op.getPackRegister() && 
"Expecting pack_register should be false");
 
  125    uint32_t vBlocks = op.getVBlocks();
 
   // The condition guarding this error is not visible in the listing.
  127      return op.emitOpError(
"expecting v_blocks to be 1");
 
  129    uint32_t tileHeight = op.getTileHeight();
 
  130    uint32_t tileWidth = op.getTileWidth();
 
   // Per the final message, only 32 and 64 bit elements are accepted here.
  131    switch (op.getElemSizeInBits()) {
 
  133      if (tileHeight < 1 || tileHeight > 32)
 
  134        return op.emitOpError(
"expecting tile_height to be between 1 and 32");
 
  135      if (tileWidth < 1 || tileWidth > 8)
 
  136        return op.emitOpError(
"expecting tile_width to be between 1 and 8");
 
  140        return op.emitOpError(
 
  141            "expecting tile_height to be 8 for 64 bit elements");
 
  142      if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
 
  143        return op.emitOpError(
"expecting tile_width to be 1, 2, or 4");
 
  146      return op.emitOpError(
"transpose is only supported for 32 and 64 bit " 
   // Mode 3: pack_register (the remaining combination, enforced by the assert).
  153  assert(op.getPackRegister() && !op.getTranspose() &&
 
  154         "Expecting pack_register should be true and transpose should be " 
  157  uint32_t vBlocks = op.getVBlocks();
 
  158  if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
 
  159    return op.emitOpError(
"expecting v_blocks to be 1, 2, or 4");
 
  161  uint32_t tileHeight = op.getTileHeight();
 
  162  uint32_t tileWidth = op.getTileWidth();
 
   // Per the final message, only 8 and 16 bit elements are accepted here.
  163  switch (op.getElemSizeInBits()) {
 
  165    if (tileHeight < 4 || tileHeight > 32)
 
  166      return op.emitOpError(
"expecting tile_height to be between 4 and 32");
 
  167    if (tileWidth < 4 || tileWidth > 16)
 
  168      return op.emitOpError(
"expecting tile_width to be between 4 and 16");
 
  171    if (tileHeight < 2 || tileHeight > 32)
 
  172      return op.emitOpError(
"expecting tile_height to be between 2 and 32");
 
  173    if (tileWidth < 2 || tileWidth > 16)
 
  174      return op.emitOpError(
"expecting tile_width to be between 2 and 16");
 
  175    if (tileWidth * vBlocks > 32)
 
  176      return op.emitOpError(
 
  177          "tile_width * v_blocks should be less than or equal " 
  178          "to 32 for 16 bit elements");
 
  181    return op.emitOpError(
"pack_register is only supported for 8 and 16 bit " 
  // Checks the tile-shape limits of a BlockStore2dOp: tile_height must lie in
  // [1, 8], tile_width has a range that depends on elem_size_in_bits, and an
  // error about v_blocks is emitted at the end.
  // NOTE(review): this listing omits lines (the embedded line numbers skip).
  // The switch's case labels, the condition guarding the v_blocks error and the
  // function's tail are not visible.
  188static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) {
 
  189  uint32_t tileHeight = op.getTileHeight();
 
  190  if (tileHeight < 1 || tileHeight > 8)
 
  191    return op.emitOpError(
"expecting tile_height to be between 1 and 8");
 
  193  uint32_t tileWidth = op.getTileWidth();
 
  // One tile_width range per element size; the last message lists the accepted
  // sizes as 8, 16, 32 and 64.
  194  switch (op.getElemSizeInBits()) {
 
  196    if (tileWidth < 4 || tileWidth > 64)
 
  197      return op.emitOpError(
"expecting tile_width to be between 4 and 64");
 
  200    if (tileWidth < 2 || tileWidth > 32)
 
  201      return op.emitOpError(
"expecting tile_width to be between 2 and 32");
 
  204    if (tileWidth < 1 || tileWidth > 16)
 
  205      return op.emitOpError(
"expecting tile_width to be between 1 and 16");
 
  208    if (tileWidth < 1 || tileWidth > 8)
 
  209      return op.emitOpError(
"expecting tile_width to be between 1 and 8");
 
  212    return op.emitOpError(
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
 
  215  uint32_t vBlocks = op.getVBlocks();
 
  // The condition guarding this error is not visible in the listing.
  217    return op.emitOpError(
"expecting v_blocks to be 1");
 
  // Verifier for BlockLoad2dOp. Runs verify2DBlockLoadRestriction and
  // verifyMatrixInput first, then checks the result element type and, when
  // pack_register is set, the tile width.
  // NOTE(review): this listing omits lines (the embedded line numbers skip);
  // the statements executed when the two helper checks fail are not visible.
  223LogicalResult BlockLoad2dOp::verify() {
 
  224  if (verify2DBlockLoadRestriction(*this).failed())
 
  227  if (verifyMatrixInput(*this).failed())
 
  230  VectorType resTy = getRes().getType();
 
  // NOTE(review): the message below reads "int of float"; the same check in
  // verify2DBlockLoadRestriction says "int or float" -- looks like a typo.
  231  if (!resTy.getElementType().isIntOrFloat())
 
  232    return emitOpError() << 
"expecting result element type to be int of float";
 
  233  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
 
  // 32-bit element loads and pack_register loads require a 32-bit result
  // element type.
  234  if (getElemSizeInBits() == 32 || getPackRegister()) {
 
  235    if (resElemTySize != 32)
 
  236      return emitOpError() << 
"expecting result element type to be 32 bits";
 
  239  uint32_t tileWidth = getTileWidth();
 
  // The comparison inside this branch is not visible; per the message, the
  // tile width is expected to equal the subgroup size (16 elements).
  240  if (getPackRegister()) {
 
  243          "tile_width when pack_register is true should be equal " 
  244          "to subgroup size (16 elements)");
 
  // Verifier for BlockStore2dOp. Runs verify2DBlockStoreRestriction and
  // verifyMatrixInput, then checks tile_width per element size.
  // NOTE(review): this listing omits lines (the embedded line numbers skip).
  // The case labels, the 16- and 32-bit conditions and the tails of the error
  // messages are not visible.
  251LogicalResult BlockStore2dOp::verify() {
 
  252  if (verify2DBlockStoreRestriction(*this).failed())
 
  255  if (verifyMatrixInput(*this).failed())
 
  258  uint32_t tileWidth = getTileWidth();
 
  259  switch (getElemSizeInBits()) {
 
  // 8-bit elements (per the message): tile_width must be 16 or 32.
  261    if (tileWidth != 16 && tileWidth != 32)
 
  262      return emitOpError(
"tile_width for 8 bit elements should be equal to " 
  267      return emitOpError(
"tile_width for 16 bit elements should be equal " 
  272      return emitOpError(
"tile_width for 32 bit elements should be equal " 
  // Any other element size is treated as unreachable here.
  276    llvm_unreachable(
"unexpected element size");
 
  // Verifier for BlockPrefetch2dOp. Runs verifyMatrixInput, then checks
  // tile_width per element size.
  // NOTE(review): this listing omits lines (the embedded line numbers skip).
  // The case labels, the 16-bit condition and parts of the error statements are
  // not visible.
  282LogicalResult BlockPrefetch2dOp::verify() {
 
  283  if (verifyMatrixInput(*this).failed())
 
  286  uint32_t tileWidth = getTileWidth();
 
  287  switch (getElemSizeInBits()) {
 
  // 8-bit elements (per the message): tile_width must be 16 or 32.
  289    if (tileWidth != 16 && tileWidth != 32)
 
  290      return emitOpError(
"tile_width for 8 bit elements should be equal to " 
  295      return emitOpError(
"tile_width for 16 bit elements should be equal " 
  // 32-bit elements (per the message): tile_width must be 8 or 16.
  299    if (tileWidth != 8 && tileWidth != 16)
 
  301          "tile_width for 32 bit elements should be equal to 8 or 16");
 
  // Any other element size is treated as unreachable here.
  304    llvm_unreachable(
"unexpected element size");
 
  // Verifier helper for the 1D block ops; the enable_if constraint accepts only
  // BlockLoadOp and BlockStoreOp. It inspects the loaded result type (load) or
  // the stored value type (store) and validates the vector length against the
  // element size.
  // NOTE(review): this listing omits lines (the embedded line numbers skip).
  // The line carrying the function name and parameter list is not visible, nor
  // are the branches between the `contains` tests and the error returns.
  310template <
typename OpType, 
typename = std::enable_if_t<llvm::is_one_of<
 
  311                               OpType, BlockLoadOp, BlockStoreOp>::value>>
 
  314  if constexpr (std::is_same_v<OpType, BlockLoadOp>)
 
  315    srcOrDstTy = op.getResult().getType();
 
  317    srcOrDstTy = op.getVal().getType();
 
  // NOTE(review): dyn_cast yields a null VectorType for non-vector types. The
  // lines between this cast and the first use of `vTy` are omitted here, so a
  // null guard cannot be confirmed from this listing -- verify in the file.
  318  VectorType vTy = dyn_cast<VectorType>(srcOrDstTy);
 
  // Element size in bytes (bit width / 8).
  322  int elemTySize = vTy.getElementType().getIntOrFloatBitWidth() / 8;
 
  323  if (elemTySize == 1) {
 
  324    llvm::SmallSet<int, 4> validSizes{2, 4, 8, 16};
 
  325    if (validSizes.contains(vTy.getNumElements()))
 
  328      return op.emitOpError(
 
  329          "vector size must be 2, 4, 8 or 16 for 8-bit element type");
 
  // Wider element types accept a smaller set of vector lengths.
  331    llvm::SmallSet<int, 3> validSizes{2, 4, 8};
 
  332    if (validSizes.contains(vTy.getNumElements()))
 
  335      return op.emitOpError(
 
  336          "vector size must be 2, 4 or 8 for element type > 8 bits");
 
 
  // Verifier for MMAOp. The only statement visible in this listing reports
  // that the C operand's type must match the result type.
  // NOTE(review): the condition guarding that error and the rest of the body
  // are omitted from the listing (the embedded line numbers skip).
  344LogicalResult MMAOp::verify() {
 
  347      return emitOpError(
"type of C operand must match result type");
 
  // Fragment of a verifier taking (among others) `triple`, `chip` and `flags`;
  // the body also reads `O` and `linkFiles`. Errors are reported via
  // emitError().
  // NOTE(review): the start of the signature (return type, function name and
  // leading parameters) is not visible in this listing, and interior lines are
  // omitted as well -- presumably this is the target attribute's verifier;
  // confirm against the full file.
  354                       StringRef triple, StringRef chip, DictionaryAttr flags,
 
  // Optimization level must lie in [0, 3].
  356  if (O < 0 || O > 3) {
 
  358           << 
"The optimization level must be a number between 0 and 3.";
 
  360  if (triple.empty()) {
 
  361    return emitError() << 
"The target triple cannot be empty.";
 
  // The condition guarding this error is not visible in the listing.
  364    return emitError() << 
"The target chip cannot be empty.";
 
  // Each entry that is a StringAttr must be a non-empty path that exists on
  // the file system at verification time (llvm::sys::fs::exists).
  367    for (Attribute fileAttr : linkFiles) {
 
  368      if (
auto fileStrAttr = llvm::dyn_cast<StringAttr>(fileAttr)) {
 
  369        StringRef filePath = fileStrAttr.getValue();
 
  370        if (filePath.empty()) {
 
  371          return emitError() << 
"File paths in linkFiles cannot be empty.";
 
  373        if (!llvm::sys::fs::exists(filePath)) {
 
  374          return emitError() << 
"File '" << filePath << 
"' does not exist.";
 
  // Dialect initialization hook. The visible lines include the generated op and
  // attribute lists and declare that XeVMTargetAttr promises an implementation
  // of gpu::TargetAttrInterface.
  // NOTE(review): the calls that wrap the two includes are omitted from this
  // listing (the embedded line numbers skip), as is the closing brace.
  382void XeVMDialect::initialize() {
 
  385#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc" 
  389#define GET_ATTRDEF_LIST 
  390#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc" 
  392  declarePromisedInterface<mlir::gpu::TargetAttrInterface,
 
  393                           mlir::xevm::XeVMTargetAttr>();
 
  396#define GET_OP_CLASSES 
  397#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc" 
  399#define GET_ATTRDEF_CLASSES 
  400#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc" 
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
 
LogicalResult verify1DBlockArg(OpType op)
 
This class represents a diagnostic that is inflight and set to be reported.
 
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
 
This provides public APIs that all operations should have.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
 
Include the generated interface declarations.
 
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
 
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
 
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
 
llvm::function_ref< Fn > function_ref