1//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Support/LLVM.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/FormatVariadic.h"
30
32#include "mlir/IR/Types.h"
33
34#include "llvm/ADT/TypeSwitch.h"
35
36#include <numeric>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45namespace {
46
47// TODO: The values below are uArch-dependent and should not be hardcoded.
48static constexpr int32_t systolicDepth{8};
49static constexpr int32_t executionSize{16};
50
51// Offsets of the individual fields within the 8xi32 nd tensor descriptor payload.
52enum class NdTdescOffset : uint32_t {
53 BasePtr = 0, // Base pointer (i64)
54 BaseShapeW = 2, // Base shape width (i32)
55 BaseShapeH = 3, // Base shape height (i32)
56 TensorOffsetW = 4, // Tensor offset W (i32)
57 TensorOffsetH = 5 // Tensor offset H (i32)
58};
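// As an illustration (hypothetical shape, not taken from the code above): for
// a 2D descriptor over a 64x32xf16 source, the payload would roughly hold
//   lanes 0-1: base pointer (one i64 stored as two i32 lanes)
//   lane  2  : base shape width  = 32
//   lane  3  : base shape height = 64
//   lanes 4-5: tensor offsets W/H (currently always 0)
// with the remaining lanes unused.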
59
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
66 }
67 llvm_unreachable("Unknown XeGPU memory space");
68}
69
70// Returns a flat vector type with the same total bit width but a new element type.
71static VectorType encodeVectorTypeTo(VectorType currentVecType,
72 Type toElemType) {
73 auto elemType = currentVecType.getElementType();
74 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
75 auto newBitWidth = toElemType.getIntOrFloatBitWidth();
76 const int size =
77 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
78 return VectorType::get(size, toElemType);
79}
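// For example (illustrative): encodeVectorTypeTo(vector<16xf16>, i32) yields
// vector<8xi32>; 16 x 16 bits are repacked as 8 x 32 bits, so the total bit
// width is preserved while the element type changes.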
80
81static xevm::LoadCacheControl
82translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
83 std::optional<xegpu::CachePolicy> L3hint) {
84 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
85 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
86 switch (L1hintVal) {
87 case xegpu::CachePolicy::CACHED:
88 if (L3hintVal == xegpu::CachePolicy::CACHED)
89 return xevm::LoadCacheControl::L1C_L2UC_L3C;
90 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
91 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
92 else
93 llvm_unreachable("Unsupported cache control.");
94 case xegpu::CachePolicy::UNCACHED:
95 if (L3hintVal == xegpu::CachePolicy::CACHED)
96 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
97 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
98 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
99 else
100 llvm_unreachable("Unsupported cache control.");
101 case xegpu::CachePolicy::STREAMING:
102 if (L3hintVal == xegpu::CachePolicy::CACHED)
103 return xevm::LoadCacheControl::L1S_L2UC_L3C;
104 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
105 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
106 else
107 llvm_unreachable("Unsupported cache control.");
108 case xegpu::CachePolicy::READ_INVALIDATE:
109 return xevm::LoadCacheControl::INVALIDATE_READ;
110 default:
111 llvm_unreachable("Unsupported cache control.");
112 }
113}
114
115static xevm::StoreCacheControl
116translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
117 std::optional<xegpu::CachePolicy> L3hint) {
118 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
119 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
120 switch (L1hintVal) {
121 case xegpu::CachePolicy::UNCACHED:
122 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
123 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
124 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
125 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
126 else
127 llvm_unreachable("Unsupported cache control.");
128 case xegpu::CachePolicy::STREAMING:
129 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
130 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
131 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
132 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
133 else
134 llvm_unreachable("Unsupported cache control.");
135 case xegpu::CachePolicy::WRITE_BACK:
136 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
137 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
138 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
139 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
140 else
141 llvm_unreachable("Unsupported cache control.");
142 case xegpu::CachePolicy::WRITE_THROUGH:
143 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
144 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
145 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
146 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
147 else
148 llvm_unreachable("Unsupported cache control.");
149 default:
150 llvm_unreachable("Unsupported cache control.");
151 }
152}
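// As a quick illustration of the two helpers above: l1_hint = CACHED with
// l3_hint = UNCACHED maps to L1C_L2UC_L3UC for loads; hints are optional and
// default to UNCACHED, so an unannotated op maps to L1UC_L2UC_L3UC for both
// loads and stores.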
153
154class CreateNdDescToXeVMPattern
155 : public OpConversionPattern<xegpu::CreateNdDescOp> {
156 using OpConversionPattern::OpConversionPattern;
157 LogicalResult
158 matchAndRewrite(xegpu::CreateNdDescOp op,
159 xegpu::CreateNdDescOp::Adaptor adaptor,
160 ConversionPatternRewriter &rewriter) const override {
161 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
162 if (mixedOffsets.size() != 0)
163 return rewriter.notifyMatchFailure(op, "Offsets not supported.");
164 auto loc = op.getLoc();
165 auto source = op.getSource();
166 // Op is lowered to a code sequence that populates the payload.
167 // The payload is an 8xi32 vector; offsets of the individual fields are
168 // defined in the NdTdescOffset enum.
169 Type payloadElemTy = rewriter.getI32Type();
170 VectorType payloadTy = VectorType::get(8, payloadElemTy);
171 Type i64Ty = rewriter.getI64Type();
172 // 4xi64 view is used for inserting the base pointer.
173 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
174 // Initialize payload to zero.
175 Value payload = arith::ConstantOp::create(
176 rewriter, loc,
177 DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
178
179 Value baseAddr;
180 Value baseShapeW;
181 Value baseShapeH;
182 Value offsetW;
183 Value offsetH;
184
185 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
186 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
187 // Descriptor shape is expected to be 2D.
188 int64_t rank = mixedSizes.size();
189 auto sourceTy = source.getType();
190 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
191 // If source is a memref, we need to extract the aligned pointer as index.
192 // Pointer type is passed as i32 or i64 by type converter.
193 if (sourceMemrefTy) {
194 if (!sourceMemrefTy.hasRank()) {
195 return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
196 }
197 // Access adaptor after failure check to avoid rolling back generated code
198 // for materialization cast.
199 baseAddr = adaptor.getSource();
200 } else {
201 baseAddr = adaptor.getSource();
202 if (baseAddr.getType() != i64Ty) {
203 // Pointer type may be i32. Cast to i64 if needed.
204 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
205 }
206 }
207 // 1D tensor descriptor is just the base address.
208 if (rank == 1) {
209 rewriter.replaceOp(op, baseAddr);
210 return success();
211 }
212 // Utility for creating offset values from op fold result.
213 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
214 unsigned idx) -> Value {
215 Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
216 val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
217 return val;
218 };
219 // Offsets are not supported (0 is used).
220 offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
221 offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
222 // Get shape values from op fold results.
223 baseShapeW = createOffset(mixedSizes, 1);
224 baseShapeH = createOffset(mixedSizes, 0);
225 // Populate payload.
226 Value payLoadAsI64 =
227 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
228 payLoadAsI64 =
229 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
230 static_cast<int>(NdTdescOffset::BasePtr));
231 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
232 payload =
233 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
234 static_cast<int>(NdTdescOffset::BaseShapeW));
235 payload =
236 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
237 static_cast<int>(NdTdescOffset::BaseShapeH));
238 payload = vector::InsertOp::create(
239 rewriter, loc, offsetW, payload,
240 static_cast<int>(NdTdescOffset::TensorOffsetW));
241 payload = vector::InsertOp::create(
242 rewriter, loc, offsetH, payload,
243 static_cast<int>(NdTdescOffset::TensorOffsetH));
244 rewriter.replaceOp(op, payload);
245 return success();
246 }
247};
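// Illustrative example for the pattern above (types and values are assumed):
//   %td = xegpu.create_nd_tdesc %src : memref<64x32xf16>
//           -> !xegpu.tensor_desc<8x16xf16>
// becomes a vector<8xi32> payload holding the extracted base pointer, base
// shape W = 32 / H = 64, and zero tensor offsets, while a 1D descriptor
// lowers to just the base address as an i64.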
248
249template <
250 typename OpType,
251 typename = std::enable_if_t<llvm::is_one_of<
252 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
253class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
254 using OpConversionPattern<OpType>::OpConversionPattern;
255 LogicalResult
256 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
257 ConversionPatternRewriter &rewriter) const override {
258 auto mixedOffsets = op.getMixedOffsets();
259 int64_t opOffsetsSize = mixedOffsets.size();
260 auto loc = op.getLoc();
261 auto ctxt = rewriter.getContext();
262
263 auto tdesc = adaptor.getTensorDesc();
264 auto tdescTy = op.getTensorDescType();
265 auto tileRank = tdescTy.getRank();
266 if (opOffsetsSize != tileRank)
267 return rewriter.notifyMatchFailure(
268 op, "Expected offset rank to match descriptor rank.");
269 auto elemType = tdescTy.getElementType();
270 auto elemBitSize = elemType.getIntOrFloatBitWidth();
271 if (elemBitSize % 8 != 0)
272 return rewriter.notifyMatchFailure(
273 op, "Expected element type bit width to be multiple of 8.");
274
275 // Get address space from tensor descriptor memory space.
276 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
277 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
278 if (tileRank == 2) {
279 // Compute element byte size.
280 Value elemByteSize = arith::ConstantIntOp::create(
281 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
282 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
283 Value payLoadAsI64 =
284 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
285 Value basePtr =
286 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
287 static_cast<int>(NdTdescOffset::BasePtr));
288 Value baseShapeW = vector::ExtractOp::create(
289 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
290 Value baseShapeH = vector::ExtractOp::create(
291 rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
292 // Offsets are provided by the op.
293 // convert them to i32.
294 Value offsetW =
295 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
296 offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
297 rewriter.getI32Type(), offsetW);
298 Value offsetH =
299 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
300 offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
301 rewriter.getI32Type(), offsetH);
302 // Convert base pointer (i64) to LLVM pointer type.
303 Value basePtrLLVM =
304 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
305 // Compute width in bytes.
306 Value surfaceW =
307 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
308
309 // Get tile width from the tensor descriptor type.
310 auto tileW = tdescTy.getDimSize(tileRank - 1);
311 // Get tile height from the tensor descriptor type.
312 auto tileH = tdescTy.getDimSize(0);
313 // Get vblocks from the tensor descriptor type.
314 int32_t vblocks = tdescTy.getArrayLength();
315 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
316 Value src = adaptor.getValue();
317 // If store value is a scalar, get value from op instead of adaptor.
318 // The adaptor might have optimized away a single-element vector.
319 if (src.getType().isIntOrFloat()) {
320 src = op.getValue();
321 }
322 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
323 if (!srcVecTy)
324 return rewriter.notifyMatchFailure(
325 op, "Expected store value to be a vector type.");
326 // Get flat vector type of integer type with matching element bit size.
327 VectorType newSrcVecTy =
328 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
329 if (srcVecTy != newSrcVecTy)
330 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
331 auto storeCacheControl =
332 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
333 xevm::BlockStore2dOp::create(
334 rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
335 offsetH, elemBitSize, tileW, tileH, src,
336 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
337 rewriter.eraseOp(op);
338 } else {
339 auto loadCacheControl =
340 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
341 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
342 xevm::BlockPrefetch2dOp::create(
343 rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
344 offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
345 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
346 rewriter.eraseOp(op);
347 } else {
348 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
349 const bool vnni = op.getPacked().value_or(false);
350 auto transposeValue = op.getTranspose();
351 bool transpose =
352 transposeValue.has_value() && transposeValue.value()[0] == 1;
353 VectorType loadedTy = encodeVectorTypeTo(
354 dstVecTy, vnni ? rewriter.getI32Type()
355 : rewriter.getIntegerType(elemBitSize));
356
357 Value resultFlatVec = xevm::BlockLoad2dOp::create(
358 rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
359 surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
360 transpose, vnni,
361 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
362 resultFlatVec = vector::BitCastOp::create(
363 rewriter, loc,
364 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
365 resultFlatVec);
366 rewriter.replaceOp(op, resultFlatVec);
367 }
368 }
369 } else {
370 // 1D tensor descriptor.
371 // `tdesc` represents the base address as i64.
372 // The offset is in number of elements; multiply by the element byte size
373 // to get the byte offset:
374 // byteOffset = offset * elemByteSize
375 Value offset =
376 getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
377 offset = getValueOrCreateCastToIndexLike(rewriter, loc,
378 rewriter.getI64Type(), offset);
379 // Compute element byte size.
380 Value elemByteSize = arith::ConstantIntOp::create(
381 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
382 Value byteOffset =
383 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
384 // Final address = basePtr + byteOffset
385 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
386 loc, tdesc,
387 getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
388 byteOffset));
389 // Convert base pointer (i64) to LLVM pointer type.
390 Value finalPtrLLVM =
391 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
392 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
393 Value src = adaptor.getValue();
394 // If store value is a scalar, get value from op instead of adaptor.
395 // The adaptor might have optimized away a single-element vector.
396 if (src.getType().isIntOrFloat()) {
397 src = op.getValue();
398 }
399 VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
400 if (!srcVecTy)
401 return rewriter.notifyMatchFailure(
402 op, "Expected store value to be a vector type.");
403 // Get flat vector type of integer type with matching element bit size.
404 VectorType newSrcVecTy =
405 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
406 if (srcVecTy != newSrcVecTy)
407 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
408 auto storeCacheControl =
409 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
410 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
411 op, finalPtrLLVM, src,
412 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
413 } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
414 auto loadCacheControl =
415 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
416 VectorType resTy = cast<VectorType>(op.getValue().getType());
417 VectorType loadedTy =
418 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
419 Value load = xevm::BlockLoadOp::create(
420 rewriter, loc, loadedTy, finalPtrLLVM,
421 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
422 if (loadedTy != resTy)
423 load = vector::BitCastOp::create(rewriter, loc, resTy, load);
424 rewriter.replaceOp(op, load);
425 } else {
426 return rewriter.notifyMatchFailure(
427 op, "Unsupported operation: xegpu.prefetch_nd with tensor "
428 "descriptor rank == 1");
429 }
430 }
431 return success();
432 }
433};
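// Sketch of the 2D block-load path above (shapes are assumptions): loading an
// 8x16xf16 tile emits xevm.blockload2d with elem_size_in_bits = 16,
// tile_width = 16, tile_height = 8, v_blocks = 1, and the surface width in
// bytes (baseShapeW * 2); the flat i16 result vector is then bitcast back to
// the f16 result type.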
434
435// Helper that computes the byte address:
436//   baseAddr + offset * elemByteSize
437static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
438 Location loc, Value baseAddr, Value offset,
439 int64_t elemByteSize) {
440 Value byteSize = arith::ConstantIntOp::create(
441 rewriter, loc, baseAddr.getType(), elemByteSize);
442 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
443 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
444 return newAddr;
445}
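// For instance (illustrative numbers): with offset = 5 and elemByteSize = 4,
// this helper produces baseAddr + 20, i.e. the byte address of element 5.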
446
447template <typename OpType,
448 typename = std::enable_if_t<llvm::is_one_of<
449 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
450class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
451 using OpConversionPattern<OpType>::OpConversionPattern;
452 LogicalResult
453 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
454 ConversionPatternRewriter &rewriter) const override {
455 Value offset = adaptor.getOffsets();
456 if (!offset)
457 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
458 auto loc = op.getLoc();
459 auto ctxt = rewriter.getContext();
460 auto tdescTy = op.getTensorDescType();
461 Value basePtrI64;
462 // Load result or store value type can be a vector or a scalar.
463 Type valOrResTy;
464 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
465 valOrResTy =
466 this->getTypeConverter()->convertType(op.getResult().getType());
467 else
468 valOrResTy = adaptor.getValue().getType();
469 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
470 bool hasScalarVal = !valOrResVecTy;
471 int64_t elemBitWidth =
472 hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
473 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
474 // Element bit width must be a multiple of 8.
475 if (elemBitWidth % 8 != 0)
476 return rewriter.notifyMatchFailure(
477 op, "Expected element type bit width to be multiple of 8.");
478 int64_t elemByteSize = elemBitWidth / 8;
479 // Default memory space is global.
480 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
481 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
482 // If tensor descriptor is available, we use its memory space.
483 if (tdescTy)
484 ptrTypeLLVM = LLVM::LLVMPointerType::get(
485 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
486 // Base pointer can come from source (load) or dest (store).
487 // If they are memrefs, we use their memory space.
488 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
489 basePtrI64 = adaptor.getSource();
490 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
491 auto addrSpace = memRefTy.getMemorySpaceAsInt();
492 if (addrSpace != 0)
493 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
494 }
495 } else {
496 basePtrI64 = adaptor.getDest();
497 if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
498 auto addrSpace = memRefTy.getMemorySpaceAsInt();
499 if (addrSpace != 0)
500 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
501 }
502 }
503 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
504 if (basePtrI64.getType() != rewriter.getI64Type()) {
505 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
506 basePtrI64);
507 }
508 Value mask = adaptor.getMask();
509 if (dyn_cast<VectorType>(offset.getType())) {
510 // Offset needs to be a scalar; a single-element vector is converted to a
511 // scalar by the type converter.
512 return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
513 } else {
514 // If an offset is provided, add it to the base pointer.
515 // The offset is in number of elements, so we need to multiply it by the
516 // element byte size.
517 basePtrI64 =
518 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
519 }
520 // Convert base pointer (i64) to LLVM pointer type.
521 Value basePtrLLVM =
522 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
523
524 Value maskForLane;
525 VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
526 if (maskVecTy) {
527 // Mask needs to be a scalar; a single-element vector is converted to a
528 // scalar by the type converter.
529 return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
530 } else
531 maskForLane = mask;
532 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
533 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
534 maskForLane, true, true);
535 // If mask is true (then clause), load from memory and yield the value.
536 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
537 if (!hasScalarVal)
538 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
539 valOrResVecTy.getElementType());
540 Value loaded =
541 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
542 // Set cache control attribute on the load operation.
543 loaded.getDefiningOp()->setAttr(
544 "cache_control", xevm::LoadCacheControlAttr::get(
545 ctxt, translateLoadXeGPUCacheHint(
546 op.getL1Hint(), op.getL3Hint())));
547 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
548 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
549 // If mask is false (else clause), yield a zero constant of the result type.
550 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
551 TypedAttr eVal;
552 if (eTy.isFloat())
553 eVal = FloatAttr::get(eTy, 0.0);
554 else
555 eVal = IntegerAttr::get(eTy, 0);
556 if (hasScalarVal)
557 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
558 else
559 loaded = arith::ConstantOp::create(
560 rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
561 scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
562 rewriter.replaceOp(op, ifOp.getResult(0));
563 } else {
564 // If mask is true, perform the store.
565 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
566 auto body = ifOp.getBody();
567 rewriter.setInsertionPointToStart(body);
568 auto storeOp =
569 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
570 // Set cache control attribute on the store operation.
571 storeOp.getOperation()->setAttr(
572 "cache_control", xevm::StoreCacheControlAttr::get(
573 ctxt, translateStoreXeGPUCacheHint(
574 op.getL1Hint(), op.getL3Hint())));
575 rewriter.eraseOp(op);
576 }
577 return success();
578 }
579};
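// Shape of the guarded IR produced above for a gathered load (pseudo-IR,
// illustrative only):
//   %r = scf.if %mask -> (T) {
//     %v = llvm.load %ptr {cache_control = ...} : !llvm.ptr -> T
//     scf.yield %v
//   } else {
//     scf.yield %zero
//   }
// Scattered stores are wrapped the same way, without an else branch.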
580
581class CreateMemDescOpPattern final
582 : public OpConversionPattern<xegpu::CreateMemDescOp> {
583public:
584 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
585 LogicalResult
586 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
587 ConversionPatternRewriter &rewriter) const override {
588
589 rewriter.replaceOp(op, adaptor.getSource());
590 return success();
591 }
592};
593
594template <typename OpType,
595 typename = std::enable_if_t<llvm::is_one_of<
596 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
597class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
598 using OpConversionPattern<OpType>::OpConversionPattern;
599 LogicalResult
600 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
601 ConversionPatternRewriter &rewriter) const override {
602
603 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
604 if (offsets.empty())
605 return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
606
607 auto loc = op.getLoc();
608 auto ctxt = rewriter.getContext();
609 Value baseAddr32 = adaptor.getMemDesc();
610 Value mdescVal = op.getMemDesc();
611 // Load result or store value type can be a vector or a scalar.
612 Value data;
613 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
614 data = op.getResult();
615 else
616 data = adaptor.getData();
617 VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
618 if (!valOrResVecTy)
619 valOrResVecTy = VectorType::get(1, data.getType());
620 if (valOrResVecTy.getShape().size() != 1)
621 return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
622
623 int64_t elemBitWidth =
624 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
625 // Element bit width must be a multiple of 8.
626 if (elemBitWidth % 8 != 0)
627 return rewriter.notifyMatchFailure(
628 op, "Expected element type bit width to be multiple of 8.");
629 int64_t elemByteSize = elemBitWidth / 8;
630
631 // Default memory space is SLM.
632 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
633 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
634
635 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
636
637 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
638 linearOffset = arith::IndexCastUIOp::create(
639 rewriter, loc, rewriter.getI32Type(), linearOffset);
640 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
641 linearOffset, elemByteSize);
642
643 // convert base pointer (i32) to LLVM pointer type
644 Value basePtrLLVM =
645 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
646
647 if (op.getSubgroupBlockIoAttr()) {
648 // If the 'subgroup_block_io' attribute is set, this lowers to
649 // xevm.blockload / xevm.blockstore.
650
651 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
652 VectorType intVecTy =
653 VectorType::get(valOrResVecTy.getShape(), intElemTy);
654
655 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
656 Value loadOp =
657 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
658 if (intVecTy != valOrResVecTy) {
659 loadOp =
660 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
661 }
662 rewriter.replaceOp(op, loadOp);
663 } else {
664 Value dataToStore = adaptor.getData();
665 if (valOrResVecTy != intVecTy) {
666 dataToStore =
667 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
668 }
669 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
670 nullptr);
671 rewriter.eraseOp(op);
672 }
673 return success();
674 }
675
676 if (valOrResVecTy.getNumElements() >= 1) {
677 auto chipOpt = xegpu::getChipStr(op);
678 if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
679 // The chunked load/store lowering is only supported on pvc and bmg.
680 return rewriter.notifyMatchFailure(
681 op, "The lowering is specific to pvc or bmg.");
682 }
683 }
684
685 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
686 // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
687 // operation. LLVM load/store does not support vector of size 1, so we
688 // need to handle this case separately.
689 auto scalarTy = valOrResVecTy.getElementType();
690 LLVM::LoadOp loadOp;
691 if (valOrResVecTy.getNumElements() == 1)
692 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
693 else
694 loadOp =
695 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
696 rewriter.replaceOp(op, loadOp);
697 } else {
698 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
699 rewriter.eraseOp(op);
700 }
701 return success();
702 }
703};
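// Illustrative example for the matrix pattern above (assumed types): loading
// vector<8xf16> from a mem_desc becomes an llvm.load from an SLM pointer
// computed as base + linearOffset * 2; with subgroup_block_io it becomes an
// xevm.blockload of vector<8xi16> followed by a bitcast to vector<8xf16>.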
704
705class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
706 using OpConversionPattern::OpConversionPattern;
707 LogicalResult
708 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
709 ConversionPatternRewriter &rewriter) const override {
710 auto loc = op.getLoc();
711 auto ctxt = rewriter.getContext();
712 auto tdescTy = op.getTensorDescType();
713 Value basePtrI64 = adaptor.getSource();
714 // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
715 if (basePtrI64.getType() != rewriter.getI64Type())
716 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
717 basePtrI64);
718 Value offsets = adaptor.getOffsets();
719 if (offsets) {
720 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
721 if (offsetsVecTy) {
722 // Offsets must be a single scalar value.
723 return rewriter.notifyMatchFailure(op,
724 "Expected offsets to be a scalar.");
725 } else {
726 int64_t elemBitWidth{0};
727 int64_t elemByteSize;
728 // Element byte size can come from three sources:
729 if (tdescTy) {
730 // If tensor descriptor is available, we use its element type to
731 // determine element byte size.
732 elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
733 } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
734 // If memref is available, we use its element type to
735 // determine element byte size.
736 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
737 } else {
738 // Otherwise, we use the provided offset byte alignment.
739 elemByteSize = *op.getOffsetAlignByte();
740 }
741 if (elemBitWidth != 0) {
742 if (elemBitWidth % 8 != 0)
743 return rewriter.notifyMatchFailure(
744 op, "Expected element type bit width to be multiple of 8.");
745 elemByteSize = elemBitWidth / 8;
746 }
747 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
748 elemByteSize);
749 }
750 }
751 // Default memory space is global.
752 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
753 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
754 // If tensor descriptor is available, we use its memory space.
755 if (tdescTy)
756 ptrTypeLLVM = LLVM::LLVMPointerType::get(
757 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
758 // If source is a memref, we use its memory space.
759 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
760 auto addrSpace = memRefTy.getMemorySpaceAsInt();
761 if (addrSpace != 0)
762 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
763 }
764 // Convert base pointer (i64) to LLVM pointer type.
765 Value ptrLLVM =
766 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
767 // Create the prefetch op with cache control attribute.
768 xevm::PrefetchOp::create(
769 rewriter, loc, ptrLLVM,
770 xevm::LoadCacheControlAttr::get(
771 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
772 rewriter.eraseOp(op);
773 return success();
774 }
775};
776
777class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
778 using OpConversionPattern::OpConversionPattern;
779 LogicalResult
780 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
781 ConversionPatternRewriter &rewriter) const override {
782 auto loc = op.getLoc();
783 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
784 switch (op.getFenceScope()) {
785 case xegpu::FenceScope::Workgroup:
786 memScope = xevm::MemScope::WORKGROUP;
787 break;
788 case xegpu::FenceScope::GPU:
789 memScope = xevm::MemScope::DEVICE;
790 break;
791 }
792 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
793 switch (op.getMemoryKind()) {
794 case xegpu::MemorySpace::Global:
795 addrSpace = xevm::AddrSpace::GLOBAL;
796 break;
797 case xegpu::MemorySpace::SLM:
798 addrSpace = xevm::AddrSpace::SHARED;
799 break;
800 }
801 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
802 rewriter.eraseOp(op);
803 return success();
804 }
805};
806
807class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
808 using OpConversionPattern::OpConversionPattern;
809 LogicalResult
810 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
811 ConversionPatternRewriter &rewriter) const override {
812 auto loc = op.getLoc();
813 auto ctxt = rewriter.getContext();
814 auto aTy = cast<VectorType>(op.getLhs().getType());
815 auto bTy = cast<VectorType>(op.getRhs().getType());
816 auto resultType = cast<VectorType>(op.getResultType());
817
818 auto encodePrecision = [&](Type type) -> xevm::ElemType {
819 if (type == rewriter.getBF16Type())
820 return xevm::ElemType::BF16;
821 else if (type == rewriter.getF16Type())
822 return xevm::ElemType::F16;
823 else if (type == rewriter.getTF32Type())
824 return xevm::ElemType::TF32;
825 else if (type.isInteger(8)) {
826 if (type.isUnsignedInteger())
827 return xevm::ElemType::U8;
828 return xevm::ElemType::S8;
829 } else if (type == rewriter.getF32Type())
830 return xevm::ElemType::F32;
831 else if (type.isInteger(32))
832 return xevm::ElemType::S32;
833 llvm_unreachable("add more support for ElemType");
834 };
835 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
836 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
837 Value c = op.getAcc();
838 if (!c) {
839 auto elementTy = resultType.getElementType();
840 Attribute initValueAttr;
841 if (isa<FloatType>(elementTy))
842 initValueAttr = FloatAttr::get(elementTy, 0.0);
843 else
844 initValueAttr = IntegerAttr::get(elementTy, 0);
845 c = arith::ConstantOp::create(
846 rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
847 }
848
849 Value aVec = op.getLhs();
850 Value bVec = op.getRhs();
851 auto cvecty = cast<VectorType>(c.getType());
852 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
853 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
854 VectorType cNty =
855 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
856 if (cvecty != cNty)
857 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
858 Value dpasRes = xevm::MMAOp::create(
859 rewriter, loc, cNty, aVec, bVec, c,
860 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
861 systolicDepth *
862 getNumOperandsPerDword(precATy)),
863 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
864 if (cvecty != cNty)
865 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
866 rewriter.replaceOp(op, dpasRes);
867 return success();
868 }
869
870private:
871 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
872 switch (pTy) {
873 case xevm::ElemType::TF32:
874 return 1;
875 case xevm::ElemType::BF16:
876 case xevm::ElemType::F16:
877 return 2;
878 case xevm::ElemType::U8:
879 case xevm::ElemType::S8:
880 return 4;
881 default:
882 llvm_unreachable("unsupported xevm::ElemType");
883 }
884 }
885};
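// Worked example for the DPAS lowering above (assuming f16 operands): the
// MMA shape is M = number of accumulator elements, N = executionSize = 16,
// K = systolicDepth * getNumOperandsPerDword(F16) = 8 * 2 = 16, i.e. the
// usual 8x16x16 DPAS tile when the accumulator holds 8 elements.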
886
887static std::optional<LLVM::AtomicBinOp>
888matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
889 switch (arithKind) {
890 case arith::AtomicRMWKind::addf:
891 return LLVM::AtomicBinOp::fadd;
892 case arith::AtomicRMWKind::addi:
893 return LLVM::AtomicBinOp::add;
894 case arith::AtomicRMWKind::assign:
895 return LLVM::AtomicBinOp::xchg;
896 case arith::AtomicRMWKind::maximumf:
897 return LLVM::AtomicBinOp::fmax;
898 case arith::AtomicRMWKind::maxs:
899 return LLVM::AtomicBinOp::max;
900 case arith::AtomicRMWKind::maxu:
901 return LLVM::AtomicBinOp::umax;
902 case arith::AtomicRMWKind::minimumf:
903 return LLVM::AtomicBinOp::fmin;
904 case arith::AtomicRMWKind::mins:
905 return LLVM::AtomicBinOp::min;
906 case arith::AtomicRMWKind::minu:
907 return LLVM::AtomicBinOp::umin;
908 case arith::AtomicRMWKind::ori:
909 return LLVM::AtomicBinOp::_or;
910 case arith::AtomicRMWKind::andi:
911 return LLVM::AtomicBinOp::_and;
912 default:
913 return std::nullopt;
914 }
915}
916
917class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
918 using OpConversionPattern::OpConversionPattern;
919 LogicalResult
920 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
921 ConversionPatternRewriter &rewriter) const override {
922 auto loc = op.getLoc();
923 auto ctxt = rewriter.getContext();
924 auto tdesc = op.getTensorDesc().getType();
925 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
926 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
927 Value basePtrI64 = arith::IndexCastOp::create(
928 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
929 Value basePtrLLVM =
930 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
931 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
932 VectorType srcOrDstFlatVecTy = VectorType::get(
933 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
934 Value srcFlatVec = vector::ShapeCastOp::create(
935 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
936 auto atomicKind = matchSimpleAtomicOp(op.getKind());
937 assert(atomicKind.has_value());
938 Value resVec = srcFlatVec;
939 for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
940 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
941 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
942 rewriter.getIndexAttr(i));
943 Value currPtr =
944 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
945 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
946 Value newVal =
947 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
948 val, LLVM::AtomicOrdering::seq_cst);
949 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
950 }
951 rewriter.replaceOp(op, resVec);
952 return success();
953 }
954};
955
956//===----------------------------------------------------------------------===//
957// Pass Definition
958//===----------------------------------------------------------------------===//
959
960struct ConvertXeGPUToXeVMPass
961 : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
962 using Base::Base;
963
964 void runOnOperation() override {
965 LLVMTypeConverter typeConverter(&getContext());
966 typeConverter.addConversion([&](VectorType type) -> Type {
967 unsigned rank = type.getRank();
968 auto elemType = type.getElementType();
969 // If the element type is index, convert it to i64.
970 if (llvm::isa<IndexType>(elemType))
971 elemType = IntegerType::get(&getContext(), 64);
972 // If the vector is rank-0 or has a single element, return the element type.
973 if (rank < 1 || type.getNumElements() == 1)
974 return elemType;
975 // Otherwise, convert the vector to a flat vector type.
976 int64_t numElems = llvm::product_of(type.getShape());
977 return VectorType::get(numElems, elemType);
978 });
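// For example (illustrative): vector<8x16xf16> converts to vector<128xf16>,
// vector<1xindex> collapses to i64, and vector<1xf32> collapses to f32.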
979 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
980 // Scattered descriptors are not supported in XeVM lowering.
981 if (type.isScattered())
982 return {};
983 if (type.getRank() == 1)
984 return IntegerType::get(&getContext(), 64);
985 auto i32Type = IntegerType::get(&getContext(), 32);
986 return VectorType::get(8, i32Type);
987 });
988 // Convert MemDescType into i32 for SLM
989 typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
990 return IntegerType::get(&getContext(), 32);
991 });
992
993 typeConverter.addConversion([&](MemRefType type) -> Type {
994 if (type.getMemorySpaceAsInt() == 3)
995 return IntegerType::get(&getContext(), 32);
996 return IntegerType::get(&getContext(), 64);
997 });
998
999 // The LLVM type converter inserts unrealized casts for the cases below;
1000 // add materialization casts to handle them.
1001
1002 // Materialization to convert memref to i64
1003 auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
1004 ValueRange inputs,
1005 Location loc) -> Value {
1006 if (inputs.size() != 1)
1007 return {};
1008 auto input = inputs.front();
1009 if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1010
1011 Value addr =
1012 memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
1013 return arith::IndexCastUIOp::create(builder, loc, type, addr)
1014 .getResult();
1015 }
1016 return {};
1017 };
1018
1019 // Materialization to convert ui64 to i64
1020 auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
1021 ValueRange inputs,
1022 Location loc) -> Value {
1023 if (inputs.size() != 1)
1024 return {};
1025 auto input = inputs.front();
1026 if (input.getType() == builder.getIntegerType(64, false)) {
1027 Value cast =
1028 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1029 .getResult();
1030 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1031 .getResult();
1032 }
1033 return {};
1034 };
1035
1036 // Materialization to convert ui32 to i32
1037 auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
1038 ValueRange inputs,
1039 Location loc) -> Value {
1040 if (inputs.size() != 1)
1041 return {};
1042 auto input = inputs.front();
1043 if (input.getType() == builder.getIntegerType(32, false)) {
1044 Value cast =
1045 index::CastUOp::create(builder, loc, builder.getIndexType(), input)
1046 .getResult();
1047 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1048 .getResult();
1049 }
1050 return {};
1051 };
1052
1053 // Materialization to convert:
1054 // - a single-element 1D vector to a scalar
1055 // - a same-rank vector via bitcast
1056 // - a different-rank vector with the same element type via shape_cast
1057 auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
1058 ValueRange inputs,
1059 Location loc) -> Value {
1060 if (inputs.size() != 1)
1061 return {};
1062 auto input = inputs.front();
1063 if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
1064 if (vecTy.getNumElements() == 1) {
1065 // If the vector has a single element, extract and return that element.
1066 Value cast =
1067 vector::ExtractOp::create(builder, loc, input, 0).getResult();
1068 if (vecTy.getElementType() == builder.getIndexType())
1069 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1070 .getResult();
1071 return cast;
1072 } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
1073 // If the target type is a vector of same rank,
1074 // bitcast to the target type.
1075 if (targetVecTy.getRank() == vecTy.getRank())
1076 return vector::BitCastOp::create(builder, loc, targetVecTy, input)
1077 .getResult();
1078 else if (targetVecTy.getElementType() == vecTy.getElementType()) {
1079 // If the target type is a vector of different rank but same element
1080 // type, reshape to the target type.
1081 return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
1082 .getResult();
1083 }
1084 }
1085 }
1086 return {};
1087 };
1088
1089 // If the result type of the original op is a single-element vector and the
1090 // lowered type is a scalar, this materialization cast creates a
1091 // single-element vector by broadcasting the scalar value.
1092 auto singleElementVectorMaterializationCast =
1093 [](OpBuilder &builder, Type type, ValueRange inputs,
1094 Location loc) -> Value {
1095 if (inputs.size() != 1)
1096 return {};
1097 auto input = inputs.front();
1098 if (input.getType().isIntOrIndexOrFloat()) {
1099 // If the input is a scalar, and the target type is a vector of single
1100 // element, create a single element vector by broadcasting.
1101 if (auto vecTy = dyn_cast<VectorType>(type)) {
1102 if (vecTy.getNumElements() == 1) {
1103 return vector::BroadcastOp::create(builder, loc, vecTy, input)
1104 .getResult();
1105 }
1106 }
1107 }
1108 return {};
1109 };
1110 typeConverter.addSourceMaterialization(
1111 singleElementVectorMaterializationCast);
1112 typeConverter.addTargetMaterialization(memrefMaterializationCast);
1113 typeConverter.addTargetMaterialization(ui32MaterializationCast);
1114 typeConverter.addTargetMaterialization(ui64MaterializationCast);
1115 typeConverter.addTargetMaterialization(vectorMaterializationCast);
1116 ConversionTarget target(getContext());
1117 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1118 vector::VectorDialect, arith::ArithDialect,
1119 memref::MemRefDialect, gpu::GPUDialect,
1120 index::IndexDialect>();
1121 target.addIllegalDialect<xegpu::XeGPUDialect>();
1122
1123 RewritePatternSet patterns(&getContext());
1124 populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
1125 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
1126 patterns, target);
1127 if (failed(applyPartialConversion(getOperation(), target,
1128 std::move(patterns))))
1129 signalPassFailure();
1130 }
1131};
1132} // namespace
1133
1134//===----------------------------------------------------------------------===//
1135// Pattern Population
1136//===----------------------------------------------------------------------===//
1137void mlir::populateXeGPUToXeVMConversionPatterns(
1138 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1139 patterns.add<CreateNdDescToXeVMPattern,
1140 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1141 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1142 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1143 typeConverter, patterns.getContext());
1144 patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1145 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1146 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1147 typeConverter, patterns.getContext());
1148 patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1149 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1150 CreateMemDescOpPattern>(typeConverter, patterns.getContext());
1151 patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1152 patterns.getContext());
1153}