VectorToXeGPU.cpp
1//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector operations to XeGPU dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
19#include "mlir/Dialect/Vector/IR/VectorOps.h"
20#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
21#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26#include <algorithm>
27#include <optional>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35
36namespace {
37
38// Return true if value represents a zero constant.
39static bool isZeroConstant(Value val) {
40 auto constant = val.getDefiningOp<arith::ConstantOp>();
41 if (!constant)
42 return false;
43
44 return TypeSwitch<Attribute, bool>(constant.getValue())
45 .Case<FloatAttr>(
46 [](auto floatAttr) { return floatAttr.getValue().isZero(); })
47 .Case<IntegerAttr>(
48 [](auto intAttr) { return intAttr.getValue().isZero(); })
49 .Default(false);
50}
51
52static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
53 Operation *op, VectorType vecTy) {
54 // Validate only the vector type, as the basic vector store and load ops
55 // guarantee an XeGPU-compatible memref source.
56 unsigned vecRank = vecTy.getRank();
57 if (!(vecRank == 1 || vecRank == 2))
58 return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
59
60 return success();
61}
62
63static LogicalResult transferPreconditions(PatternRewriter &rewriter,
64 VectorTransferOpInterface xferOp) {
65 if (xferOp.getMask())
66 return rewriter.notifyMatchFailure(xferOp,
67 "Masked transfer is not supported");
68
69 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
70 if (!srcTy)
71 return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
72
73 // Validate further transfer op semantics.
74 SmallVector<int64_t> strides;
75 int64_t offset;
76 if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
77 return rewriter.notifyMatchFailure(
78 xferOp, "Buffer must be contiguous in the innermost dimension");
79
80 VectorType vecTy = xferOp.getVectorType();
81 unsigned vecRank = vecTy.getRank();
82 if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
83 return rewriter.notifyMatchFailure(
84 xferOp, "Boundary check is available only for block instructions.");
85
86 AffineMap map = xferOp.getPermutationMap();
87 if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
88 return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
89 unsigned numInputDims = map.getNumInputs();
90 for (AffineExpr expr : map.getResults().take_back(vecRank)) {
91 auto dim = dyn_cast<AffineDimExpr>(expr);
92 if (dim.getPosition() < (numInputDims - vecRank))
93 return rewriter.notifyMatchFailure(
94 xferOp, "Only the innermost dimensions can be accessed");
95 }
96
97 return success();
98}
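// For illustration (an example, not part of the original checks), a read such
// as
//   %v = vector.transfer_read %src[%i, %j], %cst {in_bounds = [true, true]}
//       : memref<?x64xf32>, vector<8x16xf32>
// satisfies these preconditions: it is unmasked, reads from a memref with a
// unit innermost stride, produces a 2D vector, and its (default) permutation
// map only touches the two innermost dimensions.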
99
100static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
101 Location loc,
102 xegpu::TensorDescType descType,
103 TypedValue<MemRefType> src) {
104 MemRefType srcTy = src.getType();
105 auto [strides, offset] = srcTy.getStridesAndOffset();
106
107 xegpu::CreateNdDescOp ndDesc;
108 if (srcTy.hasStaticShape()) {
109 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
110 } else {
111 // In case of any dynamic shapes, source's shape and strides have to be
112 // explicitly provided.
113 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
114 ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
115 meta.getConstifiedMixedSizes(),
116 meta.getConstifiedMixedStrides());
117 }
118
119 return ndDesc;
120}
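// For illustration, assuming a statically shaped source, the helper above
// materializes roughly:
//   %desc = xegpu.create_nd_tdesc %src : memref<8x16xf32>
//       -> !xegpu.tensor_desc<8x16xf32>
// while a dynamically shaped source additionally threads the sizes and strides
// recovered by memref.extract_strided_metadata into the descriptor.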
121
122// Adjusts the strides of a memref according to a given permutation map for
123// vector operations.
124//
125// This function updates the innermost strides in the `strides` array to
126// reflect the permutation specified by `permMap`. The permutation is computed
127// using the inverse and broadcasting-aware version of the permutation map,
128// and is applied to the relevant strides. This ensures that memory accesses
129// are consistent with the logical permutation of vector elements.
130//
131// Example:
132// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
133// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
134// 0]), then after calling this function, the last two strides will be
135// swapped:
136// Original strides: [s0, s1, s2, s3]
137// After permutation: [s0, s1, s3, s2]
138//
139static void adjustStridesForPermutation(AffineMap permMap,
140 SmallVectorImpl<Value> &strides) {
141
142 AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
143 SmallVector<unsigned> perms;
144 invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
145 SmallVector<int64_t> perms64(perms.begin(), perms.end());
146 strides = applyPermutation(strides, perms64);
147}
148
149// Computes memory strides and a memref offset for vector transfer operations,
150// handling both static and dynamic memrefs while applying permutation
151// transformations for XeGPU lowering.
152template <
153 typename OpType,
154 typename = std::enable_if_t<llvm::is_one_of<
155 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
156 vector::GatherOp, vector::ScatterOp>::value>>
157static std::pair<SmallVector<Value>, Value>
158computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
159 SmallVector<Value> strides;
160 Value baseMemref = xferOp.getBase();
161 MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
162
163 Location loc = xferOp.getLoc();
164 Value offsetVal = nullptr;
165 if (memrefType.hasStaticShape()) {
166 int64_t offset;
167 SmallVector<int64_t> intStrides;
168 if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
169 return {{}, offsetVal};
170 bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
171 return ShapedType::isDynamic(strideVal);
172 });
173
174 if (!hasDynamicStrides)
175 for (int64_t s : intStrides)
176 strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
177
178 if (!ShapedType::isDynamic(offset))
179 offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
180 }
181
182 if (strides.empty() || !offsetVal) {
183 // For dynamic shape memref, use memref.extract_strided_metadata to get
184 // stride values
185 unsigned rank = memrefType.getRank();
186 Type indexType = rewriter.getIndexType();
187
188 // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
189 // size0, size1, ..., sizeN-1]
190 SmallVector<Type> resultTypes;
191 resultTypes.push_back(MemRefType::get(
192 {}, memrefType.getElementType())); // base memref (unranked)
193 resultTypes.push_back(indexType); // offset
194
195 for (unsigned i = 0; i < rank; ++i)
196 resultTypes.push_back(indexType); // strides
197
198 for (unsigned i = 0; i < rank; ++i)
199 resultTypes.push_back(indexType); // sizes
200
201 auto meta = memref::ExtractStridedMetadataOp::create(
202 rewriter, loc, resultTypes, baseMemref);
203
204 if (strides.empty())
205 strides.append(meta.getStrides().begin(), meta.getStrides().end());
206
207 if (!offsetVal)
208 offsetVal = meta.getOffset();
209 }
210
211 if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
212 vector::TransferWriteOp>::value) {
213 AffineMap permMap = xferOp.getPermutationMap();
214 // Adjust strides according to the permutation map (e.g., for transpose)
215 adjustStridesForPermutation(permMap, strides);
216 }
217
218 return {strides, offsetVal};
219}
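// For example, for a statically shaped memref<8x16xf32> with the default
// identity layout, this helper produces constant strides [16, 1] and a
// constant offset of 0; any dynamic stride or offset instead falls back to
// memref.extract_strided_metadata.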
220
221// This function computes the vector of localOffsets for scattered loads/stores.
222// It is used in the lowering of vector.transfer_read/write to
223// load_gather/store_scatter. Example:
224// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
225// %cst {in_bounds = [true, true, true, true]} :
226// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
227//
228// %6 = vector.step: vector<4xindex>
229// %7 = vector.step: vector<2xindex>
230// %8 = vector.step: vector<6xindex>
231// %9 = vector.step: vector<32xindex>
232// %10 = arith.mul %6, 384
233// %11 = arith.mul %7, 192
234// %12 = arith.mul %8, 32
235// %13 = arith.mul %9, 1
236// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
237// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
238// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
239// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
240// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
241// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
242// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
243// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
244// %22 = arith.add %18, %19
245// %23 = arith.add %20, %21
246// %local_offsets = arith.add %22, %23
247// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
248// %offsets = memref_offset + orig_offset + local_offsets
249static Value computeOffsets(VectorTransferOpInterface xferOp,
250 PatternRewriter &rewriter, ArrayRef<Value> strides,
251 Value baseOffset) {
252 Location loc = xferOp.getLoc();
253 VectorType vectorType = xferOp.getVectorType();
254 SmallVector<Value> indices(xferOp.getIndices().begin(),
255 xferOp.getIndices().end());
256 ArrayRef<int64_t> vectorShape = vectorType.getShape();
257
258 // Create vector.step operations for each dimension
259 SmallVector<Value> stepVectors;
260 llvm::map_to_vector(vectorShape, [&](int64_t dim) {
261 auto stepType = VectorType::get({dim}, rewriter.getIndexType());
262 auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
263 stepVectors.push_back(stepOp);
264 return stepOp;
265 });
266
267 // Multiply step vectors by corresponding strides
268 size_t memrefRank = strides.size();
269 size_t vectorRank = vectorShape.size();
270 SmallVector<Value> strideMultiplied;
271 for (size_t i = 0; i < vectorRank; ++i) {
272 size_t memrefDim = memrefRank - vectorRank + i;
273 Value strideValue = strides[memrefDim];
274 auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
275 auto bcastOp =
276 vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
277 auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
278 strideMultiplied.push_back(mulOp);
279 }
280
281 // Shape cast each multiplied vector to add singleton dimensions
282 SmallVector<Value> shapeCasted;
283 for (size_t i = 0; i < vectorRank; ++i) {
284 SmallVector<int64_t> newShape(vectorRank, 1);
285 newShape[i] = vectorShape[i];
286 auto newType = VectorType::get(newShape, rewriter.getIndexType());
287 auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
288 strideMultiplied[i]);
289 shapeCasted.push_back(castOp);
290 }
291
292 // Broadcast each shape-casted vector to full vector shape
293 SmallVector<Value> broadcasted;
294 auto fullIndexVectorType =
295 VectorType::get(vectorShape, rewriter.getIndexType());
296 for (Value shapeCastVal : shapeCasted) {
297 auto broadcastOp = vector::BroadcastOp::create(
298 rewriter, loc, fullIndexVectorType, shapeCastVal);
299 broadcasted.push_back(broadcastOp);
300 }
301
302 // Add all broadcasted vectors together to compute local offsets
303 Value localOffsets = broadcasted[0];
304 for (size_t i = 1; i < broadcasted.size(); ++i)
305 localOffsets =
306 arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
307
308 // Compute base offset from transfer read indices
309 for (size_t i = 0; i < indices.size(); ++i) {
310 Value strideVal = strides[i];
311 Value offsetContrib =
312 arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
313 baseOffset =
314 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
315 }
316 // Broadcast base offset to match vector shape
317 Value bcastBase = vector::BroadcastOp::create(
318 rewriter, loc, fullIndexVectorType, baseOffset);
319 localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
320 return localOffsets;
321}
322
323// Compute the element-wise offsets for vector.gather or vector.scatter ops.
324//
325// This function linearizes the base offsets of the gather/scatter operation
326// and combines them with the per-element indices to produce a final vector of
327// memory offsets.
328template <
329 typename OpType,
330 typename = std::enable_if_t<llvm::is_one_of<
331 std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
332static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
333 ArrayRef<Value> strides, Value baseOffset) {
334 Location loc = gatScatOp.getLoc();
335 SmallVector<Value> offsets = gatScatOp.getOffsets();
336 for (size_t i = 0; i < offsets.size(); ++i) {
337 Value offsetContrib =
338 arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
339 baseOffset =
340 arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
341 }
342 Value indices = gatScatOp.getIndices();
343 VectorType vecType = cast<VectorType>(indices.getType());
344
345 Value strideVector =
346 vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
347 .getResult();
348 Value stridedIndices =
349 arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
350
351 Value baseVector =
352 vector::BroadcastOp::create(
353 rewriter, loc,
354 VectorType::get(vecType.getShape(), rewriter.getIndexType()),
355 baseOffset)
356 .getResult();
357 return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
358 .getResult();
359}
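// In other words, for base offsets [%i, %j], per-element indices
// %idx : vector<16xindex>, and strides [%s0, %s1], the result is roughly
//   broadcast(%i * %s0 + %j * %s1 + baseOffset) + %idx * %s1,
// i.e. the linearized base offset plus the innermost-stride-scaled indices.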
360
361// Collapses the shape of an nD memref to the target rank while applying
362// offsets for the collapsed dimensions. Returns the new memref value and the
363// remaining offsets for the last targetRank dimensions. For example:
364// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
365// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
366static std::pair<Value, SmallVector<OpFoldResult>>
367convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
368 Value memref,
369 SmallVector<OpFoldResult> offsets,
370 int64_t targetRank) {
371 auto memrefType = cast<MemRefType>(memref.getType());
372 unsigned rank = memrefType.getRank();
373
374 if (rank <= targetRank)
375 return {memref, offsets};
376
377 int64_t numCombinedDims = rank - targetRank;
378 SmallVector<OpFoldResult> subviewOffsets;
379 SmallVector<OpFoldResult> subviewSizes;
380 SmallVector<OpFoldResult> subviewStrides;
381
382 // For the combined dimensions: use the provided offsets, size=1, stride=1
383 for (unsigned i = 0; i < numCombinedDims; ++i) {
384 subviewOffsets.push_back(offsets[i]);
385 subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
386 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
387 }
388
389 // For the last targetRank dimensions: offset=0, use full size, stride=1
390 SmallVector<int64_t> resultShape;
391 auto originalShape = memrefType.getShape();
392 auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
393 for (unsigned i = numCombinedDims; i < rank; ++i) {
394 subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
395 if (ShapedType::isDynamic(originalShape[i])) {
396 subviewSizes.push_back(meta.getSizes()[i]);
397 resultShape.push_back(ShapedType::kDynamic);
398 } else {
399 subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
400 resultShape.push_back(originalShape[i]);
401 }
402 subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
403 }
404
405 auto resultType = memref::SubViewOp::inferRankReducedResultType(
406 resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
407 auto subviewOp =
408 memref::SubViewOp::create(rewriter, loc, resultType, memref,
409 subviewOffsets, subviewSizes, subviewStrides);
410
411 // Return the remaining offsets for the last targetRank dimensions
412 SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
413 offsets.end());
414 return {subviewOp.getResult(), newOffsets};
415}
416
417template <
418 typename OpType,
419 typename = std::enable_if_t<llvm::is_one_of<
420 std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
421 vector::GatherOp, vector::ScatterOp>::value>>
422// Convert memref to i64 base pointer
423static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
424 Location loc = xferOp.getLoc();
425 auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
426 rewriter, loc, xferOp.getBase())
427 .getResult();
428 return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
429 indexPtr)
430 .getResult();
431}
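// The resulting i64 value is the aligned base address of the buffer; it serves
// as the flat "pointer" operand of the scattered XeGPU ops created below.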
432
433static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
434 PatternRewriter &rewriter) {
435
436 Location loc = readOp.getLoc();
437 VectorType vectorType = readOp.getVectorType();
438 ArrayRef<int64_t> vectorShape = vectorType.getShape();
439 auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
440 if (!memrefType)
441 return rewriter.notifyMatchFailure(readOp, "Expected memref source");
442
443 auto meta = computeMemrefMeta(readOp, rewriter);
444 if (meta.first.empty())
445 return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
446
447 Value localOffsets =
448 computeOffsets(readOp, rewriter, meta.first, meta.second);
449
450 Value flatMemref = memrefToIndexPtr(readOp, rewriter);
451
452 Value mask = vector::ConstantMaskOp::create(
453 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
454 vectorShape);
455 auto gatherOp = xegpu::LoadGatherOp::create(
456 rewriter, loc, vectorType, flatMemref, localOffsets, mask,
457 /*chunk_size=*/IntegerAttr{},
458 /*l1_hint=*/xegpu::CachePolicyAttr{},
459 /*l2_hint=*/xegpu::CachePolicyAttr{},
460 /*l3_hint=*/xegpu::CachePolicyAttr{},
461 /*layout=*/nullptr);
462
463 rewriter.replaceOp(readOp, gatherOp.getResult());
464 return success();
465}
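// As a rough sketch, a 1D read such as
//   %v = vector.transfer_read %src[%i], %c0_f32 : memref<?xf32>, vector<16xf32>
// is rewritten into an all-true vector.constant_mask plus a LoadGatherOp that
// takes the flat i64 base pointer, the vector<16xindex> offsets computed above,
// and the mask, and yields the vector<16xf32> result.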
466
467static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
468 PatternRewriter &rewriter) {
469
470 Location loc = writeOp.getLoc();
471 VectorType vectorType = writeOp.getVectorType();
472 ArrayRef<int64_t> vectorShape = vectorType.getShape();
473
474 auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
475 if (!memrefType)
476 return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
477
478 auto meta = computeMemrefMeta(writeOp, rewriter);
479 if (meta.first.empty())
480 return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
481
482 Value localOffsets =
483 computeOffsets(writeOp, rewriter, meta.first, meta.second);
484
485 Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
486
487 Value mask = vector::ConstantMaskOp::create(
488 rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
489 vectorShape);
490 xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
491 localOffsets, mask,
492 /*chunk_size=*/IntegerAttr{},
493 /*l1_hint=*/xegpu::CachePolicyAttr{},
494 /*l2_hint=*/xegpu::CachePolicyAttr{},
495 /*l3_hint=*/xegpu::CachePolicyAttr{},
496 /*layout=*/nullptr);
497 rewriter.eraseOp(writeOp);
498 return success();
499}
500
501struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
502 using Base::Base;
503
504 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
505 PatternRewriter &rewriter) const override {
506 Location loc = readOp.getLoc();
507
508 if (failed(transferPreconditions(rewriter, readOp)))
509 return failure();
510
511 // TODO: This check needs to be replaced with a proper uArch capability check
512 auto chip = xegpu::getChipStr(readOp);
513 if (chip != "pvc" && chip != "bmg") {
514 // lower to scattered load Op if the target HW doesn't have 2d block load
515 // support
516 // TODO: add support for OutOfBound access
517 if (readOp.hasOutOfBoundsDim())
518 return failure();
519 return lowerToScatteredLoadOp(readOp, rewriter);
520 }
521
522 VectorType vecTy = readOp.getVectorType();
523
524 // Lower using load.gather in 1D case
525 if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
526 return lowerToScatteredLoadOp(readOp, rewriter);
527
528 // Perform common data transfer checks.
529 if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
530 return failure();
531
532 bool isOutOfBounds = readOp.hasOutOfBoundsDim();
533 if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
534 return rewriter.notifyMatchFailure(
535 readOp, "Unsupported non-zero padded out-of-bounds read");
536
537 AffineMap readMap = readOp.getPermutationMap();
538 bool isTransposeLoad = !readMap.isMinorIdentity();
539
540 Type elementType = vecTy.getElementType();
541 unsigned minTransposeBitWidth = 32;
542 if (isTransposeLoad &&
543 elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
544 return rewriter.notifyMatchFailure(
545 readOp, "Unsupported data type for transposition");
546
547 // If load is transposed, get the base shape for the tensor descriptor.
548 SmallVector<int64_t> descShape(vecTy.getShape());
549 if (isTransposeLoad)
550 std::reverse(descShape.begin(), descShape.end());
551 auto descType = xegpu::TensorDescType::get(
552 descShape, elementType, /*array_length=*/1,
553 /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
554
555 DenseI64ArrayAttr transposeAttr =
556 !isTransposeLoad ? nullptr
557 : DenseI64ArrayAttr::get(rewriter.getContext(),
558 ArrayRef<int64_t>{1, 0});
559 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
560 rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
561 vecTy.getRank());
562 // By default, no specific caching policy is assigned.
563 xegpu::CachePolicyAttr hint = nullptr;
564 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
565 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
566
567 auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
568 /*packed=*/nullptr, transposeAttr,
569 /*l1_hint=*/hint,
570 /*l2_hint=*/hint, /*l3_hint=*/hint,
571 /*layout=*/nullptr);
572 rewriter.replaceOp(readOp, loadOp);
573
574 return success();
575 }
576};
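// As a rough sketch of the block-load path, a 2D in-bounds read such as
//   %v = vector.transfer_read %src[%i, %j], %c0_f32 {in_bounds = [true, true]}
//       : memref<128x128xf32>, vector<8x16xf32>
// becomes an xegpu.create_nd_tdesc over the (possibly rank-reduced) source
// followed by an xegpu.load_nd at the remaining offsets; a transposed
// permutation map is handled by swapping the descriptor shape and attaching a
// [1, 0] transpose attribute to the load.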
577
578struct TransferWriteLowering
579 : public OpRewritePattern<vector::TransferWriteOp> {
580 using Base::Base;
581
582 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
583 PatternRewriter &rewriter) const override {
584 Location loc = writeOp.getLoc();
585
586 if (failed(transferPreconditions(rewriter, writeOp)))
587 return failure();
588
589 // TODO: This check needs to be replaced with a proper uArch capability check
590 auto chip = xegpu::getChipStr(writeOp);
591 if (chip != "pvc" && chip != "bmg") {
592 // lower to scattered store Op if the target HW doesn't have 2d block
593 // store support
594 // TODO: add support for OutOfBound access
595 if (writeOp.hasOutOfBoundsDim())
596 return failure();
597 return lowerToScatteredStoreOp(writeOp, rewriter);
598 }
599
600 // Perform common data transfer checks.
601 VectorType vecTy = writeOp.getVectorType();
602 if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
603 return failure();
604
605 AffineMap map = writeOp.getPermutationMap();
606 if (!map.isMinorIdentity())
607 return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
608
609 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
610 rewriter, loc, writeOp.getBase(),
611 getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
612
613 auto descType = xegpu::TensorDescType::get(
614 vecTy.getShape(), vecTy.getElementType(),
615 /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
616 xegpu::MemorySpace::Global);
617 // By default, no specific caching policy is assigned.
618 xegpu::CachePolicyAttr hint = nullptr;
619 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
620 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
621
622 auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
623 ndDesc, indices,
624 /*l1_hint=*/hint,
625 /*l2_hint=*/hint, /*l3_hint=*/hint,
626 /*layout=*/nullptr);
627 rewriter.replaceOp(writeOp, storeOp);
628
629 return success();
630 }
631};
632
633struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
634 using Base::Base;
635
636 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
637 PatternRewriter &rewriter) const override {
638 auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
639 if (!srcTy)
640 return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
641
642 Location loc = gatherOp.getLoc();
643 VectorType vectorType = gatherOp.getVectorType();
644
645 auto meta = computeMemrefMeta(gatherOp, rewriter);
646 if (meta.first.empty())
647 return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
648
649 Value localOffsets =
650 computeOffsets(rewriter, gatherOp, meta.first, meta.second);
651 Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
652
653 auto xeGatherOp = xegpu::LoadGatherOp::create(
654 rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
655 /*chunk_size=*/IntegerAttr{},
656 /*l1_hint=*/xegpu::CachePolicyAttr{},
657 /*l2_hint=*/xegpu::CachePolicyAttr{},
658 /*l3_hint=*/xegpu::CachePolicyAttr{},
659 /*layout=*/nullptr);
660
661 auto selectOp =
662 arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
663 xeGatherOp.getResult(), gatherOp.getPassThru());
664 rewriter.replaceOp(gatherOp, selectOp.getResult());
665 return success();
666 }
667};
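// Note: vector.gather defines masked-off lanes from the pass-thru operand,
// whereas the lowered gather does not guarantee a value for them; the
// arith.select above restores the pass-thru semantics.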
668
669struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
670 using Base::Base;
671
672 LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
673 PatternRewriter &rewriter) const override {
674 auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
675 if (!srcTy)
676 return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
677
678 Location loc = scatterOp.getLoc();
679 auto meta = computeMemrefMeta(scatterOp, rewriter);
680 if (meta.first.empty())
681 return rewriter.notifyMatchFailure(scatterOp,
682 "Failed to compute strides");
683
684 Value localOffsets =
685 computeOffsets(rewriter, scatterOp, meta.first, meta.second);
686 Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
687
688 xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
689 flatMemref, localOffsets, scatterOp.getMask(),
690 /*chunk_size=*/IntegerAttr{},
691 /*l1_hint=*/xegpu::CachePolicyAttr{},
692 /*l2_hint=*/xegpu::CachePolicyAttr{},
693 /*l3_hint=*/xegpu::CachePolicyAttr{},
694 /*layout=*/nullptr);
695 rewriter.eraseOp(scatterOp);
696 return success();
697 }
698};
699
700struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
701 using Base::Base;
702
703 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
704 PatternRewriter &rewriter) const override {
705 Location loc = loadOp.getLoc();
706
707 VectorType vecTy = loadOp.getResult().getType();
708 if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
709 return failure();
710
711 // Boundary check is available only for block instructions.
712 bool boundaryCheck = vecTy.getRank() > 1;
713 // By default, no specific caching policy is assigned.
714 xegpu::CachePolicyAttr hint = nullptr;
715
716 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
717 rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
718 vecTy.getRank());
719
720 auto descType = xegpu::TensorDescType::get(
721 vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
722 boundaryCheck, xegpu::MemorySpace::Global);
723
724 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
725 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
726 auto loadNdOp =
727 xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
728 /*packed=*/nullptr, /*transpose=*/nullptr,
729 /*l1_hint=*/hint,
730 /*l2_hint=*/hint, /*l3_hint=*/hint,
731 /*layout=*/nullptr);
732 rewriter.replaceOp(loadOp, loadNdOp);
733
734 return success();
735 }
736};
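// As a rough sketch, %v = vector.load %src[%i, %j] : memref<64x64xf32>,
// vector<8x16xf32> becomes an xegpu.create_nd_tdesc over %src followed by an
// xegpu.load_nd at offsets [%i, %j], with boundary checking enabled since the
// 2D (block) form supports it.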
737
738struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
739 using Base::Base;
740
741 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
742 PatternRewriter &rewriter) const override {
743 Location loc = storeOp.getLoc();
744
745 TypedValue<VectorType> vector = storeOp.getValueToStore();
746 VectorType vecTy = vector.getType();
747 if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
748 return failure();
749
750 // Boundary check is available only for block instructions.
751 bool boundaryCheck = vecTy.getRank() > 1;
752
753 auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
754 rewriter, loc, storeOp.getBase(),
755 getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
756
757 auto descType = xegpu::TensorDescType::get(
758 vecTy.getShape(), vecTy.getElementType(),
759 /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
760
761 // By default, no specific caching policy is assigned.
762 xegpu::CachePolicyAttr hint = nullptr;
763 xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
764 rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
765
766 auto storeNdOp =
767 xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
768 /*l1_hint=*/hint,
769 /*l2_hint=*/hint, /*l3_hint=*/hint,
770 /*layout=*/nullptr);
771
772 rewriter.replaceOp(storeOp, storeNdOp);
773
774 return success();
775 }
776};
777
778struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
779 using Base::Base;
780
781 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
782 PatternRewriter &rewriter) const override {
783 Location loc = contractOp.getLoc();
784
785 if (contractOp.getKind() != vector::CombiningKind::ADD)
786 return rewriter.notifyMatchFailure(contractOp,
787 "Expects add combining kind");
788
789 TypedValue<Type> acc = contractOp.getAcc();
790 VectorType accType = dyn_cast<VectorType>(acc.getType());
791 if (!accType || accType.getRank() != 2)
792 return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
793
794 // Accept only plain 2D data layout.
795 // VNNI packing is applied to DPAS as a separate lowering step.
796 TypedValue<VectorType> lhs = contractOp.getLhs();
797 TypedValue<VectorType> rhs = contractOp.getRhs();
798 if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
799 return rewriter.notifyMatchFailure(contractOp,
800 "Expects lhs and rhs 2D vectors");
801
802 if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
803 return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
804
805 auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
806 TypeRange{contractOp.getResultType()},
807 ValueRange{lhs, rhs, acc});
808 rewriter.replaceOp(contractOp, dpasOp);
809
810 return success();
811 }
812};
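// As a rough sketch, a row-major contraction with an additive combining kind,
//   %d = vector.contract {kind = #vector.kind<add>, ...} %a, %b, %c
//       : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
// maps directly onto xegpu.dpas %a, %b, %c with the same operand order; any
// VNNI packing of the B operand is introduced by later XeGPU-specific steps,
// as noted above.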
813
814struct ConvertVectorToXeGPUPass
815 : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
816 void runOnOperation() override {
817 RewritePatternSet patterns(&getContext());
818 populateVectorToXeGPUConversionPatterns(patterns);
819 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
820 return signalPassFailure();
821 }
822};
823
824} // namespace
825
826void mlir::populateVectorToXeGPUConversionPatterns(
827 RewritePatternSet &patterns) {
828 patterns
829 .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
830 ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
831 patterns.getContext());
832}