//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: The values below are uArch-dependent; they should be queried from the
// uArch interface instead of being hardcoded.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
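
// Payload layout illustration (i32 lanes; lanes 1 and 5..7 are unused; the
// base pointer occupies lanes 0..1 through the 4xi64 bitcast view):
//   lane:   0..1        2           3           4
//   field:  BasePtr     BaseShapeW  BaseShapeH  BasePitch
//           (i64 view)  (i32)       (i32)       (i32)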

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}

/// Checks if the given MemRefType refers to shared memory.
static bool isSharedMemRef(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

// Get a flat vector type with the same total bit width but a new element type.
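// E.g., vector<8xf16> with toElemType = i32 becomes vector<4xi32>.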
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  // If no hints are provided, use the default cache control.
  if (!L1hint && !L3hint)
    return xevm::LoadCacheControl::USE_DEFAULT;
  // If only one of the hints is provided, use the default for the other level.
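  // E.g., an L1 hint of UNCACHED with no L3 hint defaults L3 to CACHED and
  // yields L1UC_L2UC_L3C.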
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  // If no hints are provided, use the default cache control.
  if (!L1hint && !L3hint)
    return xevm::StoreCacheControl::USE_DEFAULT;
  // If only one of the hints is provided, use the default for the other level.
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

//
// Note:
// Block operations on tiles of sub-byte element types are emulated using
// larger element types. Tensor descriptors are kept intact; only the ops
// consuming them are emulated.
//
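// For example, with i4 elements (subByteFactor = 2, sub16BitFactor = 4),
// these are the two cases handled in the pattern below:
//  - a packed load of a 32x32 i4 tile (systolicDepth * 4 rows by
//    executionSize * subByteFactor columns) is emulated as a packed load of a
//    32x16 i8 tile;
//  - an 8x64 i4 tile whose width equals executionSize * sub16BitFactor is
//    emulated as an 8x16 i16 tile.
//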

class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload. The
    // payload is an 8xi32 vector; offsets of the individual fields are
    // defined in the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize the payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;

    // The source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    // The descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, we need to extract the aligned pointer as an
    // index. The pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      // Access the adaptor after the failure check to avoid rolling back
      // generated code for the materialization cast.
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // The pointer type may be i32. Cast to i64 if needed.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // A 1D tensor descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Get the shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Get the pitch value from op fold results.
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate the payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
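
// Illustrative sketch (operand and attribute details elided): an op like
//   %td = xegpu.create_nd_tdesc %src : memref<64x128xf16>
//       -> !xegpu.tensor_desc<8x16xf16>
// lowers to a vector<8xi32> payload holding the aligned base address in lanes
// 0..1, baseShapeW = 128, baseShapeH = 64, and basePitch = 128 (the row-major
// stride of the outer dimension).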

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;

    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // For sub-byte types, only 4-bit elements are currently supported.
    if (isSubByte) {
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub byte types of 4bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      // Handle the special case of packed loads.
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // A packed load is implemented as packed loads of 8-bit elements.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            // Usage case: loading as Matrix B with a pack request.
            // The source is assumed to be pre-packed into 8-bit elements.
            // Emulate with 8-bit loads with a pack request.
            // scaled_tileW = executionSize
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      // If not handled by the packed load case above, handle other cases.
      if (wScaleFactor == 1) {
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          // Usage case: loading as the Matrix A operand.
          // Emulate with 16-bit loads/stores.
          // scaled_tileW = executionSize
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub byte types.");
        }
      }
      // Recompute the element bit size for emulation.
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }

    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op; convert them to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // FIXME: The width (pitch) is not the same as baseShapeW; it should be
      // the stride of the second-to-last dimension in row-major layout.
      // Compute the width in bytes.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Compute the pitch in bytes.
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

      if (wScaleFactor > 1) {
        // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
        // Note: tileW is already scaled above.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      // Get the tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead of
        // the adaptor. The adaptor might have optimized away the single
        // element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));

          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D tensor descriptor: `tdesc` represents the base address as i64.
      // The offset is in number of elements, so multiply by the element byte
      // size to compute the byte offset:
      //   byteOffset = offset * elementByteSize
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset.
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead of
        // the adaptor. The adaptor might have optimized away the single
        // element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};

// Helper that computes
//   offset * elemByteSize + baseAddr
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value basePtrI64;
    // The load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type must be a multiple of 8 bits.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, we use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // If an offset is provided, add it to the base pointer. The offset is
      // in number of elements, so multiply by the element byte size.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield a vector of zeros.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");

    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // The load result or store value type can be a vector or a scalar.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Some transforms may leave unit dimensions in the 2D vector; the
      // adaptor does not catch this for results.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type must be a multiple of 8 bits.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // The default memory space is SLM.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);

    // Convert the base pointer (i32) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // If the 'subgroup_block_io' attribute is set, lower to
      // xevm.blockload / xevm.blockstore.

      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xegpu::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The lowering for chunk load only works for pvc and bmg.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // If the size of valOrResVecTy is 1, this lowers to a scalar load/store
      // operation. LLVM load/store does not support a vector of size 1, so we
      // need to handle this case separately.
      auto scalarTy = valOrResVecTy.getElementType();
      LLVM::LoadOp loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size can come from two sources:
        if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, we use its element type to determine
          // the element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, we use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If the source is a memref, we use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with a cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    // Get the correct dpasInst from the chip info.
    auto chipStr = xegpu::getChipStr(op);
    if (!chipStr)
      return rewriter.notifyMatchFailure(op, "cannot determine target chip");

    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
    if (!uArch)
      return rewriter.notifyMatchFailure(op, "unsupported target uArch");

    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
    if (!dpasInst)
      return rewriter.notifyMatchFailure(op,
                                         "DPAS not supported by target uArch");

    auto checkSupportedTypes = [&](VectorType vecTy,
                                   xegpu::uArch::MMAOpndKind kind) -> bool {
      auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
      return llvm::find(supported, vecTy.getElementType()) != supported.end();
    };

    if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
      return rewriter.notifyMatchFailure(
          op, "A-matrix element type not supported by target uArch");
    if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
      return rewriter.notifyMatchFailure(
          op, "B-matrix element type not supported by target uArch");
    // NOTE: Supported types for MatrixC and MatrixD are identical.
    if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
      return rewriter.notifyMatchFailure(
          op, "result/accumulator element type not supported by target uArch");

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    // The XeVM type converter is based on the LLVM type converter with the
    // following customizations.
    // First, type conversion rules are added for the xegpu custom types,
    // TensorDescType and MemDescType.
    // Second, MemRefType is lowered to a single integer type.
    // Third, a VectorType with a single element or rank 0 is converted to its
    // element type. Otherwise, the vector type is flattened to 1D.
    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter typeConverter(context, options);

    Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
    Type i32Type = IntegerType::get(context, 32);
    typeConverter.addConversion([&](VectorType type) -> Type {
      auto elemType = typeConverter.convertType(type.getElementType());
      // If the vector has rank 0 or a single element, return the element
      // type.
      unsigned rank = type.getRank();
      if (rank == 0 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat vector type.
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
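    // E.g., vector<8x16xf16> converts to vector<128xf16>, while vector<1xf32>
    // and 0-D vectors convert to the plain element type (f32).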
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      if (type.getRank() == 1)
        return xevmIndexType;
      return VectorType::get(8, i32Type);
    });
    // SLM access related type conversions.
    // TODO: LLVM DLTI provides a clean way of representing different pointer
    // sizes based on address space. Currently the pointer size of SLM access
    // is hardcoded to 32 bits. Update to use DLTI when switching the overall
    // XeGPU lowering to DLTI instead of the use64bitIndex option used above.

    // Convert MemDescType into i32 for SLM.
    typeConverter.addConversion(
        [&](xegpu::MemDescType type) -> Type { return i32Type; });

    typeConverter.addConversion([&](MemRefType type) -> Type {
      return isSharedMemRef(type) ? i32Type : xevmIndexType;
    });

    // The LLVM type converter puts unrealized casts in the following cases;
    // add materialization casts to handle them.

    // Materialization to convert a memref to i64 or i32, depending on
    // global/SLM. Applies only to target materialization.
    // Note: int-to-memref materialization is not required, as xegpu ops
    // currently do not produce memrefs as results.
    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Type indexType = builder.getIndexType();

        int64_t intOffsets;
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {
          // Result types: [base_memref, offset, stride0, stride1, ...,
          // strideN-1, size0, size1, ..., sizeN-1]
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          // strides + sizes
          resultTypes.append(2 * rank, indexType);

          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);

          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }

        auto addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        auto offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);

        // Compute the final address: base address + byte offset.
        auto byteSize = arith::ConstantOp::create(
            builder, loc, type,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);

        return addrWithOffset.getResult();
      }
      return {};
    };

    // Materialization to convert ui64 to i64.
    // Applies only to target materialization.
    // Note: i64-to-ui64 materialization is not required, as xegpu ops
    // currently do not produce ui64 as a result.
    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    // Applies only to target materialization.
    // Note: i32-to-ui32 materialization is not required, as xegpu ops
    // currently do not produce ui32 as a result.
    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert between vector types:
    // - add a shape cast for different shapes,
    // - add a bitcast for different element types.
    // Applies to both source and target materialization.
    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
                                                ValueRange inputs,
                                                Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          Value cast = input;
          // If the target type has a different shape, add a shape cast.
          // If the target type has a different element type, add a bitcast.
          if (targetVecTy.getShape() != vecTy.getShape()) {
            cast = vector::ShapeCastOp::create(
                       builder, loc,
                       VectorType::get(targetVecTy.getShape(),
                                       vecTy.getElementType()),
                       cast)
                       .getResult();
          }
          if (targetVecTy.getElementType() != vecTy.getElementType()) {
            cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
                       .getResult();
          }
          return cast;
        }
      }
      return {};
    };

    // Materialization to convert a single-element vector to a single element
    // of the vector's element type.
    // Applies only to target materialization.
    auto vectorToSingleElementMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        // The source needs to be a single-element vector.
        auto rank = vecTy.getRank();
        if (rank != 0 && vecTy.getNumElements() != 1)
          return {};
        auto inElemTy = vecTy.getElementType();
        // Extract the scalar.
        Value cast = input;
        if (rank == 0) {
          cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
        } else {
          cast = vector::ExtractOp::create(builder, loc, cast,
                                           SmallVector<int64_t>(rank, 0))
                     .getResult();
        }
        // The extracted element type may need conversion. Two cases:
        // 1. Index type to integer type.
        // 2. Other element type mismatch.
        if (inElemTy.isIndex()) {
          cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                     .getResult();
        } else if (inElemTy != type) {
          cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
        }
        return cast;
      }
      return {};
    };

    // Materialization to convert a single element of the vector element type
    // to a single-element vector. If the result type of the original op is a
    // single-element vector and the lowered type is a scalar, this
    // materialization cast creates a single-element vector by first
    // converting the element type if needed and then broadcasting.
    // Applies only to source materialization.
    auto singleElementToVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      auto inTy = input.getType();
      if (!inTy.isIntOrFloat())
        return {};
      // If the target type is a vector of rank 0, or a single-element vector
      // with element type matching the input type, broadcast the input to the
      // target type.
      if (auto vecTy = dyn_cast<VectorType>(type)) {
        if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
          return {};
        auto outElemTy = vecTy.getElementType();
        Value cast = input;
        if (outElemTy.isIndex()) {
          cast = arith::IndexCastUIOp::create(builder, loc,
                                              builder.getIndexType(), cast)
                     .getResult();
        } else if (inTy != outElemTy) {
          cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
                     .getResult();
        }
        return vector::BroadcastOp::create(builder, loc, vecTy, cast)
            .getResult();
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementToVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
    typeConverter.addTargetMaterialization(
        vectorToSingleElementMaterializationCast);
    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
    ConversionTarget target(*context);
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(context);
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}
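
// Usage sketch (assuming the standard generated pass registration; the flag
// name derives from ConvertXeGPUToXeVMPass in the conversion Passes.td):
//   mlir-opt --convert-xegpu-to-xevm input.mlir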