//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: The values below are uArch-dependent; they should not be hardcoded
// here.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to the individual fields of the 8xi32 payload of an nd tensor
// descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
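
// Illustrative payload layout implied by the offsets above (i32 lanes 5-7
// are currently unused by this lowering):
//
//   vector<8xi32>: [0..1] base pointer (a single i64, written through a
//                         4xi64 bitcast view)
//                  [2]    base shape width  (i32)
//                  [3]    base shape height (i32)
//                  [4]    base pitch        (i32)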

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}

/// Checks if the given MemRefType refers to shared memory.
static bool isSharedMemRef(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}

// Get a flat vector type with the same total bit width but the given new
// element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
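
// For example, re-encoding a vector<16xf16> (256 bits total) with an i32
// element type yields vector<8xi32>.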

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

//
// Note:
// Block operations on tiles of sub-byte element types are handled by
// emulating them with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//

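// Illustrative example of the emulation (schematic): a 32x64 tile of i4
// elements loaded as the Matrix A operand is emulated as a 32x16 tile of
// i16 (four i4 values per i16), so the tile width shrinks by the scale
// factor while the byte footprint stays the same.
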
class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload.
    // The payload is an 8xi32 vector. Offsets of the individual fields are
    // defined by the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize the payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;

    // The source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    // The descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, extract the aligned pointer as an index.
    // The pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      // Access the adaptor after the failure check to avoid rolling back
      // generated code for the materialization cast.
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // The pointer type may be i32; cast to i64 if needed.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // A 1D tensor descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Get the shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Get the pitch value from op fold results.
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate the payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
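
// Illustrative lowering sketch (schematic, not verbatim pass output) for
//
//   %td = xegpu.create_nd_tdesc %src : memref<64x64xf16>
//           -> !xegpu.tensor_desc<8x16xf16>
//
// The pattern emits roughly:
//
//   %zero = arith.constant dense<0> : vector<8xi32>
//   %v64  = vector.bitcast %zero : vector<8xi32> to vector<4xi64>
//   %v64b = vector.insert %baseAddr, %v64 [0] : i64 into vector<4xi64>
//   %pay  = vector.bitcast %v64b : vector<4xi64> to vector<8xi32>
//   ... vector.insert of shape width at [2], shape height at [3], and
//   pitch at [4] ...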

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;

    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // For sub-byte types, only 4-bit elements are currently supported.
    if (isSubByte) {
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub byte types of 4bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      // Handle the special case of a packed load.
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // A packed load is implemented as packed loads of 8-bit elements.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            // Usage case: loading as Matrix B with a pack request.
            // The source is assumed to be pre-packed into 8-bit elements.
            // Emulate with 8-bit loads with a pack request.
            // scaled_tileW = executionSize
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      // If not handled by the packed load case above, handle the other cases.
      if (wScaleFactor == 1) {
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          // Usage case: loading as the Matrix A operand.
          // Emulate with 16-bit loads/stores.
          // scaled_tileW = executionSize
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub byte types.");
        }
      }
      // Recompute the element bit size for emulation.
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }

    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op; convert them to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // FIXME: The width (pitch) is not the same as baseShapeW; it should be
      // the stride of the second-to-last dimension in row-major layout.
      // Compute the width in bytes.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Compute the pitch in bytes.
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

      if (wScaleFactor > 1) {
        // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
        // Note: tileW is already scaled above.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      // Get the tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor; the adaptor might have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with a matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));

          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D tensor descriptor.
      // `tdesc` represents the base address as i64.
      // The offset is in number of elements; multiply by the element byte
      // size to compute the byte offset:
      // byteOffset = offset * elementByteSize
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      // Compute the element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert the base pointer (i64) to an LLVM pointer type.
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead
        // of the adaptor; the adaptor might have optimized away a
        // single-element vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get a flat vector type of integer type with a matching element bit
        // size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};
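
// Illustrative 2D lowering sketch (schematic, not verbatim pass output):
// the payload fields are unpacked, width and pitch are scaled to bytes, and
// the access maps onto a single XeVM 2D block op. For example, an
// xegpu.load_nd of an 8x16xf16 tile becomes an xevm.blockload2d with
// elem_size_in_bits = 16, tile_width = 16 and tile_height = 8, followed by
// a vector.bitcast of the flat integer result back to the f16 element type.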

// A helper that computes: offset * elemByteSize + baseAddr.
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // The load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type must be a multiple of 8 bits wide.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64
    // if needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset needs to be a scalar; a single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // The offset is in number of elements, so multiply by the element byte
      // size before adding it to the base pointer.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask needs to be a scalar; a single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield zeros.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
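
// Illustrative per-lane lowering sketch (schematic): a masked gather such as
//
//   %v = xegpu.load %src[%off], %mask : ..., i1 -> f32
//
// becomes an scf.if on the lane's mask bit that either performs an llvm.load
// from (%src + %off * elemByteSize) annotated with a cache_control attribute,
// or yields a zero constant of the result type.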

class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");

    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // The load result or store value type can be a vector or a scalar.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Some transforms may leave unit dimensions in the 2D vector; adaptors
      // do not catch this for results.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // The element type must be a multiple of 8 bits wide.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // The default memory space is SLM.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);

    // Convert the base pointer (i32) to an LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // If the 'subgroup_block_io' attribute is set, lower to
      // xevm.blockload / xevm.blockstore.

      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xegpu::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The lowering for chunk load only works for pvc and bmg.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // If valOrResVecTy has a single element, this lowers to a scalar load.
      // LLVM load/store does not support vectors of size 1, so that case is
      // handled separately.
      auto scalarTy = valOrResVecTy.getElementType();
      LLVM::LoadOp loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64
    // if needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to determine the
          // element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // The default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If the source is a memref, use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert the base pointer (i64) to an LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with a cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};
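
// Illustrative mapping (from the switches above): xegpu.fence with
// memory_kind = global and fence_scope = workgroup becomes xevm.memfence
// with the GLOBAL address space and WORKGROUP scope; fence_scope = gpu maps
// to the DEVICE scope.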

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    // Get the correct dpasInst from the target chip info.
    auto chipStr = xegpu::getChipStr(op);
    if (!chipStr)
      return rewriter.notifyMatchFailure(op, "cannot determine target chip");

    const auto *uArch = mlir::xegpu::uArch::getUArch(*chipStr);
    if (!uArch)
      return rewriter.notifyMatchFailure(op, "unsupported target uArch");

    auto *dpasInst = const_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc *>(
        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
    if (!dpasInst)
      return rewriter.notifyMatchFailure(op,
                                         "DPAS not supported by target uArch");

    auto checkSupportedTypes = [&](VectorType vecTy,
                                   xegpu::uArch::MMAOpndKind kind) -> bool {
      auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
      return llvm::find(supported, vecTy.getElementType()) != supported.end();
    };

    if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
      return rewriter.notifyMatchFailure(
          op, "A-matrix element type not supported by target uArch");
    if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
      return rewriter.notifyMatchFailure(
          op, "B-matrix element type not supported by target uArch");
    // NOTE: The supported types for MatrixC and MatrixD are identical.
    if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
      return rewriter.notifyMatchFailure(
          op, "result/accumulator element type not supported by target uArch");

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
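
// Illustrative shape computation (schematic): for f16 operands,
// getNumOperandsPerDword returns 2, so K = systolicDepth * 2 = 16 and
// N = executionSize = 16. An xegpu.dpas accumulating into vector<8xf32>
// thus maps to an xevm.mma with shape <m = 8, n = 16, k = 16> and types
// <d = f32, a = f16, b = f16, c = f32>.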

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // If the element type is index, convert it to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // If the vector rank is 0 or it has a single element, return the
      // element type.
      if (rank == 0 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat vector type.
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
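    // For example, vector<8x16xf16> flattens to vector<128xf16>, while
    // vector<1xindex> converts to the scalar type i64.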
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered descriptors are not supported in the XeVM lowering.
      if (type.isScattered())
        return {};
      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    // Convert MemDescType into i32 for SLM.
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
      return IntegerType::get(&getContext(), 32);
    });

    typeConverter.addConversion([&](MemRefType type) -> Type {
      return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
    });

    // The LLVM type converter inserts unrealized casts for the following
    // cases; add materialization casts to handle them.

    // Materialization to convert a memref to i64 or i32, depending on
    // global/SLM. Applies only to target materialization.
    // Note: an int-to-memref materialization is not required, as xegpu ops
    // currently do not produce memrefs as results.
    auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Type indexType = builder.getIndexType();

        int64_t intOffsets;
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {

          // Result types: [base_memref, offset, stride0, stride1, ...,
          // strideN-1, size0, size1, ..., sizeN-1]
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          // strides + sizes
          resultTypes.append(2 * rank, indexType);

          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);

          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }

        auto addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        auto offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);

        // Compute the final address: base address + byte offset.
        auto byteSize = arith::ConstantOp::create(
            builder, loc, type,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);

        return addrWithOffset.getResult();
      }
      return {};
    };

    // Materialization to convert ui64 to i64.
    // Applies only to target materialization.
    // Note: an i64-to-ui64 materialization is not required, as xegpu ops
    // currently do not produce ui64 as a result.
    auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    // Applies only to target materialization.
    // Note: an i32-to-ui32 materialization is not required, as xegpu ops
    // currently do not produce ui32 as a result.
    auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
                                           ValueRange inputs,
                                           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert
    // - a vector to a vector of the same rank via bitcast,
    // - a vector to a vector of different rank but the same element type via
    //   shape cast.
    // Applies to both source and target materialization.
    auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
                                                ValueRange inputs,
                                                Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          // If the target type is a vector of the same rank,
          // bitcast to the target type.
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            // If the target type is a vector of different rank but the same
            // element type, reshape to the target type.
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };

    // Materialization to convert a single-element vector to a single element
    // of the vector's element type.
    // Applies only to target materialization.
    auto vectorToSingleElementMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (type == vecTy.getElementType() ||
            ((vecTy.getElementType() == builder.getIndexType()) &&
             type.isInteger())) {
          // If the vector rank is 0 or it has a single element,
          // extract a scalar of the target type.
          auto rank = vecTy.getRank();
          Value cast;
          if (rank == 0) {
            cast =
                vector::ExtractOp::create(builder, loc, input, {}).getResult();
          } else {
            cast = vector::ExtractOp::create(builder, loc, input,
                                             SmallVector<int64_t>(rank, 0))
                       .getResult();
          }
          if (type != vecTy.getElementType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        }
      }
      return {};
    };

    // Materialization to convert a single element of a vector's element type
    // to a single-element vector. If the result type of the original op is a
    // single-element vector and the lowered type is a scalar, this
    // materialization cast creates a single-element vector by broadcasting
    // the scalar value.
    // Applies only to source materialization.
    auto singleElementToVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      // If the target type is a vector of rank 0 or a single-element vector
      // with an element type matching the input type, broadcast the input to
      // the target type.
      if (auto vecTy = dyn_cast<VectorType>(type)) {
        if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
          if (input.getType() == vecTy.getElementType()) {
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
          } else if (vecTy.getElementType() == builder.getIndexType()) {
            Value cast = arith::IndexCastUIOp::create(
                             builder, loc, builder.getIndexType(), input)
                             .getResult();
            return vector::BroadcastOp::create(builder, loc, vecTy, cast)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementToVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
    typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
    typeConverter.addTargetMaterialization(
        vectorToSingleElementMaterializationCast);
    typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}