//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case([](FloatAttr floatAttr) { return floatAttr.getValue().isZero(); })
      .Case([](IntegerAttr intAttr) { return intAttr.getValue().isZero(); })
      .Default(false);
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  if (!vecTy.getElementType().isIntOrFloat())
    return rewriter.notifyMatchFailure(
        op, "Expected scalar type with known bitwidth");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
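
// Illustrative sketch (not part of the pattern logic): a transfer op that
// these preconditions accept reads the innermost dimensions of a contiguous
// memref with an unmasked, projected-permutation map, e.g.
//
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//          : memref<16x32xf32>, vector<8x16xf32>
//
// whereas a transfer with a mask, a tensor source, a non-contiguous innermost
// stride, or a map touching outer dimensions is rejected with a match failure.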

static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  // The memref is dynamic if any of its shape, offset, or strides is dynamic.
  if (!srcTy.hasStaticShape())
    isStatic = false;

  if (!ShapedType::isStatic(offset))
    isStatic = false;

  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // For a ranked dynamic memref, instead of passing the memref itself, the
    // i64 base address together with the source's offset, shape, and strides
    // has to be provided explicitly.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
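
// Rough sketch (illustrative pseudo-IR, not exact assembly syntax) of what the
// dynamic-shape branch above materializes for a memref<?x?xf32> source:
//
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
//   %ptr      = memref.extract_aligned_pointer_as_index %base
//   %offBytes = arith.muli %offset, %c4        // element size in bytes (f32)
//   %addr     = arith.addi %ptr, %offBytes
//   %addrI64  = arith.index_cast %addr : index to i64
//   %desc     = xegpu.create_nd_tdesc %addrI64 with the extracted sizes/strides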

// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
//   [1, 0]), then after calling this function, the last two strides will be
//   swapped:
//     Original strides:  [s0, s1, s2, s3]
//     After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {

  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For dynamic shape memref, use memref.extract_strided_metadata to get
    // stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
    //                size0, size1, ..., sizeN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (unranked)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose).
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
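
// Illustrative note: for a statically shaped, identity-layout source such as
// memref<8x16xf32>, the helper above simply materializes constant strides and
// offset, e.g. strides = [%c16, %c1] and offset = %c0; only dynamic shapes or
// layouts fall back to memref.extract_strided_metadata.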

// Computes the vector of local offsets for scattered loads/stores. It is used
// in the lowering of vector.transfer_read/write to load_gather/store_scatter.
// Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//        %cst {in_bounds = [true, true, true, true]} :
//        memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * 4x2x6x32 // consider using an affine map
//   %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply the step vectors by the corresponding strides.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute the local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute the base offset from the transfer op indices.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Broadcast the base offset to match the vector shape.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}

// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
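
// In effect, the computed per-element offset is (pseudo-formula):
//
//   offset[k] = memrefOffset + sum_i(baseIndices[i] * strides[i])
//             + indices[k] * strides.back()
//
// i.e. the scalar base indices are linearized once and the per-lane indices
// are scaled by the innermost stride.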

// Collapses the shape of an nD memref to the target rank while applying
// offsets for the collapsed dimensions. Returns the new memref value and the
// remaining offsets for the last targetRank dimensions. For example:
//   input:  %memref = memref<2x4x8x32xf32>, offsets = [%i0, %i1, %i2, %i3]
//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // For the combined dimensions: use the provided offsets, size=1, stride=1.
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // For the last targetRank dimensions: offset=0, use full size, stride=1.
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  // Return the remaining offsets for the last targetRank dimensions.
  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}

// Convert a memref to an i64 base pointer.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {

  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{},
      /*layout=*/nullptr);

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
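
// For reference, an illustrative result of this scattered-load lowering (not
// exact printed syntax): the nD read becomes a flat xegpu.load_gather over an
// i64 base pointer with a vector of linearized index offsets and an all-true
// mask, e.g.
//
//   %ptr = memref.extract_aligned_pointer_as_index %src
//   %p64 = arith.index_cast %ptr : index to i64
//   %v   = xegpu.load_gather %p64[%offsets], %mask : ... -> vector<...>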

static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {

  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{},
                                /*layout=*/nullptr);
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: Add support for out-of-bounds access.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();

    // Lower to load_gather in the 1D case.
    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
      return lowerToScatteredLoadOp(readOp, rewriter);

    // Perform common data transfer checks.
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
        vecTy.getRank());
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
                                          /*layout=*/nullptr);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: Add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
                                            /*layout=*/nullptr);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{},
        /*layout=*/nullptr);

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};

struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
                                  /*layout=*/nullptr);
    rewriter.eraseOp(scatterOp);
    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
        vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
                                 /*layout=*/nullptr);

    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
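
// Illustrative shape of the mapping handled above (pseudo-IR, exact syntax
// omitted): a row-major 2D matmul contraction with an additive accumulator,
//
//   %d = vector.contract %a, %b, %c
//          : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
//
// is rewritten to a single xegpu.dpas taking %a, %b, and %c as operands.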

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
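
// Usage sketch (assumed driver code, not part of this file): the patterns can
// be exercised either through the registered conversion pass, e.g.
//
//   mlir-opt input.mlir -convert-vector-to-xegpu
//
// or programmatically by populating a RewritePatternSet and applying it
// greedily, mirroring what the pass above does:
//
//   RewritePatternSet patterns(ctx);
//   populateVectorToXeGPUConversionPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));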