MLIR  21.0.0git
XeGPUBlocking.cpp
Go to the documentation of this file.
1 //===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Pass/PassManager.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 namespace mlir {
24 namespace xegpu {
25 #define GEN_PASS_DEF_XEGPUBLOCKING
26 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
27 } // namespace xegpu
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "xegpu-blocking"
31 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
32 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
33 
34 using namespace mlir;
35 
36 namespace {
37 
38 // reslove the unrealized conversion cast ops generated when doing SCF
39 // Structural Type Conversion. It will have two formats, N:1 vector
40 // cast and 1:N vector cast. vector::insert_strided_slice ops will be
41 // used for the first case, and vector::extract_strided_slice ops will be
42 // used for the second case.
43 static void
44 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
45  ValueRange inputs = castOp.getInputs();
46  ValueRange outputs = castOp.getOutputs();
47 
48  auto hasIdenticalVectorTypes = [](ValueRange values) {
49  auto types = values.getTypes();
50  return llvm::all_of(types, [&](Type type) {
51  return isa<VectorType>(type) && type == types.front();
52  });
53  };
54 
55  // We only interest in the case where all inputs and outputs have the
56  // identical VectorTypes
57  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
58  LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
59  return;
60  }
61 
62  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
63  OpBuilder builder(castOp);
64  if (inputs.size() > 1 && outputs.size() == 1) {
65  // the castOp is emulating an unpack op
66  ArrayRef<int64_t> shape = outputTy.getShape();
68  builder, castOp.getLoc(), inputs, shape);
69  castOp->replaceAllUsesWith(ValueRange(result));
70  castOp->erase();
71  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
72  // the castOp is emulating a pack op
73  ArrayRef<int64_t> tileShape = outputTy.getShape();
75  builder, castOp.getLoc(), inputs[0], tileShape);
76  castOp->replaceAllUsesWith(results);
77  castOp->erase();
78  }
79 }
80 
81 //===------------------------------------------------------------------------===//
82 // The XeGPUBlockingPass leverages the unroll patterns for XeGPU and Vector ops
83 // to partition operations that process large shapes into multiple operations on
84 // smaller shapes, as specified by the inst_data in the layout attribute. This
85 // enables each resulting operation to be efficiently mapped to a hardware
86 // instruction.
87 //===------------------------------------------------------------------------===//
88 
// Pass that partitions large-shape XeGPU/Vector operations into multiple
// smaller operations as dictated by the `inst_data` field of the layout
// attribute, so each resulting op maps to a single hardware instruction.
class XeGPUBlockingPass final
    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
public:
  void runOnOperation() override;

private:
  // Get the tile shape for a given OpOperand or OpResult by examining the
  // corresponding layout attribute. If layout is not present or is not a
  // subgroup level layout, it returns std::nullopt.
  // Constrained via enable_if so only OpOperand/OpResult instantiate it.
  template <typename T,
            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                        std::is_same_v<T, OpResult>>>
  std::optional<SmallVector<int64_t>>
  getTileShape(const T &operandOrResult) const;

  // Get the tile shape for a given operation, chosen from the operand or
  // result that carries the authoritative layout for that op kind.
  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;

  // Determine if the operation requires unrolling. Return false if all operands
  // and results have tile shapes identical to their original types. Otherwise,
  // return true.
  bool needsUnroll(Operation *op) const;
};
112 } // namespace
113 
114 template <typename T, typename>
115 std::optional<SmallVector<int64_t>>
116 XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
117  Value value;
118  if constexpr (std::is_same_v<T, OpOperand>)
119  value = operandOrResult.get();
120  else
121  value = (Value)operandOrResult;
122 
123  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
124  if (layout && layout.isSgLayout()) {
125  if (auto inst_data = layout.getInstData())
126  return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
127 
128  if (auto type = dyn_cast<ShapedType>(value.getType()))
129  return llvm::to_vector(type.getShape());
130  }
131  LDBG("failed to getTileShape for: " << value);
132  return std::nullopt;
133 }
134 
135 std::optional<SmallVector<int64_t>>
137  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
138  xegpu::UpdateOffsetOp>(op))
139  return getTileShape(op->getOpResult(0));
140  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
141  xegpu::LoadGatherOp>(op))
142  return getTileShape(op->getOpOperand(0));
143  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
144  return getTileShape(op->getOpOperand(1));
145 
146  if (isa<xegpu::DpasOp>(op)) {
147  std::optional<SmallVector<int64_t>> aTile =
148  getTileShape(op->getOpOperand(0));
149  std::optional<SmallVector<int64_t>> bTile =
150  getTileShape(op->getOpOperand(1));
151 
152  if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
153  return std::nullopt;
154 
155  // semantic check for A and B
156  if ((*aTile)[1] != (*bTile)[0])
157  return std::nullopt;
158 
159  // semantic check for C
160  if (op->getNumOperands() == 3) {
161  std::optional<SmallVector<int64_t>> cTile =
162  getTileShape(op->getOpOperand(2));
163  int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
164  if (!cTile || !llvm::equal(*cTile, expectedCTile))
165  return std::nullopt;
166  }
167 
168  return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
169  }
170 
172  return getTileShape(op->getOpResult(0));
173 
174  if (isa<vector::MultiDimReductionOp>(op))
175  return getTileShape(op->getOpOperand(0));
176 
177  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
178  return getTileShape(op->getOpResult(0));
179 
180  return std::nullopt;
181 }
182 
183 bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
184  // skip the op if any of its operands or results has workgroup level layouts
185  bool hasWgLayoutOperands =
186  llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
187  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
188  return layout && layout.isWgLayout();
189  });
190  bool hasWgLayoutResults =
191  llvm::any_of(op->getOpResults(), [](OpResult result) {
192  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
193  return layout && layout.isWgLayout();
194  });
195  if (hasWgLayoutOperands || hasWgLayoutResults) {
196  LDBG("skip unrolling for op with workgroup level layout: " << *op);
197  return false;
198  }
199 
200  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
201  Type valTy = value.getType();
202  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
203  xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
204  return layout && layout.getInstData();
205  }
206  auto shapedType = dyn_cast<ShapedType>(valTy);
207  return shapedType && !llvm::equal(tileShape, shapedType.getShape());
208  };
209 
210  bool hasUnrollableOperands =
211  llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
212  std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
213  return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
214  });
215  bool hasUnrollableResults =
216  llvm::any_of(op->getOpResults(), [&](OpResult result) {
217  std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
218  return tileShape.has_value() && isUnrollable(result, *tileShape);
219  });
220  return hasUnrollableOperands || hasUnrollableResults;
221 }
222 
223 void XeGPUBlockingPass::runOnOperation() {
224  MLIRContext *ctx = &getContext();
225  Operation *op = getOperation();
226 
227  // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
228  // This ensures that the LayoutAttr remains accessible even if the defining
229  // operation is replaced.
230  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
231 
232  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
233  xegpu::LayoutAttr layout) {
234  int count = 1;
235  SmallVector<int64_t> tileShape(shape);
236  if (layout && layout.getInstData()) {
237  DenseI32ArrayAttr instData = layout.getInstData();
238  tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
239  count = computeProduct(shape) / computeProduct(tileShape);
240  }
241  return std::make_pair(tileShape, count);
242  };
243 
244  // Perform type conversion for SCF control folow ops
245  TypeConverter converter;
246  converter.addConversion([](Type type) -> Type { return type; });
247  converter.addConversion(
248  [&](RankedTensorType type,
249  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
250  Type elemTy = type.getElementType();
251  ArrayRef<int64_t> shape = type.getShape();
252 
253  auto layout =
254  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
255  if (layout && layout.isWgLayout())
256  return failure();
257 
258  int count;
259  SmallVector<int64_t> subShape;
260  std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
261  auto newTy = VectorType::get(subShape, elemTy);
262  result.append(count, newTy);
263  return success();
264  });
265  converter.addConversion(
266  [&](xegpu::TensorDescType type,
267  SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
268  Type elemTy = type.getElementType();
269  ArrayRef<int64_t> shape = type.getShape();
270 
271  xegpu::LayoutAttr layout = type.getLayoutAttr();
272  if (layout && layout.isWgLayout())
273  return failure();
274 
275  int count;
276  SmallVector<int64_t> subShape;
277  std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
278 
279  if (layout)
280  layout = layout.dropInstData();
281 
282  auto newTy = xegpu::TensorDescType::get(
283  type.getContext(), subShape, elemTy, type.getEncoding(), layout);
284  result.append(count, newTy);
285  return success();
286  });
287 
289 
291  options.setFilterConstraint(
292  [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
293 
294  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
295 
296  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
297  Type elemTy = type.getElementType();
298  Type newTy;
299 
300  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
301 
302  Attribute encoding = tdescTy.getEncoding();
303  // If the encoding is a ScatterTensorDescAttr, we need to
304  // potentially adjust the chunk size based on the inst_data.
305  if (tdescTy.isScattered()) {
306  auto scatterAttr =
307  llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
308  int64_t chunkSize = scatterAttr.getChunkSize().getInt();
309 
310  if (chunkSize > 1) {
311  int64_t blockedChunkSize = chunkSize;
312  auto instData = tdescTy.getLayoutAttr().getInstData();
313  if (!instData.empty())
314  blockedChunkSize = instData.asArrayRef().back();
315 
316  // To create a new attribute with a different chunk_size:
317  auto newEncoding = xegpu::ScatterTensorDescAttr::get(
318  ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
319 
320  encoding = newEncoding;
321  }
322  }
323 
324  newTy =
325  xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
326  tdescTy.getLayoutAttr().dropInstData());
327  } else {
328  newTy = type.clone(tileShape, elemTy);
329  }
330 
331  std::optional<SmallVector<int64_t>> ratio =
332  computeShapeRatio(type.getShape(), tileShape);
333  assert(ratio && "The shape of the type must be a multiple of tileShape.");
334  return SmallVector<Type>(computeProduct(*ratio), newTy);
335  });
336 
338 
339  vector::UnrollVectorOptions vectorOptions;
340  vectorOptions.setNativeShapeFn(options.nativeShape);
341 
343  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
344 
345  (void)applyPatternsGreedily(op, std::move(patterns));
346 
347  op->walk([](Operation *op) {
348  // Remove the layout attributes cached per operands.
349  for (OpOperand &opr : op->getOpOperands()) {
350  std::string name = xegpu::getLayoutName(opr);
351  if (op->hasAttrOfType<xegpu::LayoutAttr>(name))
352  op->removeAttr(name);
353  }
354 
355  // Update the layout attributes per result.
356  for (OpResult result : op->getOpResults()) {
357  std::string name = xegpu::getLayoutName(result);
358  if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
359  op->removeAttr(name);
360  if (!isa<LoopLikeOpInterface>(op))
361  xegpu::setLayoutAttr(result, layout.dropInstData());
362  }
363  }
364 
365  // Resolve unrealized conversion cast ops emulating pack/unpack
366  if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
367  resolveUnrealizedConversionCastOp(castOp);
368  });
369 }
static MLIRContext * getContext(OpFoldResult val)
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition: MMAUtils.cpp:39
static llvm::ManagedStatic< PassManagerOptions > options
#define LDBG(X)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:388
bool hasAttrOfType(NameT &&name)
Definition: Operation.h:575
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getOpResults()
Definition: Operation.h:420
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:600
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
Definition: XeGPUUtils.cpp:208
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns, const UnrollOptions &options)
Collect a set of patterns to unroll xegpu operations to a smaller shapes.
LayoutAttr getLayoutAttr(const Value value)
Retrieves the LayoutAttr associated with a given Value.
Definition: XeGPUUtils.cpp:115
void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout)
Sets the LayoutAttr for a given OpOperand or OpResult by attaching it to the owner's dictionary attri...
Definition: XeGPUUtils.cpp:156
std::string getLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach LayoutAttr.
Definition: XeGPUUtils.cpp:104
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type convertion patterns...
Definition: XeGPUUtils.cpp:233
void setLayoutAttrs(Operation *op, function_ref< LayoutAttr(Value)> getLayoutImpl)
Set the LayoutAttr for each OpOperand and OpResult of the given operation.
Definition: XeGPUUtils.cpp:173
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
Definition: XeGPUUtils.cpp:188
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Options to control the XeGPU unrolling.
Definition: Transforms.h:27