MLIR 23.0.0git
VectorToXeGPU.cpp
Go to the documentation of this file.
1//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to XeGPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>
29
30namespace mlir {
31#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
32#include "mlir/Conversion/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36
37namespace {
38
39// Return true if value represents a zero constant.
40static bool isZeroConstant(Value val) {
41 auto constant = val.getDefiningOp<arith::ConstantOp>();
42 if (!constant)
43 return false;
44
45 return TypeSwitch<Attribute, bool>(constant.getValue())
46 .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
47 .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
48 .Default(false);
49}
50
51static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
52 Operation *op, VectorType vecTy,
53 MemRefType memTy) {
54 // Validate only vector as the basic vector store and load ops guarantee
55 // XeGPU-compatible memref source.
56 unsigned vecRank = vecTy.getRank();
57 if (!(vecRank == 1 || vecRank == 2))
58 return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
59
60 if (!vecTy.getElementType().isIntOrFloat())
61 return rewriter.notifyMatchFailure(
62 op, "Expected scalar type with known bitwidth");
63
64 // XeGPU requires the memref to have a scalar integer or float element type.
65 // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
66 // supported because createNdDescriptor computes byte offsets using
67 // getElementTypeBitWidth(), which asserts on non-integer/float types.
68 if (!memTy.getElementType().isIntOrFloat())
69 return rewriter.notifyMatchFailure(
70 op, "Unsupported memref element type: expected integer or float");
71
72 return success();
73}
74
75static LogicalResult transferPreconditions(PatternRewriter &rewriter,
76 VectorTransferOpInterface xferOp) {
77 if (xferOp.getMask())
78 return rewriter.notifyMatchFailure(xferOp,
79 "Masked transfer is not supported");
80
81 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
82 if (!srcTy)
83 return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
84
85 // Validate further transfer op semantics.
87 int64_t offset;
88 if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
89 return rewriter.notifyMatchFailure(
90 xferOp, "Buffer must be contiguous in the innermost dimension");
91
92 VectorType vecTy = xferOp.getVectorType();
93 unsigned vecRank = vecTy.getRank();
94 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
95 return rewriter.notifyMatchFailure(
96 xferOp, "Boundary check is available only for block instructions.");
97
98 AffineMap map = xferOp.getPermutationMap();
99 if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
100 return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
101 unsigned numInputDims = map.getNumInputs();
102 for (AffineExpr expr : map.getResults().take_back(vecRank)) {
103 auto dim = dyn_cast<AffineDimExpr>(expr);
104 if (dim.getPosition() < (numInputDims - vecRank))
105 return rewriter.notifyMatchFailure(
106 xferOp, "Only the innermost dimensions can be accessed");
107 }
108
109 return success();
110}
111
112static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
113 Location loc,
114 xegpu::TensorDescType descType,
116 MemRefType srcTy = src.getType();
117 assert(srcTy.isStrided() && "Expected strided memref type");
118 auto [strides, offset] = srcTy.getStridesAndOffset();
119 bool isStatic = true;
120
121 // Memref is dynamic if any of its shape, offset or strides is dynamic.
122 if (!srcTy.hasStaticShape())
123 isStatic = false;
124
125 if (!ShapedType::isStatic(offset))
126 isStatic = false;
127
128 for (auto stride : strides) {
129 if (!ShapedType::isStatic(stride)) {
130 isStatic = false;
131 break;
132 }
133 }
134
135 xegpu::CreateNdDescOp ndDesc;
136 if (isStatic) {
137 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
138 } else {
139 // In case of ranked dynamic memref, instead of passing on the memref,
140 // i64 base address, source's offset, shape and strides have to be
141 // explicitly provided.
142 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
143 auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
144 rewriter, loc, meta.getBaseBuffer());
145 auto offset = meta.getOffset();
146 auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
147 auto offsetInBytes = arith::MulIOp::create(
148 rewriter, loc, offset,
149 arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
150 auto adjustedBaseAddr = arith::AddIOp::create(
151 rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
152 auto adjustedAddrI64 = arith::IndexCastOp::create(
153 rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
154 ndDesc = xegpu::CreateNdDescOp::create(
155 rewriter, loc, descType, adjustedAddrI64,
156 meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
157 }
158
159 return ndDesc;
160}
161
162// Adjusts the strides of a memref according to a given permutation map for
163// vector operations.
164//
165// This function updates the innermost strides in the `strides` array to
166// reflect the permutation specified by `permMap`. The permutation is computed
167// using the inverse and broadcasting-aware version of the permutation map,
168// and is applied to the relevant strides. This ensures that memory accesses
169// are consistent with the logical permutation of vector elements.
170//
171// Example:
172// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
173// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
174// 0]), then after calling this function, the last two strides will be
175// swapped:
176// Original strides: [s0, s1, s2, s3]
177// After permutation: [s0, s1, s3, s2]
178//
179static void adjustStridesForPermutation(AffineMap permMap,
180 SmallVectorImpl<Value> &strides) {
181
185 SmallVector<int64_t> perms64(perms.begin(), perms.end());
186 strides = applyPermutation(strides, perms64);
187}
188
189// Computes memory strides and a memref offset for vector transfer operations,
190// handling both static and dynamic memrefs while applying permutation
191// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  // Fast path: a fully static layout is materialized as index constants
  // without emitting extract_strided_metadata.
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal}; // Empty strides signal failure to the caller.
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For dynamic shape memref, use memref.extract_strided_metadata to get
    // stride values
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
    // size0, size1, ..., sizeN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (unranked)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    // Only fill in the pieces that were not already produced as constants.
    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  // Transfer ops may permute dimensions (e.g. transposed access); reorder
  // the strides accordingly. Gather/scatter ops have no permutation map.
  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose)
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
260
261// This function compute the vectors of localOffsets for scattered load/stores.
262// It is used in the lowering of vector.transfer_read/write to
263// load_gather/store_scatter Example:
264// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
265// %cst {in_bounds = [true, true, true, true]}>} :
266// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
267//
268// %6 = vector.step: vector<4xindex>
269// %7 = vector.step: vector<2xindex>
270// %8 = vector.step: vector<6xindex>
271// %9 = vector.step: vector<32xindex>
272// %10 = arith.mul %6, 384
273// %11 = arith.mul %7, 192
274// %12 = arith.mul %8, 32
275// %13 = arith.mul %9, 1
276// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
277// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
278// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
279// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
280// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
281// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
282// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
283// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
284// %22 = arith.add %18, %19
285// %23 = arith.add %20, %21
286// %local_offsets = arith.add %22, %23
287// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
288// %offsets = memref_offset + orig_offset + local_offsets
289static Value computeOffsets(VectorTransferOpInterface xferOp,
290 PatternRewriter &rewriter, ArrayRef<Value> strides,
291 Value baseOffset) {
292 Location loc = xferOp.getLoc();
293 VectorType vectorType = xferOp.getVectorType();
294 SmallVector<Value> indices(xferOp.getIndices().begin(),
295 xferOp.getIndices().end());
296 ArrayRef<int64_t> vectorShape = vectorType.getShape();
297
298 // Create vector.step operations for each dimension
299 SmallVector<Value> stepVectors;
300 llvm::map_to_vector(vectorShape, [&](int64_t dim) {
301 auto stepType = VectorType::get({dim}, rewriter.getIndexType());
302 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
303 stepVectors.push_back(stepOp);
304 return stepOp;
305 });
306
307 // Multiply step vectors by corresponding strides
308 size_t memrefRank = strides.size();
309 size_t vectorRank = vectorShape.size();
310 SmallVector<Value> strideMultiplied;
311 for (size_t i = 0; i < vectorRank; ++i) {
312 size_t memrefDim = memrefRank - vectorRank + i;
313 Value strideValue = strides[memrefDim];
314 auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
315 auto bcastOp =
316 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
317 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
318 strideMultiplied.push_back(mulOp);
319 }
320
321 // Shape cast each multiplied vector to add singleton dimensions
322 SmallVector<Value> shapeCasted;
323 for (size_t i = 0; i < vectorRank; ++i) {
324 SmallVector<int64_t> newShape(vectorRank, 1);
325 newShape[i] = vectorShape[i];
326 auto newType = VectorType::get(newShape, rewriter.getIndexType());
327 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
328 strideMultiplied[i]);
329 shapeCasted.push_back(castOp);
330 }
331
332 // Broadcast each shape-casted vector to full vector shape
333 SmallVector<Value> broadcasted;
334 auto fullIndexVectorType =
335 VectorType::get(vectorShape, rewriter.getIndexType());
336 for (Value shapeCastVal : shapeCasted) {
337 auto broadcastOp = vector::BroadcastOp::create(
338 rewriter, loc, fullIndexVectorType, shapeCastVal);
339 broadcasted.push_back(broadcastOp);
340 }
341
342 // Add all broadcasted vectors together to compute local offsets
343 Value localOffsets = broadcasted[0];
344 for (size_t i = 1; i < broadcasted.size(); ++i)
345 localOffsets =
346 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
347
348 // Compute base offset from transfer read indices
349 for (size_t i = 0; i < indices.size(); ++i) {
350 Value strideVal = strides[i];
351 Value offsetContrib =
352 arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
353 baseOffset =
354 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
355 }
356 // Broadcast base offset to match vector shape
357 Value bcastBase = vector::BroadcastOp::create(
358 rewriter, loc, fullIndexVectorType, baseOffset);
359 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
360 return localOffsets;
361}
362
363// Compute the element-wise offsets for vector.gather or vector.scatter ops.
364//
365// This function linearizes the base offsets of the gather/scatter operation
366// and combines them with the per-element indices to produce a final vector of
367// memory offsets.
368template <
369 typename OpType,
370 typename = std::enable_if_t<llvm::is_one_of<
371 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
372static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
373 ArrayRef<Value> strides, Value baseOffset) {
374 Location loc = gatScatOp.getLoc();
375 SmallVector<Value> offsets = gatScatOp.getOffsets();
376 for (size_t i = 0; i < offsets.size(); ++i) {
377 Value offsetContrib =
378 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
379 baseOffset =
380 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
381 }
382 Value indices = gatScatOp.getIndices();
383 VectorType vecType = cast<VectorType>(indices.getType());
384
385 Value strideVector =
386 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
387 .getResult();
388 Value stridedIndices =
389 arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
390
391 Value baseVector =
392 vector::BroadcastOp::create(
393 rewriter, loc,
394 VectorType::get(vecType.getShape(), rewriter.getIndexType()),
395 baseOffset)
396 .getResult();
397 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
398 .getResult();
399}
400
401// Collapses shapes of a nD memref to the target rank while applying offsets for
402// the collapsed dimensions. Returns the new memref value and the remaining
403// offsets for the last targetRank dimensions. For example:
404// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
405// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
406static std::pair<Value, SmallVector<OpFoldResult>>
407convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
410 int64_t targetRank) {
411 auto memrefType = cast<MemRefType>(memref.getType());
412 unsigned rank = memrefType.getRank();
413
414 if (rank <= targetRank)
415 return {memref, offsets};
416
417 int64_t numCombinedDims = rank - targetRank;
418 SmallVector<OpFoldResult> subviewOffsets;
419 SmallVector<OpFoldResult> subviewSizes;
420 SmallVector<OpFoldResult> subviewStrides;
421
422 // For the combined dimensions: use the provided offsets, size=1, stride=1
423 for (unsigned i = 0; i < numCombinedDims; ++i) {
424 subviewOffsets.push_back(offsets[i]);
425 subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
426 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
427 }
428
429 // For the last targetRank dimensions: offset=0, use full size, stride=1
430 SmallVector<int64_t> resultShape;
431 auto originalShape = memrefType.getShape();
432 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
433 for (unsigned i = numCombinedDims; i < rank; ++i) {
434 subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
435 if (ShapedType::isDynamic(originalShape[i])) {
436 subviewSizes.push_back(meta.getSizes()[i]);
437 resultShape.push_back(ShapedType::kDynamic);
438 } else {
439 subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
440 resultShape.push_back(originalShape[i]);
441 }
442 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
443 }
444
445 auto resultType = memref::SubViewOp::inferRankReducedResultType(
446 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
447 auto subviewOp =
448 memref::SubViewOp::create(rewriter, loc, resultType, memref,
449 subviewOffsets, subviewSizes, subviewStrides);
450
451 // Return the remaining offsets for the last targetRank dimensions
452 SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
453 offsets.end());
454 return {subviewOp.getResult(), newOffsets};
455}
456
457template <
458 typename OpType,
459 typename = std::enable_if_t<llvm::is_one_of<
460 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
461 vector::GatherOp, vector::ScatterOp>::value>>
462// Convert memref to i64 base pointer
463static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
464 Location loc = xferOp.getLoc();
465 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
466 rewriter, loc, xferOp.getBase())
467 .getResult();
468 return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
469 indexPtr)
470 .getResult();
471}
472
473static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
474 PatternRewriter &rewriter) {
475
476 Location loc = readOp.getLoc();
477 VectorType vectorType = readOp.getVectorType();
478 ArrayRef<int64_t> vectorShape = vectorType.getShape();
479 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
480 if (!memrefType)
481 return rewriter.notifyMatchFailure(readOp, "Expected memref source");
482
483 auto meta = computeMemrefMeta(readOp, rewriter);
484 if (meta.first.empty())
485 return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
486
487 Value localOffsets =
488 computeOffsets(readOp, rewriter, meta.first, meta.second);
489
490 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
491
492 Value mask = vector::ConstantMaskOp::create(
493 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
495 auto gatherOp = xegpu::LoadGatherOp::create(
496 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
497 /*chunk_size=*/IntegerAttr{},
498 /*l1_hint=*/xegpu::CachePolicyAttr{},
499 /*l2_hint=*/xegpu::CachePolicyAttr{},
500 /*l3_hint=*/xegpu::CachePolicyAttr{},
501 /*layout=*/nullptr);
502
503 rewriter.replaceOp(readOp, gatherOp.getResult());
504 return success();
505}
506
507static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
508 PatternRewriter &rewriter) {
509
510 Location loc = writeOp.getLoc();
511 VectorType vectorType = writeOp.getVectorType();
512 ArrayRef<int64_t> vectorShape = vectorType.getShape();
513
514 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
515 if (!memrefType)
516 return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
517
518 auto meta = computeMemrefMeta(writeOp, rewriter);
519 if (meta.first.empty())
520 return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
521
522 Value localOffsets =
523 computeOffsets(writeOp, rewriter, meta.first, meta.second);
524
525 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
526
527 Value mask = vector::ConstantMaskOp::create(
528 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
530 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
531 localOffsets, mask,
532 /*chunk_size=*/IntegerAttr{},
533 /*l1_hint=*/xegpu::CachePolicyAttr{},
534 /*l2_hint=*/xegpu::CachePolicyAttr{},
535 /*l3_hint=*/xegpu::CachePolicyAttr{},
536 /*layout=*/nullptr);
537 rewriter.eraseOp(writeOp);
538 return success();
539}
540
// Lowers vector.transfer_read to XeGPU. Three paths are taken depending on
// the source and target: shared local memory uses create_mem_desc +
// load_matrix; hardware without 2D block-load support (or rank > 2 vectors)
// and the 1D in-bounds case use a scattered load; the remaining cases use an
// nd-descriptor block load, optionally followed by a vector.transpose.
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();
    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
    VectorType loadedVecTy = readOp.getVectorType();
    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    // Check if the memref has address space 3 (shared local memory)
    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
    // Handle the SLM case.
    if (isSharedMemory) {
      // If the memref is SLM only support 2D case for now.
      if (loadedVecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            readOp, "Only 2D vector loads are supported for SLM");
      AffineMap readMap = readOp.getPermutationMap();
      if (!readMap.isMinorIdentity())
        return rewriter.notifyMatchFailure(
            readOp,
            "Non identity transposition is not supported for SLM loads.");
      // Out of bounds case is not supported for SLM loads.
      if (isOutOfBounds)
        return rewriter.notifyMatchFailure(
            readOp, "Out-of-bounds access is not supported for SLM loads");

      // Create mem_desc for SLM
      auto memDescType =
          xegpu::MemDescType::get(rewriter.getContext(), readMemTy.getShape(),
                                  readMemTy.getElementType(),
                                  /*mem_layout=*/nullptr);
      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, readOp.getBase());
      // Convert indices to OpFoldResult for LoadMatrixOp
      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(readOp.getIndices());
      auto loadMatrixOp = xegpu::LoadMatrixOp::create(
          rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
          /*layout=*/nullptr);

      rewriter.replaceOp(readOp, loadMatrixOp.getResult());
      return success();
    }

    // TODO:This check needs to be replaced with proper uArch capability check
    auto chip = xegpu::getChipStr(readOp);
    // Lower to scattered load Op if the target HW doesn't have 2d block load
    // support and the load is not from shared memory.
    if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
        readOp.getVectorType().getRank() > 2) {

      // TODO: add support for OutOfBound access
      if (isOutOfBounds)
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // Handle the 1D non-SLM case using load.gather.
    if (loadedVecTy.getRank() == 1 && !isOutOfBounds)
      return lowerToScatteredLoadOp(readOp, rewriter);

    // Perform common data transfer checks.
    // TODO: Maybe too strict for SLM case.
    if (failed(
            storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
      return failure();

    // Out-of-bounds reads fill with the padding value; only zero padding can
    // be expressed through the descriptor's boundary check.
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    // Check if this is a transpose: the map must have exactly 2 results,
    // and those 2 results must be the last 2 input dimensions interchanged.
    // Examples:
    //   (d0, d1) -> (d1, d0)     // transpose
    //   (d0, d1) -> (d0, d1)     // not a transpose
    //   (d0, d1, d2) -> (d2, d1) // transpose (last 2 dims swapped)
    bool isTransposeLoad = false;
    if (readMap.getNumResults() == 2) {
      auto results = readMap.getResults();
      unsigned numInputs = readMap.getNumInputs();
      if (numInputs >= 2) {
        auto lastDim = getAffineDimExpr(numInputs - 1, readMap.getContext());
        auto secondLastDim =
            getAffineDimExpr(numInputs - 2, readMap.getContext());
        isTransposeLoad =
            (results[0] == lastDim && results[1] == secondLastDim);
      }
    }
    auto elementType = loadedVecTy.getElementType();

    SmallVector<int64_t> descShape(loadedVecTy.getShape());
    if (isTransposeLoad) {
      // If load is transposed, simply swap the last two dimensions of the
      // loaded vector type to get the descriptor shape.
      size_t rank = descShape.size();
      assert(rank >= 2 && "Transpose requires at least 2 dimensions");
      std::swap(descShape[rank - 1], descShape[rank - 2]);
      loadedVecTy = VectorType::get(descShape, elementType);
    }
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
    // Collapse higher-rank memrefs to the vector rank, folding the leading
    // indices into a subview.
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
        loadedVecTy.getRank());
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    Operation *loadedOp =
        xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    if (isTransposeLoad) {
      // Transposing the loaded vector with a separate vector.transpose
      // operation
      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
      SmallVector<int64_t> perm(
          range.rbegin(), range.rend()); // reverse the range for transpose
      loadedOp = vector::TransposeOp::create(rewriter, loc,
                                             loadedOp->getResult(0), perm);
    }
    rewriter.replaceOp(readOp, loadedOp);

    return success();
  }
};
677
// Lowers vector.transfer_write to XeGPU, mirroring TransferReadLowering:
// shared local memory uses create_mem_desc + store_matrix; hardware without
// 2D block-store support (or rank > 2 vectors) uses a scattered store; the
// remaining cases use an nd-descriptor block store.
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();
    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
    // Check if the memref has address space 3 (shared local memory)
    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);

    // For shared local memory (address space 3), use create_mem_desc +
    // store_matrix
    if (isSharedMemory) {
      // Only support 2D case for now.
      if (vecTy.getRank() != 2)
        return rewriter.notifyMatchFailure(
            writeOp, "Only 2D vector stores are supported for SLM");
      // Create mem_desc for SLM
      auto memDescType =
          xegpu::MemDescType::get(rewriter.getContext(), writeMemTy.getShape(),
                                  writeMemTy.getElementType(),
                                  /*mem_layout=*/nullptr);

      auto createMemDescOp = xegpu::CreateMemDescOp::create(
          rewriter, loc, memDescType, writeOp.getBase());

      // Convert indices to OpFoldResult for StoreMatrixOp
      SmallVector<OpFoldResult> indices =
          getAsOpFoldResult(writeOp.getIndices());

      xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
                                   createMemDescOp.getResult(), indices,
                                   /*layout=*/nullptr);

      rewriter.eraseOp(writeOp);
      return success();
    }

    // TODO:This check needs to be replaced with proper uArch capability check
    auto chip = xegpu::getChipStr(writeOp);
    // Lower to scattered store Op if the target HW doesn't have 2d block
    // store support and the memref is not SLM.
    if ((chip != "pvc" && chip != "bmg" && chip != "cri") ||
        writeOp.getVectorType().getRank() > 2) {

      // TODO: add support for OutOfBound access
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
      return failure();

    // Transposed stores are not handled on this path; only minor-identity
    // permutation maps are accepted.
    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    // Collapse higher-rank memrefs to the vector rank, folding the leading
    // indices into a subview.
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
                                            /*layout=*/nullptr);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
765
// Lowers vector.gather to xegpu.load (gather form) over a flat i64 base
// pointer plus linearized per-element offsets; masked-off lanes take the
// pass-through values via arith.select.
struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    // Strides and base offset of the source memref (static or dynamic).
    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{},
        /*layout=*/nullptr);

    // vector.gather semantics: masked-off lanes must yield the pass-through
    // values, so select between the loaded result and pass-through.
    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};
801
// Lowers vector.scatter to xegpu.store (scatter form) over a flat i64 base
// pointer plus linearized per-element offsets; the mask is forwarded so
// masked-off lanes are not written.
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    // Strides and base offset of the destination memref (static or dynamic).
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
                                  /*layout=*/nullptr);
    rewriter.eraseOp(scatterOp);
    return success();
  }
};
832
833struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
834 using Base::Base;
835
836 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
837 PatternRewriter &rewriter) const override {
838 Location loc = loadOp.getLoc();
839
840 VectorType vecTy = loadOp.getResult().getType();
841 MemRefType memTy = loadOp.getBase().getType();
842 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
843 return failure();
844
845 // Boundary check is available only for block instructions.
846 bool boundaryCheck = vecTy.getRank() > 1;
847 // By default, no specific caching policy is assigned.
848 xegpu::CachePolicyAttr hint = nullptr;
849
850 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
851 rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
852 vecTy.getRank());
853
854 auto descType = xegpu::TensorDescType::get(
855 vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
856 boundaryCheck, xegpu::MemorySpace::Global);
857
858 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
859 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
860 auto loadNdOp =
861 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
862 /*packed=*/nullptr, /*transpose=*/nullptr,
863 /*l1_hint=*/hint,
864 /*l2_hint=*/hint, /*l3_hint=*/hint,
865 /*layout=*/nullptr);
866 rewriter.replaceOp(loadOp, loadNdOp);
867
868 return success();
869 }
870};
871
872struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
873 using Base::Base;
874
875 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
876 PatternRewriter &rewriter) const override {
877 Location loc = storeOp.getLoc();
878
879 TypedValue<VectorType> vector = storeOp.getValueToStore();
880 VectorType vecTy = vector.getType();
881 MemRefType memTy = storeOp.getBase().getType();
882 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
883 return failure();
884
885 // Boundary check is available only for block instructions.
886 bool boundaryCheck = vecTy.getRank() > 1;
887
888 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
889 rewriter, loc, storeOp.getBase(),
890 getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
891
892 auto descType = xegpu::TensorDescType::get(
893 vecTy.getShape(), vecTy.getElementType(),
894 /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
895
896 // By default, no specific caching policy is assigned.
897 xegpu::CachePolicyAttr hint = nullptr;
898 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
899 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
900
901 auto storeNdOp =
902 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
903 /*l1_hint=*/hint,
904 /*l2_hint=*/hint, /*l3_hint=*/hint,
905 /*layout=*/nullptr);
906
907 rewriter.replaceOp(storeOp, storeNdOp);
908
909 return success();
910 }
911};
912
913struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
914 using Base::Base;
915
916 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
917 PatternRewriter &rewriter) const override {
918 Location loc = contractOp.getLoc();
919
920 if (contractOp.getKind() != vector::CombiningKind::ADD)
921 return rewriter.notifyMatchFailure(contractOp,
922 "Expects add combining kind");
923
924 TypedValue<Type> acc = contractOp.getAcc();
925 VectorType accType = dyn_cast<VectorType>(acc.getType());
926 if (!accType || accType.getRank() != 2)
927 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
928
929 // Accept only plain 2D data layout.
930 // VNNI packing is applied to DPAS as a separate lowering step.
931 TypedValue<VectorType> lhs = contractOp.getLhs();
932 TypedValue<VectorType> rhs = contractOp.getRhs();
933 if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
934 return rewriter.notifyMatchFailure(contractOp,
935 "Expects lhs and rhs 2D vectors");
936
937 if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
938 return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
939
940 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
941 TypeRange{contractOp.getResultType()},
942 ValueRange{lhs, rhs, acc});
943 rewriter.replaceOp(contractOp, dpasOp);
944
945 return success();
946 }
947};
948
949struct ConvertVectorToXeGPUPass
950 : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
951 void runOnOperation() override {
952 RewritePatternSet patterns(&getContext());
955 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
956 return signalPassFailure();
957 }
958};
959
960} // namespace
961
963 RewritePatternSet &patterns) {
964 patterns
965 .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
966 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
967 patterns.getContext());
968}
return success()
lhs
b getContext())
static std::optional< VectorShape > vectorShape(Type type)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumResults() const
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...