XeGPUUnroll.cpp
1//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for unrolling XeGPU operations. It follows a
10// concept and design similar to the vector unroll patterns, and serves as a
11// complement to them.
12//
13//===----------------------------------------------------------------------===//
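//
// As an illustrative (hypothetical) example: unrolling an xegpu.load_nd that
// produces a vector<32x32xf32> with a native tile shape of 8x16 yields a 4x2
// grid of loads, each producing a vector<8x16xf32>; the tiles are then
// reassembled into the original vector<32x32xf32> value with
// vector.insert_strided_slice (see the unpack helper below).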
14
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/DebugLog.h"
22
23namespace mlir {
24namespace xegpu {
25#define GEN_PASS_DEF_XEGPUUNROLL
26#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
27} // namespace xegpu
28} // namespace mlir
29
30#define DEBUG_TYPE "xegpu-unroll"
31
32using namespace mlir;
33
34namespace {
35
36template <typename SourceOp>
37struct UnrollPattern : public OpRewritePattern<SourceOp> {
38 UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
39 PatternBenefit benefit = 1)
40 : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
41
42protected:
43 /// Return the target shape for the given `op`. Return std::nullopt if the
44 /// op shouldn't be or cannot be unrolled.
45 std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
46 LDBG() << "Get unroll shape for: " << *op;
47
48 if (options.filterConstraint && failed(options.filterConstraint(op))) {
49      LDBG() << "--filter constraint failed -> BAIL";
50 return std::nullopt;
51 }
52
53    assert(options.nativeShape &&
54           "expects the nativeShape callback to be provided.");
55 auto nativeShape = options.nativeShape(op);
56 return nativeShape;
57 }
58
59 SmallVector<Type> getUnrolledTypes(ShapedType type,
60 ArrayRef<int64_t> tileShape,
61 bool returnSingleType = false) const {
62 return options.getUnrolledTypes(type, tileShape, returnSingleType);
63 }
64
65  /// Emulate the unpack behavior using insert_strided_slice for VectorType
66 /// values and unrealized_conversion_cast for TensorDescType values.
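  ///
  /// A sketch of the vector case (shapes are hypothetical): given four
  /// vector<8x16xf32> tiles and destTy = vector<16x32xf32> with
  /// blockSize = [8, 16], the tiles are stitched back into a single
  /// vector<16x32xf32> via vector.insert_strided_slice at offsets
  /// [0,0], [0,16], [8,0], and [8,16]. For TensorDescType destinations, an
  /// unrealized_conversion_cast tagged with the unpack marker and the block
  /// shape is emitted instead, to be resolved by a later pass.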
67 Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
68 Location loc, PatternRewriter &rewriter) const {
69 if (auto vecTy = dyn_cast<VectorType>(destTy)) {
70 auto shape = vecTy.getShape();
71 return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
72 }
73
74 if (isa<xegpu::TensorDescType>(destTy)) {
75 auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
76 rewriter.getUnitAttr());
77 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
78 rewriter.getDenseI64ArrayAttr(blockSize));
79 auto castOp = UnrealizedConversionCastOp::create(
80 rewriter, loc, destTy, srcs,
81 ArrayRef<NamedAttribute>({attr, blkAttr}));
82 return castOp.getResult(0);
83 }
84
85 llvm_unreachable("Unexpected destTy.");
86 return Value();
87 }
88
89  /// Emulate the pack behavior using extract_strided_slice for VectorType
90 /// values and unrealized_conversion_cast for TensorDescType values.
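  ///
  /// A sketch of the vector case (shapes are hypothetical): packing a
  /// vector<16x32xf32> source with blockSize = [8, 16] yields four
  /// vector<8x16xf32> values, extracted with vector.extract_strided_slice in
  /// row-major tile order.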
91 SmallVector<Value> pack(Value src, TypeRange destTypes,
92 ArrayRef<int64_t> blockSize, Location loc,
93 PatternRewriter &rewriter) const {
94 if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
95 return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
96 blockSize);
97 }
98
99 if (isa<xegpu::TensorDescType>(src.getType())) {
100 auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
101 rewriter.getUnitAttr());
102 auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
103 rewriter.getDenseI64ArrayAttr(blockSize));
104 auto castOp = UnrealizedConversionCastOp::create(
105 rewriter, loc, destTypes, src,
106 ArrayRef<NamedAttribute>({attr, blkAttr}));
107 return castOp.getResults();
108 }
109
110 llvm_unreachable("Unexpected src type.");
111 return SmallVector<Value>();
112 }
113
114private:
115 const char *const packAttrName = "__xegpu_blocking_pack__";
116 const char *const unpackAttrName = "__xegpu_blocking_unpack__";
117 const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
118
118
119  xegpu::UnrollOptions options;
120};
121
122// Generic helper function for unrolling operations with offsets.
123//
124// Iterates over tile offsets within the tensor descriptor shape and calls
125// the provided createOp function for each computed offset. This is used by
126// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
127// have explicit offsets that need to be adjusted for each unrolled tile.
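//
// For example, with a tensor descriptor shape of 32x32 and a target tile shape
// of 8x16, the relative tile offsets visited are (0,0), (0,16), (8,0), (8,16),
// ..., (24,16); each is added to the trailing `rank` entries of mixedOffsets
// before createOp is invoked for that tile.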
128SmallVector<Value> computeUnrolledOffsets(
129 SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
130 ArrayRef<int64_t> targetShape,
131 const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
132 Location loc, PatternRewriter &rewriter) {
133 int64_t rank = tdescTy.getRank();
134 ArrayRef<int64_t> shape = tdescTy.getShape();
135
136 auto addi = [&](OpFoldResult a, int64_t b) -> Value {
137 std::optional<int64_t> maybeInt = getConstantIntValue(a);
138 if (maybeInt) {
139 return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
140 } else {
141 auto aV = llvm::cast<Value>(a);
142 auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
143 return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
144 }
145 };
146
147 SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
148 llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
149 auto validIdxes =
150 llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
151
152 SmallVector<Value> newOps;
153 for (SmallVector<int64_t> offsets :
154 StaticTileOffsetRange(shape, targetShape)) {
155
156 for (auto [idx, oldOff, offset] :
157 llvm::zip(validIdxes, oldOffsets, offsets))
158 mixedOffsets[idx] = addi(oldOff, offset);
159
160 auto newOp = createOp(mixedOffsets);
161 newOps.push_back(newOp);
162 }
163 return newOps;
164}
165
166struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
167 using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
168 LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
169 PatternRewriter &rewriter) const override {
170 Location loc = op.getLoc();
171 xegpu::TensorDescType tdescTy = op.getType();
172
173 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
174 if (!targetShape)
175 return failure();
176
177 SmallVector<Value> newOps;
178
179 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
180 bool hasOffsets = op.getMixedOffsets().size() != 0;
181 if (!hasOffsets) {
182 auto newOp = xegpu::CreateNdDescOp::create(
183 rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
184 op.getMixedStrides());
185 newOps.push_back(newOp);
186 } else {
187 auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
188 return xegpu::CreateNdDescOp::create(
189 rewriter, loc, newTdescTy, op.getSource(), offsets,
190 op.getMixedSizes(), op.getMixedStrides());
191 };
192
193 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
194 *targetShape, createOp, loc, rewriter);
195 }
196 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
197 rewriter.replaceOp(op, castOp);
198
199 return success();
200 }
201};
202
203struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
204 using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
205 LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
206 PatternRewriter &rewriter) const override {
207 Location loc = op.getLoc();
208 xegpu::TensorDescType tdescTy = op.getTensorDescType();
209
210 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
211 if (!targetShape)
212 return failure();
213
214 SmallVector<Type> convertedTdescTypes =
215 getUnrolledTypes(tdescTy, *targetShape);
216 SmallVector<Value> convertedTdesc = pack(
217 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
218
219 SmallVector<Value> newOps;
220 for (auto t : convertedTdesc) {
221 auto newOp = xegpu::UpdateNdOffsetOp::create(
222 rewriter, loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
223 newOps.push_back(newOp);
224 }
225 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
226 rewriter.replaceOp(op, castOp);
227 return success();
228 }
229};
230
231struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
232 using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
233 LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
234 PatternRewriter &rewriter) const override {
235 Location loc = op.getLoc();
236 xegpu::TensorDescType tdescTy = op.getTensorDescType();
237
238 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
239 if (!targetShape)
240 return failure();
241
242 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
243 if (layout)
244 layout = layout.dropInstData();
245 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
246 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
247
248 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
249 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
250
251 SmallVector<Value> convertedTdesc = pack(
252 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
253
254 if (!hasOffsets) {
255 for (auto t : convertedTdesc)
256 xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
257 xegpu::dropInstDataOnAttrs(op->getAttrs()));
258 } else {
259 auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
260 xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
261 op.getL1HintAttr(), op.getL2HintAttr(),
262 op.getL3HintAttr(), layout);
263 // return dummy Value to satisfy function's signature
264 return nullptr;
265 };
266
267 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
268 createPrefetch, loc, rewriter);
269 }
270
271 rewriter.eraseOp(op);
272 return success();
273 }
274};
275
276struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
277 using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
278 LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
279 PatternRewriter &rewriter) const override {
280
281 Location loc = op.getLoc();
282 VectorType valueTy = op.getType();
283 xegpu::TensorDescType tdescTy = op.getTensorDescType();
284
285 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
286 if (!targetShape)
287 return failure();
288
289 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
290 if (layout)
291 layout = layout.dropInstData();
292 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
293 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
294
295 Type elemTy = tdescTy.getElementType();
296 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
297
298 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
299 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
300
301 SmallVector<Value> convertedTdescs = pack(
302 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
303 SmallVector<Value> newOps;
304
305 if (!hasOffsets) {
306 for (auto t : convertedTdescs) {
307 auto newOp =
308 xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
309 xegpu::dropInstDataOnAttrs(op->getAttrs()));
310 newOps.push_back(newOp);
311 }
312 } else {
313 auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
314 return xegpu::LoadNdOp::create(
315 rewriter, loc, newValueTy, convertedTdescs[0], offsets,
316 op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
317 op.getL2HintAttr(), op.getL3HintAttr(), layout);
318 };
319 newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
320 *targetShape, createLoad, loc, rewriter);
321 }
322
323 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
324
325 rewriter.replaceOp(op, castOp);
326 return success();
327 }
328};
329
330struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
331 using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
332 LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
333 PatternRewriter &rewriter) const override {
334 Location loc = op.getLoc();
335 VectorType valueTy = op.getValueType();
336 xegpu::TensorDescType tdescTy = op.getTensorDescType();
337
338 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
339 if (!targetShape)
340 return failure();
341
342 xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
343 if (layout)
344 layout = layout.dropInstData();
345 int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
346 bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
347
348 SmallVector<Type> convertedValTypes =
349 getUnrolledTypes(valueTy, *targetShape);
350 SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
351 tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
352
353 SmallVector<Value> convertedTdescs = pack(
354 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
355
356 SmallVector<Value> convertedValues =
357 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
358 if (!hasOffsets) {
359 for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
360 xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
361 op.getL2HintAttr(), op.getL3HintAttr());
362 } else {
363 size_t valueIndex = 0;
364 auto createStore = [&](SmallVector<OpFoldResult> offsets) {
365 xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
366 convertedTdescs[0], offsets,
367 op.getL1HintAttr(), op.getL2HintAttr(),
368 op.getL3HintAttr(), layout);
369 // return dummy Value to satisfy function's signature
370 return nullptr;
371 };
372
373 computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
374 createStore, loc, rewriter);
375 }
376
377 rewriter.eraseOp(op);
378 return success();
379 }
380};
381
382struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
383 using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
384 LogicalResult matchAndRewrite(xegpu::DpasOp op,
385 PatternRewriter &rewriter) const override {
386 Location loc = op.getLoc();
387
388    // Expecting every operand to be a 2D vector.
389 if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
390 auto vecTy = dyn_cast<VectorType>(type);
391 return !vecTy || vecTy.getRank() != 2;
392 }))
393 return failure();
394
395 // A vector of 3 elements should be returned, representing M, K, N
396 // respectively.
397 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
398 if (!targetShape || targetShape->size() != 3)
399 return failure();
400 auto M = (*targetShape)[0];
401 auto K = (*targetShape)[1];
402 auto N = (*targetShape)[2];
403
404 int64_t aBlockSize[2] = {M, K};
405 int64_t bBlockSize[2] = {K, N};
406 int64_t cBlockSize[2] = {M, N};
407
408 auto packWrapper = [&](TypedValue<VectorType> val,
409 ArrayRef<int64_t> blockSize) {
410 VectorType type = val.getType();
411 std::optional<SmallVector<int64_t>> grids =
412 computeShapeRatio(type.getShape(), blockSize);
413 assert(grids && "Expecting grids to be computed.");
414 auto numNewOps = computeProduct(*grids);
415 if (numNewOps == 1)
416 return SmallVector<Value>({val});
417 VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
418 SmallVector<Type> convertedTypes(numNewOps, newVecTy);
419 SmallVector<Value> values =
420 pack(val, convertedTypes, blockSize, loc, rewriter);
421 return values;
422 };
423
424 auto a = op.getLhs();
425 auto b = op.getRhs();
426 auto c = op.getAcc();
427
428 auto aShape = a.getType().getShape();
429 auto bShape = b.getType().getShape();
430
431 SmallVector<Value> aVals, bVals, cVals;
432 aVals = packWrapper(a, aBlockSize);
433 bVals = packWrapper(b, bBlockSize);
434
435 if (c)
436 cVals = packWrapper(c, cBlockSize);
437
438    // Skip the operation if any operand has an invalid blocking size (empty),
439    // or if every operand's shape already matches its blocking size (size == 1).
440 auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
441 : SmallVector<ValueRange>({aVals, bVals});
442 if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
443 llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
444 return failure();
445
446 VectorType resultTy = op.getResult().getType();
447 auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
448
449 int64_t mIters = aShape[0] / M;
450 int64_t kIters = aShape[1] / K;
451 int64_t nIters = bShape[1] / N;
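    // Illustrative (hypothetical) sizes: for A = vector<32x32xf16>,
    // B = vector<32x32xf16>, and a native shape of M=8, K=16, N=16, this gives
    // mIters=4, kIters=2, nIters=2; each (i, j) result tile accumulates over
    // the k tiles through the chained tmpC operand below.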
452
453 SmallVector<Value> newOps;
454 for (int64_t i = 0; i < mIters; ++i) {
455 for (int64_t j = 0; j < nIters; ++j) {
456 Value tmpC;
457 if (c)
458 tmpC = cVals[i * nIters + j]; // init with acc
459
460 for (int64_t k = 0; k < kIters; ++k) {
461 Value aVec = aVals[i * kIters + k];
462 Value bVec = bVals[k * nIters + j];
463 SmallVector<Value> operands({aVec, bVec});
464 if (tmpC)
465 operands.push_back(tmpC);
466
467 tmpC =
468 xegpu::DpasOp::create(rewriter, loc, vecTy, operands,
469 xegpu::dropInstDataOnAttrs(op->getAttrs()));
470 }
471 newOps.push_back(tmpC);
472 }
473 }
474 Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
475 rewriter.replaceOp(op, castOp);
476 return success();
477 }
478};
479
480struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
481 using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
482 LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
483 PatternRewriter &rewriter) const override {
484 Location loc = op.getLoc();
485 xegpu::TensorDescType tdescTy = op.getType();
486 TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
487 VectorType indiceVecTy = indiceVec.getType();
488
489 if (!tdescTy.isScattered())
490 return failure();
491
492 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
493 if (!targetShape)
494 return failure();
495
496 SmallVector<int64_t> targetIndiceShape(*targetShape);
497 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
498    // indiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.
499 if (originalChunkSize > 1)
500 targetIndiceShape.pop_back();
501
502 auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
503 SmallVector<Type> convertedIndiceTypes =
504 getUnrolledTypes(indiceVecTy, targetIndiceShape);
505 SmallVector<Value> convertedIndiceVec =
506 pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
507
508 SmallVector<Value> newOps;
509
510    // More indices are needed when chunkSize > 1, since a large load from one
511    // address may be broken into multiple smaller loads.
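    // For example (illustrative numbers): with originalChunkSize = 8 and
    // blockedChunkSize = 2, numNewChunks = 4, and each converted index vector
    // is reused four times with additive offsets 0, 2, 4, and 6.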
512 if (originalChunkSize > 1) {
513 int64_t blockedChunkSize = targetShape->back();
514 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
515
516 for (auto [indice, indiceType] :
517 llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
518 for (int64_t i = 0; i < numNewChunks; ++i) {
519 // Compute the offset
520 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
521 i * blockedChunkSize);
522 Value incVec =
523 vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
524 Value offsetIndice =
525 arith::AddIOp::create(rewriter, loc, indice, incVec);
526
527 auto newOp = xegpu::CreateDescOp::create(
528 rewriter, loc, newTdescTy, op.getSource(), offsetIndice);
529
530 newOps.push_back(newOp);
531 }
532 }
533 } else {
534 for (auto indice : convertedIndiceVec) {
535 auto newOp = xegpu::CreateDescOp::create(rewriter, loc, newTdescTy,
536 op.getSource(), indice);
537 newOps.push_back(newOp);
538 }
539 }
540
541 Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
542 rewriter.replaceOp(op, castOp);
543
544 return success();
545 }
546};
547
548struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
549 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
550 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
551 PatternRewriter &rewriter) const override {
552
553 Location loc = op.getLoc();
554 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
555 xegpu::TensorDescType tdescTy = op.getTensorDescType();
556
557    // TODO: handle the unstructured source case (!tdescTy)
558 if (!tdescTy || op.getOffsets())
559 return failure();
560
561 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
562 if (!targetShape)
563 return failure();
564
565 SmallVector<int64_t> targetMaskShape(*targetShape);
566 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
567
568 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
569
570 Type elemTy = tdescTy.getElementType();
571 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
572
573 SmallVector<Type> convertedTdescTypes =
574 getUnrolledTypes(tdescTy, *targetShape);
575 SmallVector<Value> convertedTdescs = pack(
576 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
577
578 SmallVector<Type> convertedMaskTypes;
579 SmallVector<Value> convertedMasks;
580
581 if (originalChunkSize > 1) {
582 targetMaskShape.pop_back();
583 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
584 int64_t blockedChunkSize = targetShape->back();
585 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
586
587 // the mask is reused across the chunk_size dimension
588 for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
589 loc, rewriter))
590 convertedMasks.append(numNewChunks, mask);
591
592 newValueTy = valueTy.cloneWith(*targetShape, elemTy);
593 } else {
594 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
595 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
596 loc, rewriter);
597 }
598
599 SmallVector<Value> newOps;
600 for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
601 auto newOp = xegpu::LoadGatherOp::create(
602 rewriter, loc, newValueTy, t, m, op.getL1HintAttr(),
603 op.getL2HintAttr(), op.getL3HintAttr());
604 newOps.push_back(newOp);
605 }
606
607 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
608 rewriter.replaceOp(op, castOp);
609 return success();
610 }
611};
612
613/// This pattern handles the unrolling of LoadGatherOp with offsets (gathered
614/// load).
615/// It unrolls the offsets and mask operands accordingly, and creates multiple
616/// LoadGatherOp with the unrolled operands.
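///
/// A sketch (hypothetical shapes, chunk_size == 1): a gathered load producing
/// a vector<32xf32> from vector<32xindex> offsets and a vector<32xi1> mask,
/// with a target shape of [16], becomes two loads of vector<16xf32>, each
/// taking the matching vector<16xindex> offsets slice and vector<16xi1> mask
/// slice.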
617struct UnrollLoadGatherOpWithOffset
618 : public UnrollPattern<xegpu::LoadGatherOp> {
619 using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
620 LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
621 PatternRewriter &rewriter) const override {
622 Location loc = op.getLoc();
623 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
624 Value offsets = op.getOffsets();
625 Value mask = op.getMask();
626
627 // Only handle the case where offsets are present (scattered load)
628 if (!offsets)
629 return failure();
630
631 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
632 if (!targetShape)
633 return failure();
634
635 SmallVector<int64_t> targetMaskShape(*targetShape);
636 int64_t chunkSize = 1;
637 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
638 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
639 chunkSize = intAttr.getInt();
640 }
641
642 // Unroll mask and offsets with correct shape
643 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
644 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
645 Type elemTy = valueTy.getElementType();
646 VectorType newValueTy = VectorType::get(*targetShape, elemTy);
647
648 SmallVector<Type> convertedMaskTypes;
649 SmallVector<Value> convertedMasks;
650 SmallVector<Type> convertedOffsetTypes;
651 SmallVector<Value> convertedOffsets;
652
653 if (chunkSize > 1) {
654 // For chunked loads, mask and offsets have one less dimension
655 targetMaskShape.pop_back();
656 int64_t blockedChunkSize = targetShape->back();
657 int64_t numNewChunks = chunkSize / blockedChunkSize;
658 chunkSize = blockedChunkSize;
659
660 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
661 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
662
663 SmallVector<Value> convertedMasksBase =
664 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
665 SmallVector<Value> convertedOffsetsBase =
666 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
667
668 for (auto maskVal : convertedMasksBase)
669 convertedMasks.append(numNewChunks, maskVal);
670
671 for (auto [baseOffset, offsetType] :
672 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
673 for (int64_t i = 0; i < numNewChunks; ++i) {
674 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
675 i * blockedChunkSize);
676 Value incVec =
677 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
678 Value offsetVal =
679 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
680 convertedOffsets.push_back(offsetVal);
681 }
682 }
683 } else {
684 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
685 convertedMasks =
686 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
687
688 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
689 convertedOffsets =
690 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
691 }
692
693 auto layout = op.getLayoutAttr();
694 if (layout)
695 layout = layout.dropInstData();
696
697 SmallVector<Value> newOps;
698 for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
699 auto newOp = xegpu::LoadGatherOp::create(
700 rewriter, loc, newValueTy, op.getSource(), o, m,
701 rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
702 op.getL2HintAttr(), op.getL3HintAttr(), layout);
703 newOps.push_back(newOp);
704 }
705
706 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
707 rewriter.replaceOp(op, castOp);
708 return success();
709 }
710};
711
712/// This pattern handles the unrolling of StoreScatterOp with offsets (scattered
713/// store).
714/// It unrolls the offsets and mask operands accordingly, and creates multiple
715/// StoreScatterOp with the unrolled operands.
716struct UnrollStoreScatterOpWithOffsets
717 : public UnrollPattern<xegpu::StoreScatterOp> {
718 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
719 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
720 PatternRewriter &rewriter) const override {
721 Location loc = op.getLoc();
722 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
723 Value offsets = op.getOffsets();
724 Value mask = op.getMask();
725
726 // Only handle the case where offsets are present (scattered store)
727 if (!offsets)
728 return failure();
729
730 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
731 if (!targetShape)
732 return failure();
733
734 int64_t chunkSize = 1;
735 if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
736 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
737 chunkSize = intAttr.getInt();
738 }
739
740 SmallVector<int64_t> targetMaskShape(*targetShape);
741 VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
742 VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
743
744 SmallVector<Type> convertedMaskTypes;
745 SmallVector<Value> convertedMasks;
746 SmallVector<Type> convertedOffsetTypes;
747 SmallVector<Value> convertedOffsets;
748
749 if (chunkSize > 1) {
750 targetMaskShape.pop_back();
751 int64_t blockedChunkSize = targetShape->back();
752 int64_t numNewChunks = chunkSize / blockedChunkSize;
753 chunkSize = blockedChunkSize;
754
755 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
756 convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);
757
758 SmallVector<Value> convertedMasksBase =
759 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
760 SmallVector<Value> convertedOffsetsBase =
761 pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);
762
763 for (auto maskVal : convertedMasksBase)
764 convertedMasks.append(numNewChunks, maskVal);
765
766 for (auto [baseOffset, offsetType] :
767 llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
768 for (int64_t i = 0; i < numNewChunks; ++i) {
769 Value inc = arith::ConstantIndexOp::create(rewriter, loc,
770 i * blockedChunkSize);
771 Value incVec =
772 vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
773 Value offsetVal =
774 arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
775 convertedOffsets.push_back(offsetVal);
776 }
777 }
778 } else {
779 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
780 convertedMasks =
781 pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
782
783 convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
784 convertedOffsets =
785 pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
786 }
787
788 SmallVector<Type> convertedValTypes =
789 getUnrolledTypes(valueTy, *targetShape);
790 SmallVector<Value> convertedValues =
791 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
792
793 auto layout = op.getLayoutAttr();
794 if (layout)
795 layout = layout.dropInstData();
796
797 for (auto [v, o, m] :
798 llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
799 xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
800 rewriter.getI64IntegerAttr(chunkSize),
801 op.getL1HintAttr(), op.getL2HintAttr(),
802 op.getL3HintAttr(), layout);
803 }
804
805 rewriter.eraseOp(op);
806 return success();
807 }
808};
809
810struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
811 using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
812 LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
813 PatternRewriter &rewriter) const override {
814 Location loc = op.getLoc();
815 xegpu::TensorDescType tdescTy = op.getTensorDescType();
816
817    // TODO: handle the unstructured source case (!tdescTy)
818 if (!tdescTy || op.getOffsets())
819 return failure();
820
821 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
822 if (!targetShape)
823 return failure();
824
825 SmallVector<Type> convertedTdescTypes =
826 getUnrolledTypes(tdescTy, *targetShape);
827 SmallVector<Value> convertedTdesc = pack(
828 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
829
830 for (auto t : convertedTdesc)
831 xegpu::PrefetchOp::create(rewriter, loc, TypeRange(), t,
832 xegpu::dropInstDataOnAttrs(op->getAttrs()));
833
834 rewriter.eraseOp(op);
835 return success();
836 }
837};
838
839struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
840 using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
841 LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
842 PatternRewriter &rewriter) const override {
843
844 Location loc = op.getLoc();
845 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
846 xegpu::TensorDescType tdescTy = op.getTensorDescType();
847
848    // TODO: handle the unstructured source case (!tdescTy)
849 if (!tdescTy || op.getOffsets())
850 return failure();
851
852 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
853 if (!targetShape)
854 return failure();
855
856 SmallVector<int64_t> targetMaskShape(*targetShape);
857 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
858
859 VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
860
861 SmallVector<Type> convertedTdescTypes =
862 getUnrolledTypes(tdescTy, *targetShape);
863 SmallVector<Value> convertedTdescs = pack(
864 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
865
866 SmallVector<Type> convertedMaskTypes;
867 SmallVector<Value> convertedMasks;
868
869 if (originalChunkSize > 1) {
870 targetMaskShape.pop_back();
871 int64_t blockedChunkSize = targetShape->back();
872 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
873 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
874
875 // the mask is reused across the chunk_size dimension
876 for (auto mask : pack(op.getMask(), convertedMaskTypes, targetMaskShape,
877 loc, rewriter))
878 convertedMasks.append(numNewChunks, mask);
879 } else {
880 convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
881 convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
882 loc, rewriter);
883 }
884
885 SmallVector<Type> convertedValTypes =
886 getUnrolledTypes(valueTy, *targetShape);
887 SmallVector<Value> convertedValues =
888 pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
889
890 for (size_t i = 0; i < convertedValues.size(); ++i) {
891 Value v = convertedValues[i];
892 Value t = convertedTdescs[i];
893 Value m = op.getMask() ? convertedMasks[i] : nullptr;
894 xegpu::StoreScatterOp::create(rewriter, loc, v, t, m, op.getL1HintAttr(),
895 op.getL2HintAttr(), op.getL3HintAttr());
896 }
897
898 rewriter.eraseOp(op);
899 return success();
900 }
901};
902
903struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
904 using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
905 LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
906 PatternRewriter &rewriter) const override {
907 Location loc = op.getLoc();
908 xegpu::TensorDescType tdescTy = op.getTensorDescType();
909
910 if (!tdescTy.isScattered())
911 return failure();
912
913 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
914 if (!targetShape)
915 return failure();
916
917 SmallVector<Type> convertedTdescTypes =
918 getUnrolledTypes(tdescTy, *targetShape);
919 SmallVector<Value> convertedTdesc = pack(
920 op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
921
922 TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
923 VectorType offsetVecTy = offsetVec.getType();
924 SmallVector<Type> convertedOffsetTypes;
925 SmallVector<Value> convertedOffsetVec;
926 SmallVector<Value> newOps;
927 int64_t originalChunkSize = tdescTy.getChunkSizeAsInt();
928 if (originalChunkSize > 1) {
929 auto targetOffsetShape = ArrayRef<int64_t>(*targetShape).drop_back();
930 convertedOffsetTypes = getUnrolledTypes(offsetVecTy, targetOffsetShape);
931
932 int64_t blockedChunkSize = targetShape->back();
933 int64_t numNewChunks = originalChunkSize / blockedChunkSize;
934 // the offset is reused across the chunk_size dimension
935 for (auto offset : pack(offsetVec, convertedOffsetTypes,
936 targetOffsetShape, loc, rewriter))
937 convertedOffsetVec.append(numNewChunks, offset);
938
939 } else {
940 convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
941 convertedOffsetVec =
942 pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
943 }
944
945 for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
946 auto newOp =
947 xegpu::UpdateOffsetOp::create(rewriter, loc, t.getType(), t, o);
948 newOps.push_back(newOp);
949 }
950 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
951 rewriter.replaceOp(op, castOp);
952 return success();
953 }
954};
955
956struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
957 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
958 LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
959 PatternRewriter &rewriter) const override {
960 Location loc = op.getLoc();
961 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
962 assert(valueTy && "the value type must be vector type!");
963
964 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
965 if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
966 return failure();
967
968 Type elemTy = valueTy.getElementType();
969 ArrayRef<int64_t> shape = valueTy.getShape();
970 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
971
972 VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
973
974    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
975    SmallVector<SmallVector<OpFoldResult>> offsetsList;
976 for (SmallVector<int64_t> offsets :
977 StaticTileOffsetRange(shape, *targetShape)) {
978 auto adds = xegpu::addElementwise(
979 rewriter, loc, mixedOffsets,
980 getAsIndexOpFoldResult(op.getContext(), offsets));
981 offsetsList.push_back(adds);
982 }
983
984 SmallVector<Value> newOps;
985 layout = layout.dropInstData();
986 for (SmallVector<OpFoldResult> offsets : offsetsList) {
987 auto newOp = xegpu::LoadMatrixOp::create(
988 rewriter, op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
989 newOps.push_back(newOp);
990 }
991 Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
992 rewriter.replaceOp(op, castOp);
993 return success();
994 }
995};
996
997struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
998 using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
999 LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
1000 PatternRewriter &rewriter) const override {
1001 std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
1002 if (!targetShape)
1003 return failure();
1004
1005 Location loc = op.getLoc();
1006 VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
1007 assert(valueTy && "the value type must be vector type!");
1008 ArrayRef<int64_t> shape = valueTy.getShape();
1009 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
1010
1011 SmallVector<Type> convertedValTypes =
1012 getUnrolledTypes(valueTy, *targetShape);
1013 SmallVector<Value> convertedValues =
1014 pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
1015
1016    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
1017    SmallVector<SmallVector<OpFoldResult>> offsetsList;
1018 for (SmallVector<int64_t> offsets :
1019 StaticTileOffsetRange(shape, *targetShape)) {
1020 auto adds = xegpu::addElementwise(
1021 rewriter, loc, mixedOffsets,
1022 getAsIndexOpFoldResult(op.getContext(), offsets));
1023 offsetsList.push_back(adds);
1024 }
1025
1026 for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
1027 xegpu::StoreMatrixOp::create(rewriter, loc, v, op.getMemDesc(), offsets,
1028 layout.dropInstData());
1029
1030 rewriter.eraseOp(op);
1031 return success();
1032 }
1033};
1034
1035} // namespace
1036
1037void mlir::xegpu::populateXeGPUUnrollPatterns(
1038    RewritePatternSet &patterns, const UnrollOptions &options) {
1039  patterns
1040 .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
1041 UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
1042 UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
1043 UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
1044 UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
1045 patterns.getContext(), options);
1046}
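//
// Typical use (a sketch; the driving pass is not shown here): a client
// configures xegpu::UnrollOptions with a callback that returns the per-op
// native tile shape (and, optionally, a filter constraint and an
// unrolled-type computation), then calls populateXeGPUUnrollPatterns and
// applies the resulting patterns, e.g. with a greedy pattern rewrite driver.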