MLIR 23.0.0git
XeGPUToXeVM.cpp
Go to the documentation of this file.
1//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LLVM.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
31
33#include "mlir/IR/Types.h"
34
35#include "llvm/ADT/TypeSwitch.h"
36
37#include <numeric>
38
39namespace mlir {
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
42} // namespace mlir
43
44using namespace mlir;
45
46namespace {
47
// TODO: These values depend on the target micro-architecture and should be
// obtained from a uArch description instead of being hardcoded.

/// Systolic depth; the packed (matrix B) sub-byte load path recognizes a tile
/// height of systolicDepth * 4.
static constexpr int32_t systolicDepth = 8;
/// Execution size; sub-byte emulation rescales the tile width to this value.
static constexpr int32_t executionSize = 16;
51
/// Field positions inside the 8xi32 payload that encodes a 2D nd tensor
/// descriptor. The shape/pitch fields index the i32 view of the payload; the
/// base pointer is written/read through a 4xi64 view at index 0, so it
/// occupies i32 slots 0 and 1.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
59
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
66 }
67 llvm_unreachable("Unknown XeGPU memory space");
68}
69
70/// Checks if the given MemRefType refers to shared memory.
71static bool isSharedMemRef(const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
73 if (!attr)
74 return false;
75 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
77 if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
80}
81
82// Get same bitwidth flat vector type of new element type.
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
84 Type toElemType) {
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
87 auto newBitWidth = toElemType.getIntOrFloatBitWidth();
88 const int size =
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
91}
92
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
96 // If no hints are provided, use the default cache control.
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
99 // If only one of the hints is provided, use the default for the other level.
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
102 switch (L1hintVal) {
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
108 else
109 llvm_unreachable("Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
115 else
116 llvm_unreachable("Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
122 else
123 llvm_unreachable("Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
126 default:
127 llvm_unreachable("Unsupported cache control.");
128 }
129}
130
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
134 // If no hints are provided, use the default cache control.
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
137 // If only one of the hints is provided, use the default for the other level.
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
140 switch (L1hintVal) {
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
146 else
147 llvm_unreachable("Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
153 else
154 llvm_unreachable("Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
160 else
161 llvm_unreachable("Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
167 else
168 llvm_unreachable("Unsupported cache control.");
169 default:
170 llvm_unreachable("Unsupported cache control.");
171 }
172}
173
//
// Note:
// Block operations on tiles with sub-byte element types are handled by
// emulating them with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//
181
182class CreateNdDescToXeVMPattern
183 : public OpConversionPattern<xegpu::CreateNdDescOp> {
184 using OpConversionPattern::OpConversionPattern;
185 LogicalResult
186 matchAndRewrite(xegpu::CreateNdDescOp op,
187 xegpu::CreateNdDescOp::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter) const override {
189 auto loc = op.getLoc();
190 auto source = op.getSource();
191 // Op is lowered to a code sequence that populates payload.
192 // Payload is a 8xi32 vector. Offset to individual fields are defined in
193 // NdTdescOffset enum.
194 Type payloadElemTy = rewriter.getI32Type();
195 VectorType payloadTy = VectorType::get(8, payloadElemTy);
196 Type i64Ty = rewriter.getI64Type();
197 // 4xi64 view is used for inserting the base pointer.
198 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
199 // Initialize payload to zero.
200 Value payload = arith::ConstantOp::create(
201 rewriter, loc,
202 DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
203
204 Value baseAddr;
205 Value baseShapeW;
206 Value baseShapeH;
207
208 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
209 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
210 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
211 // Descriptor shape is expected to be 2D.
212 int64_t rank = mixedSizes.size();
213 auto sourceTy = source.getType();
214 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
215 // If source is a memref, we need to extract the aligned pointer as index.
216 // Pointer type is passed as i32 or i64 by type converter.
217 if (sourceMemrefTy) {
218 if (!sourceMemrefTy.hasRank()) {
219 return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
220 }
221 // Access adaptor after failure check to avoid rolling back generated code
222 // for materialization cast.
223 baseAddr = adaptor.getSource();
224 } else {
225 baseAddr = adaptor.getSource();
226 if (baseAddr.getType() != i64Ty) {
227 // Pointer type may be i32. Cast to i64 if needed.
228 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
229 }
230 }
231 // 1D tensor descriptor is just the base address.
232 if (rank == 1) {
233 rewriter.replaceOp(op, baseAddr);
234 return success();
235 }
236 // Utility for creating offset values from op fold result.
237 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
238 unsigned idx) -> Value {
239 Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
240 val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
241 return val;
242 };
243 // Get shape values from op fold results.
244 baseShapeW = createOffset(mixedSizes, 1);
245 baseShapeH = createOffset(mixedSizes, 0);
246 // Get pitch value from op fold results.
247 Value basePitch = createOffset(mixedStrides, 0);
248 // Populate payload.
249 Value payLoadAsI64 =
250 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
251 payLoadAsI64 =
252 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
253 static_cast<int>(NdTdescOffset::BasePtr));
254 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
255 payload =
256 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
257 static_cast<int>(NdTdescOffset::BaseShapeW));
258 payload =
259 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
260 static_cast<int>(NdTdescOffset::BaseShapeH));
261 payload =
262 vector::InsertOp::create(rewriter, loc, basePitch, payload,
263 static_cast<int>(NdTdescOffset::BasePitch));
264 rewriter.replaceOp(op, payload);
265 return success();
266 }
267};
268
269template <
270 typename OpType,
271 typename = std::enable_if_t<llvm::is_one_of<
272 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
273class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
274 using OpConversionPattern<OpType>::OpConversionPattern;
275 LogicalResult
276 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
277 ConversionPatternRewriter &rewriter) const override {
278 auto mixedOffsets = op.getMixedOffsets();
279 int64_t opOffsetsSize = mixedOffsets.size();
280 auto loc = op.getLoc();
281 auto ctxt = rewriter.getContext();
282
283 auto tdesc = adaptor.getTensorDesc();
284 auto tdescTy = op.getTensorDescType();
285 auto tileRank = tdescTy.getRank();
286 if (opOffsetsSize != tileRank)
287 return rewriter.notifyMatchFailure(
288 op, "Expected offset rank to match descriptor rank.");
289 auto elemType = tdescTy.getElementType();
290 auto elemBitSize = elemType.getIntOrFloatBitWidth();
291 bool isSubByte = elemBitSize < 8;
292 uint64_t wScaleFactor = 1;
293
294 if (!isSubByte && (elemBitSize % 8 != 0))
295 return rewriter.notifyMatchFailure(
296 op, "Expected element type bit width to be multiple of 8.");
297 auto tileW = tdescTy.getDimSize(tileRank - 1);
298 // For sub byte types, only 4bits are currently supported.
299 if (isSubByte) {
300 if (elemBitSize != 4)
301 return rewriter.notifyMatchFailure(
302 op, "Only sub byte types of 4bits are supported.");
303 if (tileRank != 2)
304 return rewriter.notifyMatchFailure(
305 op, "Sub byte types are only supported for 2D tensor descriptors.");
306 auto subByteFactor = 8 / elemBitSize;
307 auto tileH = tdescTy.getDimSize(0);
308 // Handle special case for packed load.
309 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
310 if (op.getPacked().value_or(false)) {
311 // packed load is implemented as packed loads of 8bit elements.
312 if (tileH == systolicDepth * 4 &&
313 tileW == executionSize * subByteFactor) {
314 // Usage case for loading as Matrix B with pack request.
315 // source is assumed to pre-packed into 8bit elements
316 // Emulate with 8bit loads with pack request.
317 // scaled_tileW = executionSize
318 elemType = rewriter.getIntegerType(8);
319 tileW = executionSize;
320 wScaleFactor = subByteFactor;
321 }
322 }
323 }
324 // If not handled by packed load case above, handle other cases.
325 if (wScaleFactor == 1) {
326 auto sub16BitFactor = subByteFactor * 2;
327 if (tileW == executionSize * sub16BitFactor) {
328 // Usage case for loading as Matrix A operand
329 // Emulate with 16bit loads/stores.
330 // scaled_tileW = executionSize
331 elemType = rewriter.getIntegerType(16);
332 tileW = executionSize;
333 wScaleFactor = sub16BitFactor;
334 } else {
335 return rewriter.notifyMatchFailure(
336 op, "Unsupported tile shape for sub byte types.");
337 }
338 }
339 // recompute element bit size for emulation.
340 elemBitSize = elemType.getIntOrFloatBitWidth();
341 }
342
343 // Get address space from tensor descriptor memory space.
344 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
345 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
346 if (tileRank == 2) {
347 // Compute element byte size.
348 Value elemByteSize = arith::ConstantIntOp::create(
349 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
350 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
351 Value payLoadAsI64 =
352 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
353 Value basePtr =
354 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
355 static_cast<int>(NdTdescOffset::BasePtr));
356 Value baseShapeW = vector::ExtractOp::create(
357 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
358 Value baseShapeH = vector::ExtractOp::create(
359 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
360 Value basePitch = vector::ExtractOp::create(
361 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
362 // Offsets are provided by the op.
363 // convert them to i32.
364 Value offsetW =
365 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
366 offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
367 rewriter.getI32Type(), offsetW);
368 Value offsetH =
369 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
370 offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
371 rewriter.getI32Type(), offsetH);
372 // Convert base pointer (i64) to LLVM pointer type.
373 Value basePtrLLVM =
374 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
375 // FIXME: width or pitch is not the same as baseShapeW it should be the
376 // stride of the second to last dimension in row major layout.
377 // Compute width in bytes.
378 Value baseShapeWInBytes =
379 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
380 // Compute pitch in bytes.
381 Value basePitchBytes =
382 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
383
384 if (wScaleFactor > 1) {
385 // Scale offsetW, baseShapeWInBytes for sub byte emulation.
386 // Note: tileW is already scaled above.
387 Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
388 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
389 baseShapeWInBytes = arith::ShRSIOp::create(
390 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
391 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
392 wScaleFactorValLog2);
393 offsetW =
394 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
395 }
396 // Get tile height from the tensor descriptor type.
397 auto tileH = tdescTy.getDimSize(0);
398 // Get vblocks from the tensor descriptor type.
399 int32_t vblocks = tdescTy.getArrayLength();
400 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
401 Value src = adaptor.getValue();
402 // If store value is a scalar, get value from op instead of adaptor.
403 // Adaptor might have optimized away single element vector
404 if (src.getType().isIntOrFloat()) {
405 src = op.getValue();
406 }
407 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
408 if (!srcVecTy)
409 return rewriter.notifyMatchFailure(
410 op, "Expected store value to be a vector type.");
411 // Get flat vector type of integer type with matching element bit size.
412 VectorType newSrcVecTy =
413 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
414 if (srcVecTy != newSrcVecTy)
415 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
416 auto storeCacheControl =
417 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
418 xevm::BlockStore2dOp::create(
419 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
420 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
421 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
422 rewriter.eraseOp(op);
423 } else {
424 auto loadCacheControl =
425 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
426 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
427 xevm::BlockPrefetch2dOp::create(
428 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
429 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
430 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
431 rewriter.eraseOp(op);
432 } else {
433 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
434 const bool vnni = op.getPacked().value_or(false);
435 auto transposeValue = op.getTranspose();
436 bool transpose =
437 transposeValue.has_value() && transposeValue.value()[0] == 1;
438 VectorType loadedTy = encodeVectorTypeTo(
439 dstVecTy, vnni ? rewriter.getI32Type()
440 : rewriter.getIntegerType(elemBitSize));
441
442 Value resultFlatVec = xevm::BlockLoad2dOp::create(
443 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
444 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
445 tileH, vblocks, transpose, vnni,
446 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
447 resultFlatVec = vector::BitCastOp::create(
448 rewriter, loc,
449 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
450 resultFlatVec);
451 rewriter.replaceOp(op, resultFlatVec);
452 }
453 }
454 } else {
455 // 1D tensor descriptor.
456 // `tdesc` represents base address as i64
457 // Offset in number of elements, need to multiply by element byte size.
458 // Compute byte offset.
459 // byteOffset = offset * elementByteSize
460 Value offset =
461 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
462 offset = getValueOrCreateCastToIndexLike(rewriter, loc,
463 rewriter.getI64Type(), offset);
464 // Compute element byte size.
465 Value elemByteSize = arith::ConstantIntOp::create(
466 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
467 Value byteOffset =
468 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
469 // Final address = basePtr + byteOffset
470 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
471 loc, tdesc,
472 getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
473 byteOffset));
474 // Convert base pointer (i64) to LLVM pointer type.
475 Value finalPtrLLVM =
476 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
477 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
478 Value src = adaptor.getValue();
479 // If store value is a scalar, get value from op instead of adaptor.
480 // Adaptor might have optimized away single element vector
481 if (src.getType().isIntOrFloat()) {
482 src = op.getValue();
483 }
484 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
485 if (!srcVecTy)
486 return rewriter.notifyMatchFailure(
487 op, "Expected store value to be a vector type.");
488 // Get flat vector type of integer type with matching element bit size.
489 VectorType newSrcVecTy =
490 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
491 if (srcVecTy != newSrcVecTy)
492 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
493 auto storeCacheControl =
494 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
495 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
496 op, finalPtrLLVM, src,
497 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
498 } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
499 auto loadCacheControl =
500 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
501 VectorType resTy = cast<VectorType>(op.getValue().getType());
502 VectorType loadedTy =
503 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
504 Value load = xevm::BlockLoadOp::create(
505 rewriter, loc, loadedTy, finalPtrLLVM,
506 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
507 if (loadedTy != resTy)
508 load = vector::BitCastOp::create(rewriter, loc, resTy, load);
509 rewriter.replaceOp(op, load);
510 } else {
511 return rewriter.notifyMatchFailure(
512 op, "Unsupported operation: xegpu.prefetch_nd with tensor "
513 "descriptor rank == 1");
514 }
515 }
516 return success();
517 }
518};
519
520// Add a builder that creates
521// offset * elemByteSize + baseAddr
522static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
523 Location loc, Value baseAddr, Value offset,
524 int64_t elemByteSize) {
526 rewriter, loc, baseAddr.getType(), elemByteSize);
527 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
528 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
529 return newAddr;
530}
531
532template <typename OpType,
533 typename = std::enable_if_t<llvm::is_one_of<
534 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
535class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
536 using OpConversionPattern<OpType>::OpConversionPattern;
537 LogicalResult
538 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter) const override {
540 Value offset = adaptor.getOffsets();
541 if (!offset)
542 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
543 auto loc = op.getLoc();
544 auto ctxt = rewriter.getContext();
545 Value basePtrI64;
546 // Load result or Store valye Type can be vector or scalar.
547 Type valOrResTy;
548 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
549 valOrResTy =
550 this->getTypeConverter()->convertType(op.getResult().getType());
551 else
552 valOrResTy = adaptor.getValue().getType();
553 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
554 bool hasScalarVal = !valOrResVecTy;
555 int64_t elemBitWidth =
556 hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
557 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
558 // Element type must be multiple of 8 bits.
559 if (elemBitWidth % 8 != 0)
560 return rewriter.notifyMatchFailure(
561 op, "Expected element type bit width to be multiple of 8.");
562 int64_t elemByteSize = elemBitWidth / 8;
563 // Default memory space is global.
564 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
565 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
566 // Base pointer can come from source (load) or dest (store).
567 // If they are memrefs, we use their memory space.
568 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
569 basePtrI64 = adaptor.getSource();
570 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
571 auto addrSpace = memRefTy.getMemorySpaceAsInt();
572 if (addrSpace != 0)
573 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
574 }
575 } else {
576 basePtrI64 = adaptor.getDest();
577 if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
578 auto addrSpace = memRefTy.getMemorySpaceAsInt();
579 if (addrSpace != 0)
580 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
581 }
582 }
583 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
584 if (basePtrI64.getType() != rewriter.getI64Type()) {
585 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
586 basePtrI64);
587 }
588 Value mask = adaptor.getMask();
589 if (dyn_cast<VectorType>(offset.getType())) {
590 // Offset needs be scalar. Single element vector is converted to scalar
591 // by type converter.
592 return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
593 } else {
594 // If offset is provided, we add them to the base pointer.
595 // Offset is in number of elements, we need to multiply by
596 // element byte size.
597 basePtrI64 =
598 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
599 }
600 // Convert base pointer (i64) to LLVM pointer type.
601 Value basePtrLLVM =
602 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
603
604 Value maskForLane;
605 VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
606 if (maskVecTy) {
607 // Mask needs be scalar. Single element vector is converted to scalar by
608 // type converter.
609 return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
610 } else
611 maskForLane = mask;
612 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
613 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
614 maskForLane, true, true);
615 // If mask is true,- then clause - load from memory and yield.
616 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
617 if (!hasScalarVal)
618 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
619 valOrResVecTy.getElementType());
620 Value loaded =
621 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
622 // Set cache control attribute on the load operation.
623 loaded.getDefiningOp()->setAttr(
624 "cache_control", xevm::LoadCacheControlAttr::get(
625 ctxt, translateLoadXeGPUCacheHint(
626 op.getL1Hint(), op.getL3Hint())));
627 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
628 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
629 // If mask is false - else clause -yield a vector of zeros.
630 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
631 TypedAttr eVal;
632 if (eTy.isFloat())
633 eVal = FloatAttr::get(eTy, 0.0);
634 else
635 eVal = IntegerAttr::get(eTy, 0);
636 if (hasScalarVal)
637 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
638 else
639 loaded = arith::ConstantOp::create(
640 rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
641 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
642 rewriter.replaceOp(op, ifOp.getResult(0));
643 } else {
644 // If mask is true, perform the store.
645 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
646 auto body = ifOp.getBody();
647 rewriter.setInsertionPointToStart(body);
648 auto storeOp =
649 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
650 // Set cache control attribute on the store operation.
651 storeOp.getOperation()->setAttr(
652 "cache_control", xevm::StoreCacheControlAttr::get(
653 ctxt, translateStoreXeGPUCacheHint(
654 op.getL1Hint(), op.getL3Hint())));
655 rewriter.eraseOp(op);
656 }
657 return success();
658 }
659};
660
661class CreateMemDescOpPattern final
662 : public OpConversionPattern<xegpu::CreateMemDescOp> {
663public:
664 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
665 LogicalResult
666 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter) const override {
668
669 rewriter.replaceOp(op, adaptor.getSource());
670 return success();
671 }
672};
673
674template <typename OpType,
675 typename = std::enable_if_t<llvm::is_one_of<
676 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
677class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
678 using OpConversionPattern<OpType>::OpConversionPattern;
679 LogicalResult
680 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter) const override {
682
683 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
684 if (offsets.empty())
685 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
686
687 auto loc = op.getLoc();
688 auto ctxt = rewriter.getContext();
689 Value baseAddr32 = adaptor.getMemDesc();
690 Value mdescVal = op.getMemDesc();
691 // Load result or Store value Type can be vector or scalar.
692 Type dataTy;
693 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
694 Type resType = op.getResult().getType();
695 // Some transforms may leave unit dimension in the 2D vector, adaptors do
696 // not catch it for results.
697 if (auto vecType = dyn_cast<VectorType>(resType)) {
698 assert(llvm::count_if(vecType.getShape(),
699 [](int64_t d) { return d != 1; }) <= 1 &&
700 "Expected either 1D vector or nD with unit dimensions");
701 resType = VectorType::get({vecType.getNumElements()},
702 vecType.getElementType());
703 }
704 dataTy = resType;
705 } else
706 dataTy = adaptor.getData().getType();
707 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
708 if (!valOrResVecTy)
709 valOrResVecTy = VectorType::get(1, dataTy);
710
711 int64_t elemBitWidth =
712 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
713 // Element type must be multiple of 8 bits.
714 if (elemBitWidth % 8 != 0)
715 return rewriter.notifyMatchFailure(
716 op, "Expected element type bit width to be multiple of 8.");
717 int64_t elemByteSize = elemBitWidth / 8;
718
719 // Default memory space is SLM.
720 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
721 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
722
723 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
724
725 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
726 linearOffset = arith::IndexCastUIOp::create(
727 rewriter, loc, rewriter.getI32Type(), linearOffset);
728 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
729 linearOffset, elemByteSize);
730
731 // convert base pointer (i32) to LLVM pointer type
732 Value basePtrLLVM =
733 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
734
735 if (op.getSubgroupBlockIoAttr()) {
736 // if the attribute 'subgroup_block_io' is set to true, it lowers to
737 // xevm.blockload
738
739 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
740 VectorType intVecTy =
741 VectorType::get(valOrResVecTy.getShape(), intElemTy);
742
743 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
744 Value loadOp =
745 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
746 if (intVecTy != valOrResVecTy) {
747 loadOp =
748 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
749 }
750 rewriter.replaceOp(op, loadOp);
751 } else {
752 Value dataToStore = adaptor.getData();
753 if (valOrResVecTy != intVecTy) {
754 dataToStore =
755 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
756 }
757 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
758 nullptr);
759 rewriter.eraseOp(op);
760 }
761 return success();
762 }
763
764 if (valOrResVecTy.getNumElements() >= 1) {
765 auto chipOpt = xegpu::getChipStr(op);
766 if (!chipOpt ||
767 (*chipOpt != "pvc" && *chipOpt != "bmg" && *chipOpt != "cri")) {
768 // the lowering for chunk load only works for pvc, bmg or cri
769 return rewriter.notifyMatchFailure(
770 op, "The lowering is specific to pvc, bmg or cri.");
771 }
772 }
773
774 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
775 // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
776 // operation. LLVM load/store does not support vector of size 1, so we
777 // need to handle this case separately.
778 auto scalarTy = valOrResVecTy.getElementType();
779 LLVM::LoadOp loadOp;
780 if (valOrResVecTy.getNumElements() == 1)
781 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
782 else
783 loadOp =
784 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
785 rewriter.replaceOp(op, loadOp);
786 } else {
787 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
788 rewriter.eraseOp(op);
789 }
790 return success();
791 }
792};
793
794class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
795 using OpConversionPattern::OpConversionPattern;
796 LogicalResult
797 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
798 ConversionPatternRewriter &rewriter) const override {
799 auto loc = op.getLoc();
800 auto ctxt = rewriter.getContext();
801 Value basePtrI64 = adaptor.getSource();
802 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
803 if (basePtrI64.getType() != rewriter.getI64Type())
804 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
805 basePtrI64);
806 Value offsets = adaptor.getOffsets();
807 if (offsets) {
808 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
809 if (offsetsVecTy) {
810 // Offset needs be scalar.
811 return rewriter.notifyMatchFailure(op,
812 "Expected offsets to be a scalar.");
813 } else {
814 int64_t elemBitWidth{0};
815 int64_t elemByteSize;
816 // Element byte size can come from two sources:
817 if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
818 // If memref is available, we use its element type to
819 // determine element byte size.
820 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
821 } else {
822 // Otherwise, we use the provided offset byte alignment.
823 elemByteSize = *op.getOffsetAlignByte();
824 }
825 if (elemBitWidth != 0) {
826 if (elemBitWidth % 8 != 0)
827 return rewriter.notifyMatchFailure(
828 op, "Expected element type bit width to be multiple of 8.");
829 elemByteSize = elemBitWidth / 8;
830 }
831 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
832 elemByteSize);
833 }
834 }
835 // Default memory space is global.
836 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
837 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
838 // If source is a memref, we use its memory space.
839 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
840 auto addrSpace = memRefTy.getMemorySpaceAsInt();
841 if (addrSpace != 0)
842 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
843 }
844 // Convert base pointer (i64) to LLVM pointer type.
845 Value ptrLLVM =
846 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
847 // Create the prefetch op with cache control attribute.
848 xevm::PrefetchOp::create(
849 rewriter, loc, ptrLLVM,
850 xevm::LoadCacheControlAttr::get(
851 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
852 rewriter.eraseOp(op);
853 return success();
854 }
855};
856
857class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
858 using OpConversionPattern::OpConversionPattern;
859 LogicalResult
860 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
861 ConversionPatternRewriter &rewriter) const override {
862 auto loc = op.getLoc();
863 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
864 switch (op.getFenceScope()) {
865 case xegpu::FenceScope::Workgroup:
866 memScope = xevm::MemScope::WORKGROUP;
867 break;
868 case xegpu::FenceScope::GPU:
869 memScope = xevm::MemScope::DEVICE;
870 break;
871 }
872 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
873 switch (op.getMemoryKind()) {
874 case xegpu::MemorySpace::Global:
875 addrSpace = xevm::AddrSpace::GLOBAL;
876 break;
877 case xegpu::MemorySpace::SLM:
878 addrSpace = xevm::AddrSpace::SHARED;
879 break;
880 }
881 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
882 rewriter.eraseOp(op);
883 return success();
884 }
885};
886
887static auto encodePrecision = [](Type type) -> xevm::ElemType {
888 if (type.isBF16())
889 return xevm::ElemType::BF16;
890 else if (type.isF16())
891 return xevm::ElemType::F16;
892 else if (type.isTF32())
893 return xevm::ElemType::TF32;
894 else if (type.isInteger(8)) {
895 if (type.isUnsignedInteger())
896 return xevm::ElemType::U8;
897 return xevm::ElemType::S8;
898 } else if (type.isF32())
899 return xevm::ElemType::F32;
900 else if (type.isInteger(32))
901 return xevm::ElemType::S32;
902 else if (type.isF8E5M2())
903 return xevm::ElemType::BF8;
904 else if (type.isF8E4M3FN())
905 return xevm::ElemType::F8;
906 else if (mlir::isa<Float4E2M1FNType>(type))
907 return xevm::ElemType::E2M1;
908 llvm_unreachable("add more support for ElemType");
909};
910
911static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
912 switch (pTy) {
913 case xevm::ElemType::TF32:
914 return 1;
915 case xevm::ElemType::BF16:
916 case xevm::ElemType::F16:
917 return 2;
918 case xevm::ElemType::U8:
919 case xevm::ElemType::S8:
920 case xevm::ElemType::F8:
921 case xevm::ElemType::BF8:
922 return 4;
923 case xevm::ElemType::E2M1:
924 return 8;
925 default:
926 llvm_unreachable("unsupported xevm::ElemType");
927 }
928}
929
930class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
931 using OpConversionPattern::OpConversionPattern;
932 LogicalResult
933 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
934 ConversionPatternRewriter &rewriter) const override {
935 auto loc = op.getLoc();
936 auto ctxt = rewriter.getContext();
937 auto aTy = cast<VectorType>(op.getLhs().getType());
938 auto bTy = cast<VectorType>(op.getRhs().getType());
939 auto resultType = cast<VectorType>(op.getResultType());
940
941 // get the correct dpasInst by getting info from chip
942 auto chipStr = xegpu::getChipStr(op);
943 if (!chipStr)
944 return rewriter.notifyMatchFailure(op, "cannot determine target chip");
945
946 const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
947 if (!uArch)
948 return rewriter.notifyMatchFailure(op, "unsupported target uArch");
949
950 auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
951 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
952 uArch->getInstruction(
953 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
954 if (!dpasInst)
955 return rewriter.notifyMatchFailure(op,
956 "DPAS not supported by target uArch");
957
958 auto checkSupportedTypes = [&](VectorType vecTy,
959 xegpu::uArch::MMAOpndKind kind) -> bool {
960 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
961 return llvm::find(supported, vecTy.getElementType()) != supported.end();
962 };
963
964 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
965 return rewriter.notifyMatchFailure(
966 op, "A-matrix element type not supported by target uArch");
967 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
968 return rewriter.notifyMatchFailure(
969 op, "B-matrix element type not supported by target uArch");
970 // NOTE: Supported types for MatrixC and MatrixD are identical
971 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
972 return rewriter.notifyMatchFailure(
973 op, "result/accumulator element type not supported by target uArch");
974
975 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
976 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
977 Value c = op.getAcc();
978 if (!c) {
979 auto elementTy = resultType.getElementType();
980 Attribute initValueAttr;
981 if (isa<FloatType>(elementTy))
982 initValueAttr = FloatAttr::get(elementTy, 0.0);
983 else
984 initValueAttr = IntegerAttr::get(elementTy, 0);
985 c = arith::ConstantOp::create(
986 rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
987 }
988
989 Value aVec = op.getLhs();
990 Value bVec = op.getRhs();
991 auto cvecty = cast<VectorType>(c.getType());
992 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
993 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
994 VectorType cNty =
995 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
996 if (cvecty != cNty)
997 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
998 Value dpasRes = xevm::MMAOp::create(
999 rewriter, loc, cNty, aVec, bVec, c,
1000 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
1001 systolicDepth *
1002 getNumOperandsPerDword(precATy)),
1003 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1004 if (cvecty != cNty)
1005 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
1006 rewriter.replaceOp(op, dpasRes);
1007 return success();
1008 }
1009};
1010
1011static std::optional<LLVM::AtomicBinOp>
1012matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1013 switch (arithKind) {
1014 case arith::AtomicRMWKind::addf:
1015 return LLVM::AtomicBinOp::fadd;
1016 case arith::AtomicRMWKind::addi:
1017 return LLVM::AtomicBinOp::add;
1018 case arith::AtomicRMWKind::assign:
1019 return LLVM::AtomicBinOp::xchg;
1020 case arith::AtomicRMWKind::maximumf:
1021 return LLVM::AtomicBinOp::fmax;
1022 case arith::AtomicRMWKind::maxs:
1023 return LLVM::AtomicBinOp::max;
1024 case arith::AtomicRMWKind::maxu:
1025 return LLVM::AtomicBinOp::umax;
1026 case arith::AtomicRMWKind::minimumf:
1027 return LLVM::AtomicBinOp::fmin;
1028 case arith::AtomicRMWKind::mins:
1029 return LLVM::AtomicBinOp::min;
1030 case arith::AtomicRMWKind::minu:
1031 return LLVM::AtomicBinOp::umin;
1032 case arith::AtomicRMWKind::ori:
1033 return LLVM::AtomicBinOp::_or;
1034 case arith::AtomicRMWKind::andi:
1035 return LLVM::AtomicBinOp::_and;
1036 default:
1037 return std::nullopt;
1038 }
1039}
1040
1041class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
1042 using OpConversionPattern::OpConversionPattern;
1043 LogicalResult
1044 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1045 ConversionPatternRewriter &rewriter) const override {
1046 auto loc = op.getLoc();
1047 auto ctxt = rewriter.getContext();
1048 auto tdesc = op.getTensorDesc().getType();
1049 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1050 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1051 Value basePtrI64 = arith::IndexCastOp::create(
1052 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1053 Value basePtrLLVM =
1054 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1055 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1056 VectorType srcOrDstFlatVecTy = VectorType::get(
1057 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1058 Value srcFlatVec = vector::ShapeCastOp::create(
1059 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1060 auto atomicKind = matchSimpleAtomicOp(op.getKind());
1061 assert(atomicKind.has_value());
1062 Value resVec = srcFlatVec;
1063 for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1064 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1065 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1066 rewriter.getIndexAttr(i));
1067 Value currPtr =
1068 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1069 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1070 Value newVal =
1071 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1072 val, LLVM::AtomicOrdering::seq_cst);
1073 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1074 }
1075 rewriter.replaceOp(op, resVec);
1076 return success();
1077 }
1078};
1079
1080class DpasMxToXeVMPattern : public OpConversionPattern<xegpu::DpasMxOp> {
1081 using OpConversionPattern::OpConversionPattern;
1082 LogicalResult
1083 matchAndRewrite(xegpu::DpasMxOp op, xegpu::DpasMxOp::Adaptor adaptor,
1084 ConversionPatternRewriter &rewriter) const override {
1085 auto loc = op.getLoc();
1086 auto ctxt = rewriter.getContext();
1087 auto aTy = op.getA().getType();
1088 auto bTy = op.getB().getType();
1089 auto resVecTy =
1090 cast<VectorType>(getTypeConverter()->convertType(op.getType()));
1091
1092 auto chipStr = xegpu::getChipStr(op);
1093 if (!chipStr)
1094 return rewriter.notifyMatchFailure(op, "cannot determine target chip");
1095
1096 const auto *uArch = xegpu::uArch::getUArch(*chipStr);
1097 if (!uArch)
1098 return rewriter.notifyMatchFailure(op, "unsupported target uArch");
1099
1100 // TODO: Add supported shape check
1101
1102 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1103 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1104 Value c = adaptor.getAcc();
1105 if (!c) {
1106 auto elementTy = resVecTy.getElementType();
1107 Attribute initValueAttr;
1108 if (isa<FloatType>(elementTy))
1109 initValueAttr = FloatAttr::get(elementTy, 0.0);
1110 else
1111 initValueAttr = IntegerAttr::get(elementTy, 0);
1112 c = arith::ConstantOp::create(
1113 rewriter, loc, DenseElementsAttr::get(resVecTy, initValueAttr));
1114 }
1115
1116 Value aVec = adaptor.getA();
1117 Value bVec = adaptor.getB();
1118 auto aVecTy = cast<VectorType>(aVec.getType());
1119 auto bVecTy = cast<VectorType>(bVec.getType());
1120 if (aVecTy.getElementTypeBitWidth() == 4)
1121 aVec = vector::BitCastOp::create(
1122 rewriter, loc,
1123 VectorType::get(aVecTy.getNumElements() / 2, rewriter.getI8Type()),
1124 aVec);
1125 if (bVecTy.getElementTypeBitWidth() == 4)
1126 bVec = vector::BitCastOp::create(
1127 rewriter, loc,
1128 VectorType::get(bVecTy.getNumElements() / 2, rewriter.getI8Type()),
1129 bVec);
1130 auto cVecTy = cast<VectorType>(c.getType());
1131 xevm::ElemType precCTy = encodePrecision(cVecTy.getElementType());
1132 xevm::ElemType precDTy = encodePrecision(resVecTy.getElementType());
1133 Value scaleA = adaptor.getScaleA();
1134 Value scaleB = adaptor.getScaleB();
1135 Value dpasMxRes = xevm::MMAMxOp::create(
1136 rewriter, loc, resVecTy, aVec, bVec, scaleA, scaleB, c,
1137 xevm::MMAShapeAttr::get(ctxt, cVecTy.getNumElements(), executionSize,
1138 systolicDepth *
1139 getNumOperandsPerDword(precATy)),
1140 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1141 rewriter.replaceOp(op, dpasMxRes);
1142 return success();
1143 }
1144};
1145
1146//===----------------------------------------------------------------------===//
1147// Pass Definition
1148//===----------------------------------------------------------------------===//
1149
1150struct ConvertXeGPUToXeVMPass
1151 : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
1152 using Base::Base;
1153
1154 void runOnOperation() override {
1155 MLIRContext *context = &getContext();
1156
1157 // XeVM type converter is based on LLVM type converter with the
1158 // following customizations.
1159 // First, type conversion rules are added for xegpu custom types,
1160 // TensorDescType and MemDescType.
1161 // Second, MemRefType is lowered to single integer type
1162 // Third, VectorType of single element or 0D is converted to vector
1163 // element type. Otherwise, vector type is flatten to 1D.
1164 LowerToLLVMOptions options(context);
1165 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1166 LLVMTypeConverter typeConverter(context, options);
1167
1168 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1169 Type i32Type = IntegerType::get(context, 32);
1170 typeConverter.addConversion([&](VectorType type) -> Type {
1171 auto elemType = typeConverter.convertType(type.getElementType());
1172 // If the vector rank is 0 or has a single element, return the element
1173 unsigned rank = type.getRank();
1174 if (rank == 0 || type.getNumElements() == 1)
1175 return elemType;
1176 // Otherwise, convert the vector to a flat vector type.
1177 int64_t sum = llvm::product_of(type.getShape());
1178 return VectorType::get(sum, elemType);
1179 });
1180 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1181 if (type.getRank() == 1)
1182 return xevmIndexType;
1183 return VectorType::get(8, i32Type);
1184 });
1185 // SLM access related type conversions.
1186 // TODO: LLVM DLTI provides clean way of representing different pointer size
1187 // based on address space. Currently pointer size of SLM access is hard
1188 // coded to 32bit. Update to use DLTI when switching overall XeGPU lowering
1189 // to use DLTI instead of use64bitIndex option used above.
1190
1191 // Convert MemDescType into i32 for SLM
1192 typeConverter.addConversion(
1193 [&](xegpu::MemDescType type) -> Type { return i32Type; });
1194
1195 typeConverter.addConversion([&](MemRefType type) -> Type {
1196 return isSharedMemRef(type) ? i32Type : xevmIndexType;
1197 });
1198
1199 // LLVM type converter puts unrealized casts for the following cases:
1200 // add materialization casts to handle them.
1201
1202 // Materialization to convert memref to i64 or i32 depending on global/SLM
1203 // Applies only to target materialization.
1204 // Note: int type to memref materialization is not required as xegpu ops
1205 // currently do not produce memrefs as result.
1206 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1207 ValueRange inputs,
1208 Location loc) -> Value {
1209 if (inputs.size() != 1)
1210 return {};
1211 auto input = inputs.front();
1212 if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1213 unsigned rank = memrefTy.getRank();
1214 Type indexType = builder.getIndexType();
1215
1216 int64_t intOffsets;
1217 SmallVector<int64_t> intStrides;
1218 Value addr;
1219 Value offset;
1220 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1221 ShapedType::isStatic(intOffsets)) {
1222 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1223 input);
1224 offset = arith::ConstantOp::create(builder, loc,
1225 builder.getIndexAttr(intOffsets));
1226 } else {
1227
1228 // Result types: [base_memref, offset, stride0, stride1, ...,
1229 // strideN-1, size0, size1, ..., sizeN-1]
1230 SmallVector<Type> resultTypes{
1231 MemRefType::get({}, memrefTy.getElementType(),
1232 MemRefLayoutAttrInterface(),
1233 memrefTy.getMemorySpace()),
1234 indexType};
1235 // strides + sizes
1236 resultTypes.append(2 * rank, indexType);
1237
1238 auto meta = memref::ExtractStridedMetadataOp::create(
1239 builder, loc, resultTypes, input);
1240
1241 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1242 builder, loc, meta.getBaseBuffer());
1243 offset = meta.getOffset();
1244 }
1245
1246 auto addrCasted =
1247 arith::IndexCastUIOp::create(builder, loc, type, addr);
1248 auto offsetCasted =
1249 arith::IndexCastUIOp::create(builder, loc, type, offset);
1250
1251 // Compute the final address: base address + byte offset
1252 auto byteSize = arith::ConstantOp::create(
1253 builder, loc, type,
1254 builder.getIntegerAttr(type,
1255 memrefTy.getElementTypeBitWidth() / 8));
1256 auto byteOffset =
1257 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1258 auto addrWithOffset =
1259 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1260
1261 return addrWithOffset.getResult();
1262 }
1263 return {};
1264 };
1265
1266 // Materialization to convert ui64 to i64
1267 // Applies only to target materialization.
1268 // Note: i64 to ui64 materialization is not required as xegpu ops
1269 // currently do not produce ui64 as result.
1270 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1271 ValueRange inputs,
1272 Location loc) -> Value {
1273 if (inputs.size() != 1)
1274 return {};
1275 auto input = inputs.front();
1276 if (input.getType() == builder.getIntegerType(64, false)) {
1277 Value cast =
1278 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1279 .getResult();
1280 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1281 .getResult();
1282 }
1283 return {};
1284 };
1285
1286 // Materialization to convert ui32 to i32
1287 // Applies only to target materialization.
1288 // Note: i32 to ui32 materialization is not required as xegpu ops
1289 // currently do not produce ui32 as result.
1290 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1291 ValueRange inputs,
1292 Location loc) -> Value {
1293 if (inputs.size() != 1)
1294 return {};
1295 auto input = inputs.front();
1296 if (input.getType() == builder.getIntegerType(32, false)) {
1297 Value cast =
1298 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1299 .getResult();
1300 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1301 .getResult();
1302 }
1303 return {};
1304 };
1305
1306 // Materialization to convert between vector types
1307 // - Add shape cast for different shapes
1308 // - Add bitcast for different element types
1309 // Applies to both source and target materialization.
1310 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1311 ValueRange inputs,
1312 Location loc) -> Value {
1313 if (inputs.size() != 1)
1314 return {};
1315 auto input = inputs.front();
1316 if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
1317 if (auto targetVecTy = dyn_cast<VectorType>(type)) {
1318 Value cast = input;
1319 // If the target type has a different shape, add a shape cast
1320 // If the target type has a different element type, add a bitcast
1321 if (targetVecTy.getShape() != vecTy.getShape()) {
1322 cast = vector::ShapeCastOp::create(
1323 builder, loc,
1324 VectorType::get(targetVecTy.getShape(),
1325 vecTy.getElementType()),
1326 cast)
1327 .getResult();
1328 }
1329 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1330 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
1331 .getResult();
1332 }
1333 return cast;
1334 }
1335 }
1336 return {};
1337 };
1338
1339 // Materialization to convert
1340 // - single element vector to single element of vector element type
1341 // Applies only to target materialization.
1342 auto vectorToSingleElementMaterializationCast =
1343 [](OpBuilder &builder, Type type, ValueRange inputs,
1344 Location loc) -> Value {
1345 if (inputs.size() != 1)
1346 return {};
1347 auto input = inputs.front();
1348 if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
1349 // Source needs to be single element vector
1350 auto rank = vecTy.getRank();
1351 if (rank != 0 && vecTy.getNumElements() != 1)
1352 return {};
1353 auto inElemTy = vecTy.getElementType();
1354 // extract scalar
1355 Value cast = input;
1356 if (rank == 0) {
1357 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1358 } else {
1359 cast = vector::ExtractOp::create(builder, loc, cast,
1360 SmallVector<int64_t>(rank, 0))
1361 .getResult();
1362 }
1363 // Extracted element type may need conversion
1364 // Two cases
1365 // 1. Index type to integer type
1366 // 2. Other element type mismatch
1367 if (inElemTy.isIndex()) {
1368 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1369 .getResult();
1370 } else if (inElemTy != type) {
1371 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
1372 }
1373 return cast;
1374 }
1375 return {};
1376 };
1377
1378 // Materialization to convert
1379 // - single element of vector element type to single element vector
1380 // If result type of original op is single element vector and lowered type
1381 // is scalar. This materialization cast creates a single element vector by
1382 // First convert element type if needed and then broadcast to single
1383 // element vector.
1384 // Applies only to source materialization.
1385 auto singleElementToVectorMaterializationCast =
1386 [](OpBuilder &builder, Type type, ValueRange inputs,
1387 Location loc) -> Value {
1388 if (inputs.size() != 1)
1389 return {};
1390 auto input = inputs.front();
1391 auto inTy = input.getType();
1392 if (!inTy.isIntOrFloat())
1393 return {};
1394 // If the target type is a vector of rank 0 or single element vector
1395 // of element type matching input type, broadcast input to target type.
1396 if (auto vecTy = dyn_cast<VectorType>(type)) {
1397 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1398 return {};
1399 auto outElemTy = vecTy.getElementType();
1400 Value cast = input;
1401 if (outElemTy.isIndex()) {
1402 cast = arith::IndexCastUIOp::create(builder, loc,
1403 builder.getIndexType(), cast)
1404 .getResult();
1405 } else if (inTy != outElemTy) {
1406 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1407 .getResult();
1408 }
1409 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
1410 .getResult();
1411 }
1412 return {};
1413 };
1414 typeConverter.addSourceMaterialization(
1415 singleElementToVectorMaterializationCast);
1416 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1417 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1418 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1419 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1420 typeConverter.addTargetMaterialization(
1421 vectorToSingleElementMaterializationCast);
1422 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
1423 ConversionTarget target(*context);
1424 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1425 vector::VectorDialect, arith::ArithDialect,
1426 memref::MemRefDialect, gpu::GPUDialect,
1427 index::IndexDialect>();
1428 target.addIllegalDialect<xegpu::XeGPUDialect>();
1429
1430 RewritePatternSet patterns(context);
1431 populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
1433 patterns, target);
1434 if (failed(applyPartialConversion(getOperation(), target,
1435 std::move(patterns))))
1436 signalPassFailure();
1437 }
1438};
1439} // namespace
1440
1441//===----------------------------------------------------------------------===//
1442// Pattern Population
1443//===----------------------------------------------------------------------===//
1445 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1446 patterns.add<CreateNdDescToXeVMPattern,
1447 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1448 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1449 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1450 typeConverter, patterns.getContext());
1451 patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1452 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1453 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1454 typeConverter, patterns.getContext());
1455 patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1456 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1457 CreateMemDescOpPattern>(typeConverter, patterns.getContext());
1458 patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1459 patterns.getContext());
1460 patterns.add<DpasMxToXeVMPattern>(typeConverter, patterns.getContext());
1461}
return success()
b getContext())
auto load
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition Operation.h:607
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:283
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:105
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)