MLIR  19.0.0git
BasicPtxBuilderInterface.cpp
Go to the documentation of this file.
1 //===- BasicPtxBuilderInterface.cpp - PTX builder interface -*- C++ -*-=======//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10 // automatically. It is used by the NVVM to LLVM conversion pass.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 
17 #define DEBUG_TYPE "ptx-builder"
18 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
19 #define DBGSNL() (llvm::dbgs() << "\n")
20 
21 //===----------------------------------------------------------------------===//
22 // BasicPtxBuilderInterface
23 //===----------------------------------------------------------------------===//
24 
25 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
26 
27 using namespace mlir;
28 using namespace NVVM;
29 
30 static constexpr int64_t kSharedMemorySpace = 3;
31 
32 static char getRegisterType(Type type) {
33  if (type.isInteger(1))
34  return 'b';
35  if (type.isInteger(16))
36  return 'h';
37  if (type.isInteger(32))
38  return 'r';
39  if (type.isInteger(64))
40  return 'l';
41  if (type.isF32())
42  return 'f';
43  if (type.isF64())
44  return 'd';
45  if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
46  // Shared address spaces is addressed with 32-bit pointers.
47  if (ptr.getAddressSpace() == kSharedMemorySpace) {
48  return 'r';
49  }
50  return 'l';
51  }
52  // register type for struct is not supported.
53  llvm_unreachable("The register type could not deduced from MLIR type");
54  return '?';
55 }
56 
57 static char getRegisterType(Value v) {
58  if (v.getDefiningOp<LLVM::ConstantOp>())
59  return 'n';
60  return getRegisterType(v.getType());
61 }
62 
64  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
65  auto getModifier = [&]() -> const char * {
66  if (itype == PTXRegisterMod::ReadWrite) {
67  assert(false && "Read-Write modifier is not supported. Try setting the "
68  "same value as Write and Read seperately.");
69  return "+";
70  }
71  if (itype == PTXRegisterMod::Write) {
72  return "=";
73  }
74  return "";
75  };
76  auto addValue = [&](Value v) {
77  if (itype == PTXRegisterMod::Read) {
78  ptxOperands.push_back(v);
79  return;
80  }
81  if (itype == PTXRegisterMod::ReadWrite)
82  ptxOperands.push_back(v);
83  hasResult = true;
84  };
85 
86  llvm::raw_string_ostream ss(registerConstraints);
87  // Handle Structs
88  if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
89  if (itype == PTXRegisterMod::Write) {
90  addValue(v);
91  }
92  for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
93  if (itype != PTXRegisterMod::Write) {
94  Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
95  interfaceOp->getLoc(), v, idx);
96  addValue(extractValue);
97  }
98  if (itype == PTXRegisterMod::ReadWrite) {
99  ss << idx << ",";
100  } else {
101  ss << getModifier() << getRegisterType(t) << ",";
102  }
103  ss.flush();
104  }
105  return;
106  }
107  // Handle Scalars
108  addValue(v);
109  ss << getModifier() << getRegisterType(v) << ",";
110  ss.flush();
111 }
112 
113 LLVM::InlineAsmOp PtxBuilder::build() {
114  auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
115  LLVM::AsmDialect::AD_ATT);
116 
117  auto resultTypes = interfaceOp->getResultTypes();
118 
119  // Remove the last comma from the constraints string.
120  if (!registerConstraints.empty() &&
121  registerConstraints[registerConstraints.size() - 1] == ',')
122  registerConstraints.pop_back();
123 
124  std::string ptxInstruction = interfaceOp.getPtx();
125 
126  // Add the predicate to the asm string.
127  if (interfaceOp.getPredicate().has_value() &&
128  interfaceOp.getPredicate().value()) {
129  std::string predicateStr = "@%";
130  predicateStr += std::to_string((ptxOperands.size() - 1));
131  ptxInstruction = predicateStr + " " + ptxInstruction;
132  }
133 
134  // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
135  // Replace all % with $
136  std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
137 
138  return rewriter.create<LLVM::InlineAsmOp>(
139  interfaceOp->getLoc(),
140  /*result types=*/resultTypes,
141  /*operands=*/ptxOperands,
142  /*asm_string=*/llvm::StringRef(ptxInstruction),
143  /*constraints=*/registerConstraints.data(),
144  /*has_side_effects=*/interfaceOp.hasSideEffect(),
145  /*is_align_stack=*/false,
146  /*asm_dialect=*/asmDialectAttr,
147  /*operand_attrs=*/ArrayAttr());
148 }
149 
151  LLVM::InlineAsmOp inlineAsmOp = build();
152  LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
153  if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
154  rewriter.replaceOp(interfaceOp, inlineAsmOp);
155  } else {
156  rewriter.eraseOp(interfaceOp);
157  }
158 }
static char getRegisterType(Type type)
#define DBGS()
void insertValue(Value v, PTXRegisterMod itype=PTXRegisterMod::Read)
Add an operand with the read/write input type.
LLVM::InlineAsmOp build()
Builds the inline assembly Op and returns it.
void buildAndReplaceOp()
Shortcut to build the inline assembly Op and use it to replace the original op, or to erase the original op when the result counts differ.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
U dyn_cast() const
Definition: Types.h:330
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
PTXRegisterMod
Register read/write modifier to build constraint string for PTX inline https://docs....
@ Write
Write register with '=' modifier.
@ ReadWrite
Read and write register with '+' modifier.
@ Read
Read register with no modifier.
@ kSharedMemorySpace
Shared memory space identifier.
Definition: NVVMDialect.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...