XeGPUToXeVM.cpp
//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: Below are uArch dependent values, should move away from hardcoding
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // Base pointer (i64)
  BaseShapeW = 2, // Base shape width (i32)
  BaseShapeH = 3, // Base shape height (i32)
  BasePitch = 4,  // Base pitch (i32)
};
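
// Illustrative layout sketch (derived from the enum above and the
// create_nd_tdesc lowering below; not an additional definition): the payload
// is an 8xi32 vector whose first two i32 slots hold the base pointer and are
// written through a 4xi64 bitcast view.
//
//   8xi32 view: [ ptr.lo | ptr.hi | shapeW | shapeH | pitch |  -  |  -  |  - ]
//   4xi64 view: [   base pointer  |   slots 2-3     |  slots 4-5 | slots 6-7 ]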

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}

/// Checks if the given MemRefType refers to shared memory.
static bool isSharedMemRef(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
// Get a flat vector type of the same total bit width, with the new element
// type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
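
// For example (illustration only): re-encoding a vector<8xf16> to element type
// i32 yields vector<4xi32>, since 8 x 16 bits == 4 x 32 bits; the total bit
// width is preserved while the element type changes.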

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
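
// A few concrete mappings implied by the switches above (illustration only):
//   load:  L1 = CACHED,     L3 = UNCACHED   -> L1C_L2UC_L3UC
//   load:  L1 = STREAMING,  L3 = CACHED     -> L1S_L2UC_L3C
//   store: L1 = WRITE_BACK, L3 = WRITE_BACK -> L1WB_L2UC_L3WB
// Note that L2 is always emitted as uncached by these helpers; the XeGPU hints
// only control L1 and L3 behavior.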

//
// Note:
// Block operations on tiles of sub-byte element types are handled by
// emulating them with larger element types.
// Tensor descriptors are kept intact; only the ops consuming them are
// emulated.
//
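
// Worked example of the emulation arithmetic used below (numbers are
// illustrative, derived from the i4 handling in
// LoadStorePrefetchNdToXeVMPattern): for a 4-bit element type,
// subByteFactor = 8 / 4 = 2.
//   - Packed load as Matrix B: a tile with tileH == systolicDepth * 4 == 32
//     and tileW == executionSize * 2 == 32 is emulated as an 8-bit packed
//     load with tileW = executionSize = 16 and wScaleFactor = 2.
//   - Matrix A usage: a tile with tileW == executionSize * 4 == 64 is
//     emulated with 16-bit loads/stores, tileW = executionSize = 16 and
//     wScaleFactor = 4.
// Offsets, width, and pitch are later shifted right by log2(wScaleFactor).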

class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // Op is lowered to a code sequence that populates the payload.
    // Payload is an 8xi32 vector. Offsets to individual fields are defined in
    // the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;

    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    // Descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If source is a memref, we need to extract the aligned pointer as index.
    // Pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      // Access the adaptor after the failure check to avoid rolling back
      // generated code for the materialization cast.
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // Pointer type may be i32. Cast to i64 if needed.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // 1D tensor descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Get pitch value from op fold results.
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;

    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // For sub-byte types, only 4-bit elements are currently supported.
    if (isSubByte) {
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub byte types of 4bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      // Handle special case for packed load.
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // Packed load is implemented as packed loads of 8-bit elements.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            // Use case: loading as Matrix B with pack request.
            // Source is assumed to be pre-packed into 8-bit elements.
            // Emulate with 8-bit loads with pack request.
            // scaled_tileW = executionSize
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      // If not handled by the packed load case above, handle other cases.
      if (wScaleFactor == 1) {
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          // Use case: loading as Matrix A operand.
          // Emulate with 16-bit loads/stores.
          // scaled_tileW = executionSize
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub byte types.");
        }
      }
      // Recompute element bit size for emulation.
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }

    // Get address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Compute element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op.
      // Convert them to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      // Convert base pointer (i64) to LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // FIXME: width or pitch is not the same as baseShapeW; it should be the
      // stride of the second-to-last dimension in row-major layout.
      // Compute width in bytes.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Compute pitch in bytes.
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

      if (wScaleFactor > 1) {
        // Scale offsetW and baseShapeWInBytes for sub-byte emulation.
        // Note: tileW is already scaled above.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      // Get tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead of
        // the adaptor. The adaptor might have optimized away a single element
        // vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get flat vector type of integer type with matching element bit size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));

          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D tensor descriptor.
      // `tdesc` represents the base address as i64.
      // Offset is in number of elements; we need to multiply by the element
      // byte size.
      // Compute byte offset:
      // byteOffset = offset * elementByteSize
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      // Compute element byte size.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert base pointer (i64) to LLVM pointer type.
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // If the store value is a scalar, get the value from the op instead of
        // the adaptor. The adaptor might have optimized away a single element
        // vector.
        if (src.getType().isIntOrFloat()) {
          src = op.getValue();
        }
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        // Get flat vector type of integer type with matching element bit size.
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};

// Helper that computes: baseAddr + offset * elemByteSize.
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
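
// For example (illustration only): with elemByteSize = 4 and an offset of 5
// elements, the helper emits baseAddr + 5 * 4, i.e. a 20-byte displacement in
// the same integer type as baseAddr (i32 for SLM, i64 for global memory).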

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // Load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // Element type must be a multiple of 8 bits.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, we use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // Base pointer can come from source (load) or dest (store).
    // If they are memrefs, we use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // Base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // Offset needs to be a scalar. A single element vector is converted to a
      // scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // If an offset is provided, we add it to the base pointer.
      // Offset is in number of elements; we need to multiply by the element
      // byte size.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // Mask needs to be a scalar. A single element vector is converted to a
      // scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If mask is false (else clause), yield a zero value of the same type.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
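
// Rough shape of the code emitted per lane by the gather-load branch above
// (illustrative pseudo-IR, not taken from a test):
//   %res = scf.if %mask -> (T) {
//     %v = llvm.load %ptr {cache_control = #xevm.load_cache_control<...>} : T
//     scf.yield %v
//   } else {
//     scf.yield %zero   // constant zero of the same type
//   }
// The scatter-store branch emits a result-less scf.if wrapping an llvm.store
// that carries the analogous store cache_control attribute.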

class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");

    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // Load result or store value type can be a vector or a scalar.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Some transforms may leave unit dimensions in the 2D vector; the
      // adaptor does not catch this for results.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // Element type must be a multiple of 8 bits.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // Default memory space is SLM.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());

    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);

    // Convert base pointer (i32) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // If the attribute 'subgroup_block_io' is set to true, lower to
      // xevm.blockload / xevm.blockstore.

      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xegpu::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The lowering for chunk load only works for pvc and bmg.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // If the size of valOrResVecTy is 1, it lowers to a scalar load/store
      // operation. LLVM load/store does not support vectors of size 1, so we
      // need to handle this case separately.
      auto scalarTy = valOrResVecTy.getElementType();
      LLVM::LoadOp loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // Base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // Offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // Element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to
          // determine the element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, we use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If source is a memref, we use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
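
// Worked example for the MMA shape constructed above (illustrative numbers,
// not an additional constraint): for f16 A/B operands getNumOperandsPerDword
// returns 2, so K = systolicDepth * 2 = 16 and N = executionSize = 16; M is
// the number of accumulator elements held per work-item (e.g. 8 for an
// 8-element C vector), i.e. an M x N x K = 8x16x16 DPAS per subgroup.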

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // If the element type is index, convert it to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // If the vector is a scalar or has a single element, return the element
      // type.
      if (rank < 1 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat vector type.
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered descriptors are not supported in XeVM lowering.
      if (type.isScattered())
        return {};
      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    // Convert MemDescType into i32 for SLM.
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
      return IntegerType::get(&getContext(), 32);
    });

    typeConverter.addConversion([&](MemRefType type) -> Type {
      return IntegerType::get(&getContext(), (isSharedMemRef(type) ? 32 : 64));
    });
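
    // Summary of the conversions registered above (examples are illustrative):
    //   vector<8x16xf16>         -> vector<128xf16> (flattened)
    //   single-element vectors   -> their element type (vector<1xf32> -> f32)
    //   rank-2 xegpu.tensor_desc -> vector<8xi32> payload
    //   rank-1 xegpu.tensor_desc -> i64 base address
    //   xegpu.mem_desc           -> i32 (SLM address)
    //   memref                   -> i64 (global) or i32 (shared/SLM)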

    // The LLVM type converter inserts unrealized casts for the following
    // cases; add materialization casts to handle them.

    // Materialization to convert a memref to i64 or i32, depending on
    // global/SLM.
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Type indexType = builder.getIndexType();

        int64_t intOffsets;
        SmallVector<int64_t> intStrides;
        Value addr;
        Value offset;
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {
          // Result types: [base_memref, offset, stride0, stride1, ...,
          // strideN-1, size0, size1, ..., sizeN-1]
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          // strides + sizes
          resultTypes.append(2 * rank, indexType);

          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);

          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }

        auto addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        auto offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);

        // Compute the final address: base address + byte offset.
        auto byteSize = arith::ConstantOp::create(
            builder, loc, type,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);

        return addrWithOffset.getResult();
      }
      return {};
    };
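
    // For example (illustration only): a memref<8x16xf32, strided<[16, 1],
    // offset: 4>> materializes to the aligned base pointer plus the byte
    // offset 4 * sizeof(f32) = 16, cast to i64 for global memory or i32 for
    // SLM.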

    // Materialization to convert ui64 to i64.
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert:
    // - a single element 1D vector to a scalar,
    // - a vector to another vector of the same rank via bitcast,
    // - a vector of different rank but the same element type via shape cast.
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
          // If the vector has a single element, extract and return the scalar
          // element.
          Value cast =
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
          if (vecTy.getElementType() == builder.getIndexType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          // If the target type is a vector of the same rank,
          // bitcast to the target type.
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            // If the target type is a vector of different rank but the same
            // element type, reshape to the target type.
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };

    // If the result type of the original op is a single-element vector and the
    // lowered type is a scalar, this materialization cast creates a
    // single-element vector by broadcasting the scalar value.
    auto singleElementVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType().isIntOrIndexOrFloat()) {
        // If the input is a scalar and the target type is a single-element
        // vector, create the vector by broadcasting.
        if (auto vecTy = dyn_cast<VectorType>(type)) {
          if (vecTy.getNumElements() == 1) {
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}
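
// Downstream usage sketch (assuming the registration generated from
// GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS follows the usual MLIR naming convention
// and exposes a --convert-xegpu-to-xevm pass flag):
//   mlir-opt --convert-xegpu-to-xevm input.mlir
// Alternatively, a downstream pipeline can call
// populateXeGPUToXeVMConversionPatterns(typeConverter, patterns) directly.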