MLIR 23.0.0git
VectorToXeGPU.cpp
Go to the documentation of this file.
1//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to XeGPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
15
24#include "mlir/Pass/Pass.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28#include <algorithm>
29#include <optional>
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38namespace {
39
40// Return true if value represents a zero constant.
41static bool isZeroConstant(Value val) {
42 auto constant = val.getDefiningOp<arith::ConstantOp>();
43 if (!constant)
44 return false;
45
46 return TypeSwitch<Attribute, bool>(constant.getValue())
47 .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
48 .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
49 .Default(false);
50}
51
52static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
53 Operation *op, VectorType vecTy,
54 MemRefType memTy) {
55 // Validate only vector as the basic vector store and load ops guarantee
56 // XeGPU-compatible memref source.
57 unsigned vecRank = vecTy.getRank();
58 if (!(vecRank == 1 || vecRank == 2))
59 return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
60
61 if (!vecTy.getElementType().isIntOrFloat())
62 return rewriter.notifyMatchFailure(
63 op, "Expected scalar type with known bitwidth");
64
65 // XeGPU requires the memref to have a scalar integer or float element type.
66 // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
67 // supported because createNdDescriptor computes byte offsets using
68 // getElementTypeBitWidth(), which asserts on non-integer/float types.
69 if (!memTy.getElementType().isIntOrFloat())
70 return rewriter.notifyMatchFailure(
71 op, "Unsupported memref element type: expected integer or float");
72
73 return success();
74}
75
76static LogicalResult transferPreconditions(PatternRewriter &rewriter,
77 VectorTransferOpInterface xferOp) {
78 if (xferOp.getMask())
79 return rewriter.notifyMatchFailure(xferOp,
80 "Masked transfer is not supported");
81
82 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
83 if (!srcTy)
84 return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
85
86 // Validate further transfer op semantics.
88 int64_t offset;
89 if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
90 return rewriter.notifyMatchFailure(
91 xferOp, "Buffer must be contiguous in the innermost dimension");
92
93 VectorType vecTy = xferOp.getVectorType();
94 unsigned vecRank = vecTy.getRank();
95 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
96 return rewriter.notifyMatchFailure(
97 xferOp, "Boundary check is available only for block instructions.");
98
99 AffineMap map = xferOp.getPermutationMap();
100 if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
101 return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
102 unsigned numInputDims = map.getNumInputs();
103 for (AffineExpr expr : map.getResults().take_back(vecRank)) {
104 auto dim = dyn_cast<AffineDimExpr>(expr);
105 if (dim.getPosition() < (numInputDims - vecRank))
106 return rewriter.notifyMatchFailure(
107 xferOp, "Only the innermost dimensions can be accessed");
108 }
109
110 return success();
111}
112
113static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
114 Location loc,
115 xegpu::TensorDescType descType,
117 MemRefType srcTy = src.getType();
118 assert(srcTy.isStrided() && "Expected strided memref type");
119 auto [strides, offset] = srcTy.getStridesAndOffset();
120 bool isStatic = true;
121
122 // Memref is dynamic if any of its shape, offset or strides is dynamic.
123 if (!srcTy.hasStaticShape())
124 isStatic = false;
125
126 if (!ShapedType::isStatic(offset))
127 isStatic = false;
128
129 for (auto stride : strides) {
130 if (!ShapedType::isStatic(stride)) {
131 isStatic = false;
132 break;
133 }
134 }
135
136 xegpu::CreateNdDescOp ndDesc;
137 if (isStatic) {
138 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
139 } else {
140 // In case of ranked dynamic memref, instead of passing on the memref,
141 // i64 base address, source's offset, shape and strides have to be
142 // explicitly provided.
143 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
144 auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
145 rewriter, loc, meta.getBaseBuffer());
146 auto offset = meta.getOffset();
147 auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
148 auto offsetInBytes = arith::MulIOp::create(
149 rewriter, loc, offset,
150 arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
151 auto adjustedBaseAddr = arith::AddIOp::create(
152 rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
153 auto adjustedAddrI64 = arith::IndexCastOp::create(
154 rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
155 ndDesc = xegpu::CreateNdDescOp::create(
156 rewriter, loc, descType, adjustedAddrI64,
157 meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
158 }
159
160 return ndDesc;
161}
162
163// Adjusts the strides of a memref according to a given permutation map for
164// vector operations.
165//
166// This function updates the innermost strides in the `strides` array to
167// reflect the permutation specified by `permMap`. The permutation is computed
168// using the inverse and broadcasting-aware version of the permutation map,
169// and is applied to the relevant strides. This ensures that memory accesses
170// are consistent with the logical permutation of vector elements.
171//
172// Example:
173// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
174// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
175// 0]), then after calling this function, the last two strides will be
176// swapped:
177// Original strides: [s0, s1, s2, s3]
178// After permutation: [s0, s1, s3, s2]
179//
180static void adjustStridesForPermutation(AffineMap permMap,
181 SmallVectorImpl<Value> &strides) {
182
186 SmallVector<int64_t> perms64(perms.begin(), perms.end());
187 strides = applyPermutation(strides, perms64);
188}
189
190// Computes memory strides and a memref offset for vector transfer operations,
191// handling both static and dynamic memrefs while applying permutation
192// transformations for XeGPU lowering.
193template <
194 typename OpType,
195 typename = std::enable_if_t<llvm::is_one_of<
196 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
197 vector::GatherOp, vector::ScatterOp>::value>>
198static std::pair<SmallVector<Value>, Value>
199computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
200 SmallVector<Value> strides;
201 Value baseMemref = xferOp.getBase();
202 MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
203
204 Location loc = xferOp.getLoc();
205 Value offsetVal = nullptr;
206 if (memrefType.hasStaticShape()) {
207 int64_t offset;
208 SmallVector<int64_t> intStrides;
209 if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
210 return {{}, offsetVal};
211 bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
212 return ShapedType::isDynamic(strideVal);
213 });
214
215 if (!hasDynamicStrides)
216 for (int64_t s : intStrides)
217 strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
218
219 if (!ShapedType::isDynamic(offset))
220 offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
221 }
222
223 if (strides.empty() || !offsetVal) {
224 // For dynamic shape memref, use memref.extract_strided_metadata to get
225 // stride values
226 unsigned rank = memrefType.getRank();
227 Type indexType = rewriter.getIndexType();
228
229 // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
230 // size0, size1, ..., sizeN-1]
231 SmallVector<Type> resultTypes;
232 resultTypes.push_back(MemRefType::get(
233 {}, memrefType.getElementType())); // base memref (unranked)
234 resultTypes.push_back(indexType); // offset
235
236 for (unsigned i = 0; i < rank; ++i)
237 resultTypes.push_back(indexType); // strides
238
239 for (unsigned i = 0; i < rank; ++i)
240 resultTypes.push_back(indexType); // sizes
241
242 auto meta = memref::ExtractStridedMetadataOp::create(
243 rewriter, loc, resultTypes, baseMemref);
244
245 if (strides.empty())
246 strides.append(meta.getStrides().begin(), meta.getStrides().end());
247
248 if (!offsetVal)
249 offsetVal = meta.getOffset();
250 }
251
252 if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
253 vector::TransferWriteOp>::value) {
254 AffineMap permMap = xferOp.getPermutationMap();
255 // Adjust strides according to the permutation map (e.g., for transpose)
256 adjustStridesForPermutation(permMap, strides);
257 }
258
259 return {strides, offsetVal};
260}
261
262// This function compute the vectors of localOffsets for scattered load/stores.
263// It is used in the lowering of vector.transfer_read/write to
264// load_gather/store_scatter Example:
265// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
266// %cst {in_bounds = [true, true, true, true]}>} :
267// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
268//
269// %6 = vector.step: vector<4xindex>
270// %7 = vector.step: vector<2xindex>
271// %8 = vector.step: vector<6xindex>
272// %9 = vector.step: vector<32xindex>
273// %10 = arith.mul %6, 384
274// %11 = arith.mul %7, 192
275// %12 = arith.mul %8, 32
276// %13 = arith.mul %9, 1
277// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
278// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
279// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
280// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
281// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
282// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
283// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
284// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
285// %22 = arith.add %18, %19
286// %23 = arith.add %20, %21
287// %local_offsets = arith.add %22, %23
288// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
289// %offsets = memref_offset + orig_offset + local_offsets
290static Value computeOffsets(VectorTransferOpInterface xferOp,
291 PatternRewriter &rewriter, ArrayRef<Value> strides,
292 Value baseOffset) {
293 Location loc = xferOp.getLoc();
294 VectorType vectorType = xferOp.getVectorType();
295 SmallVector<Value> indices(xferOp.getIndices().begin(),
296 xferOp.getIndices().end());
297 ArrayRef<int64_t> vectorShape = vectorType.getShape();
298
299 // Create vector.step operations for each dimension
300 SmallVector<Value> stepVectors;
301 llvm::map_to_vector(vectorShape, [&](int64_t dim) {
302 auto stepType = VectorType::get({dim}, rewriter.getIndexType());
303 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
304 stepVectors.push_back(stepOp);
305 return stepOp;
306 });
307
308 // Multiply step vectors by corresponding strides
309 size_t memrefRank = strides.size();
310 size_t vectorRank = vectorShape.size();
311 SmallVector<Value> strideMultiplied;
312 for (size_t i = 0; i < vectorRank; ++i) {
313 size_t memrefDim = memrefRank - vectorRank + i;
314 Value strideValue = strides[memrefDim];
315 auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
316 auto bcastOp =
317 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
318 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
319 strideMultiplied.push_back(mulOp);
320 }
321
322 // Shape cast each multiplied vector to add singleton dimensions
323 SmallVector<Value> shapeCasted;
324 for (size_t i = 0; i < vectorRank; ++i) {
325 SmallVector<int64_t> newShape(vectorRank, 1);
326 newShape[i] = vectorShape[i];
327 auto newType = VectorType::get(newShape, rewriter.getIndexType());
328 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
329 strideMultiplied[i]);
330 shapeCasted.push_back(castOp);
331 }
332
333 // Broadcast each shape-casted vector to full vector shape
334 SmallVector<Value> broadcasted;
335 auto fullIndexVectorType =
336 VectorType::get(vectorShape, rewriter.getIndexType());
337 for (Value shapeCastVal : shapeCasted) {
338 auto broadcastOp = vector::BroadcastOp::create(
339 rewriter, loc, fullIndexVectorType, shapeCastVal);
340 broadcasted.push_back(broadcastOp);
341 }
342
343 // Add all broadcasted vectors together to compute local offsets
344 Value localOffsets = broadcasted[0];
345 for (size_t i = 1; i < broadcasted.size(); ++i)
346 localOffsets =
347 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
348
349 // Compute base offset from transfer read indices
350 for (size_t i = 0; i < indices.size(); ++i) {
351 Value strideVal = strides[i];
352 Value offsetContrib =
353 arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
354 baseOffset =
355 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
356 }
357 // Broadcast base offset to match vector shape
358 Value bcastBase = vector::BroadcastOp::create(
359 rewriter, loc, fullIndexVectorType, baseOffset);
360 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
361 return localOffsets;
362}
363
364// Compute the element-wise offsets for vector.gather or vector.scatter ops.
365//
366// This function linearizes the base offsets of the gather/scatter operation
367// and combines them with the per-element indices to produce a final vector of
368// memory offsets.
369template <
370 typename OpType,
371 typename = std::enable_if_t<llvm::is_one_of<
372 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
373static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
374 ArrayRef<Value> strides, Value baseOffset) {
375 Location loc = gatScatOp.getLoc();
376 SmallVector<Value> offsets = gatScatOp.getOffsets();
377 for (size_t i = 0; i < offsets.size(); ++i) {
378 Value offsetContrib =
379 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
380 baseOffset =
381 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
382 }
383 Value indices = gatScatOp.getIndices();
384 VectorType vecType = cast<VectorType>(indices.getType());
385
386 Value strideVector =
387 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
388 .getResult();
389 Value stridedIndices =
390 arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
391
392 Value baseVector =
393 vector::BroadcastOp::create(
394 rewriter, loc,
395 VectorType::get(vecType.getShape(), rewriter.getIndexType()),
396 baseOffset)
397 .getResult();
398 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
399 .getResult();
400}
401
402// Collapses shapes of a nD memref to the target rank while applying offsets for
403// the collapsed dimensions. Returns the new memref value and the remaining
404// offsets for the last targetRank dimensions. For example:
405// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
406// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
407static std::pair<Value, SmallVector<OpFoldResult>>
408convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
411 int64_t targetRank) {
412 auto memrefType = cast<MemRefType>(memref.getType());
413 unsigned rank = memrefType.getRank();
414
415 if (rank <= targetRank)
416 return {memref, offsets};
417
418 int64_t numCombinedDims = rank - targetRank;
419 SmallVector<OpFoldResult> subviewOffsets;
420 SmallVector<OpFoldResult> subviewSizes;
421 SmallVector<OpFoldResult> subviewStrides;
422
423 // For the combined dimensions: use the provided offsets, size=1, stride=1
424 for (unsigned i = 0; i < numCombinedDims; ++i) {
425 subviewOffsets.push_back(offsets[i]);
426 subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
427 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
428 }
429
430 // For the last targetRank dimensions: offset=0, use full size, stride=1
431 SmallVector<int64_t> resultShape;
432 auto originalShape = memrefType.getShape();
433 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
434 for (unsigned i = numCombinedDims; i < rank; ++i) {
435 subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
436 if (ShapedType::isDynamic(originalShape[i])) {
437 subviewSizes.push_back(meta.getSizes()[i]);
438 resultShape.push_back(ShapedType::kDynamic);
439 } else {
440 subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
441 resultShape.push_back(originalShape[i]);
442 }
443 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
444 }
445
446 auto resultType = memref::SubViewOp::inferRankReducedResultType(
447 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
448 auto subviewOp =
449 memref::SubViewOp::create(rewriter, loc, resultType, memref,
450 subviewOffsets, subviewSizes, subviewStrides);
451
452 // Return the remaining offsets for the last targetRank dimensions
453 SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
454 offsets.end());
455 return {subviewOp.getResult(), newOffsets};
456}
457
458template <
459 typename OpType,
460 typename = std::enable_if_t<llvm::is_one_of<
461 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
462 vector::GatherOp, vector::ScatterOp>::value>>
463// Convert memref to i64 base pointer
464static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
465 Location loc = xferOp.getLoc();
466 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
467 rewriter, loc, xferOp.getBase())
468 .getResult();
469 return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
470 indexPtr)
471 .getResult();
472}
473
474static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
475 PatternRewriter &rewriter) {
476
477 Location loc = readOp.getLoc();
478 VectorType vectorType = readOp.getVectorType();
479 ArrayRef<int64_t> vectorShape = vectorType.getShape();
480 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
481 if (!memrefType)
482 return rewriter.notifyMatchFailure(readOp, "Expected memref source");
483
484 auto meta = computeMemrefMeta(readOp, rewriter);
485 if (meta.first.empty())
486 return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
487
488 Value localOffsets =
489 computeOffsets(readOp, rewriter, meta.first, meta.second);
490
491 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
492
493 Value mask = vector::ConstantMaskOp::create(
494 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
496 auto gatherOp = xegpu::LoadGatherOp::create(
497 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
498 /*chunk_size=*/IntegerAttr{},
499 /*l1_hint=*/xegpu::CachePolicyAttr{},
500 /*l2_hint=*/xegpu::CachePolicyAttr{},
501 /*l3_hint=*/xegpu::CachePolicyAttr{},
502 /*layout=*/nullptr);
503
504 rewriter.replaceOp(readOp, gatherOp.getResult());
505 return success();
506}
507
508static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
509 PatternRewriter &rewriter) {
510
511 Location loc = writeOp.getLoc();
512 VectorType vectorType = writeOp.getVectorType();
513 ArrayRef<int64_t> vectorShape = vectorType.getShape();
514
515 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
516 if (!memrefType)
517 return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
518
519 auto meta = computeMemrefMeta(writeOp, rewriter);
520 if (meta.first.empty())
521 return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
522
523 Value localOffsets =
524 computeOffsets(writeOp, rewriter, meta.first, meta.second);
525
526 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
527
528 Value mask = vector::ConstantMaskOp::create(
529 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
531 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
532 localOffsets, mask,
533 /*chunk_size=*/IntegerAttr{},
534 /*l1_hint=*/xegpu::CachePolicyAttr{},
535 /*l2_hint=*/xegpu::CachePolicyAttr{},
536 /*l3_hint=*/xegpu::CachePolicyAttr{},
537 /*layout=*/nullptr);
538 rewriter.eraseOp(writeOp);
539 return success();
540}
541
542struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
543 using Base::Base;
544
545 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
546 PatternRewriter &rewriter) const override {
547 Location loc = readOp.getLoc();
548
549 if (failed(transferPreconditions(rewriter, readOp)))
550 return failure();
551 auto readMemTy = cast<MemRefType>(readOp.getShapedType());
552 VectorType loadedVecTy = readOp.getVectorType();
553 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
554 // Check if the memref has address space 3 (shared local memory)
555 bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
556 // Handle the SLM case.
557 if (isSharedMemory) {
558 // If the memref is SLM only support 2D case for now.
559 if (loadedVecTy.getRank() != 2)
560 return rewriter.notifyMatchFailure(
561 readOp, "Only 2D vector loads are supported for SLM");
562 AffineMap readMap = readOp.getPermutationMap();
563 if (!readMap.isMinorIdentity())
564 return rewriter.notifyMatchFailure(
565 readOp,
566 "Non identity transposition is not supported for SLM loads.");
567 // Out of bounds case is not supported for SLM loads.
568 if (isOutOfBounds)
569 return rewriter.notifyMatchFailure(
570 readOp, "Out-of-bounds access is not supported for SLM loads");
571
572 // Create mem_desc for SLM
573 auto memDescType =
574 xegpu::MemDescType::get(rewriter.getContext(), readMemTy.getShape(),
575 readMemTy.getElementType(),
576 /*mem_layout=*/nullptr);
577 auto createMemDescOp = xegpu::CreateMemDescOp::create(
578 rewriter, loc, memDescType, readOp.getBase());
579 // Convert indices to OpFoldResult for LoadMatrixOp
580 SmallVector<OpFoldResult> indices =
581 getAsOpFoldResult(readOp.getIndices());
582 auto loadMatrixOp = xegpu::LoadMatrixOp::create(
583 rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
584 /*layout=*/nullptr);
585
586 rewriter.replaceOp(readOp, loadMatrixOp.getResult());
587 return success();
588 }
589
590 // TODO:This check needs to be replaced with proper uArch capability check
591 auto chip = xegpu::getChipStr(readOp);
592 // Lower to scattered load Op if the target HW doesn't have 2d block load
593 // support and the load is not from shared memory.
594 if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
595 readOp.getVectorType().getRank() > 2) {
596
597 // TODO: add support for OutOfBound access
598 if (isOutOfBounds)
599 return failure();
600 return lowerToScatteredLoadOp(readOp, rewriter);
601 }
602
603 // Handle the 1D non-SLM case using load.gather.
604 if (loadedVecTy.getRank() == 1 && !isOutOfBounds)
605 return lowerToScatteredLoadOp(readOp, rewriter);
606
607 // Perform common data transfer checks.
608 // TODO: Maybe too strict for SLM case.
609 if (failed(
610 storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
611 return failure();
612
613 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
614 return rewriter.notifyMatchFailure(
615 readOp, "Unsupported non-zero padded out-of-bounds read");
616
617 AffineMap readMap = readOp.getPermutationMap();
618 // Check if this is a transpose: the map must have exactly 2 results,
619 // and those 2 results must be the last 2 input dimensions interchanged.
620 // Examples:
621 // (d0, d1) -> (d1, d0) // transpose
622 // (d0, d1) -> (d0, d1) // not a transpose
623 // (d0, d1, d2) -> (d2, d1) // transpose (last 2 dims swapped)
624 bool isTransposeLoad = false;
625 if (readMap.getNumResults() == 2) {
626 auto results = readMap.getResults();
627 unsigned numInputs = readMap.getNumInputs();
628 if (numInputs >= 2) {
629 auto lastDim = getAffineDimExpr(numInputs - 1, readMap.getContext());
630 auto secondLastDim =
631 getAffineDimExpr(numInputs - 2, readMap.getContext());
632 isTransposeLoad =
633 (results[0] == lastDim && results[1] == secondLastDim);
634 }
635 }
636 auto elementType = loadedVecTy.getElementType();
637
638 SmallVector<int64_t> descShape(loadedVecTy.getShape());
639 if (isTransposeLoad) {
640 // If load is transposed, simply swap the last two dimensions of the
641 // loaded vector type to get the descriptor shape.
642 size_t rank = descShape.size();
643 assert(rank >= 2 && "Transpose requires at least 2 dimensions");
644 std::swap(descShape[rank - 1], descShape[rank - 2]);
645 loadedVecTy = VectorType::get(descShape, elementType);
646 }
647 auto descType = xegpu::TensorDescType::get(
648 descShape, elementType, /*array_length=*/1,
649 /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
650 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
651 rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
652 loadedVecTy.getRank());
653 // By default, no specific caching policy is assigned.
654 xegpu::CachePolicyAttr hint = nullptr;
655 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
656 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
657
658 Operation *loadedOp =
659 xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
660 /*packed=*/nullptr, /*transpose=*/nullptr,
661 /*l1_hint=*/hint,
662 /*l2_hint=*/hint, /*l3_hint=*/hint,
663 /*layout=*/nullptr);
664 if (isTransposeLoad) {
665 // Transposing the loaded vector with a separate vector.transpose
666 // operation
667 auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
668 SmallVector<int64_t> perm(
669 range.rbegin(), range.rend()); // reverse the range for transpose
670 loadedOp = vector::TransposeOp::create(rewriter, loc,
671 loadedOp->getResult(0), perm);
672 }
673 rewriter.replaceOp(readOp, loadedOp);
674
675 return success();
676 }
677};
678
679struct TransferWriteLowering
680 : public OpRewritePattern<vector::TransferWriteOp> {
681 using Base::Base;
682
683 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
684 PatternRewriter &rewriter) const override {
685 Location loc = writeOp.getLoc();
686
687 if (failed(transferPreconditions(rewriter, writeOp)))
688 return failure();
689 // Perform common data transfer checks.
690 VectorType vecTy = writeOp.getVectorType();
691 auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
692 // Check if the memref has address space 3 (shared local memory)
693 bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);
694
695 // For shared local memory (address space 3), use create_mem_desc +
696 // store_matrix
697 if (isSharedMemory) {
698 // Only support 2D case for now.
699 if (vecTy.getRank() != 2)
700 return rewriter.notifyMatchFailure(
701 writeOp, "Only 2D vector stores are supported for SLM");
702 // Create mem_desc for SLM
703 auto memDescType =
704 xegpu::MemDescType::get(rewriter.getContext(), writeMemTy.getShape(),
705 writeMemTy.getElementType(),
706 /*mem_layout=*/nullptr);
707
708 auto createMemDescOp = xegpu::CreateMemDescOp::create(
709 rewriter, loc, memDescType, writeOp.getBase());
710
711 // Convert indices to OpFoldResult for StoreMatrixOp
712 SmallVector<OpFoldResult> indices =
713 getAsOpFoldResult(writeOp.getIndices());
714
715 xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
716 createMemDescOp.getResult(), indices,
717 /*layout=*/nullptr);
718
719 rewriter.eraseOp(writeOp);
720 return success();
721 }
722
723 // TODO:This check needs to be replaced with proper uArch capability check
724 auto chip = xegpu::getChipStr(writeOp);
725 // Lower to scattered store Op if the target HW doesn't have 2d block
726 // store support and the memref is not SLM.
727 if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
728 writeOp.getVectorType().getRank() > 2) {
729
730 // TODO: add support for OutOfBound access
731 if (writeOp.hasOutOfBoundsDim())
732 return failure();
733 return lowerToScatteredStoreOp(writeOp, rewriter);
734 }
735
736 if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
737 return failure();
738
739 AffineMap map = writeOp.getPermutationMap();
740 if (!map.isMinorIdentity())
741 return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
742
743 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
744 rewriter, loc, writeOp.getBase(),
745 getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
746
747 auto descType = xegpu::TensorDescType::get(
748 vecTy.getShape(), vecTy.getElementType(),
749 /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
750 xegpu::MemorySpace::Global);
751 // By default, no specific caching policy is assigned.
752 xegpu::CachePolicyAttr hint = nullptr;
753 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
754 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
755
756 auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
757 ndDesc, indices,
758 /*l1_hint=*/hint,
759 /*l2_hint=*/hint, /*l3_hint=*/hint,
760 /*layout=*/nullptr);
761 rewriter.replaceOp(writeOp, storeOp);
762
763 return success();
764 }
765};
766
767struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
768 using Base::Base;
769
770 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
771 PatternRewriter &rewriter) const override {
772 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
773 if (!srcTy)
774 return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
775
776 Location loc = gatherOp.getLoc();
777 VectorType vectorType = gatherOp.getVectorType();
778
779 auto meta = computeMemrefMeta(gatherOp, rewriter);
780 if (meta.first.empty())
781 return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
782
783 Value localOffsets =
784 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
785 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
786
787 auto xeGatherOp = xegpu::LoadGatherOp::create(
788 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
789 /*chunk_size=*/IntegerAttr{},
790 /*l1_hint=*/xegpu::CachePolicyAttr{},
791 /*l2_hint=*/xegpu::CachePolicyAttr{},
792 /*l3_hint=*/xegpu::CachePolicyAttr{},
793 /*layout=*/nullptr);
794
795 auto selectOp =
796 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
797 xeGatherOp.getResult(), gatherOp.getPassThru());
798 rewriter.replaceOp(gatherOp, selectOp.getResult());
799 return success();
800 }
801};
802
803struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
804 using Base::Base;
805
806 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
807 PatternRewriter &rewriter) const override {
808 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
809 if (!srcTy)
810 return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
811
812 Location loc = scatterOp.getLoc();
813 auto meta = computeMemrefMeta(scatterOp, rewriter);
814 if (meta.first.empty())
815 return rewriter.notifyMatchFailure(scatterOp,
816 "Failed to compute strides");
817
818 Value localOffsets =
819 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
820 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
821
822 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
823 flatMemref, localOffsets, scatterOp.getMask(),
824 /*chunk_size=*/IntegerAttr{},
825 /*l1_hint=*/xegpu::CachePolicyAttr{},
826 /*l2_hint=*/xegpu::CachePolicyAttr{},
827 /*l3_hint=*/xegpu::CachePolicyAttr{},
828 /*layout=*/nullptr);
829 rewriter.eraseOp(scatterOp);
830 return success();
831 }
832};
833
834struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
835 using Base::Base;
836
837 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
838 PatternRewriter &rewriter) const override {
839 Location loc = loadOp.getLoc();
840
841 VectorType vecTy = loadOp.getResult().getType();
842 MemRefType memTy = loadOp.getBase().getType();
843 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
844 return failure();
845
846 // Boundary check is available only for block instructions.
847 bool boundaryCheck = vecTy.getRank() > 1;
848 // By default, no specific caching policy is assigned.
849 xegpu::CachePolicyAttr hint = nullptr;
850
851 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
852 rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
853 vecTy.getRank());
854
855 auto descType = xegpu::TensorDescType::get(
856 vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
857 boundaryCheck, xegpu::MemorySpace::Global);
858
859 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
860 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
861 auto loadNdOp =
862 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
863 /*packed=*/nullptr, /*transpose=*/nullptr,
864 /*l1_hint=*/hint,
865 /*l2_hint=*/hint, /*l3_hint=*/hint,
866 /*layout=*/nullptr);
867 rewriter.replaceOp(loadOp, loadNdOp);
868
869 return success();
870 }
871};
872
873struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
874 using Base::Base;
875
876 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
877 PatternRewriter &rewriter) const override {
878 Location loc = storeOp.getLoc();
879
880 TypedValue<VectorType> vector = storeOp.getValueToStore();
881 VectorType vecTy = vector.getType();
882 MemRefType memTy = storeOp.getBase().getType();
883 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
884 return failure();
885
886 // Boundary check is available only for block instructions.
887 bool boundaryCheck = vecTy.getRank() > 1;
888
889 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
890 rewriter, loc, storeOp.getBase(),
891 getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
892
893 auto descType = xegpu::TensorDescType::get(
894 vecTy.getShape(), vecTy.getElementType(),
895 /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
896
897 // By default, no specific caching policy is assigned.
898 xegpu::CachePolicyAttr hint = nullptr;
899 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
900 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
901
902 auto storeNdOp =
903 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
904 /*l1_hint=*/hint,
905 /*l2_hint=*/hint, /*l3_hint=*/hint,
906 /*layout=*/nullptr);
907
908 rewriter.replaceOp(storeOp, storeNdOp);
909
910 return success();
911 }
912};
913
914struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
915 using Base::Base;
916
917 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
918 PatternRewriter &rewriter) const override {
919 Location loc = contractOp.getLoc();
920
921 if (contractOp.getKind() != vector::CombiningKind::ADD)
922 return rewriter.notifyMatchFailure(contractOp,
923 "Expects add combining kind");
924
925 TypedValue<Type> acc = contractOp.getAcc();
926 VectorType accType = dyn_cast<VectorType>(acc.getType());
927 if (!accType || accType.getRank() != 2)
928 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
929
930 // Accept only plain 2D data layout.
931 // VNNI packing is applied to DPAS as a separate lowering step.
932 TypedValue<VectorType> lhs = contractOp.getLhs();
933 TypedValue<VectorType> rhs = contractOp.getRhs();
934 if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
935 return rewriter.notifyMatchFailure(contractOp,
936 "Expects lhs and rhs 2D vectors");
937
938 if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
939 return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
940
941 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
942 TypeRange{contractOp.getResultType()},
943 ValueRange{lhs, rhs, acc});
944 rewriter.replaceOp(contractOp, dpasOp);
945
946 return success();
947 }
948};
949
950// Returns `memrefTy` with its memory space replaced by `newMemSpace`.
951static MemRefType withMemorySpace(MemRefType memrefTy, Attribute newMemSpace) {
952 return MemRefType::get(memrefTy.getShape(), memrefTy.getElementType(),
953 memrefTy.getLayout(), newMemSpace);
954}
955
956// Rewrite every `memref.alloca` not already in shared local memory (SLM) to
957// be in SLM (address space 3), and propagate the new memory space through
958// memref-producing aliasing users (e.g. memref.cast, memref.subview,
959// memref.expand_shape, ...). Consumers that take a memref operand but
960// produce a non-memref result (e.g. vector.transfer_read, vector.load) are
961// left untouched: their operand type simply reflects the new memory space.
962//
963// This makes `xegpu.load_matrix`/`xegpu.store_matrix` lowering work end-to-end
964// for IR coming from bufferization, which by default assigns memory space 0/1
965// to allocations.
966static void promoteAllocasToSLM(Operation *root) {
967 MLIRContext *ctx = root->getContext();
968 Attribute slmAttr = IntegerAttr::get(IntegerType::get(ctx, 64), 3);
969
970 // A user is treated as a memref-producing alias (e.g. memref.cast,
971 // memref.subview, memref.expand_shape, ...) if it is side-effect free and
972 // produces at least one memref result. This excludes ops like memref.copy
973 // that have memory effects.
974 auto isMemrefResultOp = [](Operation *op) {
975 if (!isMemoryEffectFree(op))
976 return false;
977 return llvm::any_of(op->getResultTypes(),
978 [](Type t) { return isa<MemRefType>(t); });
979 };
980
981 // Update `v`'s type to have SLM memory space, then walk forward through
982 // memref-producing users and update their result types accordingly.
983 std::function<void(Value)> propagate = [&](Value v) {
984 auto memrefTy = dyn_cast<MemRefType>(v.getType());
985 if (!memrefTy || xegpu::XeGPUDialect::isSharedMemory(memrefTy))
986 return;
987 v.setType(withMemorySpace(memrefTy, slmAttr));
988 for (Operation *user : v.getUsers()) {
989 if (!isMemrefResultOp(user))
990 continue;
991 for (Value result : user->getResults())
992 propagate(result);
993 }
994 };
995
997 root->walk([&](memref::AllocaOp op) {
998 auto memrefTy = dyn_cast<MemRefType>(op.getResult().getType());
999 if (!memrefTy || xegpu::XeGPUDialect::isSharedMemory(memrefTy))
1000 return;
1001 allocas.push_back(op);
1002 });
1003
1004 for (memref::AllocaOp alloca : allocas) {
1005 OpBuilder builder(alloca);
1006 auto memrefTy = cast<MemRefType>(alloca.getResult().getType());
1007 auto newTy = withMemorySpace(memrefTy, slmAttr);
1008 auto newOp = memref::AllocaOp::create(
1009 builder, alloca.getLoc(), newTy, alloca.getDynamicSizes(),
1010 alloca.getSymbolOperands(), alloca.getAlignmentAttr());
1011 alloca.getResult().replaceAllUsesWith(newOp.getResult());
1012 alloca.erase();
1013 // Propagate the new memory space through memref-producing consumers.
1014 for (Operation *user : newOp.getResult().getUsers()) {
1015 if (!isMemrefResultOp(user))
1016 continue;
1017 for (Value result : user->getResults())
1018 propagate(result);
1019 }
1020 }
1021}
1022
1023struct ConvertVectorToXeGPUPass
1024 : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
1025 void runOnOperation() override {
1026 // Promote local allocations to SLM (address space 3) so that
1027 // load_matrix/store_matrix lowerings have well-typed memref operands.
1028 promoteAllocasToSLM(getOperation());
1029
1030 RewritePatternSet patterns(&getContext());
1033 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1034 return signalPassFailure();
1035 }
1036};
1037
1038} // namespace
1039
1041 RewritePatternSet &patterns) {
1042 patterns
1043 .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
1044 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
1045 patterns.getContext());
1046}
return success()
lhs
b getContext())
static std::optional< VectorShape > vectorShape(Type type)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
result_type_range getResultTypes()
Definition Operation.h:453
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
user_range getUsers() const
Definition Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...