MLIR 23.0.0git
VectorToXeGPU.cpp
Go to the documentation of this file.
1//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to XeGPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
15
23#include "mlir/Pass/Pass.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27#include <algorithm>
28#include <optional>
29
30namespace mlir {
31#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
32#include "mlir/Conversion/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36
37namespace {
38
39// Return true if value represents a zero constant.
40static bool isZeroConstant(Value val) {
41 auto constant = val.getDefiningOp<arith::ConstantOp>();
42 if (!constant)
43 return false;
44
45 return TypeSwitch<Attribute, bool>(constant.getValue())
46 .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
47 .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
48 .Default(false);
49}
50
51static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
52 Operation *op, VectorType vecTy,
53 MemRefType memTy) {
54 // Validate only vector as the basic vector store and load ops guarantee
55 // XeGPU-compatible memref source.
56 unsigned vecRank = vecTy.getRank();
57 if (!(vecRank == 1 || vecRank == 2))
58 return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
59
60 if (!vecTy.getElementType().isIntOrFloat())
61 return rewriter.notifyMatchFailure(
62 op, "Expected scalar type with known bitwidth");
63
64 // XeGPU requires the memref to have a scalar integer or float element type.
65 // Memrefs with vector element types (e.g. memref<?xvector<4xf32>>) are not
66 // supported because createNdDescriptor computes byte offsets using
67 // getElementTypeBitWidth(), which asserts on non-integer/float types.
68 if (!memTy.getElementType().isIntOrFloat())
69 return rewriter.notifyMatchFailure(
70 op, "Unsupported memref element type: expected integer or float");
71
72 return success();
73}
74
// Common preconditions shared by the vector.transfer_read/write lowerings:
// unmasked transfer, memref source contiguous in the innermost dimension,
// and a projected-permutation map touching only the innermost dimensions.
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  // Only memref sources are supported (tensor transfers are rejected).
  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  // NOTE(review): the declaration of `strides` (presumably
  // SmallVector<int64_t>) appears to be missing from this excerpt — confirm
  // against the full file.
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  // Out-of-bounds handling relies on 2D block instructions' boundary checks,
  // so 1D out-of-bounds transfers are rejected.
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  // With a projected permutation and no zero results, every result is an
  // AffineDimExpr, so the dyn_cast below is expected to succeed.
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
111
// Builds an xegpu::CreateNdDescOp for the given strided source memref.
// Fully static memrefs are passed directly; dynamic ones are decomposed into
// an explicit i64 base address plus sizes/strides (see comment below).
// NOTE(review): the parameter line declaring the source value (presumably
// `TypedValue<MemRefType> src) {`) is missing from this excerpt — confirm
// against the full file.
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  // Memref is dynamic if any of its shape, offset or strides is dynamic.
  if (!srcTy.hasStaticShape())
    isStatic = false;

  if (!ShapedType::isStatic(offset))
    isStatic = false;

  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    // Static memref: the descriptor can be built from the memref value alone.
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // In case of ranked dynamic memref, instead of passing on the memref,
    // i64 base address, source's offset, shape and strides have to be
    // explicitly provided.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    // Fold the element offset into the base address, in bytes.
    // NOTE(review): this division assumes the element bitwidth is a multiple
    // of 8 — sub-byte element types would truncate here; confirm callers
    // exclude them.
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
161
162// Adjusts the strides of a memref according to a given permutation map for
163// vector operations.
164//
165// This function updates the innermost strides in the `strides` array to
166// reflect the permutation specified by `permMap`. The permutation is computed
167// using the inverse and broadcasting-aware version of the permutation map,
168// and is applied to the relevant strides. This ensures that memory accesses
169// are consistent with the logical permutation of vector elements.
170//
171// Example:
172// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
173// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
174// 0]), then after calling this function, the last two strides will be
175// swapped:
176// Original strides: [s0, s1, s2, s3]
177// After permutation: [s0, s1, s3, s2]
178//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {

  // NOTE(review): the lines computing `perms` (per the comment above,
  // presumably derived from the inverted, broadcasting-aware permutation map)
  // are missing from this excerpt — confirm against the full file.
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  // Reorder the stride values in place according to the permutation.
  strides = applyPermutation(strides, perms64);
}
188
189// Computes memory strides and a memref offset for vector transfer operations,
190// handling both static and dynamic memrefs while applying permutation
191// transformations for XeGPU lowering.
// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering. Static strides/offset become
// arith.constant index ops; any dynamic component falls back to
// memref.extract_strided_metadata. An empty strides vector in the returned
// pair signals failure to callers.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  // NOTE(review): callers are expected to have checked the base is a memref;
  // memrefType is dereferenced below without a null check — confirm.
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    // Materialize strides/offset as constants only when fully static.
    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For dynamic shape memref, use memref.extract_strided_metadata to get
    // stride values
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
    // size0, size1, ..., sizeN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (unranked)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    // Only fill in the pieces that were not already materialized as
    // constants above.
    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose)
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
260
// This function computes the vectors of localOffsets for scattered
// load/stores. It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter. Example:
264// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
265// %cst {in_bounds = [true, true, true, true]}>} :
266// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
267//
268// %6 = vector.step: vector<4xindex>
269// %7 = vector.step: vector<2xindex>
270// %8 = vector.step: vector<6xindex>
271// %9 = vector.step: vector<32xindex>
272// %10 = arith.mul %6, 384
273// %11 = arith.mul %7, 192
274// %12 = arith.mul %8, 32
275// %13 = arith.mul %9, 1
276// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
277// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
278// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
279// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
280// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
281// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
282// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
283// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
284// %22 = arith.add %18, %19
285// %23 = arith.add %20, %21
286// %local_offsets = arith.add %22, %23
287// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
288// %offsets = memref_offset + orig_offset + local_offsets
289static Value computeOffsets(VectorTransferOpInterface xferOp,
290 PatternRewriter &rewriter, ArrayRef<Value> strides,
291 Value baseOffset) {
292 Location loc = xferOp.getLoc();
293 VectorType vectorType = xferOp.getVectorType();
294 SmallVector<Value> indices(xferOp.getIndices().begin(),
295 xferOp.getIndices().end());
296 ArrayRef<int64_t> vectorShape = vectorType.getShape();
297
298 // Create vector.step operations for each dimension
299 SmallVector<Value> stepVectors;
300 llvm::map_to_vector(vectorShape, [&](int64_t dim) {
301 auto stepType = VectorType::get({dim}, rewriter.getIndexType());
302 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
303 stepVectors.push_back(stepOp);
304 return stepOp;
305 });
306
307 // Multiply step vectors by corresponding strides
308 size_t memrefRank = strides.size();
309 size_t vectorRank = vectorShape.size();
310 SmallVector<Value> strideMultiplied;
311 for (size_t i = 0; i < vectorRank; ++i) {
312 size_t memrefDim = memrefRank - vectorRank + i;
313 Value strideValue = strides[memrefDim];
314 auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
315 auto bcastOp =
316 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
317 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
318 strideMultiplied.push_back(mulOp);
319 }
320
321 // Shape cast each multiplied vector to add singleton dimensions
322 SmallVector<Value> shapeCasted;
323 for (size_t i = 0; i < vectorRank; ++i) {
324 SmallVector<int64_t> newShape(vectorRank, 1);
325 newShape[i] = vectorShape[i];
326 auto newType = VectorType::get(newShape, rewriter.getIndexType());
327 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
328 strideMultiplied[i]);
329 shapeCasted.push_back(castOp);
330 }
331
332 // Broadcast each shape-casted vector to full vector shape
333 SmallVector<Value> broadcasted;
334 auto fullIndexVectorType =
335 VectorType::get(vectorShape, rewriter.getIndexType());
336 for (Value shapeCastVal : shapeCasted) {
337 auto broadcastOp = vector::BroadcastOp::create(
338 rewriter, loc, fullIndexVectorType, shapeCastVal);
339 broadcasted.push_back(broadcastOp);
340 }
341
342 // Add all broadcasted vectors together to compute local offsets
343 Value localOffsets = broadcasted[0];
344 for (size_t i = 1; i < broadcasted.size(); ++i)
345 localOffsets =
346 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
347
348 // Compute base offset from transfer read indices
349 for (size_t i = 0; i < indices.size(); ++i) {
350 Value strideVal = strides[i];
351 Value offsetContrib =
352 arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
353 baseOffset =
354 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
355 }
356 // Broadcast base offset to match vector shape
357 Value bcastBase = vector::BroadcastOp::create(
358 rewriter, loc, fullIndexVectorType, baseOffset);
359 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
360 return localOffsets;
361}
362
363// Compute the element-wise offsets for vector.gather or vector.scatter ops.
364//
365// This function linearizes the base offsets of the gather/scatter operation
366// and combines them with the per-element indices to produce a final vector of
367// memory offsets.
368template <
369 typename OpType,
370 typename = std::enable_if_t<llvm::is_one_of<
371 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
372static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
373 ArrayRef<Value> strides, Value baseOffset) {
374 Location loc = gatScatOp.getLoc();
375 SmallVector<Value> offsets = gatScatOp.getOffsets();
376 for (size_t i = 0; i < offsets.size(); ++i) {
377 Value offsetContrib =
378 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
379 baseOffset =
380 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
381 }
382 Value indices = gatScatOp.getIndices();
383 VectorType vecType = cast<VectorType>(indices.getType());
384
385 Value strideVector =
386 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
387 .getResult();
388 Value stridedIndices =
389 arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
390
391 Value baseVector =
392 vector::BroadcastOp::create(
393 rewriter, loc,
394 VectorType::get(vecType.getShape(), rewriter.getIndexType()),
395 baseOffset)
396 .getResult();
397 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
398 .getResult();
399}
400
401// Collapses shapes of a nD memref to the target rank while applying offsets for
402// the collapsed dimensions. Returns the new memref value and the remaining
403// offsets for the last targetRank dimensions. For example:
404// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
405// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    // NOTE(review): the parameter lines
                                    // declaring `memref` (Value) and
                                    // `offsets` (SmallVector<OpFoldResult>)
                                    // are missing from this excerpt — confirm
                                    // against the full file.
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  // Nothing to collapse: pass the memref and offsets through unchanged.
  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // For the combined dimensions: use the provided offsets, size=1, stride=1
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // For the last targetRank dimensions: offset=0, use full size, stride=1
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  // Metadata is needed only to recover dynamic dimension sizes below.
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // A rank-reduced subview drops the size-1 combined dimensions.
  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  // Return the remaining offsets for the last targetRank dimensions
  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}
456
457template <
458 typename OpType,
459 typename = std::enable_if_t<llvm::is_one_of<
460 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
461 vector::GatherOp, vector::ScatterOp>::value>>
462// Convert memref to i64 base pointer
463static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
464 Location loc = xferOp.getLoc();
465 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
466 rewriter, loc, xferOp.getBase())
467 .getResult();
468 return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
469 indexPtr)
470 .getResult();
471}
472
// Lowers an in-bounds vector.transfer_read to xegpu.load_gather using a flat
// i64 base address plus per-element linearized offsets.
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {

  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  // meta.first = stride values, meta.second = memref base offset; an empty
  // strides vector signals failure.
  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  // All-lanes mask; callers reject out-of-bounds transfers before reaching
  // this lowering.
  // NOTE(review): the line(s) supplying the mask dimension sizes to
  // ConstantMaskOp appear to be missing from this excerpt — confirm against
  // the full file.
  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{},
      /*layout=*/nullptr);

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
506
// Lowers an in-bounds vector.transfer_write to xegpu.store_scatter using a
// flat i64 base address plus per-element linearized offsets.
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {

  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  // meta.first = stride values, meta.second = memref base offset; an empty
  // strides vector signals failure.
  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  // All-lanes mask; callers reject out-of-bounds transfers before reaching
  // this lowering.
  // NOTE(review): the line(s) supplying the mask dimension sizes to
  // ConstantMaskOp appear to be missing from this excerpt — confirm against
  // the full file.
  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{},
                                /*layout=*/nullptr);
  rewriter.eraseOp(writeOp);
  return success();
}
540
// Lowers vector.transfer_read either to a scattered xegpu.load_gather (for
// targets without 2D block load support, or for 1D in-bounds reads) or to
// xegpu.create_nd_tdesc + xegpu.load_nd, followed by a vector.transpose when
// the permutation map transposes the result.
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO:This check needs to be replaced with proper uArch capability check
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // lower to scattered load Op if the target HW doesn't have 2d block load
      // support
      // TODO: add support for OutOfBound access
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType loadedVecTy = readOp.getVectorType();

    // Lower using load.gather in 1D case
    if (loadedVecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
      return lowerToScatteredLoadOp(readOp, rewriter);

    // Perform common data transfer checks.
    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
    if (failed(
            storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
      return failure();

    // Out-of-bounds reads only support a zero padding value.
    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();
    auto elementType = loadedVecTy.getElementType();

    SmallVector<int64_t> descShape(loadedVecTy.getShape());
    if (isTransposeLoad) {
      // If load is transposed, then the shape of the source-descriptor
      // is the opposite from the result-shape. Applying the permutation
      // to get the reversive shape.
      auto inversedMap = inversePermutation(readMap);
      descShape = applyPermutationMap(inversedMap, loadedVecTy.getShape());
      loadedVecTy = VectorType::get(descShape, elementType);
    }
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
    // Collapse the memref down to the (possibly permuted) vector rank,
    // folding the leading indices into a subview.
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
        loadedVecTy.getRank());
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    Operation *loadedOp =
        xegpu::LoadNdOp::create(rewriter, loc, loadedVecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    if (isTransposeLoad) {
      // Transposing the loaded vector with a separate vector.transpose
      // operation
      auto range = llvm::seq<int64_t>(0, readMap.getResults().size());
      SmallVector<int64_t> perm(range.begin(), range.end());
      auto permApplied = applyPermutationMap<int64_t>(readMap, perm);
      loadedOp = vector::TransposeOp::create(
          rewriter, loc, loadedOp->getResult(0), permApplied);
    }
    rewriter.replaceOp(readOp, loadedOp);

    return success();
  }
};
623
// Lowers vector.transfer_write either to a scattered xegpu.store_scatter
// (for targets without 2D block store support) or to xegpu.create_nd_tdesc +
// xegpu.store_nd. Unlike reads, transposing stores are not supported.
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO:This check needs to be replaced with proper uArch capability check
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // lower to scattered store Op if the target HW doesn't have 2d block
      // store support
      // TODO: add support for OutOfBound access
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
      return failure();

    // Only plain minor-identity maps are supported for stores.
    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    // Collapse the memref down to the vector rank, folding the leading
    // indices into a subview.
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
                                            /*layout=*/nullptr);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
679
680struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
681 using Base::Base;
682
683 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
684 PatternRewriter &rewriter) const override {
685 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
686 if (!srcTy)
687 return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
688
689 Location loc = gatherOp.getLoc();
690 VectorType vectorType = gatherOp.getVectorType();
691
692 auto meta = computeMemrefMeta(gatherOp, rewriter);
693 if (meta.first.empty())
694 return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
695
696 Value localOffsets =
697 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
698 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
699
700 auto xeGatherOp = xegpu::LoadGatherOp::create(
701 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
702 /*chunk_size=*/IntegerAttr{},
703 /*l1_hint=*/xegpu::CachePolicyAttr{},
704 /*l2_hint=*/xegpu::CachePolicyAttr{},
705 /*l3_hint=*/xegpu::CachePolicyAttr{},
706 /*layout=*/nullptr);
707
708 auto selectOp =
709 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
710 xeGatherOp.getResult(), gatherOp.getPassThru());
711 rewriter.replaceOp(gatherOp, selectOp.getResult());
712 return success();
713 }
714};
715
716struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
717 using Base::Base;
718
719 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
720 PatternRewriter &rewriter) const override {
721 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
722 if (!srcTy)
723 return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
724
725 Location loc = scatterOp.getLoc();
726 auto meta = computeMemrefMeta(scatterOp, rewriter);
727 if (meta.first.empty())
728 return rewriter.notifyMatchFailure(scatterOp,
729 "Failed to compute strides");
730
731 Value localOffsets =
732 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
733 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
734
735 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
736 flatMemref, localOffsets, scatterOp.getMask(),
737 /*chunk_size=*/IntegerAttr{},
738 /*l1_hint=*/xegpu::CachePolicyAttr{},
739 /*l2_hint=*/xegpu::CachePolicyAttr{},
740 /*l3_hint=*/xegpu::CachePolicyAttr{},
741 /*layout=*/nullptr);
742 rewriter.eraseOp(scatterOp);
743 return success();
744 }
745};
746
747struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
748 using Base::Base;
749
750 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
751 PatternRewriter &rewriter) const override {
752 Location loc = loadOp.getLoc();
753
754 VectorType vecTy = loadOp.getResult().getType();
755 MemRefType memTy = loadOp.getBase().getType();
756 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy, memTy)))
757 return failure();
758
759 // Boundary check is available only for block instructions.
760 bool boundaryCheck = vecTy.getRank() > 1;
761 // By default, no specific caching policy is assigned.
762 xegpu::CachePolicyAttr hint = nullptr;
763
764 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
765 rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
766 vecTy.getRank());
767
768 auto descType = xegpu::TensorDescType::get(
769 vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
770 boundaryCheck, xegpu::MemorySpace::Global);
771
772 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
773 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
774 auto loadNdOp =
775 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
776 /*packed=*/nullptr, /*transpose=*/nullptr,
777 /*l1_hint=*/hint,
778 /*l2_hint=*/hint, /*l3_hint=*/hint,
779 /*layout=*/nullptr);
780 rewriter.replaceOp(loadOp, loadNdOp);
781
782 return success();
783 }
784};
785
786struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
787 using Base::Base;
788
789 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
790 PatternRewriter &rewriter) const override {
791 Location loc = storeOp.getLoc();
792
793 TypedValue<VectorType> vector = storeOp.getValueToStore();
794 VectorType vecTy = vector.getType();
795 MemRefType memTy = storeOp.getBase().getType();
796 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy, memTy)))
797 return failure();
798
799 // Boundary check is available only for block instructions.
800 bool boundaryCheck = vecTy.getRank() > 1;
801
802 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
803 rewriter, loc, storeOp.getBase(),
804 getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
805
806 auto descType = xegpu::TensorDescType::get(
807 vecTy.getShape(), vecTy.getElementType(),
808 /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
809
810 // By default, no specific caching policy is assigned.
811 xegpu::CachePolicyAttr hint = nullptr;
812 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
813 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
814
815 auto storeNdOp =
816 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
817 /*l1_hint=*/hint,
818 /*l2_hint=*/hint, /*l3_hint=*/hint,
819 /*layout=*/nullptr);
820
821 rewriter.replaceOp(storeOp, storeNdOp);
822
823 return success();
824 }
825};
826
827struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
828 using Base::Base;
829
830 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
831 PatternRewriter &rewriter) const override {
832 Location loc = contractOp.getLoc();
833
834 if (contractOp.getKind() != vector::CombiningKind::ADD)
835 return rewriter.notifyMatchFailure(contractOp,
836 "Expects add combining kind");
837
838 TypedValue<Type> acc = contractOp.getAcc();
839 VectorType accType = dyn_cast<VectorType>(acc.getType());
840 if (!accType || accType.getRank() != 2)
841 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
842
843 // Accept only plain 2D data layout.
844 // VNNI packing is applied to DPAS as a separate lowering step.
845 TypedValue<VectorType> lhs = contractOp.getLhs();
846 TypedValue<VectorType> rhs = contractOp.getRhs();
847 if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
848 return rewriter.notifyMatchFailure(contractOp,
849 "Expects lhs and rhs 2D vectors");
850
851 if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
852 return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
853
854 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
855 TypeRange{contractOp.getResultType()},
856 ValueRange{lhs, rhs, acc});
857 rewriter.replaceOp(contractOp, dpasOp);
858
859 return success();
860 }
861};
862
// Pass driver: greedily applies the vector-to-XeGPU rewrite patterns to the
// root operation.
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // NOTE(review): the call that populates `patterns` (presumably
    // populateVectorToXeGPUConversionPatterns) is missing from this excerpt —
    // confirm against the full file.
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};
873
874} // namespace
875
// Registers all vector-to-XeGPU lowering patterns on the given set.
// NOTE(review): the opening signature line (presumably
// `void mlir::populateVectorToXeGPUConversionPatterns(`) is missing from this
// excerpt — confirm against the full file.
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
return success()
lhs
b getContext())
static std::optional< VectorShape > vectorShape(Type type)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayRef< AffineExpr > getResults() const
bool isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl< unsigned > &permutedDims) const
Return true if this affine map can be converted to a minor identity with broadcast by doing a permute...
unsigned getNumInputs() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getI1Type()
Definition Builders.cpp:57
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void populateVectorToXeGPUConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the vector to XeGPU ops.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isRowMajorMatmul(ArrayAttr indexingMaps)
Tests whether the given maps describe a row major matmul.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...