//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: These are uArch-dependent values; move away from hardcoding them.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
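
// Payload layout of the 8xi32 nd tensor descriptor (the i64 base pointer
// spans two i32 lanes):
//
//   lane:  | 0 . 1   | 2          | 3          | 4         | 5 . 6 . 7 |
//   field: | BasePtr | BaseShapeW | BaseShapeH | BasePitch | (unused)  |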

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}

/// Checks if the given MemRefType refers to shared memory.
static bool isSharedMemRef(MemRefType memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

// Get a flat vector type with the same total bit width but a new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
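
// For example, encodeVectorTypeTo(vector<8xf16>, i32) yields vector<4xi32>:
// the total bit width is preserved (8 x 16 == 4 x 32 == 128 bits).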

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  // If no hints are provided, use the default cache control.
  if (!L1hint && !L3hint)
    return xevm::LoadCacheControl::USE_DEFAULT;
  // If only one of the hints is provided, use the default for the other level.
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
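
// For example, L1hint = UNCACHED with no L3hint (which defaults to CACHED)
// maps to xevm::LoadCacheControl::L1UC_L2UC_L3C.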

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  // If no hints are provided, use the default cache control.
  if (!L1hint && !L3hint)
    return xevm::StoreCacheControl::USE_DEFAULT;
  // If only one of the hints is provided, use the default for the other level.
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

//
// Note:
// Block operations on tiles of sub-byte element types are handled by
// emulating them with larger element types. Tensor descriptors are kept
// intact; only the ops consuming them are emulated.
//
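
// Illustrative sub-byte example (i4 elements, subByteFactor == 2):
// - A packed load of a 32x32 i4 tile (tileH == systolicDepth * 4,
//   tileW == executionSize * 2) is emulated as a 32x16 packed load of i8.
// - A tile with tileW == executionSize * 4 (e.g. 8x64 i4, a matrix A
//   operand) is emulated as an 8x16 load/store of i16.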

class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (!mixedOffsets.empty())
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload.
    // The payload is an 8xi32 vector. Offsets to individual fields are
    // defined by the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize the payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;

    // The source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    // The descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, extract the aligned pointer as an index.
    // The pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      // Access the adaptor after the failure check to avoid rolling back
      // generated code for the materialization cast.
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // The pointer type may be i32. Cast to i64 if needed.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // A 1D tensor descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Get the pitch value from op fold results.
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate the payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
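
// Illustrative lowering (types abbreviated): for
//   %0 = xegpu.create_nd_tdesc %src : memref<64x128xf16>
//        -> !xegpu.tensor_desc<8x16xf16>
// the pattern above produces a vector<8xi32> payload holding the aligned
// base address of %src (as i64 in lanes 0-1), baseShapeW = 128,
// baseShapeH = 64, and basePitch = 128 (in elements, not bytes).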

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;

    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be a multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // For sub-byte types, only 4-bit elements are currently supported.
    if (isSubByte) {
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub-byte types of 4 bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub-byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      // Handle the special case of a packed load.
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // A packed load is emulated as a packed load of 8-bit elements.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            // Use case: loading as matrix B with a pack request.
            // The source is assumed to be pre-packed into 8-bit elements.
            // Emulate with 8-bit loads with a pack request.
            // scaled_tileW = executionSize
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      // If not handled by the packed load case above, handle other cases.
      if (wScaleFactor == 1) {
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          // Use case: loading as the matrix A operand.
          // Emulate with 16-bit loads/stores.
          // scaled_tileW = executionSize
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub-byte types.");
        }
      }
      // Recompute the element bit size for emulation.
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }

    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op; convert them to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // FIXME: The width/pitch is not necessarily the same as baseShapeW; it
      // should be the stride of the second-to-last dimension in row-major
      // layout.
      // Compute the width in bytes.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Compute the pitch in bytes.
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

      if (wScaleFactor > 1) {
        // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
        // Note: tileW is already scaled above.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      // Get the tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor. The adaptor might have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));

          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D tensor descriptor: `tdesc` represents the base address as i64.
      // The offset is in number of elements; multiply by the element byte
      // size to compute the byte offset:
      //   byteOffset = offset * elemByteSize
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor. The adaptor might have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};
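
// Illustrative 2D load lowering (f16 tile, no pack/transpose): an
//   xegpu.load_nd %tdesc[%y, %x] : ... -> vector<8x16xf16>
// becomes an xevm.blockload2d with element size 16 bits, tile width 16,
// tile height 8 and v_blocks 1, using the base pointer, shape and pitch
// extracted from the payload (shape and pitch converted to bytes), followed
// by a bitcast of the flat result vector back to the f16 element type.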

// Helper that computes baseAddr + offset * elemByteSize.
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
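
// For example, with offset = %off and elemByteSize = 4, this emits
// %newAddr = %baseAddr + %off * 4.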

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // The load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type bit width must be a multiple of 8.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be a multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64
    // if needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // If an offset is provided, add it to the base pointer.
      // The offset is in number of elements; multiply by the element byte
      // size.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield zeros.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
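
// Illustrative per-lane gather lowering (scalar f32 result): the pattern
// above produces roughly
//   %res = scf.if %mask -> (f32) {
//     %v = llvm.load %ptr {cache_control = ...} : !llvm.ptr<1> -> f32
//     scf.yield %v : f32
//   } else {
//     %zero = arith.constant 0.0 : f32
//     scf.yield %zero : f32
//   }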

class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op,
                                         "Expected offsets to be provided.");

    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // The load result or store value type can be a vector or a scalar.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Some transforms may leave unit dimensions in the 2D vector; adaptors
      // do not catch this for results.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type bit width must be a multiple of 8.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be a multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // The default memory space is SLM.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);

    // Convert the base pointer (i32) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // If the 'subgroup_block_io' attribute is set, lower to
      // xevm.blockload / xevm.blockstore.
      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xegpu::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The lowering for chunk load only works for pvc and bmg.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // If the size of valOrResVecTy is 1, it lowers to a scalar load/store
      // operation. LLVM load/store does not support vectors of size 1, so we
      // need to handle this case separately.
      auto scalarTy = valOrResVecTy.getElementType();
      LLVM::LoadOp loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};
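
// Illustrative lowering: with the subgroup_block_io attribute, a
// xegpu.load_matrix producing vector<8xf16> becomes an xevm.blockload of
// vector<8xi16> from the linearized SLM address, followed by a bitcast back
// to f16; without the attribute, it becomes a plain llvm.load (a scalar load
// when the vector has a single element).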

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64
    // if needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to determine the
          // element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be a multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If the source is a memref, use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with a cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    // Get the correct DPAS instruction by querying the target chip info.
    auto chipStr = xegpu::getChipStr(op);
    if (!chipStr)
      return rewriter.notifyMatchFailure(op, "cannot determine target chip");

    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
    if (!uArch)
      return rewriter.notifyMatchFailure(op, "unsupported target uArch");

    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
    if (!dpasInst)
      return rewriter.notifyMatchFailure(op,
                                         "DPAS not supported by target uArch");

    auto checkSupportedTypes = [&](VectorType vecTy,
                                   xegpu::uArch::MMAOpndKind kind) -> bool {
      auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
      return llvm::find(supported, vecTy.getElementType()) != supported.end();
    };

    if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
      return rewriter.notifyMatchFailure(
          op, "A-matrix element type not supported by target uArch");
    if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
      return rewriter.notifyMatchFailure(
          op, "B-matrix element type not supported by target uArch");
    // NOTE: Supported types for MatrixC and MatrixD are identical.
    if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
      return rewriter.notifyMatchFailure(
          op, "result/accumulator element type not supported by target uArch");

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
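
// Illustrative example (f16 A/B with an f32 accumulator): since
// getNumOperandsPerDword(F16) == 2, the emitted xevm.mma carries the shape
// M = <number of per-lane accumulator elements>, N = executionSize (16),
// K = systolicDepth * 2 = 16.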

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(
          rewriter, loc, rewriter.getI64Type(), rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    // The XeVM type converter is based on the LLVM type converter with the
    // following customizations.
    // First, type conversion rules are added for the XeGPU custom types,
    // TensorDescType and MemDescType.
    // Second, MemRefType is lowered to a single integer type.
    // Third, a VectorType with a single element or rank 0 is converted to
    // its element type. Otherwise, the vector type is flattened to 1D.
    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter typeConverter(context, options);

    Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
    Type i32Type = IntegerType::get(context, 32);
    typeConverter.addConversion([&](VectorType type) -> Type {
      auto elemType = typeConverter.convertType(type.getElementType());
      // If the vector has rank 0 or a single element, return the element
      // type.
      unsigned rank = type.getRank();
      if (rank == 0 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat vector type.
      int64_t numElements = llvm::product_of(type.getShape());
      return VectorType::get(numElements, elemType);
    });
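    // Examples: vector<16x8xf16> -> vector<128xf16>; vector<1xf32> -> f32;
    // 0-D vector<f32> -> f32.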
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered descriptors are not supported in the XeVM lowering.
      if (type.isScattered())
        return {};
      if (type.getRank() == 1)
        return xevmIndexType;
      return VectorType::get(8, i32Type);
    });
    // SLM access related type conversions.
    // TODO: LLVM DLTI provides a clean way of representing different pointer
    // sizes based on address space. Currently the pointer size of SLM access
    // is hardcoded to 32 bits. Update to use DLTI when switching the overall
    // XeGPU lowering to DLTI instead of the use64bitIndex option used above.

    // Convert MemDescType into i32 for SLM.
    typeConverter.addConversion(
        [&](xegpu::MemDescType type) -> Type { return i32Type; });

    typeConverter.addConversion([&](MemRefType type) -> Type {
      return isSharedMemRef(type) ? i32Type : xevmIndexType;
    });

    // The LLVM type converter inserts unrealized casts for the following
    // cases; add materialization casts to handle them.

    // Materialization to convert a memref to i64 or i32, depending on
    // global/SLM. Applies only to target materialization.
    // Note: int type to memref materialization is not required, as xegpu ops
    // currently do not produce memrefs as results.
    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Type indexType = builder.getIndexType();

        int64_t intOffsets;
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {
          // Result types: [base_memref, offset, stride0, stride1, ...,
          // strideN-1, size0, size1, ..., sizeN-1]
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          // strides + sizes
          resultTypes.append(2 * rank, indexType);

          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);

          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }

        auto addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        auto offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);

        // Compute the final address: base address + byte offset.
        auto byteSize = arith::ConstantOp::create(
            builder, loc, type,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);

        return addrWithOffset.getResult();
      }
      return {};
    };
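
    // For example, a global memref<8x16xf32, strided<[16, 1], offset: 4>>
    // materializes as: aligned base address (i64) + 4 * 4 bytes.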

    // Materialization to convert ui64 to i64.
    // Applies only to target materialization.
    // Note: i64 to ui64 materialization is not required, as xegpu ops
    // currently do not produce ui64 as a result.
    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    // Applies only to target materialization.
    // Note: i32 to ui32 materialization is not required, as xegpu ops
    // currently do not produce ui32 as a result.
    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert between vector types:
    // - Add a shape cast for different shapes.
    // - Add a bitcast for different element types.
    // Applies to both source and target materialization.
    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
                                                ValueRange inputs,
                                                Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          Value cast = input;
          // If the target type has a different shape, add a shape cast.
          // If the target type has a different element type, add a bitcast.
          if (targetVecTy.getShape() != vecTy.getShape()) {
            cast = vector::ShapeCastOp::create(
                       builder, loc,
                       VectorType::get(targetVecTy.getShape(),
                                       vecTy.getElementType()),
                       cast)
                       .getResult();
          }
          if (targetVecTy.getElementType() != vecTy.getElementType()) {
            cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
                       .getResult();
          }
          return cast;
        }
      }
      return {};
    };

    // Materialization to convert a single-element vector to a single element
    // of the vector element type.
    // Applies only to target materialization.
    auto vectorToSingleElementMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        // The source needs to be a single-element vector.
        auto rank = vecTy.getRank();
        if (rank != 0 && vecTy.getNumElements() != 1)
          return {};
        auto inElemTy = vecTy.getElementType();
        // Extract the scalar.
        Value cast = input;
        if (rank == 0) {
          cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
        } else {
          cast = vector::ExtractOp::create(builder, loc, cast,
                                           SmallVector<int64_t>(rank, 0))
                     .getResult();
        }
        // The extracted element type may need conversion. Two cases:
        // 1. Index type to integer type.
        // 2. Other element type mismatch.
        if (inElemTy.isIndex()) {
          cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                     .getResult();
        } else if (inElemTy != type) {
          cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
        }
        return cast;
      }
      return {};
    };

    // Materialization to convert a single element of the vector element type
    // to a single-element vector. If the result type of the original op is a
    // single-element vector and the lowered type is a scalar, this
    // materialization cast creates a single-element vector by first
    // converting the element type if needed and then broadcasting to a
    // single-element vector.
    // Applies only to source materialization.
    auto singleElementToVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      auto inTy = input.getType();
      if (!inTy.isIntOrFloat())
        return {};
      // If the target type is a vector of rank 0, or a single-element vector
      // with an element type matching the input type, broadcast the input to
      // the target type.
      if (auto vecTy = dyn_cast<VectorType>(type)) {
        if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
          return {};
        auto outElemTy = vecTy.getElementType();
        Value cast = input;
        if (outElemTy.isIndex()) {
          cast = arith::IndexCastUIOp::create(builder, loc,
                                              builder.getIndexType(), cast)
                     .getResult();
        } else if (inTy != outElemTy) {
          cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
                     .getResult();
        }
        return vector::BroadcastOp::create(builder, loc, vecTy, cast)
            .getResult();
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementToVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
    typeConverter.addTargetMaterialization(
        vectorToSingleElementMaterializationCast);
    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
    ConversionTarget target(*context);
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(context);
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}