//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for unrolling XeGPU operations. It follows a
// similar concept and design as vector unroll patterns, serving as a
// complement to them.
//
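// For example (illustrative only; the actual tile sizes come from the
// UnrollOptions native-shape callback), a load such as
//   %v = xegpu.load_nd %tdesc : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
// unrolled with a native shape of [16, 16] becomes four 16x16 loads on
// sub-descriptors, whose results are recombined into a vector<32x32xf32>
// via vector.insert_strided_slice.
//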
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUUNROLL
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace {

template <typename SourceOp>
struct UnrollPattern : public OpRewritePattern<SourceOp> {
  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
                PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}

protected:
  /// Return the target shape for the given `op`. Return std::nullopt if the
  /// op shouldn't be or cannot be unrolled.
  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
    LDBG("");
    LDBG("Get unroll shape for: " << *op);

    if (options.filterConstraint && failed(options.filterConstraint(op))) {
      LDBG("--failed filter constraint -> BAIL");
      return std::nullopt;
    }

    assert(options.nativeShape &&
           "expects a native shape callback function.");
    return options.nativeShape(op);
  }

  SmallVector<Type> getUnrolledTypes(ShapedType type,
                                     ArrayRef<int64_t> tileShape) const {
    return options.getUnrolledTypes(type, tileShape);
  }

  /// Emulate the unpack behavior using insert_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
               Location loc, PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of destTy.");
      auto shape = vecTy.getShape();
      auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());

      Value result = rewriter.create<arith::ConstantOp>(
          loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
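      // Write each source tile into the zero-initialized result at its
      // static offsets within the full shape.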
      for (auto [src, offsets] :
           llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
        SmallVector<int64_t> staticStrides(offsets.size(), 1);
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, src, result, offsets, staticStrides);
      }
      return result;
    }
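
    // For tensor descriptors, unpack is emulated with an
    // unrealized_conversion_cast tagged by the marker attributes below; a
    // later pass is expected to recognize these markers and fold the casts.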
    if (isa<xegpu::TensorDescType>(destTy)) {
      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResult(0);
    }

    llvm_unreachable("Unexpected destTy.");
    return Value();
  }

  /// Emulate the pack behavior using extract_strided_slice for VectorType
  /// values and unrealized_conversion_cast for TensorDescType values.
  SmallVector<Value> pack(Value src, TypeRange destTypes,
                          ArrayRef<int64_t> blockSize, Location loc,
                          PatternRewriter &rewriter) const {
    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
             "Expecting blockSize size to match the rank of src.");
      auto shape = vecTy.getShape();
      SmallVector<Value> results;
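      // Extract one blockSize-shaped tile per grid position, using static
      // offsets and unit strides.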
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(shape, blockSize)) {
        SmallVector<int64_t> staticStrides(offsets.size(), 1);
        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, src, offsets, blockSize, staticStrides);
        results.push_back(slice);
      }
      return results;
    }

    if (isa<xegpu::TensorDescType>(src.getType())) {
      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
                                 rewriter.getUnitAttr());
      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
                                    rewriter.getDenseI64ArrayAttr(blockSize));
      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
      return castOp.getResults();
    }

    llvm_unreachable("Unexpected src type.");
    return SmallVector<Value>();
  }

private:
  const char *const packAttrName = "__xegpu_blocking_pack__";
  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";

  xegpu::UnrollOptions options;
};

struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getType();
    int64_t rank = tdescTy.getRank();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
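
    // Add constant `b` to offset `a`, folding to a single constant when `a`
    // is statically known.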
    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
      std::optional<int64_t> maybeInt = getConstantIntValue(a);
      if (maybeInt) {
        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
      } else {
        auto aV = llvm::cast<Value>(a);
        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
      }
    };

    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

    // For n-D memrefs where n > rank, we need to handle the last `rank`
    // dimensions only, and keep the first `n-rank` dimensions as is.
    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
    auto validIdxes =
        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
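
    // Create one descriptor per tile, shifting the innermost `rank` offsets
    // by the tile's position within the original shape.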
    SmallVector<Value> newOps;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(shape, *targetShape)) {

      for (auto [idx, oldOff, offset] :
           llvm::zip(validIdxes, oldOffsets, offsets))
        mixedOffsets[idx] = addi(oldOff, offset);

      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
          op.getMixedStrides());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);

    return success();
  }
};

struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdesc) {
      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
      newOps.push_back(newOp);
    }
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdesc = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto t : convertedTdesc)
      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    VectorType valueTy = op.getType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> newOps;
    for (auto t : convertedTdescs) {
      auto newOp =
          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
      newOps.push_back(newOp);
    }

    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);

    rewriter.replaceOp(op, castOp);
    return success();
  }
};

struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = op.getValueType();
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    ArrayRef<int64_t> shape = tdescTy.getShape();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || llvm::equal(*targetShape, shape))
      return failure();

    SmallVector<Type> convertedValTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> convertedValues =
        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                        op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};

struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Expecting every operand to be a 2D vector.
    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
          auto vecTy = dyn_cast<VectorType>(type);
          return !vecTy || vecTy.getRank() != 2;
        }))
      return failure();

    // A vector of 3 elements should be returned, representing M, K, N
    // respectively.
    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape || targetShape->size() != 3)
      return failure();
    auto M = (*targetShape)[0];
    auto K = (*targetShape)[1];
    auto N = (*targetShape)[2];

    int64_t aBlockSize[2] = {M, K};
    int64_t bBlockSize[2] = {K, N};
    int64_t cBlockSize[2] = {M, N};
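
    // Slice a 2D vector operand into blockSize tiles; if the operand already
    // matches blockSize (a single tile), return it unchanged.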
    auto packWrapper = [&](TypedValue<VectorType> val,
                           ArrayRef<int64_t> blockSize) {
      VectorType type = val.getType();
      std::optional<SmallVector<int64_t>> grids =
          computeShapeRatio(type.getShape(), blockSize);
      assert(grids && "Expecting grids to be computed.");
      auto numNewOps = computeProduct(*grids);
      if (numNewOps == 1)
        return SmallVector<Value>({val});
      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
      SmallVector<Value> values =
          pack(val, convertedTypes, blockSize, loc, rewriter);
      return values;
    };

    auto a = op.getLhs();
    auto b = op.getRhs();
    auto c = op.getAcc();

    auto aShape = a.getType().getShape();
    auto bShape = b.getType().getShape();

    SmallVector<Value> aVals, bVals, cVals;
    aVals = packWrapper(a, aBlockSize);
    bVals = packWrapper(b, bBlockSize);

    if (c)
      cVals = packWrapper(c, cBlockSize);

    // Skip the operation if any operand has an invalid blocking size (empty)
    // or if every operand already matches its blocking size (size == 1).
    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
                    : SmallVector<ValueRange>({aVals, bVals});
    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
      return failure();

    VectorType resultTy = op.getResult().getType();
    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());

    int64_t mIters = aShape[0] / M;
    int64_t kIters = aShape[1] / K;
    int64_t nIters = bShape[1] / N;
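
    // Walk the mIters x nIters grid of result tiles; for each tile, chain
    // DPAS ops over the kIters K-tiles, threading the partial accumulator
    // through tmpC.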
    SmallVector<Value> newOps;
    for (int64_t i = 0; i < mIters; ++i) {
      for (int64_t j = 0; j < nIters; ++j) {
        Value tmpC;
        if (c)
          tmpC = cVals[i * nIters + j]; // init with acc

        for (int64_t k = 0; k < kIters; ++k) {
          Value aVec = aVals[i * kIters + k];
          Value bVec = bVals[k * nIters + j];
          SmallVector<Value> operands({aVec, bVec});
          if (tmpC)
            operands.push_back(tmpC);

          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
                                                op->getAttrs());
        }
        newOps.push_back(tmpC);
      }
    }
    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};

} // namespace

void xegpu::populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
                                        const UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
      patterns.getContext(), options);
}
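
// A minimal usage sketch (hypothetical driver code, not part of this file):
// populate the patterns inside a pass and apply them greedily. The
// `nativeShape` and `getUnrolledTypes` fields are the callbacks consumed by
// UnrollPattern above; the [16, 16] tile size is illustrative only.
//
//   xegpu::UnrollOptions options;
//   options.nativeShape =
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         return SmallVector<int64_t>{16, 16};
//       };
//   options.getUnrolledTypes = [](ShapedType type, ArrayRef<int64_t> tile) {
//     int64_t count =
//         computeProduct(*computeShapeRatio(type.getShape(), tile));
//     return SmallVector<Type>(count, type.clone(tile));
//   };
//   RewritePatternSet patterns(&getContext());
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(getOperation(), std::move(patterns));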