//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default(false);
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
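
// For illustration, a transfer_read that passes transferPreconditions above
// (rough sketch): unit innermost stride, no mask, and a projected permutation
// map touching only the innermost dimensions.
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<32x64xf32>, vector<8x16xf32>
// A masked transfer, a tensor source, or a permutation map accessing outer
// dimensions is rejected with the corresponding match-failure message.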

static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                Location loc,
                                                xegpu::TensorDescType descType,
                                                TypedValue<MemRefType> src) {
  MemRefType srcTy = src.getType();
  assert(srcTy.isStrided() && "Expected strided memref type");
  auto [strides, offset] = srcTy.getStridesAndOffset();
  bool isStatic = true;

  // The memref is dynamic if any of its shape, offset, or strides is dynamic.
  if (!srcTy.hasStaticShape())
    isStatic = false;

  if (!ShapedType::isStatic(offset))
    isStatic = false;

  for (auto stride : strides) {
    if (!ShapedType::isStatic(stride)) {
      isStatic = false;
      break;
    }
  }

  xegpu::CreateNdDescOp ndDesc;
  if (isStatic) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
  } else {
    // In case of a ranked dynamic memref, instead of passing on the memref,
    // the i64 base address plus the source's offset, shape, and strides have
    // to be provided explicitly.
    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
        rewriter, loc, meta.getBaseBuffer());
    auto offset = meta.getOffset();
    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
    auto offsetInBytes = arith::MulIOp::create(
        rewriter, loc, offset,
        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
    auto adjustedBaseAddr = arith::AddIOp::create(
        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
    auto adjustedAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, adjustedAddrI64,
        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
  }

  return ndDesc;
}
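
// For a dynamic memref, createNdDescriptor above materializes the descriptor
// from explicit metadata. A rough sketch of the emitted IR (assembly is
// abbreviated):
//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
//   %ptr    = memref.extract_aligned_pointer_as_index %base
//   %bytes  = arith.muli %offset, %c_elem_bytes : index
//   %addr   = arith.addi %ptr, %bytes : index
//   %addr64 = arith.index_cast %addr : index to i64
//   %desc   = xegpu.create_nd_tdesc %addr64, shape: ..., strides: ...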

// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
//   [1, 0]), then after calling this function, the last two strides will be
//   swapped:
//     Original strides:  [s0, s1, s2, s3]
//     After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For a dynamic-shape memref, use memref.extract_strided_metadata to get
    // the stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
    // stride0, stride1, ..., strideN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (rank 0)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose).
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}
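
// Example: for a statically shaped memref<8x16xf32>, computeMemrefMeta returns
// the strides as constants [16, 1] and the offset as the constant 0. For
// memref<?x?xf32>, or whenever the offset or any stride is dynamic, the
// missing values are taken from memref.extract_strided_metadata instead.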

// This function computes the vector of localOffsets for scattered loads and
// stores. It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter. Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//                %cst {in_bounds = [true, true, true, true]} :
//                memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * 4x2x6x32 // consider using an affine map
//   %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply step vectors by corresponding strides
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to full vector shape
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute local offsets
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute base offset from transfer read indices
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Broadcast base offset to match vector shape
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}

// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}
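
// Worked example (rough sketch) for the gather/scatter variant above: given a
// base memref<4x8xf32> with strides [8, 1], base offsets [%i, %j], and
// per-element %indices : vector<16xindex>, the result is
//   %base    = memrefOffset + %i * 8 + %j * 1
//   %offsets = broadcast(%base) + %indices * 1
// i.e. a vector<16xindex> of linearized element offsets.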

// Collapses the shape of an nD memref to the target rank while applying
// offsets for the collapsed dimensions. Returns the new memref value and the
// remaining offsets for the last targetRank dimensions. For example:
//   input:  %memref = memref<2x4x8x32xf32>, offsets = [%i0, %i1, %i2, %i3]
//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
static std::pair<Value, SmallVector<OpFoldResult>>
convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
                                    Value memref,
                                    SmallVector<OpFoldResult> offsets,
                                    int64_t targetRank) {
  auto memrefType = cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();

  if (rank <= targetRank)
    return {memref, offsets};

  int64_t numCombinedDims = rank - targetRank;
  SmallVector<OpFoldResult> subviewOffsets;
  SmallVector<OpFoldResult> subviewSizes;
  SmallVector<OpFoldResult> subviewStrides;

  // For the combined dimensions: use the provided offsets, size=1, stride=1
  for (unsigned i = 0; i < numCombinedDims; ++i) {
    subviewOffsets.push_back(offsets[i]);
    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  // For the last targetRank dimensions: offset=0, use the full size, stride=1
  SmallVector<int64_t> resultShape;
  auto originalShape = memrefType.getShape();
  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
  for (unsigned i = numCombinedDims; i < rank; ++i) {
    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
    if (ShapedType::isDynamic(originalShape[i])) {
      subviewSizes.push_back(meta.getSizes()[i]);
      resultShape.push_back(ShapedType::kDynamic);
    } else {
      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
      resultShape.push_back(originalShape[i]);
    }
    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
  }

  auto resultType = memref::SubViewOp::inferRankReducedResultType(
      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
  auto subviewOp =
      memref::SubViewOp::create(rewriter, loc, resultType, memref,
                                subviewOffsets, subviewSizes, subviewStrides);

  // Return the remaining offsets for the last targetRank dimensions
  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
                                       offsets.end());
  return {subviewOp.getResult(), newOffsets};
}

// Convert memref to i64 base pointer
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}
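
// memrefToIndexPtr expands to roughly:
//   %ptr   = memref.extract_aligned_pointer_as_index %base : memref<...> -> index
//   %ptr64 = arith.index_cast %ptr : index to i64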

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {

  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{},
      /*layout=*/nullptr);

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
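
// Rough sketch of the result: the transfer_read is replaced by an all-true
// vector.constant_mask, the linearized offsets computed above, and an
// xegpu::LoadGatherOp reading through the flat i64 base pointer; chunk size
// and cache hints are left unset.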

static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {

  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{},
                                /*layout=*/nullptr);
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: add support for out-of-bounds access.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    VectorType vecTy = readOp.getVectorType();

    // Lower using load.gather in the 1D case.
    if (vecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
      return lowerToScatteredLoadOp(readOp, rewriter);

    // Perform common data transfer checks.
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
        vecTy.getRank());
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint,
                                          /*layout=*/nullptr);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
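
// Illustrative lowering on a block-load-capable target (rough sketch; the
// printed form of the XeGPU ops is abbreviated):
//   %v = vector.transfer_read %src[%i, %j], %c0_f32 {in_bounds = [true, true]}
//       : memref<32x64xf32>, vector<8x16xf32>
// becomes approximately:
//   %desc = xegpu.create_nd_tdesc %src : memref<32x64xf32>
//             -> !xegpu.tensor_desc<8x16xf32>
//   %v    = xegpu.load_nd %desc[%i, %j] : !xegpu.tensor_desc<8x16xf32>
//             -> vector<8x16xf32>
// A transposed read additionally sets the transpose attribute on the load.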

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, writeOp.getBase(),
        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
                                            ndDesc, indices,
                                            /*l1_hint=*/hint,
                                            /*l2_hint=*/hint, /*l3_hint=*/hint,
                                            /*layout=*/nullptr);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};
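
// The write pattern mirrors the read pattern: a tensor descriptor is created
// for the (possibly rank-reduced) destination and the vector is written with
// xegpu.store_nd at the remaining offsets; only minor-identity permutation
// maps are accepted.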

struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{},
        /*layout=*/nullptr);

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};

struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{},
                                  /*layout=*/nullptr);
    rewriter.eraseOp(scatterOp);
    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
        vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
    auto loadNdOp =
        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                /*packed=*/nullptr, /*transpose=*/nullptr,
                                /*l1_hint=*/hint,
                                /*l2_hint=*/hint, /*l3_hint=*/hint,
                                /*layout=*/nullptr);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
        rewriter, loc, storeOp.getBase(),
        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint,
                                 /*layout=*/nullptr);

    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
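
// Illustrative lowering (rough sketch):
//   %r = vector.contract {kind = #vector.kind<add>, ...row-major matmul maps...}
//          %lhs, %rhs, %acc
//        : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
// becomes:
//   %r = xegpu.dpas %lhs, %rhs, %acc
//        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//          -> vector<8x16xf32>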

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
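
// Typical usage from a custom pass (sketch; the pass name is hypothetical):
//   void MyVectorToXeGPUPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateVectorToXeGPUConversionPatterns(patterns);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }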