1//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to XeGPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26#include <algorithm>
27#include <optional>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Return true if value represents a zero constant.
39static bool isZeroConstant(Value val) {
40 auto constant = val.getDefiningOp<arith::ConstantOp>();
41 if (!constant)
42 return false;
43
44 return TypeSwitch<Attribute, bool>(constant.getValue())
45 .Case<FloatAttr>(
46 [](auto floatAttr) { return floatAttr.getValue().isZero(); })
47 .Case<IntegerAttr>(
48 [](auto intAttr) { return intAttr.getValue().isZero(); })
49 .Default(false);
50}
51
52static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
53 Operation *op, VectorType vecTy) {
54 // Only validate the vector type; the plain vector load and store ops
55 // already guarantee an XeGPU-compatible memref source.
56 unsigned vecRank = vecTy.getRank();
57 if (!(vecRank == 1 || vecRank == 2))
58 return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
59
60 return success();
61}
62
63static LogicalResult transferPreconditions(PatternRewriter &rewriter,
64 VectorTransferOpInterface xferOp) {
65 if (xferOp.getMask())
66 return rewriter.notifyMatchFailure(xferOp,
67 "Masked transfer is not supported");
68
69 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
70 if (!srcTy)
71 return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
72
73 // Validate further transfer op semantics.
74 SmallVector<int64_t> strides;
75 int64_t offset;
76 if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
77 return rewriter.notifyMatchFailure(
78 xferOp, "Buffer must be contiguous in the innermost dimension");
79
80 VectorType vecTy = xferOp.getVectorType();
81 unsigned vecRank = vecTy.getRank();
82 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
83 return rewriter.notifyMatchFailure(
84 xferOp, "Boundary check is available only for block instructions.");
85
86 AffineMap map = xferOp.getPermutationMap();
87 if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
88 return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
89 unsigned numInputDims = map.getNumInputs();
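  // E.g., with a rank-3 memref and a 2-D vector, a map such as
  // (d0, d1, d2) -> (d1, d2) is accepted, while (d0, d1, d2) -> (d0, d2) is
  // rejected: only the innermost memref dimensions may be mapped to vector
  // dimensions.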
90 for (AffineExpr expr : map.getResults().take_back(vecRank)) {
91 auto dim = dyn_cast<AffineDimExpr>(expr);
92 if (dim.getPosition() < (numInputDims - vecRank))
93 return rewriter.notifyMatchFailure(
94 xferOp, "Only the innermost dimensions can be accessed");
95 }
96
97 return success();
98}
99
100static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
101 Location loc,
102 xegpu::TensorDescType descType,
103 TypedValue<MemRefType> src) {
104 MemRefType srcTy = src.getType();
105 auto [strides, offset] = srcTy.getStridesAndOffset();
106
107 xegpu::CreateNdDescOp ndDesc;
108 if (srcTy.hasStaticShape()) {
109 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
110 } else {
111 // In case of any dynamic shapes, the source's shape and strides have to
112 // be explicitly provided.
113 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
114 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
115 meta.getConstifiedMixedSizes(),
116 meta.getConstifiedMixedStrides());
117 }
118
119 return ndDesc;
120}
121
122// Adjusts the strides of a memref according to a given permutation map for
123// vector operations.
124//
125// This function updates the innermost strides in the `strides` array to
126// reflect the permutation specified by `permMap`. The permutation is computed
127// using the inverse and broadcasting-aware version of the permutation map,
128// and is applied to the relevant strides. This ensures that memory accesses
129// are consistent with the logical permutation of vector elements.
130//
131// Example:
132// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
133// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
134// 0]), then after calling this function, the last two strides will be
135// swapped:
136// Original strides: [s0, s1, s2, s3]
137// After permutation: [s0, s1, s3, s2]
138//
139static void adjustStridesForPermutation(AffineMap permMap,
140 SmallVectorImpl<Value> &strides) {
141
142 AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
143 SmallVector<unsigned> perms;
144 invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
145 SmallVector<int64_t> perms64(perms.begin(), perms.end());
146 strides = applyPermutation(strides, perms64);
147}
148
149// Computes memory strides and a memref offset for vector transfer operations,
150// handling both static and dynamic memrefs while applying permutation
151// transformations for XeGPU lowering.
152template <
153 typename OpType,
154 typename = std::enable_if_t<llvm::is_one_of<
155 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
156 vector::GatherOp, vector::ScatterOp>::value>>
157static std::pair<SmallVector<Value>, Value>
158computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
159 SmallVector<Value> strides;
160 Value baseMemref = xferOp.getBase();
161 MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
162
163 Location loc = xferOp.getLoc();
164 Value offsetVal = nullptr;
165 if (memrefType.hasStaticShape()) {
166 int64_t offset;
167 SmallVector<int64_t> intStrides;
168 if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
169 return {{}, offsetVal};
170 bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
171 return ShapedType::isDynamic(strideVal);
172 });
173
174 if (!hasDynamicStrides)
175 for (int64_t s : intStrides)
176 strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
177
178 if (!ShapedType::isDynamic(offset))
179 offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
180 }
181
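  // Note: a memref with a static shape may still carry dynamic strides or a
  // dynamic offset (e.g. a strided layout produced by a subview); those cases
  // fall through to the metadata extraction below.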
182 if (strides.empty() || !offsetVal) {
183 // For dynamic shape memref, use memref.extract_strided_metadata to get
184 // stride values
185 unsigned rank = memrefType.getRank();
186 Type indexType = rewriter.getIndexType();
187
188 // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
189 // size0, size1, ..., sizeN-1]
190 SmallVector<Type> resultTypes;
191 resultTypes.push_back(MemRefType::get(
192 {}, memrefType.getElementType())); // base memref (unranked)
193 resultTypes.push_back(indexType); // offset
194
195 for (unsigned i = 0; i < rank; ++i)
196 resultTypes.push_back(indexType); // strides
197
198 for (unsigned i = 0; i < rank; ++i)
199 resultTypes.push_back(indexType); // sizes
200
201 auto meta = memref::ExtractStridedMetadataOp::create(
202 rewriter, loc, resultTypes, baseMemref);
203
204 if (strides.empty())
205 strides.append(meta.getStrides().begin(), meta.getStrides().end());
206
207 if (!offsetVal)
208 offsetVal = meta.getOffset();
209 }
210
211 if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
212 vector::TransferWriteOp>::value) {
213 AffineMap permMap = xferOp.getPermutationMap();
214 // Adjust strides according to the permutation map (e.g., for transpose)
215 adjustStridesForPermutation(permMap, strides);
216 }
217
218 return {strides, offsetVal};
219}
220
221// Computes the vector of local offsets for scattered loads/stores. It is
222// used when lowering vector.transfer_read/write to
223// load_gather/store_scatter. Example:
224// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
225// %cst {in_bounds = [true, true, true, true]} :
226// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
227//
228// %6 = vector.step: vector<4xindex>
229// %7 = vector.step: vector<2xindex>
230// %8 = vector.step: vector<6xindex>
231// %9 = vector.step: vector<32xindex>
232// %10 = arith.mul %6, 384
233// %11 = arith.mul %7, 192
234// %12 = arith.mul %8, 32
235// %13 = arith.mul %9, 1
236// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
237// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
238// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
239// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
240// %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
241// %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
242// %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
243// %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
244// %22 = arith.add %18, %19
245// %23 = arith.add %20, %21
246// %local_offsets = arith.add %22, %23
247// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
248// %offsets = memref_offset + orig_offset + local_offsets
249static Value computeOffsets(VectorTransferOpInterface xferOp,
250 PatternRewriter &rewriter, ArrayRef<Value> strides,
251 Value baseOffset) {
252 Location loc = xferOp.getLoc();
253 VectorType vectorType = xferOp.getVectorType();
254 SmallVector<Value> indices(xferOp.getIndices().begin(),
255 xferOp.getIndices().end());
256 ArrayRef<int64_t> vectorShape = vectorType.getShape();
257
258 // Create vector.step operations for each dimension
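  // Each step vector holds [0, 1, ..., dim-1], i.e. the element index along
  // that dimension.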
259 SmallVector<Value> stepVectors;
260 llvm::map_to_vector(vectorShape, [&](int64_t dim) {
261 auto stepType = VectorType::get({dim}, rewriter.getIndexType());
262 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
263 stepVectors.push_back(stepOp);
264 return stepOp;
265 });
266
267 // Multiply step vectors by corresponding strides
268 size_t memrefRank = strides.size();
269 size_t vectorRank = vectorShape.size();
270 SmallVector<Value> strideMultiplied;
271 for (size_t i = 0; i < vectorRank; ++i) {
272 size_t memrefDim = memrefRank - vectorRank + i;
273 Value strideValue = strides[memrefDim];
274 auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
275 auto bcastOp =
276 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
277 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
278 strideMultiplied.push_back(mulOp);
279 }
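  // Each multiplied vector now holds that dimension's contribution to the
  // linearized element offset.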
280
281 // Shape cast each multiplied vector to add singleton dimensions
282 SmallVector<Value> shapeCasted;
283 for (size_t i = 0; i < vectorRank; ++i) {
284 SmallVector<int64_t> newShape(vectorRank, 1);
285 newShape[i] = vectorShape[i];
286 auto newType = VectorType::get(newShape, rewriter.getIndexType());
287 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
288 strideMultiplied[i]);
289 shapeCasted.push_back(castOp);
290 }
291
292 // Broadcast each shape-casted vector to full vector shape
293 SmallVector<Value> broadcasted;
294 auto fullIndexVectorType =
295 VectorType::get(vectorShape, rewriter.getIndexType());
296 for (Value shapeCastVal : shapeCasted) {
297 auto broadcastOp = vector::BroadcastOp::create(
298 rewriter, loc, fullIndexVectorType, shapeCastVal);
299 broadcasted.push_back(broadcastOp);
300 }
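  // Together, the shape casts and broadcasts implement an outer sum: adding
  // the broadcasted vectors gives every element the sum of its per-dimension
  // stride contributions.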
301
302 // Add all broadcasted vectors together to compute local offsets
303 Value localOffsets = broadcasted[0];
304 for (size_t i = 1; i < broadcasted.size(); ++i)
305 localOffsets =
306 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
307
308 // Compute base offset from transfer read indices
309 for (size_t i = 0; i < indices.size(); ++i) {
310 Value strideVal = strides[i];
311 Value offsetContrib =
312 arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
313 baseOffset =
314 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
315 }
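  // Note that the indices cover all memref dimensions, so leading dimensions
  // not spanned by the vector contribute only to this scalar base offset.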
316 // Broadcast base offset to match vector shape
317 Value bcastBase = vector::BroadcastOp::create(
318 rewriter, loc, fullIndexVectorType, baseOffset);
319 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
320 return localOffsets;
321}
322
323// Compute the element-wise offsets for vector.gather or vector.scatter ops.
324//
325// This function linearizes the base offsets of the gather/scatter operation
326// and combines them with the per-element indices to produce a final vector of
327// memory offsets.
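// E.g., for a gather from memref<8x16xf32> with base offsets [%i, %j] and
// indices %idx : vector<16xindex>, the strides are [16, 1] and the produced
// offsets are roughly (%i * 16 + %j * 1) + %idx * 1, broadcast to
// vector<16xindex> and added on top of the memref's own base offset.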
328template <
329 typename OpType,
330 typename = std::enable_if_t<llvm::is_one_of<
331 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
332static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
333 ArrayRef<Value> strides, Value baseOffset) {
334 Location loc = gatScatOp.getLoc();
335 SmallVector<Value> offsets = gatScatOp.getOffsets();
336 for (size_t i = 0; i < offsets.size(); ++i) {
337 Value offsetContrib =
338 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
339 baseOffset =
340 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
341 }
342 Value indices = gatScatOp.getIndices();
343 VectorType vecType = cast<VectorType>(indices.getType());
344
345 Value strideVector =
346 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
347 .getResult();
348 Value stridedIndices =
349 arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
350
351 Value baseVector =
352 vector::BroadcastOp::create(
353 rewriter, loc,
354 VectorType::get(vecType.getShape(), rewriter.getIndexType()),
355 baseOffset)
356 .getResult();
357 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
358 .getResult();
359}
360
361// Collapses the shape of an nD memref to the target rank while applying offsets for
362// the collapsed dimensions. Returns the new memref value and the remaining
363// offsets for the last targetRank dimensions. For example:
364// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
365// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
366static std::pair<Value, SmallVector<OpFoldResult>>
367convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
368 Value memref,
369 SmallVector<OpFoldResult> offsets,
370 int64_t targetRank) {
371 auto memrefType = cast<MemRefType>(memref.getType());
372 unsigned rank = memrefType.getRank();
373
374 if (rank <= targetRank)
375 return {memref, offsets};
376
377 int64_t numCombinedDims = rank - targetRank;
378 SmallVector<OpFoldResult> subviewOffsets;
379 SmallVector<OpFoldResult> subviewSizes;
380 SmallVector<OpFoldResult> subviewStrides;
381
382 // For the combined dimensions: use the provided offsets, size=1, stride=1
383 for (unsigned i = 0; i < numCombinedDims; ++i) {
384 subviewOffsets.push_back(offsets[i]);
385 subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
386 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
387 }
388
389 // For the last targetRank dimensions: offset=0, use full size, stride=1
390 SmallVector<int64_t> resultShape;
391 auto originalShape = memrefType.getShape();
392 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
393 for (unsigned i = numCombinedDims; i < rank; ++i) {
394 subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
395 if (ShapedType::isDynamic(originalShape[i])) {
396 subviewSizes.push_back(meta.getSizes()[i]);
397 resultShape.push_back(ShapedType::kDynamic);
398 } else {
399 subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
400 resultShape.push_back(originalShape[i]);
401 }
402 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
403 }
404
405 auto resultType = memref::SubViewOp::inferRankReducedResultType(
406 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
407 auto subviewOp =
408 memref::SubViewOp::create(rewriter, loc, resultType, memref,
409 subviewOffsets, subviewSizes, subviewStrides);
410
411 // Return the remaining offsets for the last targetRank dimensions
412 SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
413 offsets.end());
414 return {subviewOp.getResult(), newOffsets};
415}
416
417template <
418 typename OpType,
419 typename = std::enable_if_t<llvm::is_one_of<
420 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
421 vector::GatherOp, vector::ScatterOp>::value>>
422// Extract the memref's aligned base pointer and cast it to i64 (flat address).
423static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
424 Location loc = xferOp.getLoc();
425 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
426 rewriter, loc, xferOp.getBase())
427 .getResult();
428 return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
429 indexPtr)
430 .getResult();
431}
432
433static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
434 PatternRewriter &rewriter) {
435
436 Location loc = readOp.getLoc();
437 VectorType vectorType = readOp.getVectorType();
438 ArrayRef<int64_t> vectorShape = vectorType.getShape();
439 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
440 if (!memrefType)
441 return rewriter.notifyMatchFailure(readOp, "Expected memref source");
442
443 auto meta = computeMemrefMeta(readOp, rewriter);
444 if (meta.first.empty())
445 return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
446
447 Value localOffsets =
448 computeOffsets(readOp, rewriter, meta.first, meta.second);
449
450 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
451
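  // All lanes are active; build an all-true mask matching the vector shape.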
452 Value mask = vector::ConstantMaskOp::create(
453 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
454 vectorShape);
455 auto gatherOp = xegpu::LoadGatherOp::create(
456 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
457 /*chunk_size=*/IntegerAttr{},
458 /*l1_hint=*/xegpu::CachePolicyAttr{},
459 /*l2_hint=*/xegpu::CachePolicyAttr{},
460 /*l3_hint=*/xegpu::CachePolicyAttr{},
461 /*layout=*/nullptr);
462
463 rewriter.replaceOp(readOp, gatherOp.getResult());
464 return success();
465}
466
467static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
468 PatternRewriter &rewriter) {
469
470 Location loc = writeOp.getLoc();
471 VectorType vectorType = writeOp.getVectorType();
472 ArrayRef<int64_t> vectorShape = vectorType.getShape();
473
474 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
475 if (!memrefType)
476 return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
477
478 auto meta = computeMemrefMeta(writeOp, rewriter);
479 if (meta.first.empty())
480 return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
481
482 Value localOffsets =
483 computeOffsets(writeOp, rewriter, meta.first, meta.second);
484
485 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
486
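  // As above, all lanes are active, so the mask is all-true over the vector
  // shape.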
487 Value mask = vector::ConstantMaskOp::create(
488 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
489 vectorShape);
490 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
491 localOffsets, mask,
492 /*chunk_size=*/IntegerAttr{},
493 /*l1_hint=*/xegpu::CachePolicyAttr{},
494 /*l2_hint=*/xegpu::CachePolicyAttr{},
495 /*l3_hint=*/xegpu::CachePolicyAttr{},
496 /*layout=*/nullptr);
497 rewriter.eraseOp(writeOp);
498 return success();
499}
500
501struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
502 using Base::Base;
503
504 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
505 PatternRewriter &rewriter) const override {
506 Location loc = readOp.getLoc();
507
508 if (failed(transferPreconditions(rewriter, readOp)))
509 return failure();
510
511 // TODO: This check needs to be replaced with a proper uArch capability check.
512 auto chip = xegpu::getChipStr(readOp);
513 if (chip != "pvc" && chip != "bmg") {
514 // Lower to a scattered load op if the target HW doesn't have 2D block
515 // load support.
516 // TODO: add support for out-of-bounds accesses.
517 if (readOp.hasOutOfBoundsDim())
518 return failure();
519 return lowerToScatteredLoadOp(readOp, rewriter);
520 }
521
522 // Perform common data transfer checks.
523 VectorType vecTy = readOp.getVectorType();
524 if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
525 return failure();
526
527 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
528 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
529 return rewriter.notifyMatchFailure(
530 readOp, "Unsupported non-zero padded out-of-bounds read");
531
532 AffineMap readMap = readOp.getPermutationMap();
533 bool isTransposeLoad = !readMap.isMinorIdentity();
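    // Given the projected-permutation precondition, a non-identity map here
    // can only be a transpose of the 2-D result.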
534
535 Type elementType = vecTy.getElementType();
536 unsigned minTransposeBitWidth = 32;
537 if (isTransposeLoad &&
538 elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
539 return rewriter.notifyMatchFailure(
540 readOp, "Unsupported data type for transposition");
541
542 // If load is transposed, get the base shape for the tensor descriptor.
543 SmallVector<int64_t> descShape(vecTy.getShape());
544 if (isTransposeLoad)
545 std::reverse(descShape.begin(), descShape.end());
546 auto descType = xegpu::TensorDescType::get(
547 descShape, elementType, /*array_length=*/1,
548 /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
549
550 DenseI64ArrayAttr transposeAttr =
551 !isTransposeLoad ? nullptr
552 : DenseI64ArrayAttr::get(rewriter.getContext(),
553 ArrayRef<int64_t>{1, 0});
554 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
555 rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
556 vecTy.getRank());
557 // By default, no specific caching policy is assigned.
558 xegpu::CachePolicyAttr hint = nullptr;
559 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
560 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
561
562 auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
563 /*packed=*/nullptr, transposeAttr,
564 /*l1_hint=*/hint,
565 /*l2_hint=*/hint, /*l3_hint=*/hint);
566 rewriter.replaceOp(readOp, loadOp);
567
568 return success();
569 }
570};
571
572struct TransferWriteLowering
573 : public OpRewritePattern<vector::TransferWriteOp> {
574 using Base::Base;
575
576 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
577 PatternRewriter &rewriter) const override {
578 Location loc = writeOp.getLoc();
579
580 if (failed(transferPreconditions(rewriter, writeOp)))
581 return failure();
582
583 // TODO: This check needs to be replaced with a proper uArch capability check.
584 auto chip = xegpu::getChipStr(writeOp);
585 if (chip != "pvc" && chip != "bmg") {
586 // Lower to a scattered store op if the target HW doesn't have 2D block
587 // store support.
588 // TODO: add support for out-of-bounds accesses.
589 if (writeOp.hasOutOfBoundsDim())
590 return failure();
591 return lowerToScatteredStoreOp(writeOp, rewriter);
592 }
593
594 // Perform common data transfer checks.
595 VectorType vecTy = writeOp.getVectorType();
596 if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
597 return failure();
598
599 AffineMap map = writeOp.getPermutationMap();
600 if (!map.isMinorIdentity())
601 return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
602
603 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
604 rewriter, loc, writeOp.getBase(),
605 getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
606
607 auto descType = xegpu::TensorDescType::get(
608 vecTy.getShape(), vecTy.getElementType(),
609 /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
610 xegpu::MemorySpace::Global);
611 // By default, no specific caching policy is assigned.
612 xegpu::CachePolicyAttr hint = nullptr;
613 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
614 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
615
616 auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
617 ndDesc, indices,
618 /*l1_hint=*/hint,
619 /*l2_hint=*/hint, /*l3_hint=*/hint);
620 rewriter.replaceOp(writeOp, storeOp);
621
622 return success();
623 }
624};
625
626struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
627 using Base::Base;
628
629 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
630 PatternRewriter &rewriter) const override {
631 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
632 if (!srcTy)
633 return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
634
635 Location loc = gatherOp.getLoc();
636 VectorType vectorType = gatherOp.getVectorType();
637
638 auto meta = computeMemrefMeta(gatherOp, rewriter);
639 if (meta.first.empty())
640 return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
641
642 Value localOffsets =
643 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
644 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
645
646 auto xeGatherOp = xegpu::LoadGatherOp::create(
647 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
648 /*chunk_size=*/IntegerAttr{},
649 /*l1_hint=*/xegpu::CachePolicyAttr{},
650 /*l2_hint=*/xegpu::CachePolicyAttr{},
651 /*l3_hint=*/xegpu::CachePolicyAttr{},
652 /*layout=*/nullptr);
653
654 auto selectOp =
655 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
656 xeGatherOp.getResult(), gatherOp.getPassThru());
657 rewriter.replaceOp(gatherOp, selectOp.getResult());
658 return success();
659 }
660};
661
662struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
663 using Base::Base;
664
665 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
666 PatternRewriter &rewriter) const override {
667 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
668 if (!srcTy)
669 return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
670
671 Location loc = scatterOp.getLoc();
672 auto meta = computeMemrefMeta(scatterOp, rewriter);
673 if (meta.first.empty())
674 return rewriter.notifyMatchFailure(scatterOp,
675 "Failed to compute strides");
676
677 Value localOffsets =
678 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
679 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
680
681 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
682 flatMemref, localOffsets, scatterOp.getMask(),
683 /*chunk_size=*/IntegerAttr{},
684 /*l1_hint=*/xegpu::CachePolicyAttr{},
685 /*l2_hint=*/xegpu::CachePolicyAttr{},
686 /*l3_hint=*/xegpu::CachePolicyAttr{},
687 /*layout=*/nullptr);
688 rewriter.eraseOp(scatterOp);
689 return success();
690 }
691};
692
693struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
694 using Base::Base;
695
696 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
697 PatternRewriter &rewriter) const override {
698 Location loc = loadOp.getLoc();
699
700 VectorType vecTy = loadOp.getResult().getType();
701 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
702 return failure();
703
704 // Boundary check is available only for block instructions.
705 bool boundaryCheck = vecTy.getRank() > 1;
706 // By default, no specific caching policy is assigned.
707 xegpu::CachePolicyAttr hint = nullptr;
708
709 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
710 rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
711 vecTy.getRank());
712
713 auto descType = xegpu::TensorDescType::get(
714 vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
715 boundaryCheck, xegpu::MemorySpace::Global);
716
717 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
718 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
719 auto loadNdOp =
720 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
721 /*packed=*/nullptr, /*transpose=*/nullptr,
722 /*l1_hint=*/hint,
723 /*l2_hint=*/hint, /*l3_hint=*/hint);
724 rewriter.replaceOp(loadOp, loadNdOp);
725
726 return success();
727 }
728};
729
730struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
731 using Base::Base;
732
733 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
734 PatternRewriter &rewriter) const override {
735 Location loc = storeOp.getLoc();
736
737 TypedValue<VectorType> vector = storeOp.getValueToStore();
738 VectorType vecTy = vector.getType();
739 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
740 return failure();
741
742 // Boundary check is available only for block instructions.
743 bool boundaryCheck = vecTy.getRank() > 1;
744
745 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
746 rewriter, loc, storeOp.getBase(),
747 getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
748
749 auto descType = xegpu::TensorDescType::get(
750 vecTy.getShape(), vecTy.getElementType(),
751 /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
752
753 // By default, no specific caching policy is assigned.
754 xegpu::CachePolicyAttr hint = nullptr;
755 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
756 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
757
758 auto storeNdOp =
759 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
760 /*l1_hint=*/hint,
761 /*l2_hint=*/hint, /*l3_hint=*/hint);
762
763 rewriter.replaceOp(storeOp, storeNdOp);
764
765 return success();
766 }
767};
768
769struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
770 using Base::Base;
771
772 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
773 PatternRewriter &rewriter) const override {
774 Location loc = contractOp.getLoc();
775
776 if (contractOp.getKind() != vector::CombiningKind::ADD)
777 return rewriter.notifyMatchFailure(contractOp,
778 "Expects add combining kind");
779
780 TypedValue<Type> acc = contractOp.getAcc();
781 VectorType accType = dyn_cast<VectorType>(acc.getType());
782 if (!accType || accType.getRank() != 2)
783 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
784
785 // Accept only plain 2D data layout.
786 // VNNI packing is applied to DPAS as a separate lowering step.
787 TypedValue<VectorType> lhs = contractOp.getLhs();
788 TypedValue<VectorType> rhs = contractOp.getRhs();
789 if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
790 return rewriter.notifyMatchFailure(contractOp,
791 "Expects lhs and rhs 2D vectors");
792
793 if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
794 return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
795
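    // xegpu.dpas performs the 2-D tile multiply-accumulate: result = lhs x rhs + acc.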
796 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
797 TypeRange{contractOp.getResultType()},
798 ValueRange{lhs, rhs, acc});
799 rewriter.replaceOp(contractOp, dpasOp);
800
801 return success();
802 }
803};
804
805struct ConvertVectorToXeGPUPass
806 : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
807 void runOnOperation() override {
808 RewritePatternSet patterns(&getContext());
809 populateVectorToXeGPUConversionPatterns(patterns);
810 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
811 return signalPassFailure();
812 }
813};
814
815} // namespace
816
817void mlir::populateVectorToXeGPUConversionPatterns(
818    RewritePatternSet &patterns) {
819  patterns
820      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
821 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
822 patterns.getContext());
823}