1//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for unrolling XeGPU operations. It follows a
10// concept and design similar to the vector unroll patterns and serves as a
11// complement to them.
12//
13//===----------------------------------------------------------------------===//
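//
// As a rough illustration (the shapes here are hypothetical), an op such as
//
//   %v = xegpu.load_nd %tdesc : !xegpu.tensor_desc<32x32xf16> -> vector<32x32xf16>
//
// unrolled with a native tile shape of 8x16 becomes a 4x2 grid of loads on
// 8x16 tensor descriptors, whose results are stitched back into the original
// vector<32x32xf16> value (see `unpack` below).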
14
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
17#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
18#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21
22namespace mlir {
23namespace xegpu {
24#define GEN_PASS_DEF_XEGPUUNROLL
25#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
26} // namespace xegpu
27} // namespace mlir
28
29#define DEBUG_TYPE "xegpu-unroll"
30
31using namespace mlir;
32
33namespace {
34
35template <typename SourceOp>
36struct UnrollPattern : public OpRewritePattern<SourceOp> {
37 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
38 PatternBenefit benefit = 1)
39 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
40
41protected:
42 /// Return the target shape for the given `op`. Return std::nullopt if the
43 /// op shouldn't be or cannot be unrolled.
44 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
45 LDBG() << "Get unroll shape for: " << *op;
46
47 if (options.filterConstraint && failed(options.filterConstraint(op))) {
48 LDBG() << "--filter constraint failed -> BAIL";
49 return std::nullopt;
50 }
51
52 assert(options.nativeShape &&
53 "expects the nativeShape callback to be provided.");
54 auto nativeShape = options.nativeShape(op);
55 return nativeShape;
56 }
57
58 SmallVector<Type> getUnrolledTypes(ShapedType type,
59 ArrayRef<int64_t> tileShape,
60 bool returnSingleType = false) const {
61 return options.getUnrolledTypes(type, tileShape, returnSingleType);
62 }
63
64 /// Emulate the unpack behavior using insert_strided_slice for VectorType
65 /// values and unrealized_conversion_cast for TensorDescType values.
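///
/// For illustration (types here are hypothetical): unpacking four
/// vector<8x16xf32> tiles into a vector<16x32xf32> emits a chain of
/// vector.insert_strided_slice ops, while a tensor_desc destination becomes a
/// single cast carrying the blocking markers, roughly:
///
///   %d = builtin.unrealized_conversion_cast %t0, %t1, %t2, %t3
///          {__xegpu_blocking_unpack__, __xegpu_blocking_tile_shape__ = array<i64: 8, 16>}
///          : !xegpu.tensor_desc<8x16xf32>, ... to !xegpu.tensor_desc<16x32xf32>
///
/// which is expected to be resolved by a later blocking/materialization step.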
66 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
67 Location loc, PatternRewriter &rewriter) const {
68 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
69 auto shape = vecTy.getShape();
70 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
71 }
72
73 if (isa<xegpu::TensorDescType>(destTy)) {
74 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
75 rewriter.getUnitAttr());
76 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
77 rewriter.getDenseI64ArrayAttr(blockSize));
78 auto castOp = UnrealizedConversionCastOp::create(
79 rewriter, loc, destTy, srcs,
80 ArrayRef<NamedAttribute>({attr, blkAttr}));
81 return castOp.getResult(0);
82 }
83
84 llvm_unreachable("Unexpected destTy.");
85 return Value();
86 }
87
88 /// Emulate the pack behavior using extract_strided_slice for VectorType
89 /// values and unrealized_conversion_cast for TensorDescType values.
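///
/// This is the inverse of `unpack` above: a VectorType source is split into
/// tiles of `blockSize` via vector.extract_strided_slice, while a
/// TensorDescType source becomes one unrealized_conversion_cast tagged with
/// __xegpu_blocking_pack__ and the tile shape, yielding one result per entry
/// in `destTypes`.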
90 SmallVector<Value> pack(Value src, TypeRange destTypes,
91 ArrayRef<int64_t> blockSize, Location loc,
92 PatternRewriter &rewriter) const {
93 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
94 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
95 blockSize);
96 }
97
98 if (isa<xegpu::TensorDescType>(src.getType())) {
99 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
100 rewriter.getUnitAttr());
101 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
102 rewriter.getDenseI64ArrayAttr(blockSize));
103 auto castOp = UnrealizedConversionCastOp::create(
104 rewriter, loc, destTypes, src,
105 ArrayRef<NamedAttribute>({attr, blkAttr}));
106 return castOp.getResults();
107 }
108
109 llvm_unreachable("Unexpected src type.");
110 return SmallVector<Value>();
111 }
112
113private:
114 const char *const packAttrName = "__xegpu_blocking_pack__";
115 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
116 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
117
118 xegpu::UnrollOptions options;
119};
120
121// Generic helper function for unrolling operations with offsets.
122//
123// Iterates over tile offsets within the tensor descriptor shape and calls
124// the provided createOp function for each computed offset. This is used by
125// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
126// have explicit offsets that need to be adjusted for each unrolled tile.
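//
// For example (shapes chosen purely for illustration): with a tensor
// descriptor shape of 32x32, a target shape of 8x16, and mixed offsets
// (%x, %y), createOp is invoked with
//   (%x+0, %y+0), (%x+0, %y+16), (%x+8, %y+0), (%x+8, %y+16), ...
// i.e. once per tile in row-major order, 4 x 2 = 8 times in total.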
127SmallVector<Value> computeUnrolledOffsets(
128 SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
129 ArrayRef<int64_t> targetShape,
130 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
131 Location loc, PatternRewriter &rewriter) {
132 int64_t rank = tdescTy.getRank();
133 ArrayRef<int64_t> shape = tdescTy.getShape();
134
135 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
136 std::optional<int64_t> maybeInt = getConstantIntValue(a);
137 if (maybeInt) {
138 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
139 } else {
140 auto aV = llvm::cast<Value>(a);
141 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
142 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
143 }
144 };
145
146 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
147 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
148 auto validIdxes =
149 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
150
151 SmallVector<Value> newOps;
152 for (SmallVector<int64_t> offsets :
153 StaticTileOffsetRange(shape, targetShape)) {
154
155 for (auto [idx, oldOff, offset] :
156 llvm::zip(validIdxes, oldOffsets, offsets))
157 mixedOffsets[idx] = addi(oldOff, offset);
158
159 auto newOp = createOp(mixedOffsets);
160 newOps.push_back(newOp);
161 }
162 return newOps;
163}
164
165struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
166 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
167 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
168 PatternRewriter &rewriter) const override {
169 Location loc = op.getLoc();
170 xegpu::TensorDescType tdescTy = op.getType();
171
172 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
173 if (!targetShape)
174 return failure();
175
176 SmallVector<Value> newOps;
177
178 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
179 bool hasOffsets = op.getMixedOffsets().size() != 0;
180 if (!hasOffsets) {
181 auto newOp = xegpu::CreateNdDescOp::create(
182 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
183 op.getMixedStrides());
184 newOps.push_back(newOp);
185 } else {
186 auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
187 return xegpu::CreateNdDescOp::create(
188 rewriter, loc, newTdescTy, op.getSource(), offsets,
189 op.getMixedSizes(), op.getMixedStrides());
190 };
191
192 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
193 *targetShape, createOp, loc, rewriter);
194 }
195 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
196 rewriter.replaceOp(op, castOp);
197
198 return success();
199 }
200};
201
202struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
203 using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
204 LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
205 PatternRewriter &rewriter) const override {
206 Location loc = op.getLoc();
207 xegpu::TensorDescType tdescTy = op.getTensorDescType();
208
209 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
210 if (!targetShape)
211 return failure();
212
213 SmallVector<Type> convertedTdescTypes =
214 getUnrolledTypes(tdescTy, *targetShape);
215 SmallVector<Value> convertedTdesc = pack(
216 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
217
218 SmallVector<Value> newOps;
219 for (auto t : convertedTdesc) {
220 auto newOp = xegpu::UpdateNdOffsetOp::create(
221 rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
222 newOps.push_back(newOp);
223 }
224 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
225 rewriter.replaceOp(op, castOp);
226 return success();
227 }
228};
229
230struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
231 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
232 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
233 PatternRewriter &rewriter) const override {
234 Location loc = op.getLoc();
235 xegpu::TensorDescType tdescTy = op.getTensorDescType();
236
237 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
238 if (!targetShape)
239 return failure();
240
241 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
242 if (layout)
243 layout = layout.dropInstData();
244 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
245 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
246
247 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
248 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
249
250 SmallVector<Value> convertedTdesc = pack(
251 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
252
253 if (!hasOffsets) {
254 for (auto t : convertedTdesc)
255 xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
256 op->getAttrs());
257 } else {
258 auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
259 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
260 op.getL1HintAttr(), op.getL2HintAttr(),
261 op.getL3HintAttr(), layout);
262 // Return a dummy Value to satisfy the function's signature.
263 return nullptr;
264 };
265
266 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
267 createPrefetch, loc, rewriter);
268 }
269
270 rewriter.eraseOp(op);
271 return success();
272 }
273};
274
275struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
276 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
277 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
278 PatternRewriter &rewriter) const override {
279
280 Location loc = op.getLoc();
281 VectorType valueTy = op.getType();
282 xegpu::TensorDescType tdescTy = op.getTensorDescType();
283
284 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
285 if (!targetShape)
286 return failure();
287
288 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
289 if (layout)
290 layout = layout.dropInstData();
291 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
292 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
293
294 Type elemTy = tdescTy.getElementType();
295 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
296
297 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
298 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
299
300 SmallVector<Value> convertedTdescs = pack(
301 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
302 SmallVector<Value> newOps;
303
304 if (!hasOffsets) {
305 for (auto t : convertedTdescs) {
306 auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
307 op->getAttrs());
308 newOps.push_back(newOp);
309 }
310 } else {
311 auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
312 return xegpu::LoadNdOp::create(
313 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
314 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
315 op.getL2HintAttr(), op.getL3HintAttr(), layout);
316 };
317 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
318 *targetShape, createLoad, loc, rewriter);
319 }
320
321 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
322
323 rewriter.replaceOp(op, castOp);
324 return success();
325 }
326};
327
328struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
329 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
330 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
331 PatternRewriter &rewriter) const override {
332 Location loc = op.getLoc();
333 VectorType valueTy = op.getValueType();
334 xegpu::TensorDescType tdescTy = op.getTensorDescType();
335
336 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
337 if (!targetShape)
338 return failure();
339
340 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
341 if (layout)
342 layout = layout.dropInstData();
343 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
344 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
345
346 SmallVector<Type> convertedValTypes =
347 getUnrolledTypes(valueTy, *targetShape);
348 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
349 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
350
351 SmallVector<Value> convertedTdescs = pack(
352 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
353
354 SmallVector<Value> convertedValues =
355 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
356 if (!hasOffsets) {
357 for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
358 xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
359 op.getL2HintAttr(), op.getL3HintAttr());
360 } else {
361 size_t valueIndex = 0;
362 auto createStore = [&](SmallVector<OpFoldResult> offsets) {
363 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
364 convertedTdescs[0], offsets,
365 op.getL1HintAttr(), op.getL2HintAttr(),
366 op.getL3HintAttr(), layout);
367 // Return a dummy Value to satisfy the function's signature.
368 return nullptr;
369 };
370
371 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
372 createStore, loc, rewriter);
373 }
374
375 rewriter.eraseOp(op);
376 return success();
377 }
378};
379
380struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
381 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
382 LogicalResult matchAndRewrite(xegpu::DpasOp op,
383 PatternRewriter &rewriter) const override {
384 Location loc = op.getLoc();
385
386 // Expect every operand to be a 2D vector.
387 if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
388 auto vecTy = dyn_cast<VectorType>(type);
389 return !vecTy || vecTy.getRank() != 2;
390 }))
391 return failure();
392
393 // A vector of 3 elements should be returned, representing M, K, N
394 // respectively.
395 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
396 if (!targetShape || targetShape->size() != 3)
397 return failure();
398 auto M = (*targetShape)[0];
399 auto K = (*targetShape)[1];
400 auto N = (*targetShape)[2];
401
402 int64_t aBlockSize[2] = {M, K};
403 int64_t bBlockSize[2] = {K, N};
404 int64_t cBlockSize[2] = {M, N};
405
406 auto packWrapper = [&](TypedValue<VectorType> val,
407 ArrayRef<int64_t> blockSize) {
408 VectorType type = val.getType();
409 std::optional<SmallVector<int64_t>> grids =
410 computeShapeRatio(type.getShape(), blockSize);
411 assert(grids && "Expecting grids to be computed.");
412 auto numNewOps = computeProduct(*grids);
413 if (numNewOps == 1)
414 return SmallVector<Value>({val});
415 VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
416 SmallVector<Type> convertedTypes(numNewOps, newVecTy);
417 SmallVector<Value> values =
418 pack(val, convertedTypes, blockSize, loc, rewriter);
419 return values;
420 };
421
422 auto a = op.getLhs();
423 auto b = op.getRhs();
424 auto c = op.getAcc();
425
426 auto aShape = a.getType().getShape();
427 auto bShape = b.getType().getShape();
428
429 SmallVector<Value> aVals, bVals, cVals;
430 aVals = packWrapper(a, aBlockSize);
431 bVals = packWrapper(b, bBlockSize);
432
433 if (c)
434 cVals = packWrapper(c, cBlockSize);
435
436 // Skip the operation if any operand has an invalid blocking size (empty)
437 // or if every original shape already matches its blocking size (size == 1).
438 auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
439 : SmallVector<ValueRange>({aVals, bVals});
440 if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
441 llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
442 return failure();
443
444 VectorType resultTy = op.getResult().getType();
445 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
446
447 int64_t mIters = aShape[0] / M;
448 int64_t kIters = aShape[1] / K;
449 int64_t nIters = bShape[1] / N;
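    // For illustration (sizes are hypothetical): a 32x64 * 64x16 dpas with
    // native M/K/N = 8/16/16 gives mIters = 4, kIters = 4, nIters = 1; each
    // (i, j) tile is accumulated over k, with the previous partial result fed
    // in as the accumulator of the next 8x16x16 dpas.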
450
451 SmallVector<Value> newOps;
452 for (int64_t i = 0; i < mIters; ++i) {
453 for (int64_t j = 0; j < nIters; ++j) {
454 Value tmpC;
455 if (c)
456 tmpC = cVals[i * nIters + j]; // init with acc
457
458 for (int64_t k = 0; k < kIters; ++k) {
459 Value aVec = aVals[i * kIters + k];
460 Value bVec = bVals[k * nIters + j];
461 SmallVector<Value> operands({aVec, bVec});
462 if (tmpC)
463 operands.push_back(tmpC);
464
465 tmpC = xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
466 op->getAttrs());
467 }
468 newOps.push_back(tmpC);
469 }
470 }
471 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
472 rewriter.replaceOp(op, castOp);
473 return success();
474 }
475};
476
477struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
478 using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
479 LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
480 PatternRewriter &rewriter) const override {
481 Location loc = op.getLoc();
482 xegpu::TensorDescType tdescTy = op.getType();
483 TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
484 VectorType indiceVecTy = indiceVec.getType();
485
486 if (!tdescTy.isScattered())
487 return failure();
488
489 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
490 if (!targetShape)
491 return failure();
492
493 SmallVector<int64_t> targetIndiceShape(*targetShape);
494 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
495 // indiceVec is one dimension lower than tdescTy when chunkSize is larger than 1.
496 if (originalChunkSize > 1)
497 targetIndiceShape.pop_back();
498
499 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
500 SmallVector<Type> convertedIndiceTypes =
501 getUnrolledTypes(indiceVecTy, targetIndiceShape);
502 SmallVector<Value> convertedIndiceVec =
503 pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
504
505 SmallVector<Value> newOps;
506
507 // More indices are needed when chunkSize > 1, since a big load from one
508 // address could be broken into multiple smaller loads.
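    // For example (sizes are illustrative): with an original chunk_size of 8
    // blocked to 2, each unrolled index vector is replicated numNewChunks = 4
    // times, with 0, 2, 4, and 6 added element-wise so that every new
    // descriptor addresses its own 2-element chunk.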
509 if (originalChunkSize > 1) {
510 int64_t blockedChunkSize = targetShape->back();
511 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
512
513 for (auto [indice, indiceType] :
514 llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
515 for (int64_t i = 0; i < numNewChunks; ++i) {
516 // Compute the offset
517 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
518 i * blockedChunkSize);
519 Value incVec =
520 vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
521 Value offsetIndice =
522 arith::AddIOp::create(rewriter, loc, indice, incVec);
523
524 auto newOp = xegpu::CreateDescOp::create(
525 rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
526
527 newOps.push_back(newOp);
528 }
529 }
530 } else {
531 for (auto indice : convertedIndiceVec) {
532 auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
533 op.getSource(), indice);
534 newOps.push_back(newOp);
535 }
536 }
537
538 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
539 rewriter.replaceOp(op, castOp);
540
541 return success();
542 }
543};
544
545struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
546 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
547 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
548 PatternRewriter &rewriter) const override {
549
550 Location loc = op.getLoc();
551 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
552 xegpu::TensorDescType tdescTy = op.getTensorDescType();
553
554 // TODO: handle the unstructured source case (!tdescTy)
555 if (!tdescTy || op.getOffsets())
556 return failure();
557
558 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
559 if (!targetShape)
560 return failure();
561
562 SmallVector<int64_t> targetMaskShape(*targetShape);
563 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
564
565 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
566
567 Type elemTy = tdescTy.getElementType();
568 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
569
570 SmallVector<Type> convertedTdescTypes =
571 getUnrolledTypes(tdescTy, *targetShape);
572 SmallVector<Value> convertedTdescs = pack(
573 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
574
575 SmallVector<Type> convertedMaskTypes;
576 SmallVector<Value> convertedMasks;
577
578 if (originalChunkSize > 1) {
579 targetMaskShape.pop_back();
580 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
581 int64_t blockedChunkSize = targetShape->back();
582 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
583
584 // the mask is reused across the chunk_size dimension
585 for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
586 loc, rewriter))
587 convertedMasks.append(numNewChunks, mask);
588
589 newValueTy = valueTy.cloneWith(*targetShape, elemTy);
590 } else {
591 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
592 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
593 loc, rewriter);
594 }
595
596 SmallVector<Value> newOps;
597 for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
598 auto newOp = xegpu::LoadGatherOp::create(
599 rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
600 op.getL2HintAttr(), op.getL3HintAttr());
601 newOps.push_back(newOp);
602 }
603
604 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
605 rewriter.replaceOp(op, castOp);
606 return success();
607 }
608};
609
610/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
611/// load).
612/// It unrolls the offsets and mask operands accordingly, and creates multiple
613/// LoadGatherOps with the unrolled operands.
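/// As a rough sketch (shapes are hypothetical): a gathered load producing
/// vector<32xf32> with a target shape of 16 is split into two loads of
/// vector<16xf32>, each taking the matching 16-lane slice of the offsets and
/// mask vectors; when chunk_size > 1 the offsets are additionally shifted by
/// multiples of the blocked chunk size, mirroring UnrollCreateDescOp above.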
614struct UnrollLoadGatherOpWithOffset
615 : public UnrollPattern<xegpu::LoadGatherOp> {
616 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
617 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
618 PatternRewriter &rewriter) const override {
619 Location loc = op.getLoc();
620 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
621 Value offsets = op.getOffsets();
622 Value mask = op.getMask();
623
624 // Only handle the case where offsets are present (scattered load)
625 if (!offsets)
626 return failure();
627
628 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
629 if (!targetShape)
630 return failure();
631
632 SmallVector<int64_t> targetMaskShape(*targetShape);
633 int64_t chunkSize = 1;
634 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
635 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
636 chunkSize = intAttr.getInt();
637 }
638
639 // Unroll mask and offsets with correct shape
640 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
641 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
642 Type elemTy = valueTy.getElementType();
643 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
644
645 SmallVector<Type> convertedMaskTypes;
646 SmallVector<Value> convertedMasks;
647 SmallVector<Type> convertedOffsetTypes;
648 SmallVector<Value> convertedOffsets;
649
650 if (chunkSize > 1) {
651 // For chunked loads, mask and offsets have one less dimension
652 targetMaskShape.pop_back();
653 int64_t blockedChunkSize = targetShape->back();
654 int64_t numNewChunks = chunkSize / blockedChunkSize;
655 chunkSize = blockedChunkSize;
656
657 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
658 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
659
660 SmallVector<Value> convertedMasksBase =
661 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
662 SmallVector<Value> convertedOffsetsBase =
663 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
664
665 for (auto maskVal : convertedMasksBase)
666 convertedMasks.append(numNewChunks, maskVal);
667
668 for (auto [baseOffset, offsetType] :
669 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
670 for (int64_t i = 0; i < numNewChunks; ++i) {
671 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
672 i * blockedChunkSize);
673 Value incVec =
674 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
675 Value offsetVal =
676 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
677 convertedOffsets.push_back(offsetVal);
678 }
679 }
680 } else {
681 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
682 convertedMasks =
683 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
684
685 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
686 convertedOffsets =
687 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
688 }
689
690 auto layout = op.getLayoutAttr();
691 if (layout)
692 layout = layout.dropInstData();
693
694 SmallVector<Value> newOps;
695 for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
696 auto newOp = xegpu::LoadGatherOp::create(
697 rewriter, loc, newValueTy, op.getSource(), o, m,
698 rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
699 op.getL2HintAttr(), op.getL3HintAttr(), layout);
700 newOps.push_back(newOp);
701 }
702
703 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
704 rewriter.replaceOp(op, castOp);
705 return success();
706 }
707};
708
709/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
710/// store).
711/// It unrolls the offsets and mask operands accordingly, and creates multiple
712/// StoreScatterOps with the unrolled operands.
713struct UnrollStoreScatterOpWithOffsets
714 : public UnrollPattern<xegpu::StoreScatterOp> {
715 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
716 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
717 PatternRewriter &rewriter) const override {
718 Location loc = op.getLoc();
719 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
720 Value offsets = op.getOffsets();
721 Value mask = op.getMask();
722
723 // Only handle the case where offsets are present (scattered store)
724 if (!offsets)
725 return failure();
726
727 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
728 if (!targetShape)
729 return failure();
730
731 int64_t chunkSize = 1;
732 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
733 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
734 chunkSize = intAttr.getInt();
735 }
736
737 SmallVector<int64_t> targetMaskShape(*targetShape);
738 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
739 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
740
741 SmallVector<Type> convertedMaskTypes;
742 SmallVector<Value> convertedMasks;
743 SmallVector<Type> convertedOffsetTypes;
744 SmallVector<Value> convertedOffsets;
745
746 if (chunkSize > 1) {
747 targetMaskShape.pop_back();
748 int64_t blockedChunkSize = targetShape->back();
749 int64_t numNewChunks = chunkSize / blockedChunkSize;
750 chunkSize = blockedChunkSize;
751
752 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
753 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
754
755 SmallVector<Value> convertedMasksBase =
756 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
757 SmallVector<Value> convertedOffsetsBase =
758 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
759
760 for (auto maskVal : convertedMasksBase)
761 convertedMasks.append(numNewChunks, maskVal);
762
763 for (auto [baseOffset, offsetType] :
764 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
765 for (int64_t i = 0; i < numNewChunks; ++i) {
766 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
767 i * blockedChunkSize);
768 Value incVec =
769 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
770 Value offsetVal =
771 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
772 convertedOffsets.push_back(offsetVal);
773 }
774 }
775 } else {
776 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
777 convertedMasks =
778 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
779
780 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
781 convertedOffsets =
782 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
783 }
784
785 SmallVector<Type> convertedValTypes =
786 getUnrolledTypes(valueTy, *targetShape);
787 SmallVector<Value> convertedValues =
788 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
789
790 auto layout = op.getLayoutAttr();
791 if (layout)
792 layout = layout.dropInstData();
793
794 for (auto [v, o, m] :
795 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
796 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
797 rewriter.getI64IntegerAttr(chunkSize),
798 op.getL1HintAttr(), op.getL2HintAttr(),
799 op.getL3HintAttr(), layout);
800 }
801
802 rewriter.eraseOp(op);
803 return success();
804 }
805};
806
807struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
808 using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
809 LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
810 PatternRewriter &rewriter) const override {
811 Location loc = op.getLoc();
812 xegpu::TensorDescType tdescTy = op.getTensorDescType();
813
814 // TODO: handle the unstructured source case (!tdescTy)
815 if (!tdescTy || op.getOffsets())
816 return failure();
817
818 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
819 if (!targetShape)
820 return failure();
821
822 SmallVector<Type> convertedTdescTypes =
823 getUnrolledTypes(tdescTy, *targetShape);
824 SmallVector<Value> convertedTdesc = pack(
825 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
826
827 for (auto t : convertedTdesc)
828 xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t, op->getAttrs());
829
830 rewriter.eraseOp(op);
831 return success();
832 }
833};
834
835struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
836 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
837 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
838 PatternRewriter &rewriter) const override {
839
840 Location loc = op.getLoc();
841 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
842 xegpu::TensorDescType tdescTy = op.getTensorDescType();
843
844 // TODO: handle the unstructured source case (!tdescTy)
845 if (!tdescTy || op.getOffsets())
846 return failure();
847
848 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
849 if (!targetShape)
850 return failure();
851
852 SmallVector<int64_t> targetMaskShape(*targetShape);
853 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
854
855 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
856
857 SmallVector<Type> convertedTdescTypes =
858 getUnrolledTypes(tdescTy, *targetShape);
859 SmallVector<Value> convertedTdescs = pack(
860 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
861
862 SmallVector<Type> convertedMaskTypes;
863 SmallVector<Value> convertedMasks;
864
865 if (originalChunkSize > 1) {
866 targetMaskShape.pop_back();
867 int64_t blockedChunkSize = targetShape->back();
868 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
869 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
870
871 // the mask is reused across the chunk_size dimension
872 for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
873 loc, rewriter))
874 convertedMasks.append(numNewChunks, mask);
875 } else {
876 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
877 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
878 loc, rewriter);
879 }
880
881 SmallVector<Type> convertedValTypes =
882 getUnrolledTypes(valueTy, *targetShape);
883 SmallVector<Value> convertedValues =
884 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
885
886 for (size_t i = 0; i < convertedValues.size(); ++i) {
887 Value v = convertedValues[i];
888 Value t = convertedTdescs[i];
889 Value m = op.getMask() ? convertedMasks[i] : nullptr;
890 xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
891 op.getL2HintAttr(), op.getL3HintAttr());
892 }
893
894 rewriter.eraseOp(op);
895 return success();
896 }
897};
898
899struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
900 using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
901 LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
902 PatternRewriter &rewriter) const override {
903 Location loc = op.getLoc();
904 xegpu::TensorDescType tdescTy = op.getTensorDescType();
905
906 if (!tdescTy.isScattered())
907 return failure();
908
909 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
910 if (!targetShape)
911 return failure();
912
913 SmallVector<Type> convertedTdescTypes =
914 getUnrolledTypes(tdescTy, *targetShape);
915 SmallVector<Value> convertedTdesc = pack(
916 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
917
918 TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
919 VectorType offsetVecTy = offsetVec.getType();
920 SmallVector<Type> convertedOffsetTypes;
921 SmallVector<Value> convertedOffsetVec;
922 SmallVector<Value> newOps;
923 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
924 if (originalChunkSize > 1) {
925 auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
926 convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);
927
928 int64_t blockedChunkSize = targetShape->back();
929 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
930 // the offset is reused across the chunk_size dimension
931 for (auto offset : pack(offsetVec, convertedOffsetTypes,
932 targetOffsetShape, loc, rewriter))
933 convertedOffsetVec.append(numNewChunks, offset);
934
935 } else {
936 convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
937 convertedOffsetVec =
938 pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
939 }
940
941 for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
942 auto newOp =
943 xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
944 newOps.push_back(newOp);
945 }
946 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
947 rewriter.replaceOp(op, castOp);
948 return success();
949 }
950};
951
952struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
953 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
954 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
955 PatternRewriter &rewriter) const override {
956 Location loc = op.getLoc();
957 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
958 assert(valueTy && "the value type must be vector type!");
959
960 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
961 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
962 return failure();
963
964 Type elemTy = valueTy.getElementType();
965 ArrayRef<int64_t> shape = valueTy.getShape();
966 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
967
968 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
969
970 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
971 SmallVector<SmallVector<OpFoldResult>> offsetsList;
972 for (SmallVector<int64_t> offsets :
973 StaticTileOffsetRange(shape, *targetShape)) {
974 auto adds = xegpu::addElementwise(
975 rewriter, loc, mixedOffsets,
976 getAsIndexOpFoldResult(op.getContext(), offsets));
977 offsetsList.push_back(adds);
978 }
979
980 SmallVector<Value> newOps;
981 layout = layout.dropInstData();
982 for (SmallVector<OpFoldResult> offsets : offsetsList) {
983 auto newOp = xegpu::LoadMatrixOp::create(
984 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
985 newOps.push_back(newOp);
986 }
987 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
988 rewriter.replaceOp(op, castOp);
989 return success();
990 }
991};
992
993struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
994 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
995 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
996 PatternRewriter &rewriter) const override {
997 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
998 if (!targetShape)
999 return failure();
1000
1001 Location loc = op.getLoc();
1002 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
1003 assert(valueTy && "the value type must be vector type!");
1004 ArrayRef<int64_t> shape = valueTy.getShape();
1005 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
1006
1007 SmallVector<Type> convertedValTypes =
1008 getUnrolledTypes(valueTy, *targetShape);
1009 SmallVector<Value> convertedValues =
1010 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
1011
1012 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
1013 SmallVector<SmallVector<OpFoldResult>> offsetsList;
1014 for (SmallVector<int64_t> offsets :
1015 StaticTileOffsetRange(shape, *targetShape)) {
1016 auto adds = xegpu::addElementwise(
1017 rewriter, loc, mixedOffsets,
1018 getAsIndexOpFoldResult(op.getContext(), offsets));
1019 offsetsList.push_back(adds);
1020 }
1021
1022 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
1023 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
1024 layout.dropInstData());
1025
1026 rewriter.eraseOp(op);
1027 return success();
1028 }
1029};
1030
1031} // namespace
1032
1033void mlir::xegpu::populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
1034 const UnrollOptions &options) {
1035 patterns
1036 .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
1037 UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
1038 UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
1039 UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
1040 UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
1041 patterns.getContext(), options);
1042}
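
// A minimal usage sketch (not part of this file): the callback body and the
// 8x16 native shape below are hypothetical, and the setter name assumes the
// usual UnrollOptions-style setNativeShapeFn interface; an unrolled-types
// callback must be set up in the same way.
//
//   xegpu::UnrollOptions options;
//   options.setNativeShapeFn([](Operation *op) -> std::optional<SmallVector<int64_t>> {
//     return SmallVector<int64_t>{8, 16};
//   });
//   RewritePatternSet patterns(ctx);
//   xegpu::populateXeGPUUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(op, std::move(patterns));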